Fri 03 April 2009

CPython unterstützt ja bekanntermaßen von sich aus keine TCO (Tail Call Optimization), weshalb es zahlreiche mehr oder weniger schöne Hacks gibt, die CPython das beibringen. Aber einen wirklich beeindruckenden habe ich eben in #python auf Freenode gesehen, geschrieben von habnabit. Und zwar wird der Bytecode geändert: Aus dem CALL_FUNCTION wird ein JUMP_ABSOLUTE.

Nachtrag: Das müsste eigentlich unter verschiedenen CPython-Versionen recht zuverlässig funktionieren. Für das CALL_FUNCTION müssen die Argumente ohnehin auf dem Stack liegen. Diese werden jetzt einfach mit STORE_FAST in die Namen geschrieben, die die Funktion entgegennimmt, und dann wird wieder zum Anfang der Funktion gesprungen.

import inspect, pprint, types, dis, struct, opcode, array
# Pre-compiled codec for an unsigned 16-bit little-endian integer; used to
# decode the two-byte argument that follows an opcode in ``co_code``.
short = struct.Struct('<H')

class Label(object):
    """Opaque marker for a jump target.

    Instances carry no data; they are only told apart by identity, which is
    what makes them usable as dictionary keys for bytecode positions.
    """

class Code(object):
    """Mutable, label-based view of a code object's bytecode.

    ``from_code`` disassembles ``co_code`` into ``self.code``, a list of
    ``(opname, arg)`` tuples, and ``to_code`` assembles that list back into a
    new code object.  Besides real opcodes the list contains two pseudo
    instructions that emit no bytes:

    * ``('MARK_LABEL', label)``   -- the next instruction is a jump target.
    * ``('MARK_LINENO', lineno)`` -- the next instruction starts source line
      ``lineno``.

    Jump instructions carry a ``Label`` instead of a numeric offset, so
    instructions can be inserted or removed without fixing up jumps by hand.

    The byte layout handled here is: one opcode byte, followed by a two-byte
    little-endian argument when the opcode is >= ``opcode.HAVE_ARGUMENT``.

    Attributes set by ``from_code``:
        code_obj: the original code object (source of all untouched fields).
        names, varnames, consts: mutable list copies of ``co_names``,
            ``co_varnames`` and ``co_consts``.
        code: the instruction list described above.
    """

    @classmethod
    def from_code(cls, code_obj):
        """Build a ``Code`` instance by disassembling ``code_obj``.

        Returns the new instance; ``code_obj`` itself is not modified.
        """
        self = cls()
        self.code_obj = code_obj
        self.names = list(code_obj.co_names)
        self.varnames = list(code_obj.co_varnames)
        self.consts = list(code_obj.co_consts)
        ret = []
        # byte offset -> source line number starting at that offset
        line_starts = dict(dis.findlinestarts(code_obj))
        code = code_obj.co_code
        # byte offset of every jump target -> a fresh Label for it
        labels = dict((addr, Label()) for addr in dis.findlabels(code))
        i, l = 0, len(code)
        extended_arg = 0
        while i < l:
            op = ord(code[i])
            # Markers are emitted before the instruction they refer to.
            if i in labels:
                ret.append(('MARK_LABEL', labels[i]))
            if i in line_starts:
                ret.append(('MARK_LINENO', line_starts[i]))
            i += 1
            if op >= opcode.HAVE_ARGUMENT:
                arg, = short.unpack(code[i:i + 2])
                # Fold in the high bits supplied by a preceding EXTENDED_ARG.
                arg += extended_arg
                extended_arg = 0
                i += 2
                if op == opcode.EXTENDED_ARG:
                    # Not kept as an instruction of its own; to_code()
                    # re-creates it whenever an argument exceeds 16 bits.
                    extended_arg = arg << 16
                    continue
                elif op in opcode.hasjabs:
                    arg = labels[arg]
                elif op in opcode.hasjrel:
                    # Relative jumps count from the end of this instruction,
                    # which is where ``i`` points now.
                    arg = labels[i + arg]
            else:
                arg = None
            ret.append((opcode.opname[op], arg))
        self.code = ret
        return self

    def to_code(self):
        """Assemble ``self.code`` into a new code object.

        The bytecode and the line number table are regenerated; constants,
        names and variable names come from this instance's lists.  All other
        fields (argument count, ``co_nlocals``, ``co_stacksize``, flags,
        file name, name, first line number, free and cell variables) are
        copied unchanged from the original code object -- in particular the
        stack size is not recomputed.

        Raises:
            AssertionError: if a resolved jump offset does not fit in 16 bits
                (jumps are always emitted with a two-byte argument).
            KeyError: if a jump refers to a Label that has no ``MARK_LABEL``.
        """
        code_obj = self.code_obj
        co_code = array.array('B')
        co_lnotab = array.array('B')
        label_pos = {}   # Label -> byte offset of the instruction it marks
        jumps = []       # (byte offset of jump opcode, target Label)
        lastlineno = code_obj.co_firstlineno
        lastlinepos = 0
        for op, arg in self.code:
            if op == 'MARK_LABEL':
                label_pos[arg] = len(co_code)
            elif op == 'MARK_LINENO':
                # The table stores (bytecode delta, line delta) byte pairs
                # relative to the previous entry.
                incr_lineno = arg - lastlineno
                incr_pos = len(co_code) - lastlinepos
                lastlineno = arg
                lastlinepos = len(co_code)

                if incr_lineno == 0 and incr_pos == 0:
                    co_lnotab.append(0)
                    co_lnotab.append(0)
                else:
                    # Each delta is a single byte, so larger steps are split
                    # into several pairs.
                    while incr_pos > 255:
                        co_lnotab.append(255)
                        co_lnotab.append(0)
                        incr_pos -= 255
                    while incr_lineno > 255:
                        co_lnotab.append(incr_pos)
                        co_lnotab.append(255)
                        incr_pos = 0
                        incr_lineno -= 255
                    if incr_pos or incr_lineno:
                        co_lnotab.append(incr_pos)
                        co_lnotab.append(incr_lineno)
            elif arg is not None:
                op = opcode.opmap[op]
                if op in opcode.hasjabs or op in opcode.hasjrel:
                    # The target offset is unknown until everything has been
                    # emitted; write a placeholder and patch it afterwards.
                    jumps.append((len(co_code), arg))
                    arg = 0
                if arg > 0xffff:
                    # High 16 bits go into a preceding EXTENDED_ARG.
                    co_code.extend((opcode.EXTENDED_ARG,
                        (arg >> 16) & 0xff, (arg >> 24) & 0xff))
                co_code.extend((op,
                    arg & 0xff, (arg >> 8) & 0xff))
            else:
                co_code.append(opcode.opmap[op])

        for pos, label in jumps:
            jump = label_pos[label]
            if co_code[pos] in opcode.hasjrel:
                # Relative to the end of the 3-byte jump instruction.
                jump -= pos + 3
            assert jump <= 0xffff
            co_code[pos + 1] = jump & 0xff
            co_code[pos + 2] = (jump >> 8) & 0xff

        return types.CodeType(code_obj.co_argcount, code_obj.co_nlocals,
            code_obj.co_stacksize, code_obj.co_flags, co_code.tostring(),
            tuple(self.consts), tuple(self.names), tuple(self.varnames),
            code_obj.co_filename, code_obj.co_name, code_obj.co_firstlineno,
            co_lnotab.tostring(), code_obj.co_freevars, code_obj.co_cellvars)

    def const_idx(self, val):
        """Return the index of ``val`` in ``self.consts``, adding it if needed.

        An existing slot is reused only if it holds the very same object, or
        an object of the same type that compares equal and has the same repr.
        Plain ``list.index`` is not enough: it matches by ``==`` alone, so
        ``True``, ``1`` and ``1.0`` (or ``0.0`` and ``-0.0``) would share one
        slot and the wrong value would be loaded at run time.  Appending a new
        slot is always correct, so the match is deliberately strict.
        """
        for idx, existing in enumerate(self.consts):
            if existing is val:
                return idx
            if (type(existing) is type(val) and existing == val
                    and repr(existing) == repr(val)):
                return idx
        self.consts.append(val)
        return len(self.consts) - 1

def tail_call(func):
    """Rewrite self-recursive tail calls in ``func`` into jumps, in place.

    The function's bytecode is searched for the pattern

        <line start> LOAD_GLOBAL <own name> ... CALL_FUNCTION n  RETURN_VALUE

    Every match is replaced as follows: the ``LOAD_GLOBAL`` is dropped (so
    only the call arguments end up on the stack), constants for any omitted
    trailing default arguments are pushed, all values are stored into the
    parameter variables with ``STORE_FAST`` (last parameter first, since it is
    on top of the stack), and a ``JUMP_ABSOLUTE`` to the start of the function
    replaces the call and the return.

    The match is purely positional: a ``LOAD_GLOBAL`` of the function's name
    directly after a line start is paired with the next ``RETURN_VALUE`` that
    is directly preceded by a ``CALL_FUNCTION``.  No check is made that the
    two belong to the same call expression.

    Args:
        func: a function that refers to itself through a global of the same
            name as ``func.__name__``.

    Returns:
        ``func`` itself, with ``func.func_code`` replaced.

    Raises:
        SyntaxError: if the function's name is one of its local variables, is
            not among its global names, if no ``return`` follows a candidate
            call, or if a matched call uses ``*args``/``**kwargs``, keyword
            arguments, too few or too many positional arguments.
    """
    code = Code.from_code(func.func_code)
    func_name = func.__name__
    if func_name in code.varnames:
        raise SyntaxError('"%s" was found as a local variable in the function' %
            func_name)
    try:
        name_idx = code.names.index(func_name)
    except ValueError:
        # list.index signals a missing item with ValueError, not IndexError.
        raise SyntaxError('"%s" not found in function\'s global names' %
            func_name)
    last_idx = 0
    # Jump target for the rewritten calls: the very first instruction.
    func_start = Label()
    code.code.insert(0, ('MARK_LABEL', func_start))
    while True:
        try:
            lglobal_idx = code.code.index(('LOAD_GLOBAL', name_idx), last_idx)
        except ValueError:
            break

        # Only consider loads that open a line, i.e. ``return f(...)`` and
        # not ``return 1 + f(...)``.
        if code.code[lglobal_idx - 1][0] != 'MARK_LINENO':
            last_idx = lglobal_idx + 1
            continue

        try:
            return_idx = code.code.index(('RETURN_VALUE', None), lglobal_idx)
        except ValueError:
            raise SyntaxError('"return" not found in function after "%s"' %
                func_name)

        # The return must close the line (or be the last instruction).
        if (return_idx != len(code.code) - 1
                and code.code[return_idx + 1][0] != 'MARK_LINENO'):
            last_idx = return_idx + 1
            continue

        call_op, call_arg = code.code[return_idx - 1]

        if call_op in ('CALL_FUNCTION_VAR',
                'CALL_FUNCTION_KW', 'CALL_FUNCTION_VAR_KW'):
            raise SyntaxError('calling with *a and/or **kw is unsupported')

        if call_op != 'CALL_FUNCTION':
            last_idx = return_idx + 1
            continue

        # High byte of the CALL_FUNCTION argument counts keyword arguments.
        if call_arg & 0xff00:
            raise SyntaxError('calling with keyword arguments is unsupported')

        arg_names, _, _, defaults = inspect.getargspec(func)
        n_args = call_arg
        if defaults is None:
            defaults = ()
        if n_args + len(defaults) < len(arg_names):
            raise SyntaxError('not enough arguments provided')
        if n_args > len(arg_names):
            # Surplus values would stay on the stack and the STORE_FASTs
            # below would pick up the wrong ones.
            raise SyntaxError('too many arguments provided')

        new_bytecode = []
        if n_args < len(arg_names):
            # Push the defaults of the trailing parameters the call omitted.
            new_bytecode.extend(
                ('LOAD_CONST', code.const_idx(d))
                for d in defaults[n_args - len(arg_names):])
        # Top of stack is the last parameter, so store in reverse order.
        new_bytecode.extend(
            ('STORE_FAST', code.varnames.index(arg))
            for arg in reversed(arg_names))
        new_bytecode.append(('JUMP_ABSOLUTE', func_start))
        # Replace CALL_FUNCTION + RETURN_VALUE first; deleting the earlier
        # LOAD_GLOBAL afterwards keeps return_idx valid for the slice.
        code.code[return_idx - 1:return_idx + 1] = new_bytecode
        del code.code[lglobal_idx]

    func.func_code = code.to_code()
    return func

# Demo function: factorial with an accumulator.  The recursive call is written
# as a plain ``return factorial(...)`` statement, which is the bytecode shape
# tail_call() looks for.
def factorial(n, acc=1):
    if n <= 0:
        return acc
    # Tail call: its result is returned as-is, nothing runs after it.
    return factorial(n - 1, n * acc)

# Disassemble before and after the rewrite so the bytecode change is visible.
dis.dis(factorial)
factorial = tail_call(factorial)
print
dis.dis(factorial)

# 10000 levels of recursion in the original definition; presumably deeper
# than the interpreter's recursion limit, so this should only succeed once
# the recursion has been turned into a jump -- TODO confirm.
print factorial(10000)