def declare_outputs_xform(self, root):
    """ Remove output-marking statements from `root` and declare them.

    Every top-level expression statement accepted by `u.is_set_as_output`
    is dropped from the tree. Its receiver name (the `<name>` in
    `<name>.<method>(...)`) is recorded, and finally a statement
    `_outputs = <dict>` is appended to `root.body`. The dict maps each
    recorded name to itself.

    Args:
        root: the module node to transform.

    Returns:
        The transformed module node.
    """
    outer = self

    class OutputCollector(outer.MyTransformer):
        def visit_FunctionDef(self_t, node):
            """ Leave user-defined functions untouched (no descent). """
            return node

        def visit_Module(self_t, node):
            # var_sizes is recorded for parity with the inputs pass; this
            # pass itself never reads it.
            sizes = outer.VarSizesVisitor()
            sizes.visit(node)
            self_t.var_sizes = sizes.var_sizes
            self_t.outputs = {}
            self_t.generic_visit(node)
            return node

        def visit_Expr(self_t, node):
            if not u.is_set_as_output(node):
                return node
            # node.value is a call whose func is an attribute on a bare name.
            name = node.value.func.value.id
            self_t.outputs[name] = name
            # An empty list drops the statement (assuming MyTransformer
            # follows ast.NodeTransformer semantics -- confirm).
            return []

    collector = OutputCollector()
    root = collector.visit(root)
    declaration = "_outputs = %s" % u.dict_to_ast(collector.outputs)
    root.body.append(u.stmt_from_str(declaration))
    return root
def declare_inputs_xform(self, root):
    """ Replace input-marking statements with placeholder/one-hot setup.

    Every top-level expression statement accepted by `u.is_set_as_input`
    has the form `<name>.<method>(...)`. Each one is replaced by two
    generated statements:

        <name>_idxs.set_to(tf.placeholder(tf.int32, shape=[None],
                                          name='<name>_idxs'))
        <name>.set_to(tpt.one_hot(<name>_idxs, <size>,
                                  scope='<name>_one_hot'))

    Here `<size>` comes from the sizes gathered by `self.VarSizesVisitor`.
    A statement `_inputs = <dict>` is then appended to `root.body`; the
    dict maps `<name>` to `'<name>_idxs'`.

    Args:
        root: the module node to transform.

    Returns:
        The transformed module node. Indexing `var_sizes` fails
        (presumably with KeyError) if a marked name has no recorded size.
    """
    outer = self
    placeholder_tmpl = (
        "%s_idxs.set_to(tf.placeholder(tf.int32, shape=[None], name='%s_idxs'))")
    one_hot_tmpl = "%s.set_to(tpt.one_hot(%s_idxs, %s, scope='%s_one_hot'))"

    class InputCollector(outer.MyTransformer):
        def visit_Module(self_t, node):
            sizes = outer.VarSizesVisitor()
            sizes.visit(node)
            self_t.var_sizes = sizes.var_sizes
            self_t.inputs = {}
            self_t.generic_visit(node)
            return node

        def visit_FunctionDef(self_t, node):
            """ Leave user-defined functions untouched (no descent). """
            return node

        def visit_Expr(self_t, node):
            if not u.is_set_as_input(node):
                return node
            name = node.value.func.value.id
            size = self_t.var_sizes[name]
            idx_stmt = u.stmt_from_str(placeholder_tmpl % (name, name))
            onehot_stmt = u.stmt_from_str(
                one_hot_tmpl % (name, name, size, name))
            self_t.inputs[name] = "%s_idxs" % name
            # The returned list replaces the original statement.
            return [idx_stmt, onehot_stmt]

    collector = InputCollector()
    root = collector.visit(root)
    declaration = "_inputs = %s" % u.dict_to_ast(collector.inputs)
    root.body.append(u.stmt_from_str(declaration))
    return root
def visit_Module(self_t, node):
    """ Visit the module, then append variable/param registration statements.

    Resets the `declared_vars`, `set_vars` and `param_vars` sets. These are
    presumably filled by other visit_* methods of this transformer, which
    are not visible here. The children are then visited, and two statements
    are appended to `node.body`:

        self.var_nodes = {name: name, ...}        # declared AND set
        self.param_var_nodes = {name: name, ...}  # every param var

    Names are emitted in sorted order. Iterating a set of `str` follows
    per-process hash randomization, so without sorting the generated source
    (and the runtime dict order) could differ between runs.

    Args:
        node: the module node being transformed.

    Returns:
        The same module node, with the two statements appended.
    """
    self_t.declared_vars = set()
    self_t.set_vars = set()
    self_t.param_vars = set()
    self_t.generic_visit(node)

    var_node_names = {
        name: name
        for name in sorted(self_t.declared_vars)
        if name in self_t.set_vars
    }
    param_node_names = {name: name for name in sorted(self_t.param_vars)}

    node.body.append(
        u.stmt_from_str("self.var_nodes = %s" % u.dict_to_ast(var_node_names)))
    node.body.append(
        u.stmt_from_str("self.param_var_nodes = %s" %
                        u.dict_to_ast(param_node_names)))
    return node
def declare_params_xform(self, root):
    """ Rewrite parameter-definition assignments and register them.

    Every top-level assignment accepted by `u.is_param_definition` is
    replaced by the node list returned from `self.make_param_declaration`.
    The names returned alongside that list are recorded as
    `params[names[1]] = names[0]`. A statement `self.params = <dict>` is
    then appended to `root.body`.

    Args:
        root: the module node to transform.

    Returns:
        The transformed module node.
    """
    outer = self

    class ParamCollector(outer.MyTransformer):
        def visit_Module(self_t, node):
            self_t.params = {}
            self_t.generic_visit(node)
            return node

        def visit_FunctionDef(self_t, node):
            """ Leave user-defined functions untouched (no descent). """
            return node

        def visit_Assign(self_t, node):
            if not u.is_param_definition(node):
                return [node]
            replacement, names = outer.make_param_declaration(node)
            self_t.params[names[1]] = names[0]
            return replacement

    collector = ParamCollector()
    root = collector.visit(root)
    declaration = "self.params = %s" % u.dict_to_ast(collector.params)
    root.body.append(u.stmt_from_str(declaration))
    return root