Пример #1
0
def identify_redis_objs(func):
    """Guess which objects in *func* act as Redis client handles.

    Every method call of the form ``obj.method(...)`` found by walking the
    function's AST is classified by whether ``method`` appears in
    ``REDIS_METHODS``.  A receiver ``obj`` is reported when it has at least
    ``REDIS_METHOD_COUNT`` Redis-method calls AND those calls make up at
    least ``REDIS_METHOD_PCT`` of all method calls on it.  Receivers are
    considered the same object when ``sully.nodes_equal`` says so.

    Returns a list of receiver AST nodes, ordered by each receiver's last
    Redis-method call as visited by ``ast.walk``, latest first.
    """
    redis_receivers = []
    other_receivers = []
    walks = (ast.walk(root) for root in sully.get_func_ast(func))
    for node in itertools.chain.from_iterable(walks):
        # Only calls shaped like ``<expr>.<attr>(...)`` have a receiver
        if not isinstance(node, ast.Call):
            continue
        if not isinstance(node.func, ast.Attribute):
            continue
        if node.func.attr in REDIS_METHODS:
            redis_receivers.append(node.func.value)
        else:
            other_receivers.append(node.func.value)

    found = []
    while redis_receivers:
        candidate = redis_receivers.pop()

        # Drop every call on this same receiver from both pools, counting
        # how many were removed from each
        kept_redis = [other for other in redis_receivers
                      if not sully.nodes_equal(candidate, other)]
        kept_other = [other for other in other_receivers
                      if not sully.nodes_equal(candidate, other)]

        # The popped candidate itself counts as one Redis-method call
        redis_hits = 1 + len(redis_receivers) - len(kept_redis)
        other_hits = len(other_receivers) - len(kept_other)
        redis_receivers, other_receivers = kept_redis, kept_other

        if redis_hits < REDIS_METHOD_COUNT:
            continue
        if redis_hits * 1.0 / (redis_hits + other_hits) >= REDIS_METHOD_PCT:
            found.append(candidate)

    return found
Пример #2
0
    def process_Call(self, node, code, indent, loops):
        """Translate a Python call expression into Lua and append it to *code*.

        Supported forms:

        * builtins ``int``/``float`` (``tonumber``), ``str`` (``tostring``),
          ``range``/``xrange`` (rewritten as numeric for-loop bounds) and
          ``len`` (``#``);
        * ``time.time()``, via ``redis.call("TIME")``;
        * the methods ``replace``, ``join``, ``append`` and ``insert``;
        * method calls on ``self``;
        * ``pipe``/``execute`` pipeline calls;
        * method calls on any receiver matching (via ``sully.nodes_equal``)
          an entry of ``self.redis_objs``, which become ``redis.call``.

        :param node: the ``ast.Call`` node to translate.
        :param code: list that receives the generated ``LuaLine``.
        :param indent: indentation value passed through to ``LuaLine``.
        :param loops: not used by this method (presumably kept so all
            ``process_*`` handlers share one signature).
        :raises UntranslatableCodeException: for ``*args``/``**kwargs``
            calls, unknown builtins, callees that are neither a plain name
            nor an attribute, a wrong argument count for ``len``/``insert``,
            and method calls on receivers not recognised as Redis objects.
        """

        # Star-arguments (*args / **kwargs) cannot be expanded statically.
        # Python 2 ASTs carry them in ``starargs``/``kwargs``; Python 3.5+
        # ASTs use ``ast.Starred`` entries in ``args`` and keywords whose
        # ``arg`` is None, so check both forms.
        starred = getattr(ast, 'Starred', ())
        if getattr(node, 'starargs', None) or getattr(node, 'kwargs', None) \
                or any(isinstance(arg, starred) for arg in node.args) \
                or any(kw.arg is None for kw in node.keywords):
            raise UntranslatableCodeException(node)

        # NOTE(review): ordinary keyword arguments (f(x=1)) are not checked
        # here and are silently dropped from the generated call.
        raw_args = [self.process_node(n) for n in node.args]
        args = ', '.join(arg.code for arg in raw_args)

        # Handle some built-in functions
        if isinstance(node.func, ast.Name):
            if node.func.id in ('int', 'float'):
                line = 'tonumber(%s)' % args
            elif node.func.id == 'str':
                line = 'tostring(%s)' % args
            elif node.func.id in ('range', 'xrange'):
                # Lua numeric for-loops take inclusive (start, stop, step)
                # bounds, so Python's exclusive stop becomes stop - 1
                if len(node.args) == 1:
                    args = '0, %s - 1, 1' % args
                elif len(node.args) == 2:
                    args += ' - 1, 1'
                # NOTE(review): the three-argument form is passed through
                # unchanged, so its stop bound becomes inclusive in Lua.

                line = args
            elif node.func.id == 'len':
                # Explicit check rather than ``assert``, which -O strips
                if len(node.args) != 1:
                    raise UntranslatableCodeException(node)
                # ``#`` binds tighter than binary operators, so compound
                # arguments need parentheses: len(a + b) -> #(a + b)
                if isinstance(node.args[0], (ast.Name, ast.Attribute,
                                             ast.Subscript, ast.Call)):
                    line = '#' + args
                else:
                    line = '#(%s)' % args
            else:
                # XXX We don't know how to handle this function
                raise UntranslatableCodeException(node)

        # Every remaining branch inspects ``node.func.attr``; callees such as
        # ``f()()`` or ``handlers[0]()`` have no attribute to inspect
        elif not isinstance(node.func, ast.Attribute):
            raise UntranslatableCodeException(node)

        # Get the current time for time.time()
        elif isinstance(node.func.value, ast.Name) and \
                node.func.value.id == node.func.attr == 'time':
            line = '((function() local __TIME = redis.call("TIME"); ' \
                   'return __TIME[1] + (__TIME[2] / 1000000) end)())'

        # Perform string replacement
        # NOTE(review): string.gsub treats its pattern argument as a Lua
        # pattern and '%' in the replacement as an escape, unlike
        # str.replace, which is literal.
        elif node.func.attr == 'replace':
            line = '((function() local __TEMP, _; ' \
                   '__TEMP, _ = string.gsub(%s, %s); ' \
                   'return __TEMP end)())' \
                   % (self.process_node(node.func.value).code, args)

        # Join a table of strings: sep.join(items) -> table.concat(items, sep)
        elif node.func.attr == 'join':
            line = 'table.concat(%s, %s)\n' \
                   % (args, self.process_node(node.func.value).code)

        # If we're calling append, add to the end of a list
        elif node.func.attr == 'append':
            line = 'table.insert(%s, %s)' \
                   % (self.process_node(node.func.value).code, args)

        # If we're calling insert, add to the appropriate list position
        # (Lua tables are 1-based, hence the + 1 on the index)
        elif node.func.attr == 'insert':
            if len(raw_args) != 2:
                raise UntranslatableCodeException(node)
            line = 'table.insert(%s, %s + 1, %s)\n' \
                   % (self.process_node(node.func.value).code,
                      raw_args[0].code, raw_args[1].code)

        # Check if we have a method call on ``self``; the receiver may be any
        # expression (e.g. ``self.r`` in ``self.r.get(k)``), so confirm it is
        # a plain name before reading ``.id``
        elif isinstance(node.func.value, ast.Name) and \
                node.func.value.id == 'self':
            line = '%s(%s)' % ('self.' + node.func.attr, args)

        # XXX Assume this is a Redis pipeline execution
        elif node.func.attr == 'pipe':
            # Do nothing to start a pipeline
            line = ''
        elif node.func.attr == 'execute':
            expr = self.process_node(node.func.value).code
            line = '__PIPE_GET(\'%s\')' % expr

        # XXX Otherwise, assume this is a redis function call
        elif any(
                sully.nodes_equal(node.func.value, obj)
                for obj in self.redis_objs):
            # Generate the Redis function call expression; redis-py's
            # ``delete`` maps to the Redis command ``DEL``
            cmd = node.func.attr
            if cmd == 'delete':
                cmd = 'del'
            call = 'redis.call(\'%s\', %s)' % (cmd, args)

            # Wrap the Redis call in a function which stores the
            # result if needed later for pipelining and returns it
            expr = self.process_node(node.func.value).code
            line = '__PIPE_ADD(\'%s\', %s)' % (expr, call)
        else:
            # XXX Something we can't handle
            raise UntranslatableCodeException(node)

        code.append(LuaLine(line, node, indent))
Пример #3
0
def test_names_equal():
    """Two separately built Name nodes with the same id and context are equal."""
    first = ast.Name(id='foo', ctx=ast.Load())
    second = ast.Name(id='foo', ctx=ast.Load())
    assert nodes_equal(first, second)
Пример #4
0
def test_subscript_unequal():
    """Subscripts that differ only in their index must not compare equal."""
    index_zero = parse_source('x[0]')
    index_one = parse_source('x[1]')
    assert not nodes_equal(index_zero, index_one)
Пример #5
0
def test_attribute_unequal():
    """Attribute accesses on the same object with different names differ."""
    left = parse_source('self.x')
    right = parse_source('self.y')
    assert not nodes_equal(left, right)
Пример #6
0
def test_tpcc_fragment():
    """The TPC-C delivery transaction uses ``wtr`` then ``rdr`` as Redis objects."""
    found = identify_redis_objs(redisdriver.RedisDriver.doDelivery)
    for position, name in enumerate(('wtr', 'rdr')):
        expected = ast.Name(id=name, ctx=ast.Load())
        assert sully.nodes_equal(expected, found[position])