Esempio n. 1
0
    def declare_outputs_xform(self, root):
        """Strip output markers from ``root`` and record them in ``_outputs``.

        Each expression statement accepted by ``u.is_set_as_output`` is
        replaced by an empty list, and the identifier it refers to is
        remembered.  Once the tree has been visited, an ``_outputs = ...``
        statement built from the collected name -> name mapping is appended
        to ``root.body``.

        Returns the result of visiting ``root`` with the extra statement
        appended.
        """

        class Transformer(self.MyTransformer):
            def visit_FunctionDef(self_t, node):
                """Return function definitions as-is, without descending."""
                return node

            def visit_Expr(self_t, node):
                """Drop an output marker and remember the name it refers to."""
                if not u.is_set_as_output(node):
                    return node
                # Identifier read from node.value.func.value.id; for a
                # standard ast tree that is the receiver of a method call
                # (presumably `<name>.<method>(...)` -- confirm with u).
                output_name = node.value.func.value.id
                self_t.outputs[output_name] = output_name
                # NOTE(review): assumes MyTransformer treats an empty list
                # like ast.NodeTransformer does, i.e. removes the statement.
                return []

            def visit_Module(self_t, node):
                """Collect variable sizes, reset outputs, then visit children."""
                size_visitor = self.VarSizesVisitor()
                size_visitor.visit(node)
                self_t.var_sizes = size_visitor.var_sizes
                self_t.outputs = {}
                self_t.generic_visit(node)
                return node

        transformer = Transformer()
        root = transformer.visit(root)
        outputs_stmt = u.stmt_from_str(
            "_outputs = %s" % u.dict_to_ast(transformer.outputs))
        root.body.append(outputs_stmt)

        return root
Esempio n. 2
0
    def declare_inputs_xform(self, root):
        """Expand input markers in ``root`` and record them in ``_inputs``.

        Each expression statement accepted by ``u.is_set_as_input`` is
        replaced by two generated statements: one that sets
        ``<name>_idxs`` to an int32 placeholder, and one that sets
        ``<name>`` to a one-hot encoding of those indices, using the size
        found for ``<name>`` by ``self.VarSizesVisitor``.  After the visit,
        an ``_inputs = ...`` statement built from the collected
        name -> ``<name>_idxs`` mapping is appended to ``root.body``.

        Returns the result of visiting ``root`` with the extra statement
        appended.
        """

        class Transformer(self.MyTransformer):
            def visit_Module(self_t, node):
                """Collect variable sizes, reset inputs, then visit children."""
                size_visitor = self.VarSizesVisitor()
                size_visitor.visit(node)
                self_t.var_sizes = size_visitor.var_sizes
                self_t.inputs = {}
                self_t.generic_visit(node)
                return node

            def visit_FunctionDef(self_t, node):
                """Return function definitions as-is, without descending."""
                return node

            def visit_Expr(self_t, node):
                """Swap an input marker for placeholder + one-hot statements."""
                if not u.is_set_as_input(node):
                    return node

                # Identifier read from node.value.func.value.id; for a
                # standard ast tree that is the receiver of a method call.
                input_name = node.value.func.value.id
                input_size = self_t.var_sizes[input_name]

                placeholder_template = "%s_idxs.set_to(tf.placeholder(tf.int32, shape=[None], name='%s_idxs'))"
                one_hot_template = "%s.set_to(tpt.one_hot(%s_idxs, %s, scope='%s_one_hot'))"

                placeholder_stmt = u.stmt_from_str(
                    placeholder_template % (input_name, input_name))
                one_hot_stmt = u.stmt_from_str(
                    one_hot_template
                    % (input_name, input_name, input_size, input_name))

                # Record which placeholder name feeds this input variable.
                self_t.inputs[input_name] = "%s_idxs" % input_name
                return [placeholder_stmt, one_hot_stmt]

        transformer = Transformer()
        root = transformer.visit(root)
        inputs_stmt = u.stmt_from_str(
            "_inputs = %s" % u.dict_to_ast(transformer.inputs))
        root.body.append(inputs_stmt)
        return root
Esempio n. 3
0
            def visit_Module(self_t, node):
                """Visit the module, then append the node-name tables to it.

                Resets ``declared_vars``, ``set_vars`` and ``param_vars`` to
                empty sets, visits the children, and finally appends two
                statements to ``node.body``: ``self.var_nodes = ...`` for
                names that are both declared and set, and
                ``self.param_var_nodes = ...`` for every name in
                ``param_vars``.  Both tables map each name to itself.
                """
                # NOTE(review): presumably filled in by other visit_* methods
                # of this transformer during generic_visit -- not shown here.
                self_t.declared_vars = set()
                self_t.set_vars = set()
                self_t.param_vars = set()

                self_t.generic_visit(node)

                # Keep only declared variables that were also set.
                var_node_names = {}
                for declared in self_t.declared_vars:
                    if declared in self_t.set_vars:
                        var_node_names[declared] = declared

                param_node_names = dict(
                    (param, param) for param in self_t.param_vars)

                assignments = (
                    ("self.var_nodes = %s", var_node_names),
                    ("self.param_var_nodes = %s", param_node_names),
                )
                for template, table in assignments:
                    node.body.append(
                        u.stmt_from_str(template % u.dict_to_ast(table)))
                return node
Esempio n. 4
0
    def declare_params_xform(self, root):
        """Rewrite parameter definitions in ``root`` and record ``self.params``.

        Each assignment accepted by ``u.is_param_definition`` is replaced by
        the nodes returned from ``self.make_param_declaration``; the pair of
        names it also returns is stored as ``params[names[1]] = names[0]``.
        After the visit, a ``self.params = ...`` statement built from that
        mapping is appended to ``root.body``.

        Returns the result of visiting ``root`` with the extra statement
        appended.
        """

        class Transformer(self.MyTransformer):
            def visit_Module(self_t, node):
                """Reset the params mapping, then visit the module's children."""
                self_t.params = {}
                self_t.generic_visit(node)
                return node

            def visit_FunctionDef(self_t, node):
                """Return function definitions as-is, without descending."""
                return node

            def visit_Assign(self_t, node):
                """Replace a parameter definition with its declaration nodes."""
                if not u.is_param_definition(node):
                    # Non-parameter assignments pass through, wrapped in a list.
                    return [node]
                declaration_nodes, names = self.make_param_declaration(node)
                self_t.params[names[1]] = names[0]
                return declaration_nodes

        transformer = Transformer()
        root = transformer.visit(root)
        params_stmt = u.stmt_from_str(
            "self.params = %s" % u.dict_to_ast(transformer.params))
        root.body.append(params_stmt)
        return root