示例#1
0
def test_multi_recv_0deg():
    """Exercise repeated ``recv`` calls on a graph with a 0-in-degree node.

    Graph: two nodes and the single edge 0 -> 1, so node 0 never receives a
    message while node 1 receives node 0's 'h'. Message, reduce and apply
    UDFs are registered on the graph, together with a node initializer that
    fills with the constant 2.
    """
    def copy_src(edges):
        # The message is simply the source node's 'h'.
        return {'m': edges.src['h']}

    def add_mail(nodes):
        # Current 'h' plus the mailbox summed over axis 1.
        return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def double(nodes):
        return {'h': nodes.data['h'] * 2}

    def fill_two(shape, dtype, ctx, ids):
        return th.zeros(shape, dtype=dtype, device=ctx) + 2

    graph = DGLGraph()
    graph.register_message_func(copy_src)
    graph.register_reduce_func(add_mail)
    graph.register_apply_node_func(double)
    graph.set_n_initializer(fill_two)
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    # recv on the 0-deg node and the non-0-deg node in a single call.
    h_init = th.randn((2, 5))
    graph.ndata['h'] = h_init
    graph.send((0, 1))
    graph.recv([0, 1])
    h_new = graph.ndata['h']
    # Node 0 (0-deg): initializer value 2, then doubled by apply -> 4.
    assert U.allclose(h_new[0], th.full((5,), 4))
    # Node 1: (h0 + h1) doubled.
    assert U.allclose(h_new[1], th.sum(h_init, 0) * 2)

    # A second recv on the 0-deg node doubles it again.
    graph.recv([0])
    assert U.allclose(graph.nodes[0].data['h'], th.full((5,), 8))

    # recv on node 1, which has no pending message, doubles it again.
    graph.recv([1])
    assert U.allclose(graph.nodes[1].data['h'], th.sum(h_init, 0) * 4)
示例#2
0
def test_dynamic_addition():
    """Check message passing stays consistent while nodes/edges are added.

    Interleaves node/edge additions with partial ``send``/``recv`` calls,
    then verifies that one complete ``send``/``recv`` round over the final
    graph reproduces the node field 'h' computed incrementally.
    """
    N = 3
    D = 1

    g = DGLGraph()

    def _message(edges):
        return {
            'm':
            edges.src['h1'] + edges.dst['h2'] + edges.data['h1'] +
            edges.data['h2']
        }

    def _reduce(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    def _apply(nodes):
        return {'h': nodes.data['h']}

    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    # Zero initializers for fields that are never set explicitly. For example,
    # the edges added below only receive 'h1', never 'h2'. These are the only
    # initializers registered, so the test is deterministic.
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)

    # add nodes and edges
    g.add_nodes(N)
    g.ndata.update({'h1': F.randn((N, D)), 'h2': F.randn((N, D))})
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': F.randn((2, D)), 'h2': F.randn((2, D))})
    g.send()
    # After sending on every edge, the message index should flag all edges.
    expected = F.copy_to(F.ones((g.number_of_edges(), ), dtype=F.int64),
                         F.cpu())
    assert F.array_equal(g._get_msg_index().tousertensor(), expected)

    # add more edges (only 'h1' supplied; 'h2' comes from the initializer)
    g.add_edges([0, 2], [2, 0], {'h1': F.randn((2, D))})
    g.send(([0, 2], [2, 0]))
    g.recv(0)

    g.add_edge(1, 2)
    # Edge id 4 is the fifth edge added: 0->1, 1->0, 0->2, 2->0, 1->2.
    g.edges[4].data['h1'] = F.randn((1, D))
    g.send((1, 2))
    g.recv([1, 2])

    h = g.ndata.pop('h')

    # a complete round of send and recv must reproduce the incremental result
    g.send()
    g.recv()
    assert F.allclose(h, g.ndata['h'])
示例#3
0
def test_recv_0deg_newfld():
    """Test ``recv`` with a 0-in-degree node when the reducer writes a new field.

    The graph has two nodes joined by the single edge 0 -> 1. The reduce and
    apply UDFs produce 'h1' rather than the input field 'h'. The constant-2
    initializer is registered for 'h1' only.
    """
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def copy_src(edges):
        return {'m': edges.src['h']}

    def reduce_to_h1(nodes):
        base = nodes.data['h']
        return {'h1': base + mx.nd.sum(nodes.mailbox['m'], 1)}

    def double_h1(nodes):
        return {'h1': nodes.data['h1'] * 2}

    def fill_two(shape, dtype, ctx, ids):
        return mx.nd.zeros(shape=shape, dtype=dtype, ctx=ctx) + 2

    graph.register_message_func(copy_src)
    graph.register_reduce_func(reduce_to_h1)
    graph.register_apply_node_func(double_h1)

    # Case 1: recv on the 0-deg node and the non-0-deg node together.
    h_in = mx.nd.random.normal(shape=(2, 5))
    graph.set_n_initializer(fill_two, 'h1')
    graph.ndata['h'] = h_in
    graph.send((0, 1))
    graph.recv([0, 1])
    out = graph.ndata.pop('h1').asnumpy()
    # 0-deg node: initializer value 2, doubled by apply -> 4.
    assert np.allclose(out[0], np.full((5, ), 4))
    # non-0-deg node: (h0 + h1) doubled.
    assert np.allclose(out[1], mx.nd.sum(h_in, 0).asnumpy() * 2)

    # Case 2: recv on the 0-deg node only.
    h_in = mx.nd.random.normal(shape=(2, 5))
    graph.ndata['h'] = h_in
    # Pre-populating 'h1' is required for this case. Presumably this is because
    # 'h1' was popped above and apply reads it. TODO confirm.
    graph.ndata['h1'] = mx.nd.full((2, 5), -1)
    graph.send((0, 1))
    graph.recv(0)
    out = graph.ndata.pop('h1').asnumpy()
    # 0-deg node: only apply runs, so -1 * 2.
    assert np.allclose(out[0], np.full((5, ), -2))
    # Node 1 was not received on, so its 'h1' is unchanged.
    assert np.allclose(out[1], np.full((5, ), -1))
示例#4
0
def test_recv_0deg():
    """Test ``recv`` on a graph containing a 0-in-degree node.

    The graph has two nodes joined by the single edge 0 -> 1. A constant-2
    initializer is registered for field 'h'.
    """
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def copy_src(edges):
        return {'m': edges.src['h']}

    def add_mail(nodes):
        base = nodes.data['h']
        return {'h': base + nodes.mailbox['m'].sum(1)}

    def double(nodes):
        return {'h': nodes.data['h'] * 2}

    def fill_two(shape, dtype, ctx, ids):
        return th.zeros(shape, dtype=dtype, device=ctx) + 2

    graph.register_message_func(copy_src)
    graph.register_reduce_func(add_mail)
    graph.register_apply_node_func(double)
    graph.set_n_initializer(fill_two, 'h')

    # Case 1: recv on the 0-deg node and the non-0-deg node together.
    h_in = th.randn((2, 5))
    graph.ndata['h'] = h_in
    graph.send((0, 1))
    graph.recv([0, 1])
    h_out = graph.ndata.pop('h')
    # 0-deg node: initializer value 2, doubled by apply -> 4.
    assert U.allclose(h_out[0], th.full((5, ), 4))
    # non-0-deg node: (h0 + h1) doubled.
    assert U.allclose(h_out[1], th.sum(h_in, 0) * 2)

    # Case 2: recv on the 0-deg node alone behaves like apply_nodes.
    h_in = th.randn((2, 5))
    graph.ndata['h'] = h_in
    graph.send((0, 1))
    graph.recv(0)
    h_out = graph.ndata.pop('h')
    # 0-deg node: only apply runs, doubling its input.
    assert U.allclose(h_out[0], 2 * h_in[0])
    # Node 1 was not received on, so it is untouched.
    assert U.allclose(h_out[1], h_in[1])
示例#5
0
class TopDownNet(nn.Module):
    """Glimpse model that propagates node states down a 3-node path graph.

    The graph is ``nx.balanced_tree(1, 2)``, the path 0 - 1 - 2, rooted at
    node 0. Message, update and readout behaviour is delegated to
    ``MessageModule``, ``UpdateModule`` and ``ReadoutModule``, which are
    registered on the graph.
    """

    def __init__(self,
                 h_dims=128,
                 n_classes=10,
                 filters=None,
                 kernel_size=(3, 3),
                 final_pool_size=(2, 2),
                 glimpse_type='gaussian',
                 glimpse_size=(15, 15),
                 cnn='cnn'):
        """Build the graph and register the message/update/readout modules.

        Args:
            h_dims: width of the per-node state tensors. Also forwarded to
                all three modules.
            n_classes: forwarded to ``UpdateModule`` and ``ReadoutModule``.
            filters: forwarded to ``UpdateModule``. ``None`` (the default)
                selects ``[16, 32, 64, 128, 256]``.
            kernel_size, final_pool_size, glimpse_type, glimpse_size:
                forwarded to ``UpdateModule``.
            cnn: forwarded to ``UpdateModule``.
        """
        nn.Module.__init__(self)
        if filters is None:
            # Build a fresh list per instance instead of sharing one mutable
            # default list across every call.
            filters = [16, 32, 64, 128, 256]

        t = nx.balanced_tree(1, 2)
        self.G = DGLGraph(t)
        self.root = 0
        # Fixed top-down walk from the root along the path.
        self.walk_list = [(0, 1), (1, 2)]
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.update_module = UpdateModule(
            h_dims=h_dims,
            n_classes=n_classes,
            filters=filters,
            kernel_size=kernel_size,
            final_pool_size=final_pool_size,
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            cnn=cnn,  # honour the caller's choice instead of hard-coding 'cnn'
        )
        self.message_module = MessageModule(
            h_dims=h_dims, g_dims=self.update_module.glimpse.att_params)
        self.readout_module = ReadoutModule(
            h_dims=h_dims,
            n_classes=n_classes,
        )

        self.G.register_message_func(self.message_module)
        self.G.register_update_func(self.update_module)
        self.G.register_readout_func(self.readout_module)

    def forward(self, x):
        """Run one top-down pass over the graph for the input batch ``x``.

        ``x`` is handed to ``UpdateModule.set_image``, and its first dimension
        is used as the batch size. Every node's state is reset to zeros,
        then the root is updated and the states are propagated along
        ``self.walk_list``. Returns the result of ``self.G.readout()``.
        """
        batch_size = x.shape[0]
        g_dims = self.update_module.glimpse.att_params

        self.update_module.set_image(x)

        def zeros(cols):
            # ``x.new`` keeps the dtype and device of ``x``.
            return x.new(batch_size, cols).zero_()

        for n in self.G.nodes():
            # Fresh tensors for each node, so that nodes never share (alias)
            # the same state tensor objects.
            self.G.node[n].update({
                's': zeros(self.h_dims),
                'a': (zeros(self.h_dims), zeros(g_dims)),
                'g': None,
                'c': zeros(1),
            })

        self.G.recvfrom(self.root, [])  # Update root node
        self.G.propagate(self.walk_list)
        return self.G.readout()
示例#6
0
class DFSGlimpseSingleObjectClassifier(nn.Module):
    """Glimpse classifier that propagates along a depth-first walk of a tree.

    The graph is ``nx.balanced_tree(2, 2)`` rooted at node 0. The edge walk
    is produced by ``dfs_walk`` over ``nx.bfs_tree`` of that graph. A ``CNN``
    is built here and wrapped by ``UpdateModule``; it can optionally be
    restored from ``cnn_file``. Message and readout behaviour come from
    ``MessageModule`` and ``ReadoutModule``.
    """

    def __init__(
            self,
            h_dims=128,
            n_classes=10,
            filters=None,
            kernel_size=(3, 3),
            final_pool_size=(2, 2),
            glimpse_type='gaussian',
            glimpse_size=(15, 15),
            cnn='cnn',
            cnn_file='cnn.pt',
    ):
        """Build the graph, the CNN and the message/update/readout modules.

        Args:
            h_dims: width of the per-node 'h' state. Also forwarded to
                ``CNN``, ``UpdateModule`` and ``ReadoutModule``.
            n_classes: width of the per-node 'y' state. Also forwarded to
                the same modules.
            filters: forwarded to ``CNN``. ``None`` (the default) selects
                ``[16, 32, 64, 128, 256]``.
            kernel_size, final_pool_size, cnn: forwarded to ``CNN``.
            glimpse_type, glimpse_size: forwarded to ``UpdateModule``.
                ``glimpse_size`` is also the ``CNN`` input size.
            cnn_file: path of a state dict loaded into the CNN via
                ``T.load``. Pass ``None`` to skip loading.
        """
        nn.Module.__init__(self)
        if filters is None:
            # Build a fresh list per instance instead of sharing one mutable
            # default list across every call.
            filters = [16, 32, 64, 128, 256]

        t = nx.balanced_tree(2, 2)
        t_uni = nx.bfs_tree(t, 0)
        self.G = DGLGraph(t)
        self.root = 0
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.message_module = MessageModule()
        self.G.register_message_func(self.message_module)  # default: just copy

        cnnmodule = CNN(
            cnn=cnn,
            n_layers=6,
            h_dims=h_dims,
            n_classes=n_classes,
            final_pool_size=final_pool_size,
            filters=filters,
            kernel_size=kernel_size,
            input_size=glimpse_size,
        )
        if cnn_file is not None:
            cnnmodule.load_state_dict(T.load(cnn_file))

        self.update_module = UpdateModule(
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            cnn=cnnmodule,
            max_recur=1,  # T_MAX_RECUR
            n_classes=n_classes,
            h_dims=h_dims,
        )
        self.G.register_update_func(self.update_module)

        self.readout_module = ReadoutModule(h_dims=h_dims, n_classes=n_classes)
        self.G.register_readout_func(self.readout_module)

        # dfs_walk is expected to append the walk's edges to this list in place.
        self.walk_list = []
        dfs_walk(t_uni, self.root, self.walk_list)

    def forward(self, x, pretrain=False):
        """Classify the input batch ``x``.

        Args:
            x: input handed to ``UpdateModule.set_image``. Its first dimension
                is used as the batch size.
            pretrain: if true, only the root is updated and read out, via
                ``readout([root], pretrain=True)``. Otherwise the states are
                also propagated along ``self.walk_list`` and all nodes are
                read out.

        Returns:
            Whatever ``self.G.readout`` returns for the chosen mode.
        """
        batch_size = x.shape[0]

        self.update_module.set_image(x)
        att_params = self.update_module.glimpse.att_params

        def zeros(cols):
            # ``x.new`` keeps the dtype and device of ``x``.
            return x.new(batch_size, cols).zero_()

        for n in self.G.nodes():
            # Fresh tensors for each node, so that nodes never share (alias)
            # the same state tensor objects.
            self.G.node[n].update({
                'h': zeros(self.h_dims),
                'b': zeros(att_params),
                'b_next': zeros(att_params),
                'a': zeros(1),
                'y': zeros(self.n_classes),
                'g': None,
                'b_fix': None,
                'db': None,
            })

        # TODO: the following line (updating the root first) is needed for
        # TODO: single-object classification but is not useful, or is wrong,
        # TODO: for multi-object.
        self.G.recvfrom(self.root, [])

        if pretrain:
            return self.G.readout([self.root], pretrain=True)
        # Local per-node updates should live inside the update module.
        self.G.propagate(self.walk_list)
        return self.G.readout('all', pretrain=False)