装饰器定义为类

装饰器定义为类#

场景

将装饰器定义为类,使它既可以在类定义内部使用(装饰方法),也可以在类定义外部使用(装饰普通函数)。

from types import MethodType
from typing import Any
from functools import wraps
class Profiled:
    """Class-based decorator that counts calls to the wrapped callable.

    The instance takes over the wrapped function's metadata (and gains a
    ``__wrapped__`` reference to it) and exposes the running total as
    ``ncalls``. Implementing ``__get__`` lets the decorator be applied to
    methods defined inside a class as well as to plain functions.
    """

    def __init__(self, func):
        # Copy func's metadata onto this instance first; the counter is
        # assigned afterwards so it always starts from zero.
        wraps(func)(self)
        self.ncalls = 0

    def __call__(self, *args, **kwargs):
        """Increment the counter, then delegate to the wrapped callable."""
        self.ncalls += 1
        target = self.__wrapped__
        return target(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        """Return the decorator itself for class access, a bound method otherwise."""
        return self if obj is None else MethodType(self, obj)

可以用于装饰函数:

# Decorating a plain function: the name `add` now refers to a Profiled
# instance, so the call counter is reachable as an attribute.
@Profiled
def add(x, y):
    return x + y
print(add.ncalls) # used outside the function
a = add(2, 3)
print(add.ncalls)
a = add(4, 5)
print(add.ncalls)

也可以用于装饰类中的方法:

class Spam:
    """Demo class whose ``bar`` method is wrapped by the ``Profiled`` decorator."""
    @Profiled
    def bar(self, x):
        # Print the receiving instance together with the argument.
        print(self, x)
s = Spam()
print(Spam.bar.ncalls) # used outside the class
a = s.bar(1)
print(Spam.bar.ncalls)
a = s.bar(2)
# Here the counter is read through the instance (s.bar) instead of the class.
print(s.bar.ncalls)

Profiled 记录了被装饰的函数或方法的调用次数。

记录中间结果#

from types import MethodType
from typing import Any
from functools import wraps
from weakref import WeakValueDictionary


class Cached(type):
    """Metaclass that hands back an existing instance for repeated arguments.

    Every class created with this metaclass gets its own weak-valued cache.
    Calling the class with the same positional and keyword arguments as a
    still-alive instance returns that instance instead of building a new one.
    Entries disappear automatically once nothing else references the object.

    All argument values must be hashable, because they form the cache key.
    """

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        # Weak values: the cache alone never keeps an instance alive.
        self.__cache = WeakValueDictionary()

    def __call__(self, *args, **kwds):
        """Return the cached instance for these arguments, creating it if needed.

        Raises:
            TypeError: if any argument value is unhashable.
        """
        # Keyword arguments are part of the key; otherwise calls that differ
        # only in keywords would wrongly share one instance. Sorting makes the
        # key independent of keyword order (keyword names are unique strings,
        # so the values themselves are never compared).
        key = (args, tuple(sorted(kwds.items())))
        # Single lookup instead of "in" followed by indexing: a weak entry can
        # vanish between the two steps.
        try:
            return self.__cache[key]
        except KeyError:
            pass
        obj = super().__call__(*args, **kwds)
        self.__cache[key] = obj
        return obj

class ProfiledAccess:
    """Decorator class that counts calls and can record the latest result.

    Class-level state, shared by every decorated callable:
      cache       -- maps a wrapped function's ``__qualname__`` to the value
                     it returned most recently while recording was switched on.
      is_activate -- recording switch, turned on by ``activate()``.
    """

    cache = {}
    is_activate = False

    def __init__(self, func, *cargs, **ckwds):
        super().__init__(*cargs, **ckwds)
        self.ncalls = 0  # number of completed calls
        self.varname = "temp"
        # Metadata is copied last, after the instance attributes are in place.
        wraps(func)(self)

    def __call__(self, *args, **kwargs):
        """Run the wrapped callable, optionally record its result, count the call."""
        wrapped = self.__wrapped__
        result = wrapped(*args, **kwargs)
        owner = type(self)
        if owner.is_activate:
            owner.cache[wrapped.__qualname__] = result
        # Counted only after the wrapped callable has returned.
        self.ncalls += 1
        return result

    def __get__(self, obj, objtype=None):
        """Return the decorator itself for class access, a bound method otherwise."""
        return self if obj is None else MethodType(self, obj)

    @classmethod
    def activate(cls):
        """Switch result recording on."""
        cls.is_activate = True

    @classmethod
    def clear(cls):
        """Empty the shared result cache."""
        cls.cache.clear()
from bytecode import Bytecode, Instr
class Profiled:
    """Experimental call-counting decorator that rewrites the wrapped function's bytecode.

    NOTE(review): this class looks unfinished; see the notes in ``__init__``
    and ``__call__`` below.
    """
    def __init__(self, func):
        # Name of the local variable that transform() loads next to the result.
        # NOTE(review): "return" is a keyword and cannot name a local variable;
        # presumably a placeholder -- confirm the intended name.
        self.varname = "return"
        # NOTE(review): the callable produced by wraps() is invoked with two
        # positional arguments (self and self.varname); it looks like this
        # raises TypeError, since only the wrapper is expected -- TODO confirm.
        wraps(self.transform(func))(self, self.varname)
        self.ncalls = 0

    def transform(self, func):
        """Rewrite *func* so that it returns (original return value, value of the variable named by varname).

        ``func.__code__`` is replaced in place and the same function object is returned.
        """
        c = Bytecode.from_code(func.__code__)
        # Instruction sequence: save the value on top of the stack as _res,
        # copy the chosen local into _value, then build the 2-tuple
        # (_res, _value) and leave it on the stack.
        # NOTE(review): opcode names are specific to the CPython version -- verify.
        extra_code = [
            Instr('STORE_FAST', '_res'),
            Instr('LOAD_FAST', self.varname),
            Instr('STORE_FAST', '_value'),
            Instr('LOAD_FAST', '_res'),
            Instr('LOAD_FAST', '_value'),
            Instr('BUILD_TUPLE', 2),
            Instr('STORE_FAST', '_result_tuple'),
            Instr('LOAD_FAST', '_result_tuple'),
        ]
        # Insert the sequence just before the last instruction.
        # Assumes the last instruction is the function's only return -- TODO confirm.
        c[-1:-1] = extra_code
        func.__code__ = c.to_code()
        return func

    def __call__(self, *args, **kwds):
        # NOTE(review): the delegation to the wrapped function is commented
        # out, so a call only bumps the counter and returns None.
        self.ncalls += 1
        return #self.__wrapped__(*args, **kwds)

    # def __get__(self, obj, objtype=None):
    #     if obj is None:
    #         return self
    #     else:
    #         return MethodType(self, obj)
def func(x, y):
    """Return the square of ``x + y``."""
    # The intermediate sum is kept in the local variable `c` before squaring.
    c = x + y
    return c ** 2

# Presumably the name of the local variable inside func() whose value should
# be reported next to the result -- nothing visible here consumes it; confirm.
varname = "c"
cache = WeakValueDictionary()
func(2, 3)
# NOTE(review): looks like pasted notebook output (the intended result of the
# call above); as written it is a tuple expression whose value is discarded.
(25, 5)

其他#

# NOTE(review): Spam is called with one argument here; presumably this refers
# to a Spam class that uses the Cached metaclass and accepts an argument,
# which is not visible in this section -- confirm.
a = Spam(2)
b = Spam(3)
c = Spam(2)
# Identity comparison of the two objects built from the same argument; the
# result is discarded when run as a script.
a is c
# exec() called at module level without explicit namespaces: the statement
# runs in the current scope, so `b` is bound there and printed by the next
# line. `a` must already be defined and support `+ 1`.
exec("b = a + 1")
print(b)
def test():
    """Run an exec() assignment inside a function and read the result from a saved locals() mapping."""
    a = 13
    # Keep a reference to the mapping returned by locals() before calling exec().
    loc = locals()
    exec("b = a + 1")
    # `b` is looked up in the saved mapping, not as a local variable.
    # NOTE(review): whether exec() writes into this same mapping depends on
    # the interpreter version -- verify on the target Python.
    print(loc["b"])
test()
# Rebinds the name `test`, replacing any earlier definition with the same name.
def test():
    """Run an augmented assignment through exec() and print both the local and the saved locals() mapping."""
    x = 13
    loc = locals()
    exec("x += 1")
    # Print the real local variable first, then the saved mapping, so the two
    # values can be compared.
    print(x)
    print(loc)
test()
x = 42
# NOTE(review): eval() requires at least one argument, so this call raises
# TypeError as written -- presumably a truncated example; confirm the
# intended expression.
eval()
# exec() accepts statements, here a complete for-loop on one line.
exec("for k in range(10): print(k)")
import ast
# Parse a single expression (mode="eval") into an AST and print its dump.
ex = ast.parse("2 + 3*4 + x", mode="eval")
print(ast.dump(ex))
import ast

class CodeAnalyzer(ast.NodeVisitor):
    """AST visitor that groups every visited ``Name`` identifier by its context.

    After ``visit()`` has walked a tree:
      loaded  -- identifiers that are read
      stored  -- identifiers that are assigned to
      deleted -- identifiers removed with ``del``
    """

    def __init__(self) -> None:
        super().__init__()
        self.loaded, self.stored, self.deleted = set(), set(), set()

    def visit_Name(self, node):
        """Add ``node.id`` to the set that matches the node's context."""
        buckets = (
            (ast.Load, self.loaded),
            (ast.Store, self.stored),
            (ast.Del, self.deleted),
        )
        for ctx_type, names in buckets:
            if isinstance(node.ctx, ctx_type):
                names.add(node.id)
                break
# Sample program to analyse: it reads, assigns and deletes the name `k`.
code = """
for k in range(10):
    print(k)
del k
"""
top = ast.parse(code, mode="exec")
c = CodeAnalyzer()
c.visit(top)
# Bare attribute expressions: a notebook/REPL displays the collected name
# sets; when run as a script these lines have no visible effect.
c.loaded
c.stored
c.deleted
# Compile the parsed tree into a code object and execute it.
exec(compile(top, "<stdin>", "exec"))