class TopDownNet(nn.Module):
    """Glimpse-based network that propagates state along a fixed walk of a small tree.

    The graph is ``nx.balanced_tree(1, 2)`` wrapped in a ``DGLGraph``. A message,
    update and readout module are registered on it; ``forward`` initialises every
    node's state, updates the root, propagates along ``self.walk_list`` and returns
    the graph readout.
    """

    def __init__(self, h_dims=128, n_classes=10, filters=None,
                 kernel_size=(3, 3), final_pool_size=(2, 2),
                 glimpse_type='gaussian', glimpse_size=(15, 15), cnn='cnn'):
        """Build the graph and register its message/update/readout modules.

        Args:
            h_dims: hidden size, forwarded to the update, message and readout modules.
            n_classes: number of output classes.
            filters: per-layer filter counts forwarded to ``UpdateModule``.
                ``None`` (the default) means ``[16, 32, 64, 128, 256]``; a fresh
                list is built on every call so instances never share one object.
            kernel_size: forwarded to ``UpdateModule``.
            final_pool_size: forwarded to ``UpdateModule``.
            glimpse_type: forwarded to ``UpdateModule``.
            glimpse_size: forwarded to ``UpdateModule``.
            cnn: forwarded to ``UpdateModule`` as its ``cnn`` argument.
        """
        nn.Module.__init__(self)
        if filters is None:
            filters = [16, 32, 64, 128, 256]

        t = nx.balanced_tree(1, 2)
        self.G = DGLGraph(t)
        self.root = 0
        # Hard-coded walk root -> 1 -> 2; for this small tree it matches a
        # breadth-first traversal from the root.
        self.walk_list = [(0, 1), (1, 2)]
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.update_module = UpdateModule(
            h_dims=h_dims,
            n_classes=n_classes,
            filters=filters,
            kernel_size=kernel_size,
            final_pool_size=final_pool_size,
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            # Previously the literal 'cnn' was passed here, ignoring the argument.
            cnn=cnn,
        )
        # The message module is sized by the glimpse's attention-parameter count.
        self.message_module = MessageModule(
            h_dims=h_dims, g_dims=self.update_module.glimpse.att_params)
        self.readout_module = ReadoutModule(
            h_dims=h_dims,
            n_classes=n_classes,
        )

        self.G.register_message_func(self.message_module)
        self.G.register_update_func(self.update_module)
        self.G.register_readout_func(self.readout_module)

    def forward(self, x):
        """Run one pass over the graph for a batch ``x``.

        Only ``x.shape[0]`` (batch size) and ``x.new`` are used directly here;
        ``x`` itself is handed to ``self.update_module.set_image``.

        Returns:
            Whatever ``self.G.readout()`` returns.
        """
        batch_size = x.shape[0]
        g_dims = self.update_module.glimpse.att_params
        self.update_module.set_image(x)

        def zero_tensor_x(r, c):
            # New zero tensor of the same tensor type as x.
            return x.new(r, c).zero_()

        init_states = {
            's': zero_tensor_x(batch_size, self.h_dims),
            'a': (
                zero_tensor_x(batch_size, self.h_dims),
                zero_tensor_x(batch_size, g_dims),
            ),
            'g': None,
            'c': zero_tensor_x(batch_size, 1),
        }
        for n in self.G.nodes():
            # NOTE(review): all nodes receive references to the same tensors;
            # this is safe only if the update module rebinds rather than
            # mutates them in place — confirm.
            self.G.node[n].update(init_states)

        self.G.recvfrom(self.root, [])  # Update root node (no senders)
        self.G.propagate(self.walk_list)
        return self.G.readout()
class DFSGlimpseSingleObjectClassifier(nn.Module):
    """Single-object glimpse classifier that walks a binary tree depth-first.

    The graph is ``nx.balanced_tree(2, 2)`` wrapped in a ``DGLGraph``. A CNN
    (optionally loaded from ``cnn_file``) is wrapped in an ``UpdateModule``.
    ``forward`` either reads out the root only (``pretrain=True``) or propagates
    along a DFS walk of the tree and reads out all nodes.
    """

    def __init__(
            self,
            h_dims=128,
            n_classes=10,
            filters=None,
            kernel_size=(3, 3),
            final_pool_size=(2, 2),
            glimpse_type='gaussian',
            glimpse_size=(15, 15),
            cnn='cnn',
            cnn_file='cnn.pt',
    ):
        """Build the tree graph, the CNN-backed update module and the DFS walk.

        Args:
            h_dims: hidden size for the CNN, update and readout modules.
            n_classes: number of output classes.
            filters: per-layer filter counts for ``CNN``. ``None`` (the default)
                means ``[16, 32, 64, 128, 256]``; a fresh list is built on every
                call so instances never share one object.
            kernel_size: forwarded to ``CNN``.
            final_pool_size: forwarded to ``CNN``.
            glimpse_type: forwarded to ``UpdateModule``.
            glimpse_size: ``UpdateModule`` glimpse size, also the CNN ``input_size``.
            cnn: forwarded to ``CNN`` as its ``cnn`` argument.
            cnn_file: path of a state dict passed to ``T.load`` and loaded into
                the CNN; ``None`` skips loading.
        """
        nn.Module.__init__(self)
        if filters is None:
            filters = [16, 32, 64, 128, 256]

        t = nx.balanced_tree(2, 2)
        self.G = DGLGraph(t)
        self.root = 0
        # Directed tree oriented away from the root; used to build the DFS walk.
        t_uni = nx.bfs_tree(t, self.root)
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.message_module = MessageModule()
        self.G.register_message_func(self.message_module)  # default: just copy

        cnnmodule = CNN(
            cnn=cnn,
            n_layers=6,
            h_dims=h_dims,
            n_classes=n_classes,
            final_pool_size=final_pool_size,
            filters=filters,
            kernel_size=kernel_size,
            input_size=glimpse_size,
        )
        if cnn_file is not None:
            # NOTE(review): T.load unpickles the file; only load trusted checkpoints.
            cnnmodule.load_state_dict(T.load(cnn_file))

        self.update_module = UpdateModule(
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            cnn=cnnmodule,
            max_recur=1,  # T_MAX_RECUR
            n_classes=n_classes,
            h_dims=h_dims,
        )
        self.G.register_update_func(self.update_module)

        self.readout_module = ReadoutModule(h_dims=h_dims, n_classes=n_classes)
        self.G.register_readout_func(self.readout_module)

        # dfs_walk is handed the empty list and presumably fills it with the
        # walk's edges in place — confirm against its definition.
        self.walk_list = []
        dfs_walk(t_uni, self.root, self.walk_list)

    def forward(self, x, pretrain=False):
        """Initialise node states for batch ``x`` and run the graph.

        Args:
            x: input batch; ``x.shape[0]`` is the batch size and ``x.new`` is
                used to allocate zero states of the same tensor type.
            pretrain: if True, only the root is updated and read out.

        Returns:
            ``self.G.readout([self.root], pretrain=True)`` when ``pretrain`` is
            set, otherwise ``self.G.readout('all', pretrain=False)`` after
            propagating along ``self.walk_list``.
        """
        batch_size = x.shape[0]
        self.update_module.set_image(x)
        att_dims = self.update_module.glimpse.att_params

        def zeros(cols):
            # (batch_size, cols) zero tensor of the same tensor type as x.
            return x.new(batch_size, cols).zero_()

        init_states = {
            'h': zeros(self.h_dims),
            'b': zeros(att_dims),
            'b_next': zeros(att_dims),
            'a': zeros(1),
            'y': zeros(self.n_classes),
            'g': None,
            'b_fix': None,
            'db': None,
        }
        for n in self.G.nodes():
            self.G.node[n].update(init_states)

        # TODO: updating the root first is needed for single-object inputs,
        # TODO: but it is not useful (or is wrong) for multi-object inputs.
        self.G.recvfrom(self.root, [])
        if pretrain:
            return self.G.readout([self.root], pretrain=True)

        # Any per-node local updates are expected to happen inside the
        # update module rather than here.
        self.G.propagate(self.walk_list)
        return self.G.readout('all', pretrain=False)