예제 #1
0
 def visit_Expr(self, node):
     """Handle an expression statement based on the calls it contains.

     Every ``gast.Call`` found under ``node.value`` is examined in the order
     ``gast.walk`` yields it.  As soon as one satisfies ``is_dygraph_api`` the
     method returns ``None``.  Calls already passed to ``self._visit_Call``
     stay processed, and later calls are not visited.  Each non-dygraph call
     is handed to ``self._visit_Call``.

     NOTE(review): if the enclosing class is an ``ast``/``gast``
     ``NodeTransformer``, returning ``None`` removes this statement from the
     tree. Confirm against the class definition.

     Returns:
         ``None`` when a dygraph API call is found, otherwise ``node``.
     """
     # The generator is lazy, so walking and visiting stay interleaved exactly
     # as in a plain for/if loop.
     calls = (sub for sub in gast.walk(node.value) if isinstance(sub, gast.Call))
     for call in calls:
         if is_dygraph_api(call):
             return None
         self._visit_Call(call)
     return node
예제 #2
0
 def visit_Expr(self, node):
     """Drop or process an expression statement depending on its calls.

     Walks all nodes beneath ``node.value`` with ``gast.walk`` and skips
     anything that is not a ``gast.Call``.  For the first call accepted by
     ``is_dygraph_api``, the method returns ``None`` immediately.  Every other
     call encountered before that point is passed to ``self._visit_Call``.

     NOTE(review): presumably used in a ``NodeTransformer``, where a ``None``
     return deletes the visited statement. Verify against the enclosing class.

     Returns:
         ``None`` if a dygraph API call is present, otherwise ``node``.
     """
     for sub_node in gast.walk(node.value):
         if not isinstance(sub_node, gast.Call):
             continue
         # TODO(liym27):
         #  Considers that a dygraph api which modifies the input or has a output.
         if is_dygraph_api(sub_node):
             return
         self._visit_Call(sub_node)
     return node
예제 #3
0
def is_to_variable(node):
    """Tell whether ``node`` calls a ``to_variable`` / ``to_tensor`` style API.

    The callee's source text is obtained with ``utils.ast_to_source_code``
    and whitespace-stripped.  Which suffix it must end with depends on the
    API family:

    * dygraph APIs (``utils.is_dygraph_api``): ``"to_variable"``
    * otherwise, paddle APIs (``utils.is_paddle_api``): ``"to_tensor"``

    Any other call yields ``False``.

    Args:
        node: a ``gast.Call`` node. The ``assert`` guard raises
            ``AssertionError`` otherwise, unless Python runs with ``-O``.

    Returns:
        bool: whether the callee name ends with the family's suffix.
    """
    assert isinstance(node, gast.Call)
    callee_src = utils.ast_to_source_code(node.func).strip()

    if utils.is_dygraph_api(node):
        expected_suffix = "to_variable"
    elif utils.is_paddle_api(node):
        expected_suffix = "to_tensor"
    else:
        return False
    return callee_src.endswith(expected_suffix)
예제 #4
0
    def _update_class_node_dict(self, node):
        """Record an assignment whose value constructs a known dygraph class.

        The assignment is recorded only when ``node.value``:

        * is a ``gast.Call``,
        * is not a ``to_variable``/``to_tensor`` call (``is_to_variable``),
        * satisfies ``is_dygraph_api``, and
        * has a ``func.attr`` name with a truthy entry in
          ``dygraph_class_to_static_api``.

        In that case the call's arguments are rewritten via
        ``update_args_of_func(call, call, "__init__")``.  The call node is then
        stored in ``self.class_node_dict``, keyed by the ``astor.to_source``
        text of the first assignment target.

        NOTE(review): ``func.attr`` assumes the callee is an attribute access
        (e.g. ``module.Cls``). A bare-name callee would raise
        ``AttributeError``. The key is used exactly as ``astor`` emits it (not
        stripped), so lookups must use the same conversion.

        Args:
            node: a ``gast.Assign`` node (checked with ``assert``).

        Returns:
            bool: ``True`` if an entry was recorded, ``False`` otherwise.
        """
        assert isinstance(node, gast.Assign)
        call = node.value
        if not isinstance(call, gast.Call) or is_to_variable(call):
            return False
        # TODO: node.value is not dygraph class
        if not is_dygraph_api(call):
            return False
        if not dygraph_class_to_static_api.get(call.func.attr):
            return False

        update_args_of_func(call, call, "__init__")
        target_key = astor.to_source(gast.gast_to_ast(node.targets[0]))
        self.class_node_dict[target_key] = call
        return True
 def test_dygraph_api(self):
     """is_dygraph_api returns exactly True / False for the two fixture nodes.

     The node from ``_get_dygraph_ast_node`` must yield ``True`` and the node
     from ``_get_static_ast_node`` must yield ``False``.  ``assertIs`` checks
     identity just like the previous ``assertTrue(x is ...)``, but on failure
     it reports the value actually returned instead of "False is not true".
     """
     self.assertIs(is_dygraph_api(self._get_dygraph_ast_node()), True)
     self.assertIs(is_dygraph_api(self._get_static_ast_node()), False)