def test_multi_recv_0deg():
    """Call ``recv`` repeatedly on a graph that contains a zero-degree node.

    The graph has two nodes and the single edge ``0 -> 1``, so node 0 never
    has an incoming message.  Each assertion pins the value of ``'h'``
    observed after the corresponding ``recv`` call.
    """

    def _forward_src(edges):
        # The message carries the source node's 'h' feature.
        return {'m': edges.src['h']}

    def _accumulate(nodes):
        # Current 'h' plus the mailbox summed over axis 1.
        current = nodes.data['h']
        inbox_total = nodes.mailbox['m'].sum(1)
        return {'h': current + inbox_total}

    def _double(nodes):
        return {'h': nodes.data['h'] * 2}

    def _twos(shape, dtype, ctx, ids):
        # Initializer that produces a tensor filled with 2.
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)

    graph = DGLGraph()
    graph.register_message_func(_forward_src)
    graph.register_reduce_func(_accumulate)
    graph.register_apply_node_func(_double)
    graph.set_n_initializer(_twos)
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    # First round: recv on both the 0-degree and the non-0-degree node.
    old = th.randn((2, 5))
    graph.ndata['h'] = old
    graph.send((0, 1))
    graph.recv([0, 1])
    new = graph.ndata['h']
    # Node 0 (0-degree): value from the initializer (2), then doubled by apply.
    assert U.allclose(new[0], th.full((5,), 4))
    # Node 1: own feature plus the message from node 0, then doubled.
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # Second recv on the zero-degree node.
    graph.recv([0])
    node0_h = graph.nodes[0].data['h']
    assert U.allclose(node0_h, th.full((5,), 8))

    # Second recv on node 1, which has no pending incoming message.
    graph.recv([1])
    node1_h = graph.nodes[1].data['h']
    assert U.allclose(node1_h, th.sum(old, 0) * 4)
def test_dynamic_addition():
    """Interleave graph growth with ``send``/``recv`` and compare the result
    against one complete ``send``/``recv`` round.

    Nodes and edges are added in several steps, with partial ``send`` and
    ``recv`` calls in between.  The ``'h'`` field obtained that way must match
    the ``'h'`` produced by a full round over the final graph.
    """
    num_nodes = 3
    feat_dim = 1

    graph = DGLGraph()

    def _random_init(shape, dtype, ctx, ids):
        values = F.randn(shape)
        return F.copy_to(F.astype(values, dtype), ctx)

    # NOTE(review): these two initializers are presumably superseded by the
    # zero initializers set further down, before any node or edge is added.
    graph.set_n_initializer(_random_init)
    graph.set_e_initializer(_random_init)

    def _combine(edges):
        # Sum of: source 'h1', destination 'h2', and both edge features.
        total = edges.src['h1'] + edges.dst['h2']
        total = total + edges.data['h1']
        total = total + edges.data['h2']
        return {'m': total}

    def _sum_mailbox(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    def _identity(nodes):
        return {'h': nodes.data['h']}

    graph.register_message_func(_combine)
    graph.register_reduce_func(_sum_mailbox)
    graph.register_apply_node_func(_identity)
    graph.set_n_initializer(dgl.init.zero_initializer)
    graph.set_e_initializer(dgl.init.zero_initializer)

    # Step 1: initial nodes with features, three more nodes, two edges.
    graph.add_nodes(num_nodes)
    node_h1 = F.randn((num_nodes, feat_dim))
    node_h2 = F.randn((num_nodes, feat_dim))
    graph.ndata.update({'h1': node_h1, 'h2': node_h2})
    graph.add_nodes(3)
    graph.add_edge(0, 1)
    graph.add_edge(1, 0)
    edge_h1 = F.randn((2, feat_dim))
    edge_h2 = F.randn((2, feat_dim))
    graph.edata.update({'h1': edge_h1, 'h2': edge_h2})
    graph.send()
    # After sending on every edge the message index is expected to be all ones.
    all_ones = F.ones((graph.number_of_edges(), ), dtype=F.int64)
    expected = F.copy_to(all_ones, F.cpu())
    msg_index = graph._get_msg_index().tousertensor()
    assert F.array_equal(msg_index, expected)

    # Step 2: two more edges (only 'h1' supplied), partial send/recv.
    graph.add_edges([0, 2], [2, 0], {'h1': F.randn((2, feat_dim))})
    graph.send(([0, 2], [2, 0]))
    graph.recv(0)

    # Step 3: one more edge whose 'h1' is assigned afterwards through edges[4].
    graph.add_edge(1, 2)
    graph.edges[4].data['h1'] = F.randn((1, feat_dim))
    graph.send((1, 2))
    graph.recv([1, 2])

    incremental_h = graph.ndata.pop('h')

    # Reference: a complete round of send and recv on the final graph.
    graph.send()
    graph.recv()
    assert F.allclose(incremental_h, graph.ndata['h'])
def test_recv_0deg_newfld():
    """``recv`` with a zero-degree node when the reducer writes a new field.

    The graph has two nodes and the single edge ``0 -> 1``.  The reducer
    reads ``'h'`` but writes its result to ``'h1'``, and the apply function
    doubles ``'h1'``.
    """
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def _forward_src(edges):
        return {'m': edges.src['h']}

    def _accumulate(nodes):
        # Node's 'h' plus the mailbox summed over axis 1, stored as 'h1'.
        current = nodes.data['h']
        inbox_total = mx.nd.sum(nodes.mailbox['m'], 1)
        return {'h1': current + inbox_total}

    def _double(nodes):
        return {'h1': nodes.data['h1'] * 2}

    def _twos(shape, dtype, ctx, ids):
        # Initializer that produces an array filled with 2.
        return 2 + mx.nd.zeros(shape=shape, dtype=dtype, ctx=ctx)

    graph.register_message_func(_forward_src)
    graph.register_reduce_func(_accumulate)
    graph.register_apply_node_func(_double)

    # Case 1: recv on both the 0-degree and the non-0-degree node.
    old = mx.nd.random.normal(shape=(2, 5))
    graph.set_n_initializer(_twos, 'h1')
    graph.ndata['h'] = old
    graph.send((0, 1))
    graph.recv([0, 1])
    new = graph.ndata.pop('h1')
    # Node 0 (0-degree): initializer value (2), then doubled by apply.
    assert np.allclose(new[0].asnumpy(), np.full((5, ), 4))
    # Node 1: reduced value, then doubled.
    expected_node1 = mx.nd.sum(old, 0).asnumpy() * 2
    assert np.allclose(new[1].asnumpy(), expected_node1)

    # Case 2: recv on the 0-degree node only.
    old = mx.nd.random.normal(shape=(2, 5))
    graph.ndata['h'] = old
    # 'h1' has to be populated explicitly for this case.
    graph.ndata['h1'] = mx.nd.full((2, 5), -1)
    graph.send((0, 1))
    graph.recv(0)
    new = graph.ndata.pop('h1')
    # Node 0: falls back to apply only, so -1 becomes -2.
    assert np.allclose(new[0].asnumpy(), np.full((5, ), -2))
    # Node 1: left unchanged.
    assert np.allclose(new[1].asnumpy(), np.full((5, ), -1))
def test_recv_0deg():
    """``recv`` on a graph with one zero-degree node.

    The graph has two nodes and the single edge ``0 -> 1``.  Two cases are
    checked: recv on both nodes, and recv on the zero-degree node alone.
    """
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def _forward_src(edges):
        return {'m': edges.src['h']}

    def _accumulate(nodes):
        # Current 'h' plus the mailbox summed over axis 1.
        current = nodes.data['h']
        inbox_total = nodes.mailbox['m'].sum(1)
        return {'h': current + inbox_total}

    def _double(nodes):
        return {'h': nodes.data['h'] * 2}

    def _twos(shape, dtype, ctx, ids):
        # Initializer that produces a tensor filled with 2.
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)

    graph.register_message_func(_forward_src)
    graph.register_reduce_func(_accumulate)
    graph.register_apply_node_func(_double)
    graph.set_n_initializer(_twos, 'h')

    # Case 1: recv on both the 0-degree and the non-0-degree node.
    old = th.randn((2, 5))
    graph.ndata['h'] = old
    graph.send((0, 1))
    graph.recv([0, 1])
    new = graph.ndata.pop('h')
    # Node 0 (0-degree): initializer value (2), then doubled by apply.
    assert U.allclose(new[0], th.full((5, ), 4))
    # Node 1: own feature plus the message from node 0, then doubled.
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # Case 2: recv on the 0-degree node only behaves like apply_nodes.
    old = th.randn((2, 5))
    graph.ndata['h'] = old
    graph.send((0, 1))
    graph.recv(0)
    new = graph.ndata.pop('h')
    # Node 0: only the apply function ran.
    assert U.allclose(new[0], 2 * old[0])
    # Node 1: untouched.
    assert U.allclose(new[1], old[1])
class TopDownNet(nn.Module):
    """Network that runs update/message/readout modules over a 3-node chain.

    The graph is built from ``nx.balanced_tree(1, 2)`` (branching factor 1,
    height 2), i.e. the path ``0 - 1 - 2`` rooted at node 0.
    """

    def __init__(self,
                 h_dims=128,
                 n_classes=10,
                 filters=None,
                 kernel_size=(3, 3),
                 final_pool_size=(2, 2),
                 glimpse_type='gaussian',
                 glimpse_size=(15, 15),
                 cnn='cnn'):
        """Build the graph and register the three modules on it.

        Parameters
        ----------
        h_dims : int
            Forwarded to the update, message and readout modules; also the
            width of the zero-initialized state tensors in ``forward``.
        n_classes : int
            Forwarded to the update and readout modules.
        filters : list, optional
            Forwarded to ``UpdateModule``.  ``None`` (the default) selects
            ``[16, 32, 64, 128, 256]``; a fresh list is created per instance
            so the default is never shared between instances.
        kernel_size, final_pool_size, glimpse_type, glimpse_size, cnn
            Forwarded unchanged to ``UpdateModule``.
        """
        super(TopDownNet, self).__init__()
        if filters is None:
            filters = [16, 32, 64, 128, 256]

        t = nx.balanced_tree(1, 2)
        self.G = DGLGraph(t)
        self.root = 0
        # Edge order for propagation: the chain walked outwards from the root.
        self.walk_list = [(0, 1), (1, 2)]
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.update_module = UpdateModule(
            h_dims=h_dims,
            n_classes=n_classes,
            filters=filters,
            kernel_size=kernel_size,
            final_pool_size=final_pool_size,
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            # Honor the caller's choice; this used to be hard-coded to 'cnn',
            # which silently ignored the constructor argument.
            cnn=cnn,
        )
        self.message_module = MessageModule(
            h_dims=h_dims,
            g_dims=self.update_module.glimpse.att_params,
        )
        self.readout_module = ReadoutModule(
            h_dims=h_dims,
            n_classes=n_classes,
        )

        self.G.register_message_func(self.message_module)
        self.G.register_update_func(self.update_module)
        self.G.register_readout_func(self.readout_module)

    def forward(self, x):
        """Reset every node's state, propagate along the chain and read out.

        Parameters
        ----------
        x : tensor
            Input batch; ``x.shape[0]`` is used as the batch size and ``x``
            is handed to the update module via ``set_image``.

        Returns
        -------
        Whatever ``self.G.readout()`` returns.
        """
        batch_size = x.shape[0]
        g_dims = self.update_module.glimpse.att_params
        self.update_module.set_image(x)

        def zero_tensor_x(r, c):
            # New (r, c) tensor created from x via ``x.new``, zero-filled.
            return x.new(r, c).zero_()

        init_states = {
            's': zero_tensor_x(batch_size, self.h_dims),
            'a': (
                zero_tensor_x(batch_size, self.h_dims),
                zero_tensor_x(batch_size, g_dims),
            ),
            'g': None,
            'c': zero_tensor_x(batch_size, 1),
        }
        # NOTE(review): every node is updated from the same dict, so the
        # initial tensors are presumably shared objects across nodes.
        for n in self.G.nodes():
            self.G.node[n].update(init_states)

        # Update the root node first (recvfrom with an empty sender list).
        self.G.recvfrom(self.root, [])
        self.G.propagate(self.walk_list)
        return self.G.readout()
class DFSGlimpseSingleObjectClassifier(nn.Module):
    """Classifier that propagates along a walk over a balanced binary tree.

    The graph is built from ``nx.balanced_tree(2, 2)``; the walk list is
    produced by ``dfs_walk`` on the BFS tree of that graph rooted at node 0.
    """

    def __init__(self,
                 h_dims=128,
                 n_classes=10,
                 filters=None,
                 kernel_size=(3, 3),
                 final_pool_size=(2, 2),
                 glimpse_type='gaussian',
                 glimpse_size=(15, 15),
                 cnn='cnn',
                 cnn_file='cnn.pt',
                 ):
        """Build the graph, the CNN and the message/update/readout modules.

        Parameters
        ----------
        h_dims : int
            Forwarded to ``CNN``, ``UpdateModule`` and ``ReadoutModule``.
        n_classes : int
            Forwarded to ``CNN``, ``UpdateModule`` and ``ReadoutModule``.
        filters : list, optional
            Forwarded to ``CNN``.  ``None`` (the default) selects
            ``[16, 32, 64, 128, 256]``; a fresh list is created per instance
            so the default is never shared between instances.
        kernel_size, final_pool_size, cnn
            Forwarded to ``CNN``.
        glimpse_type
            Forwarded to ``UpdateModule``.
        glimpse_size
            Forwarded to ``UpdateModule`` and, as ``input_size``, to ``CNN``.
        cnn_file : str or None
            When not ``None``, a state dict is loaded from this path into the
            CNN before it is handed to ``UpdateModule``.
        """
        super(DFSGlimpseSingleObjectClassifier, self).__init__()
        if filters is None:
            filters = [16, 32, 64, 128, 256]

        t = nx.balanced_tree(2, 2)
        t_uni = nx.bfs_tree(t, 0)
        self.G = DGLGraph(t)
        self.root = 0
        self.h_dims = h_dims
        self.n_classes = n_classes

        # Default message function: just copy.
        self.message_module = MessageModule()
        self.G.register_message_func(self.message_module)

        cnnmodule = CNN(
            cnn=cnn,
            n_layers=6,
            h_dims=h_dims,
            n_classes=n_classes,
            final_pool_size=final_pool_size,
            filters=filters,
            kernel_size=kernel_size,
            input_size=glimpse_size,
        )
        if cnn_file is not None:
            # NOTE(review): T.load is presumably torch.load, which unpickles
            # the file -- only point cnn_file at trusted checkpoints.
            cnnmodule.load_state_dict(T.load(cnn_file))

        self.update_module = UpdateModule(
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            cnn=cnnmodule,
            max_recur=1,  # T_MAX_RECUR
            n_classes=n_classes,
            h_dims=h_dims,
        )
        self.G.register_update_func(self.update_module)

        self.readout_module = ReadoutModule(h_dims=h_dims, n_classes=n_classes)
        self.G.register_readout_func(self.readout_module)

        # dfs_walk fills the list in place.
        self.walk_list = []
        dfs_walk(t_uni, self.root, self.walk_list)

    def forward(self, x, pretrain=False):
        """Reset node states, update the root, then read out.

        Parameters
        ----------
        x : tensor
            Input batch; ``x.shape[0]`` is used as the batch size and ``x``
            is handed to the update module via ``set_image``.
        pretrain : bool
            When true, skip propagation and read out from the root node only.

        Returns
        -------
        Whatever ``self.G.readout(...)`` returns for the selected mode.
        """
        batch_size = x.shape[0]

        self.update_module.set_image(x)
        att_params = self.update_module.glimpse.att_params
        init_states = {
            'h': x.new(batch_size, self.h_dims).zero_(),
            'b': x.new(batch_size, att_params).zero_(),
            'b_next': x.new(batch_size, att_params).zero_(),
            'a': x.new(batch_size, 1).zero_(),
            'y': x.new(batch_size, self.n_classes).zero_(),
            'g': None,
            'b_fix': None,
            'db': None,
        }
        # NOTE(review): every node is updated from the same dict, so the
        # initial tensors are presumably shared objects across nodes.
        for n in self.G.nodes():
            self.G.node[n].update(init_states)

        # TODO: the root update below is needed for the single-object case,
        # TODO: but is not useful (or is wrong) for the multi-object case.
        self.G.recvfrom(self.root, [])

        if pretrain:
            return self.G.readout([self.root], pretrain=True)

        self.G.propagate(self.walk_list)
        return self.G.readout('all', pretrain=False)