def identify_redis_objs(func):
    """Identify objects likely to be used to access Redis in the code.

    Every method call ``<receiver>.<name>(...)`` found by ``ast.walk`` over
    the trees returned by ``sully.get_func_ast(func)`` is classified by
    whether ``<name>`` is in ``REDIS_METHODS``.  Receivers are grouped with
    ``sully.nodes_equal``.  A receiver is reported when it has at least
    ``REDIS_METHOD_COUNT`` Redis-method calls and the fraction of its method
    calls that are Redis methods is at least ``REDIS_METHOD_PCT``.

    :param func: the function to analyse (handed to ``sully.get_func_ast``)
    :return: list of receiver AST nodes deemed to be Redis interfaces, in
        the order they were examined (starting from the receiver of the
        last Redis-method call encountered by the walk)
    """
    redis_receivers = []
    other_receivers = []

    walkers = (ast.walk(root) for root in sully.get_func_ast(func))
    for node in itertools.chain.from_iterable(walkers):
        # Only method calls (obj.method(...)) tell us anything about obj
        is_method_call = isinstance(node, ast.Call) and \
            isinstance(node.func, ast.Attribute)
        if not is_method_call:
            continue
        if node.func.attr in REDIS_METHODS:
            bucket = redis_receivers
        else:
            bucket = other_receivers
        bucket.append(node.func.value)

    def split_off(pool, ref):
        """Return (count of nodes in pool equal to ref, the remaining nodes)."""
        rest = []
        for candidate in pool:
            if not sully.nodes_equal(ref, candidate):
                rest.append(candidate)
        return len(pool) - len(rest), rest

    redis_objs = []
    while redis_receivers:
        current = redis_receivers.pop()
        # Drop every call on this same receiver from both pools, counting
        # how many were removed; the popped call itself counts as one.
        matched, redis_receivers = split_off(redis_receivers, current)
        redis_calls = matched + 1
        nonredis_calls, other_receivers = split_off(other_receivers, current)

        if redis_calls < REDIS_METHOD_COUNT:
            continue
        share = float(redis_calls) / (redis_calls + nonredis_calls)
        if share >= REDIS_METHOD_PCT:
            redis_objs.append(current)
    return redis_objs
def process_Call(self, node, code, indent, loops):
    """Generate Lua code for a function or method call.

    Translates ``node`` (an ``ast.Call``) into a single Lua line and appends
    it to ``code`` as ``LuaLine(line, node, indent)``.  Supported forms:

    * built-ins ``int``/``float``, ``str``, ``range``/``xrange``, ``len``;
    * ``time.time()`` (via the Redis ``TIME`` command);
    * the string/list methods ``replace``, ``join``, ``append``, ``insert``;
    * method calls on ``self``;
    * pipeline ``pipe`` (emits nothing) and ``execute`` (``__PIPE_GET``);
    * method calls on objects matching ``self.redis_objs`` (``__PIPE_ADD``).

    For ``range``/``xrange`` the emitted line is the Lua numeric-for bounds
    string ``start, inclusive_stop, step`` rather than a call expression.

    :param node: the ``ast.Call`` node to translate
    :param code: list the generated ``LuaLine`` is appended to
    :param indent: indentation level passed to ``LuaLine``
    :param loops: not used by this method (presumably kept so every
        ``process_*`` handler shares one signature -- confirm)
    :raises UntranslatableCodeException: for star-args, keyword arguments,
        unknown built-ins, wrong argument counts, callees that are neither a
        plain name nor an attribute, and unrecognised method calls
    """
    # Star-args and keyword arguments have no translation.  Plain keywords
    # are rejected as well: dropping them would silently change the call.
    if getattr(node, 'starargs', None) or getattr(node, 'kwargs', None) \
            or node.keywords:
        raise UntranslatableCodeException(node)

    raw_args = [self.process_node(n) for n in node.args]
    args = ', '.join(arg.code for arg in raw_args)
    func = node.func

    # Handle some built-in functions
    if isinstance(func, ast.Name):
        if func.id in ('int', 'float'):
            line = 'tonumber(%s)' % args
        elif func.id == 'str':
            line = 'tostring(%s)' % args
        elif func.id in ('range', 'xrange'):
            # Lua numeric for-loops use an inclusive stop and always take
            # three bounds, so convert Python's exclusive stop.
            if len(raw_args) == 1:
                args = '0, %s - 1, 1' % args
            elif len(raw_args) == 2:
                args += ' - 1, 1'
            elif len(raw_args) == 3:
                start, stop, step = [arg.code for arg in raw_args]
                # Pull the stop back by one in the direction of the step
                args = '%s, %s - ((%s) > 0 and 1 or -1), %s' \
                    % (start, stop, step, step)
            line = args
        elif func.id == 'len':
            if len(raw_args) != 1:
                raise UntranslatableCodeException(node)
            line = '#' + args
        else:
            # XXX We don't know how to handle this function
            raise UntranslatableCodeException(node)

    # Anything other than name(...) or obj.attr(...) (e.g. f()() or x[0]())
    elif not isinstance(func, ast.Attribute):
        raise UntranslatableCodeException(node)

    # Get the current time for time.time()
    elif isinstance(func.value, ast.Name) and \
            func.value.id == func.attr == 'time':
        line = '((function() local __TIME = redis.call("TIME"); ' \
            'return __TIME[1] + (__TIME[2] / 1000000) end)())'

    # Perform string replacement.  Python's str.replace is literal, while
    # string.gsub takes a Lua pattern and treats '%' in the replacement as
    # special, so both are escaped at runtime (parentheses keep only the
    # first return value of the inner gsub calls).
    elif func.attr == 'replace':
        if len(raw_args) not in (2, 3):
            raise UntranslatableCodeException(node)
        old = '(string.gsub(%s, "%%p", "%%%%%%0"))' % raw_args[0].code
        new = '(string.gsub(%s, "%%%%", "%%%%%%%%"))' % raw_args[1].code
        gsub_args = [self.process_node(func.value).code, old, new]
        # Optional max-replacement count maps onto gsub's fourth argument
        gsub_args.extend(arg.code for arg in raw_args[2:])
        line = '((function() local __TEMP, _; ' \
            '__TEMP, _ = string.gsub(%s); ' \
            'return __TEMP end)())' % ', '.join(gsub_args)

    # Join a table of strings: sep.join(items) -> table.concat(items, sep)
    elif func.attr == 'join':
        line = 'table.concat(%s, %s)' \
            % (args, self.process_node(func.value).code)

    # If we're calling append, add to the end of a list
    elif func.attr == 'append':
        line = 'table.insert(%s, %s)' \
            % (self.process_node(func.value).code, args)

    # If we're calling insert, add to the appropriate list position
    elif func.attr == 'insert':
        if len(raw_args) != 2:
            raise UntranslatableCodeException(node)
        # Lua tables are 1-indexed, hence the "+ 1"
        line = 'table.insert(%s, %s + 1, %s)' \
            % (self.process_node(func.value).code,
               raw_args[0].code, raw_args[1].code)

    # Check if we have a method call on self
    elif isinstance(func.value, ast.Name) and func.value.id == 'self':
        line = 'self.%s(%s)' % (func.attr, args)

    # XXX Assume this is a Redis pipeline execution
    elif func.attr == 'pipe':
        # Do nothing to start a pipeline
        line = ''
    elif func.attr == 'execute':
        expr = self.process_node(func.value).code
        line = '__PIPE_GET(\'%s\')' % expr

    # XXX Otherwise, assume this is a redis function call
    elif any(sully.nodes_equal(func.value, obj) for obj in self.redis_objs):
        # Generate the Redis function call expression
        cmd = 'del' if func.attr == 'delete' else func.attr
        # Build the argument list so zero-argument commands do not leave a
        # dangling comma in the generated Lua.
        call_args = ['\'%s\'' % cmd] + [arg.code for arg in raw_args]
        call = 'redis.call(%s)' % ', '.join(call_args)

        # Wrap the Redis call in a function which stores the
        # result if needed later for pipelining and returns it
        expr = self.process_node(func.value).code
        line = '__PIPE_ADD(\'%s\', %s)' % (expr, call)

    else:
        # XXX Something we can't handle
        raise UntranslatableCodeException(node)

    code.append(LuaLine(line, node, indent))
def test_names_equal():
    """Two separately built Name nodes with the same id compare equal."""
    left = ast.Name(id='foo', ctx=ast.Load())
    right = ast.Name(id='foo', ctx=ast.Load())
    assert nodes_equal(left, right)
def test_subscript_unequal():
    """Subscripts of the same name with different indices are not equal."""
    first, second = parse_source('x[0]'), parse_source('x[1]')
    assert not nodes_equal(first, second)
def test_attribute_unequal():
    """Attribute accesses on the same object naming different attrs differ."""
    lhs = parse_source('self.x')
    rhs = parse_source('self.y')
    assert not nodes_equal(lhs, rhs)
def test_tpcc_fragment():
    """The first two Redis objects found in doDelivery are wtr then rdr."""
    objs = identify_redis_objs(redisdriver.RedisDriver.doDelivery)
    for position, expected_id in enumerate(('wtr', 'rdr')):
        expected = ast.Name(id=expected_id, ctx=ast.Load())
        assert sully.nodes_equal(expected, objs[position])