def _create_nonlocal_declarations(self, loop_vars):
    """Build the scope declarations (`global` / `nonlocal`) for loop vars.

    Args:
        loop_vars: iterable of variable-name objects; each must provide
            `is_composite()` and be convertible with `str()`.

    Returns:
        list: zero, one or two gast statements, in this order:
          * a `gast.Global` naming every entry of
            `self.state[_Function].scope.globals`, if that collection is
            non-empty;
          * a `gast.Nonlocal` naming every non-composite entry of
            `loop_vars` that is not among those globals, if there is any.
    """
    declared_globals = self.state[_Function].scope.globals

    declarations = []
    if declared_globals:
        declarations.append(
            gast.Global([str(name) for name in declared_globals]))

    # Composite names (e.g. attribute chains) cannot appear in a `nonlocal`
    # statement, and globals are already covered by the `global` above.
    enclosing_names = [
        name for name in loop_vars
        if not (name.is_composite() or name in declared_globals)
    ]
    if enclosing_names:
        declarations.append(
            gast.Nonlocal([str(name) for name in enclosing_names]))

    return declarations
def _update_name_to_var_shape(self, node):
    """Emit statements that keep static shape variables in sync with `node`.

    `node` must be a `gast.Assign`; only its first target is considered.
    For every target that receives a tensor shape, this method:
      * generates a fresh unique variable name from the target's source text
        ('.' replaced by '_') plus STATIC_CONVERT_VAR_SHAPE_SUFFIX;
      * emits an assignment to that variable;
      * records target source text -> new name in `self.name_to_var_shape`.

    Cases handled:
      * Tuple target, `gast.Name` value already in `self.name_to_var_shape`:
        ``<new_i> = <source_shape_var>[i]`` for each element i.
      * Tuple target, `gast.Attribute` value accepted by
        `self._is_var_shape` (e.g. ``x.shape``): ``global <new_i>``
        followed by ``<new_i> = <converted value>[i]``.
      * Other target, `gast.Name` value already in
        `self.name_to_var_shape`: ``<new> = <source_shape_var>``.
      * Other target, value accepted by `self._is_var_shape`
        (e.g. ``x.shape`` or ``x.shape[0]``): ``global <new>`` followed by
        ``<new> = <converted value>``.

    The "converted value" is a deep copy of `node.value` rewritten by
    ShapeAttributeTransformer (``x.shape`` becomes
    ``convert_var_shape_simple(x)``).

    Args:
        node (gast.Assign): the assignment being analysed.

    Returns:
        list | None: the generated gast statements. For a tuple target this
        is always a list (possibly empty); for any other target it is None
        when none of the cases above applies.
    """

    def replace_dot(name):
        # replace all '.' into '_' (e.g. `self.x` -> `self_x`)
        return name.replace('.', '_')

    def new_static_shape_var(target_id):
        # Generate a unique variable name for `target_id` and parse it into
        # an expression node usable as an assignment target.
        # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
        var_name = unique_name.generate(
            replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
        return var_name, gast.parse(var_name).body[0].value

    def converted_shape_value():
        # x.shape becomes convert_var_shape_simple(x). Work on a deep copy so
        # the original `node.value` subtree is left untouched.
        return ShapeAttributeTransformer().visit(copy.deepcopy(value_node))

    assert isinstance(node, gast.Assign)
    target_node = node.targets[0]
    value_node = node.value

    if isinstance(target_node, gast.Tuple):
        update_static_shape_var_node = []

        # Resolve the source shape variable ONCE, before the loop starts
        # rebinding entries of self.name_to_var_shape. Python evaluates the
        # right-hand side before binding any target, so in `b, a = b` the
        # element for `a` must still come from the *old* static shape of
        # `b`, not from the variable just generated for `b`.
        from_tracked_name = (isinstance(value_node, gast.Name) and
                             value_node.id in self.name_to_var_shape)
        source_shape_var_name = None
        if from_tracked_name:
            source_shape_var_name = self.name_to_var_shape[value_node.id]

        for idx, element in enumerate(target_node.elts):
            target_id = ast_to_source_code(element).strip()

            if from_tracked_name:
                static_shape_var_name, static_shape_var_node = \
                    new_static_shape_var(target_id)
                sub_node_str = "{}[{}]".format(source_shape_var_name, idx)
                sub_node = gast.parse(sub_node_str).body[0].value
                update_static_shape_var_node.append(
                    gast.Assign(
                        targets=[static_shape_var_node], value=sub_node))
                self.name_to_var_shape[target_id] = static_shape_var_name
            elif (isinstance(value_node, gast.Attribute) and
                  self._is_var_shape(value_node)):  # eg: x.shape
                static_shape_var_name, static_shape_var_node = \
                    new_static_shape_var(target_id)
                sub_node_str = "{}[{}]".format(
                    ast_to_source_code(converted_shape_value()).strip(), idx)
                sub_node = gast.parse(sub_node_str).body[0].value
                # Note(Aurelius84): Because static_shape_var_name is used in
                # eval_if_exist_else_none() as a plain string, it will not be
                # parsed as an argument in convert_loop/ifelse. We declare it
                # as a global var because it has a unique name.
                update_static_shape_var_node.append(
                    gast.Global(names=[static_shape_var_name]))
                update_static_shape_var_node.append(
                    gast.Assign(
                        targets=[static_shape_var_node], value=sub_node))
                self.name_to_var_shape[target_id] = static_shape_var_name
        return update_static_shape_var_node

    update_static_shape_var_node = None
    target_id = ast_to_source_code(target_node).strip()
    if isinstance(value_node, gast.Name):
        if value_node.id in self.name_to_var_shape:
            static_shape_var_name, static_shape_var_node = \
                new_static_shape_var(target_id)
            static_shape_value_node = gast.parse(
                self.name_to_var_shape[value_node.id]).body[0].value
            update_static_shape_var_node = [
                gast.Assign(
                    targets=[static_shape_var_node],
                    value=static_shape_value_node)
            ]
            self.name_to_var_shape[target_id] = static_shape_var_name
    elif self._is_var_shape(value_node):  # eg: x.shape or x.shape[0]
        static_shape_var_name, static_shape_var_node = \
            new_static_shape_var(target_id)
        static_shape_value_node = converted_shape_value()
        # Declare static_shape_var_name as a global var, for the same reason
        # as in the tuple-target case above.
        update_static_shape_var_node = [
            gast.Global(names=[static_shape_var_name]),
            gast.Assign(
                targets=[static_shape_var_node],
                value=static_shape_value_node),
        ]
        self.name_to_var_shape[target_id] = static_shape_var_name
    return update_static_shape_var_node