Esempio n. 1
0
    def make_param_declaration(self, assign_node):
        """Turn a parameter assignment into its TensorFlow declaration.

        The assignment's first target name ``<name>`` yields two generated
        statements:

        * ``self._m_<name> = tf.Variable(init_params(<size>), name='<name>')``
          where ``<size>`` is what ``u.param_size`` reports for the node;
        * ``<name> = tpt.softmax(self._m_<name>, scope='<name>_softmax')``.

        Returns:
            ``([declaration_stmt, softmax_stmt], ("self._m_<name>", "<name>"))``
        """
        target = assign_node.targets[0].id
        param_size = u.param_size(assign_node)

        declaration = u.stmt_from_str(
            "self._m_%s = tf.Variable(init_params(%s), name='%s')"
            % (target, param_size, target))
        normalised = u.stmt_from_str(
            "%s = tpt.softmax(self._m_%s, scope='%s_softmax')"
            % (target, target, target))

        member_name = "self._m_%s" % target
        return [declaration, normalised], (member_name, target)
Esempio n. 2
0
 def visit_Expr(self_t, node):
     """Expand an input marker into placeholder-backed statements.

     For an expression statement accepted by ``u.is_set_as_input`` (a call
     of the form ``<var>.<method>(...)``), emit two statements:

     * ``<var>_idxs.set_to(tf.placeholder(tf.int32, ...))``
     * ``<var>.set_to(tpt.one_hot(<var>_idxs, <size>, ...))``

     ``<size>`` is looked up in ``self_t.var_sizes``. The method also records
     ``self_t.inputs[<var>] = "<var>_idxs"``. Any other expression is
     returned unchanged.
     """
     if not u.is_set_as_input(node):
         return node

     name = node.value.func.value.id
     size = self_t.var_sizes[name]
     placeholder_stmt = u.stmt_from_str(
         "%s_idxs.set_to(tf.placeholder(tf.int32, shape=[None], name='%s_idxs'))"
         % (name, name))
     one_hot_stmt = u.stmt_from_str(
         "%s.set_to(tpt.one_hot(%s_idxs, %s, scope='%s_one_hot'))"
         % (name, name, size, name))
     self_t.inputs[name] = "%s_idxs" % name
     return [placeholder_stmt, one_hot_stmt]
Esempio n. 3
0
            def visit_Assign(self_t, node):
                """Rewrite input/output declarations as ``Var`` plus a marker.

                An assignment accepted by ``u.is_input_definition`` (checked
                first) or ``u.is_output_definition`` is handled the same way:

                * the called constructor is renamed to ``Var`` in place;
                * the assignment is followed by ``<name>.set_as_input()`` or
                  ``<name>.set_as_output()`` respectively.

                All other assignments are returned unchanged.
                """
                if u.is_input_definition(node):
                    marker_template = "%s.set_as_input()"
                elif u.is_output_definition(node):
                    marker_template = "%s.set_as_output()"
                else:
                    return node

                target = node.targets[0].id
                node.value.func.id = "Var"  # keep the declaration, now as a Var
                marker = u.stmt_from_str(marker_template % target)
                return [node, marker]
Esempio n. 4
0
    def merge_hypers(self, module_nodes):
        """
        Merge together the compilations for different hyperparameter settings
        into a single class.

        Plan: take the first module_node, and then merge in each of the other
        module_nodes.

        The first key of ``module_nodes`` (in its iteration order) supplies
        the base module. For each remaining setting:

        * its parameters are checked against the base with
          ``self.assert_params_are_same``;
        * its ``build_model_<hypers>`` method is appended to the base class.

        Only one copy of the parameter declarations is kept. The base's
        ``declare_params_<hypers>`` method is renamed to the
        hypers-independent ``declare_params``. Finally,
        ``import tensorflow as tf`` is inserted at the top of the module.

        Args:
            module_nodes: mapping from hyperparameter-setting name to the
                compiled module AST for that setting.

        Returns:
            The base module AST, mutated in place.

        Raises:
            IndexError: if ``module_nodes`` is empty.
        """
        # Materialise the keys: on Python 3 ``dict.keys()`` returns a view,
        # which cannot be indexed or sliced.
        hypers_names = list(module_nodes.keys())
        base_hypers = hypers_names[0]
        module_node = module_nodes[base_hypers]
        class_node = u.get_class_node(module_node)

        for hypers_name_i in hypers_names[1:]:
            module_node_i = module_nodes[hypers_name_i]
            self.assert_params_are_same(module_node, base_hypers,
                                        module_node_i, hypers_name_i)

            # Params match, so it's ok to keep just one copy of the params
            # declarations and only carry over this setting's model method.
            model_method_i = u.get_method_by_name(
                module_node_i, "build_model_%s" % hypers_name_i)
            class_node.body.append(model_method_i)

        # Finally, rename the param declaration method to be hypers-independent.
        param_decl_method = u.get_method_by_name(
            module_node, "declare_params_%s" % base_hypers)
        param_decl_method.name = "declare_params"

        module_node.body.insert(0, u.stmt_from_str("import tensorflow as tf"))

        return module_node
Esempio n. 5
0
    def declare_outputs_xform(self, root):
        """Strip output markers from ``root`` and record the marked names.

        The transformer works on expression statements outside user-defined
        functions. Any statement accepted by ``u.is_set_as_output`` (a call
        of the form ``<var>.<method>(...)``) is removed, and ``<var>`` is
        collected into a ``{name: name}`` dict. That dict is appended to the
        module as ``_outputs = {...}``.
        """
        outer = self

        class Transformer(outer.MyTransformer):
            def visit_FunctionDef(self_t, node):
                """Leave user-defined functions (and their bodies) untouched."""
                return node

            def visit_Module(self_t, node):
                sizes_visitor = outer.VarSizesVisitor()
                sizes_visitor.visit(node)
                self_t.var_sizes = sizes_visitor.var_sizes
                self_t.outputs = {}
                self_t.generic_visit(node)
                return node

            def visit_Expr(self_t, node):
                if not u.is_set_as_output(node):
                    return node
                marked = node.value.func.value.id
                self_t.outputs[marked] = marked
                return []  # drop the marker statement itself

        transformer = Transformer()
        root = transformer.visit(root)
        outputs_stmt = u.stmt_from_str(
            "_outputs = %s" % u.dict_to_ast(transformer.outputs))
        root.body.append(outputs_stmt)

        return root
Esempio n. 6
0
    def declare_inputs_xform(self, root):
        """Replace input markers in ``root`` with placeholder/one-hot code.

        The transformer works on expression statements outside user-defined
        functions. Each statement accepted by ``u.is_set_as_input`` is
        replaced by two statements:

        * an int32 ``tf.placeholder`` bound to ``<var>_idxs``;
        * a ``tpt.one_hot`` of it, sized from the ``VarSizesVisitor``
          results, assigned to ``<var>``.

        The collected ``{var: "<var>_idxs"}`` mapping is appended to the
        module as ``_inputs = {...}``.
        """
        outer = self

        class Transformer(outer.MyTransformer):
            def visit_Module(self_t, node):
                sizes_visitor = outer.VarSizesVisitor()
                sizes_visitor.visit(node)
                self_t.var_sizes = sizes_visitor.var_sizes
                self_t.inputs = {}
                self_t.generic_visit(node)
                return node

            def visit_FunctionDef(self_t, node):
                """Leave user-defined functions (and their bodies) untouched."""
                return node

            def visit_Expr(self_t, node):
                if not u.is_set_as_input(node):
                    return node

                name = node.value.func.value.id
                size = self_t.var_sizes[name]
                placeholder_stmt = u.stmt_from_str(
                    "%s_idxs.set_to(tf.placeholder(tf.int32, shape=[None], name='%s_idxs'))"
                    % (name, name))
                one_hot_stmt = u.stmt_from_str(
                    "%s.set_to(tpt.one_hot(%s_idxs, %s, scope='%s_one_hot'))"
                    % (name, name, size, name))
                self_t.inputs[name] = "%s_idxs" % name
                return [placeholder_stmt, one_hot_stmt]

        transformer = Transformer()
        root = transformer.visit(root)
        inputs_stmt = u.stmt_from_str(
            "_inputs = %s" % u.dict_to_ast(transformer.inputs))
        root.body.append(inputs_stmt)
        return root
Esempio n. 7
0
    def add_get_hypers_method(self, module_node, hypers_list):
        """Append a ``get_hypers(self)`` method to the module's class.

        The generated method body is the single statement
        ``return <json.dumps(hypers_list)>``. NOTE(review): the JSON text is
        spliced in as Python source, so ``hypers_list`` should contain only
        values whose JSON form is also a Python literal (strings, numbers,
        lists). ``True``/``False``/``None`` would render as
        ``true``/``false``/``null``.

        Returns:
            ``module_node``, with the method appended to its class in place.
        """
        get_hypers = ast.FunctionDef(
            name="get_hypers",
            args=u.make_args_list([ast.Name(id="self")]),
            decorator_list=[],
            # FunctionDef.body must be a *list* of statements, whereas
            # u.stmt_from_str returns a single statement.
            body=[u.stmt_from_str("return %s" % json.dumps(hypers_list))])

        u.get_class_node(module_node).body.append(get_hypers)

        return module_node
Esempio n. 8
0
    def class_formatting_xform(self, root, hypers_name):
        """Wrap a compiled program in a ``Model`` class inside a new module.

        ``self.split_params_from_model`` separates the parameter statements
        (and a ``param_dict`` statement) from the model statements of
        ``root``. The returned module contains, in order:

        * ``runtime = '<value>'``, where ``<value>`` is
          ``self.args['--runtime']``, or ``'logspace'`` when that is
          missing or falsy;
        * ``if runtime == 'logspace':`` importing ``terpret_tf_log_runtime``
          as ``tpt``, else ``terpret_tf_runtime`` as ``tpt``;
        * ``class Model`` with two methods:
          ``declare_params_<hypers_name>(self, init_params)`` holds the
          parameter statements followed by ``param_dict``;
          ``build_model_<hypers_name>(self)`` holds the remaining statements
          followed by ``return _inputs, _outputs``.

        Returns:
            The new ``ast.Module``.
        """
        root, param_statements, param_dict = self.split_params_from_model(root)

        return_stmt = u.stmt_from_str("return _inputs, _outputs")
        root.body.append(return_stmt)
        # Build the argument list the same way as for every other generated
        # method: FunctionDef.args expects an arguments node, not a bare list.
        model_method = ast.FunctionDef(
            name="build_model_%s" % hypers_name,
            args=u.make_args_list([ast.Name(id="self")]),
            decorator_list=[],
            body=root.body)

        args_list = u.make_args_list(
            [ast.Name(id="self"),
             ast.Name(id="init_params")])

        param_method = ast.FunctionDef(name="declare_params_%s" % hypers_name,
                                       args=args_list,
                                       decorator_list=[],
                                       body=param_statements)
        param_method.body.append(param_dict)

        class_node = ast.ClassDef(name="Model",
                                  bases=[],
                                  body=[param_method, model_method],
                                  decorator_list=[])

        runtime_assign = u.stmt_from_str(
            "runtime = '%s'" % (self.args.get('--runtime') or 'logspace'))
        condition = u.stmt_from_str("runtime == 'logspace'").value
        log_import_stmt = u.stmt_from_str(
            "import terpret_tf_log_runtime as tpt")
        standard_import_stmt = u.stmt_from_str(
            "import terpret_tf_runtime as tpt")

        if_node = ast.If(test=condition,
                         body=[log_import_stmt],
                         orelse=[standard_import_stmt])

        module_body = [runtime_assign, if_node, class_node]
        module_node = ast.Module(body=module_body)
        return module_node
Esempio n. 9
0
            def visit_Module(self_t, node):
                """Reset the name sets, visit children, then emit name tables.

                ``declared_vars``, ``set_vars`` and ``param_vars`` start empty
                before ``generic_visit``. Presumably the other ``visit_*``
                methods fill them. Two statements are then appended to the
                module:

                * ``self.var_nodes = {...}``: the declared vars that were
                  also set, in ``declared_vars`` iteration order;
                * ``self.param_var_nodes = {...}``: every param var.

                Both are identity ``{name: name}`` mappings.
                """
                self_t.declared_vars = set()
                self_t.set_vars = set()
                self_t.param_vars = set()

                self_t.generic_visit(node)

                was_set = self_t.set_vars
                var_node_names = dict(
                    (name, name) for name in self_t.declared_vars
                    if name in was_set)
                param_node_names = dict(
                    (name, name) for name in self_t.param_vars)

                for template, table in (
                        ("self.var_nodes = %s", var_node_names),
                        ("self.param_var_nodes = %s", param_node_names)):
                    node.body.append(
                        u.stmt_from_str(template % u.dict_to_ast(table)))
                return node
Esempio n. 10
0
            def visit_Expr(self_t, node):
                """Wrap a numeric literal passed to ``set_to`` in ``tpt.one_hot``.

                This applies to an expression accepted by ``u.is_set_to``
                whose first call argument is an ``ast.Num``. That argument is
                replaced in place by the expression
                ``tpt.one_hot(<number>, <size>)``, with ``<size>`` taken from
                ``self_t.size_visitor.get_var_size(<var>)``. The (possibly
                modified) node is always returned.
                """
                if not u.is_set_to(node):
                    return node
                call = node.value
                literal = call.args[0]
                if not isinstance(literal, ast.Num):
                    return node

                target = call.func.value.id
                size = self_t.size_visitor.get_var_size(target)
                one_hot = u.stmt_from_str(
                    "tpt.one_hot(%s, %s)" % (literal.n, size))
                call.args[0] = one_hot.value
                return node
Esempio n. 11
0
 def visit_FunctionDef(self_t, node):
     """Handle functions carrying a single ``Runtime`` or ``Inline`` decorator.

     * ``@Runtime([in_sizes...], out_size)``: keep the function and append
       ``<name>_tensor = tpt.make_tensor(<name>, [in_sizes], out_size)``.
     * ``@Inline(...)`` or bare ``@Inline``: drop the definition (return
       ``None``).
     * Anything else is returned unchanged. This covers no decorator,
       several decorators, or a decorator whose name cannot be resolved
       (e.g. ``@mod.deco()``).

     Raises:
         AssertionError: if a ``Runtime`` decorator does not have exactly
             two positional args.
     """
     decorator = None
     decorator_name = None
     if len(node.decorator_list) == 1:
         decorator = node.decorator_list[0]
         # ``@Name(...)`` is a Call whose ``func`` carries the name, while a
         # bare ``@Name`` carries ``id`` itself. Other shapes have no usable
         # ``id``. Previously ``.func.id`` raised AttributeError on them.
         decorator_target = getattr(decorator, "func", decorator)
         decorator_name = getattr(decorator_target, "id", None)

     if decorator_name == "Runtime":
         # A bare ``@Runtime`` has no ``args``; treat it as zero args so the
         # arity check below reports it instead of an AttributeError.
         runtime_args = getattr(decorator, "args", [])
         assert len(runtime_args) == 2, \
             "Decorator must have two args:\n%s" % unparse(decorator)
         input_sizes, output_size = runtime_args[0].elts, runtime_args[1]
         sizes_src = ",".join([str(size.n) for size in input_sizes])
         new_node = u.stmt_from_str(
             "%s_tensor = tpt.make_tensor(%s, [%s], %s)" %
             (node.name, node.name, sizes_src, output_size.n))
         return [node, new_node]
     elif decorator_name == "Inline":
         return None
     else:
         return node
Esempio n. 12
0
    def declare_params_xform(self, root):
        """Replace param definitions in ``root`` with TensorFlow declarations.

        The transformer works on assignments outside user-defined functions.
        Every assignment accepted by ``u.is_param_definition`` is replaced by
        the statements that ``self.make_param_declaration`` returns. The
        ``(member, name)`` pairs it reports are collected as
        ``{name: member}``, and that mapping is appended to the module as
        ``self.params = {...}``. Other assignments are kept as-is.
        """
        outer = self

        class Transformer(outer.MyTransformer):
            def visit_FunctionDef(self_t, node):
                """Leave user-defined functions (and their bodies) untouched."""
                return node

            def visit_Module(self_t, node):
                self_t.params = {}
                self_t.generic_visit(node)
                return node

            def visit_Assign(self_t, node):
                if not u.is_param_definition(node):
                    return [node]
                replacement, member_and_name = \
                    outer.make_param_declaration(node)
                self_t.params[member_and_name[1]] = member_and_name[0]
                return replacement

        transformer = Transformer()
        root = transformer.visit(root)
        params_stmt = u.stmt_from_str(
            "self.params = %s" % u.dict_to_ast(transformer.params))
        root.body.append(params_stmt)
        return root
Esempio n. 13
0
    def rename_if_else_block(self, if_node):
        """Give each branch private copies of shared vars, then recombine them.

        The shared variables come from ``self.vars_defined_in_all_cases``.
        For each branch yielded by ``u.ifs_in_elif_block``, the case number
        ``n`` comes from ``u.get_condition_rhs_num``. Presumably this is the
        literal in a ``lhs == n`` test. Each shared variable ``<var>`` is
        renamed to ``<var>_case<n>`` within that branch's body.

        Afterwards, one statement per shared variable is produced:
        ``<var>.set_to(tpt.weighted_sum([<var>_case<n>, ...], <lhs>, ...))``.
        Cases appear in ascending order. ``<lhs>`` is the
        ``u.get_condition_lhs`` result of the last branch visited.

        Returns:
            ``[if_node]`` followed by those generated statements.
        """
        shared_vars = self.vars_defined_in_all_cases(if_node)
        case_numbers = []

        for branch in u.ifs_in_elif_block(if_node):
            condition_lhs = u.get_condition_lhs(branch)
            case = u.get_condition_rhs_num(branch)
            case_numbers.append(case)
            renames = dict(
                (var, "%s_case%s" % (var, case)) for var in shared_vars)
            branch.body = self.subs(branch.body, **renames)

        case_numbers.sort()

        output = [if_node]
        for var in shared_vars:
            per_case = ",".join(
                ["%s_case%s" % (var, case) for case in case_numbers])
            output.append(u.stmt_from_str(
                "%s.set_to(tpt.weighted_sum([%s], %s, scope='%s_weighted_sum'))"
                % (var, per_case, condition_lhs, var)))

        return output