Exemplo n.º 1
0
  def _create_nonlocal_declarations(self, loop_vars):
    """Builds the `global` / `nonlocal` statements for a set of loop variables.

    Args:
      loop_vars: iterable of symbols. Each one must provide `is_composite()`
        and be convertible to its name with `str()`.

    Returns:
      A list holding at most two gast nodes, in this order: a `gast.Global`
      naming every entry of `self.state[_Function].scope.globals` (only when
      that collection is non-empty), then a `gast.Nonlocal` naming the loop
      variables that are neither composite nor among those globals (only when
      at least one such variable exists).
    """
    scope_globals = self.state[_Function].scope.globals
    declarations = []

    if scope_globals:
      global_names = [str(symbol) for symbol in scope_globals]
      declarations.append(gast.Global(global_names))

    # Composite symbols and symbols already declared global are left out of
    # the nonlocal statement.
    eligible = []
    for symbol in loop_vars:
      if symbol.is_composite() or symbol in scope_globals:
        continue
      eligible.append(symbol)

    if eligible:
      nonlocal_names = [str(symbol) for symbol in eligible]
      declarations.append(gast.Nonlocal(nonlocal_names))

    return declarations
Exemplo n.º 2
0
    def _update_name_to_var_shape(self, node):
        """Register static-shape helper variables for an assignment's targets.

        Only the first target of ``node`` is examined. When the assigned value
        is a name already present in ``self.name_to_var_shape``, or an
        expression accepted by ``self._is_var_shape``, a uniquely named helper
        variable is generated for the target (or for each element of a tuple
        target), ``self.name_to_var_shape`` is updated to map the target's
        source text to that helper name, and the statements that define the
        helper are returned.

        Args:
            node: a ``gast.Assign`` node.

        Returns:
            For a tuple target: a list of statements (possibly empty).
            Otherwise: a list of statements, or ``None`` when the value is
            neither a tracked name nor a shape expression.
        """
        assert isinstance(node, gast.Assign)
        lhs = node.targets[0]
        rhs = node.value

        def new_shape_var(target_src):
            # Every '.' in the target text is replaced with '_' before the
            # suffix is added and a unique name is generated from it.
            var_name = unique_name.generate(
                target_src.replace('.', '_') + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
            # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
            return var_name, gast.parse(var_name).body[0].value

        def converted_shape(shape_expr):
            # The transformer runs on a deep copy, not on the original value.
            # x.shape becomes convert_var_shape_simple(x)
            return ShapeAttributeTransformer().visit(copy.deepcopy(shape_expr))

        if isinstance(lhs, gast.Tuple):
            statements = []
            for position, elt in enumerate(lhs.elts):
                elt_src = ast_to_source_code(elt).strip()

                if (isinstance(rhs, gast.Name) and
                        rhs.id in self.name_to_var_shape):
                    var_name, var_node = new_shape_var(elt_src)
                    # Read the existing mapping before it is (possibly)
                    # overwritten for this element below.
                    source_shape_name = self.name_to_var_shape[rhs.id]
                    item_src = "{}[{}]".format(source_shape_name, position)
                    item_node = gast.parse(item_src).body[0].value

                    statements.append(
                        gast.Assign(targets=[var_node], value=item_node))
                    self.name_to_var_shape[elt_src] = var_name

                if (isinstance(rhs, gast.Attribute) and
                        self._is_var_shape(rhs)):  # eg: x.shape
                    var_name, var_node = new_shape_var(elt_src)
                    shape_src = ast_to_source_code(
                        converted_shape(rhs)).strip()
                    item_src = "{}[{}]".format(shape_src, position)
                    item_node = gast.parse(item_src).body[0].value

                    # Note(Aurelius84): Because the generated variable name is
                    # used in eval_if_exist_else_none() as a plain string, it
                    # will not be parsed as an argument in convert_loop/ifelse.
                    # We declare it as a global var because it has a unique
                    # name.
                    statements.append(gast.Global(names=[var_name]))
                    statements.append(
                        gast.Assign(targets=[var_node], value=item_node))
                    self.name_to_var_shape[elt_src] = var_name
            return statements

        target_src = ast_to_source_code(lhs).strip()

        if isinstance(rhs, gast.Name):
            if rhs.id not in self.name_to_var_shape:
                return None
            var_name, var_node = new_shape_var(target_src)
            source_shape_name = self.name_to_var_shape[rhs.id]
            source_shape_node = gast.parse(source_shape_name).body[0].value

            statements = [
                gast.Assign(targets=[var_node], value=source_shape_node)
            ]
            self.name_to_var_shape[target_src] = var_name
            return statements

        if not self._is_var_shape(rhs):
            return None

        # eg: x.shape or x.shape[0]
        var_name, var_node = new_shape_var(target_src)
        shape_node = converted_shape(rhs)
        # Declare the generated variable as a global var before assigning it.
        statements = [
            gast.Global(names=[var_name]),
            gast.Assign(targets=[var_node], value=shape_node),
        ]
        self.name_to_var_shape[target_src] = var_name
        return statements