Example #1
def test_copy_src():
    # copy_src with both fields
    g = generate_graph()
    g.register_message_func(fn.copy_src(src='h', out='m'))
    g.register_reduce_func(reducer_both)
    g.update_all()
    assert F.allclose(g.ndata['h'],
                      F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
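A note on the API used throughout these examples: in recent DGL releases fn.copy_src / fn.copy_edge are deprecated aliases of fn.copy_u / fn.copy_e. A minimal sketch of the equivalent spelling, assuming a current DGL install:

import dgl.function as fn

# fn.copy_src(src='h', out='m') copies the source-node feature 'h' into the
# message field 'm'; newer DGL spells the same built-in fn.copy_u.
message_func = fn.copy_u('h', 'm')
reduce_func = fn.sum('m', 'h')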
Example #2
    def _test(apply_func):
        g = generate_graph()
        f = g.ndata['f']

        # an out-of-place run to get the reference result
        g.pull(nodes, fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'),
               apply_func)
        result = g.ndata['f']

        # inplace deg bucket
        v1 = F.clone(f)
        g.ndata['f'] = v1
        g.pull(nodes, message_func, reduce_func, apply_func, inplace=True)
        r1 = g.ndata['f']
        # check result
        assert F.allclose(r1, result)
        # check inplace
        assert F.allclose(v1, r1)

        # inplace v2v spmv
        v1 = F.clone(f)
        g.ndata['f'] = v1
        g.pull(nodes,
               fn.copy_src(src='f', out='m'),
               fn.sum(msg='m', out='f'),
               apply_func,
               inplace=True)
        r1 = g.ndata['f']
        # check result
        assert F.allclose(r1, result)
        # check inplace
        assert F.allclose(v1, r1)

        # inplace e2v spmv
        v1 = F.clone(f)
        g.ndata['f'] = v1
        g.pull(nodes,
               message_func,
               fn.sum(msg='m', out='f'),
               apply_func,
               inplace=True)
        r1 = g.ndata['f']
        # check result
        assert F.allclose(r1, result)
        # check inplace
        assert F.allclose(v1, r1)
Example #3
    def forward(self, G):
        assert G.number_of_nodes() == self.G.number_of_nodes()
        G.ndata['deg'] = self.deg

        G.update_all(FN.copy_src('h', 'h'), FN.sum('h', 'h_agg'))  # alternative reducers: mean, max, sum

        G.apply_nodes(self.disease_update, self.disease_nodes)
        G.apply_nodes(self.miran_update, self.mirna_nodes)
Example #4
    def forward(self, graph, feat):
        r"""Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        graph = graph.local_var()

        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)

        h_self = feat_dst

        graph.srcdata['h'] = feat_src
        if self._aggr == 'sum':
            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
        elif self._aggr == 'mean':
            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
        else:
            raise ValueError(
                "Expect aggregation to be 'sum' or 'mean', got {}".format(
                    self._aggr))
        h_neigh = graph.dstdata['neigh']
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        return rst
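The neighborhood aggregation in this layer boils down to a single update_all call; the following self-contained sketch (not part of the module above; the toy graph and features are made up for illustration) shows what the 'sum' branch computes:

import torch
import dgl
import dgl.function as fn

# Toy directed cycle 0 -> 1 -> 2 -> 0: each node has exactly one in-neighbor.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
feat = torch.randn(3, 4)
g.srcdata['h'] = feat
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# The sum reducer gathers each node's in-neighbor features:
# node 0 receives feat[2], node 1 receives feat[0], node 2 receives feat[1].
assert torch.allclose(g.dstdata['neigh'], feat[[2, 0, 1]])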
Example #5
def compute_pagerank(g):
    g.ndata['pv'] = torch.ones(N) / N
    degrees = g.out_degrees(g.nodes()).type(torch.float32)
    for k in range(K):
        g.ndata['pv'] = g.ndata['pv'] / degrees
        g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                     reduce_func=fn.sum(msg='m', out='pv'))
        g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
    return g.ndata['pv']
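compute_pagerank relies on the module-level constants N, K and DAMP; a hypothetical driver (the constants and the graph below are assumptions for illustration, and a DGL version where fn.copy_src is still available is assumed):

import torch
import dgl

N = 100      # number of nodes
K = 10       # number of power-iteration steps
DAMP = 0.85  # damping factor

# A directed cycle, so every node has out-degree 1 (avoids division by zero).
src = torch.arange(N)
g = dgl.graph((src, (src + 1) % N))
pv = compute_pagerank(g)
print(pv.sum())  # the PageRank scores sum to (roughly) 1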
Example #6
    def forward(self, graph, features):
        features = self.transform(features)

        graph.ndata['h'] = features
        graph.update_all(fn.copy_src(src='h', out='m'), self.agg_fn)

        features = torch.cat([features, graph.ndata['h']], -1)

        return features
Example #7
    def forward(self, fact_graph, img_graph, sem_graph):
        self.img_graph = img_graph
        self.fact_graph = fact_graph
        self.sem_graph = sem_graph

        fact_graph.apply_nodes(func=self.apply_node)
        fact_graph.update_all(message_func=fn.copy_src(src='h', out='m'),
                              reduce_func=self.reduce)
        return fact_graph
Example #8
    def forward(self, graph, fea):
        with graph.local_scope():
            feat_src = torch.mm(fea, self.weight)
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), fn.sum(msg='m', out='h'))
            rst = graph.dstdata['h']
            rst = self.activation(rst)

            return rst
Example #9
def check_prop_flows(create_node_flow):
    num_layers = 2
    g = generate_rand_graph(100)
    g.ndata['h'] = g.ndata['h1']
    nf2 = create_node_flow(g, num_layers)
    nf2.copy_from_parent()
    # Test the computation one layer at a time.
    for i in range(num_layers):
        g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
                     lambda nodes: {'h': nodes.data['t'] + 1})

    # Test the computation on all layers.
    nf2.prop_flow(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
                  lambda nodes: {'h': nodes.data['t'] + 1})
    assert_allclose(F.asnumpy(nf2.layers[-1].data['h']),
                    F.asnumpy(g.nodes[nf2.layer_parent_nid(-1)].data['h']),
                    rtol=1e-4,
                    atol=1e-4)
Example #10
def pagerank_builtin(g):
    g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg']
    g.update_all(
        message_func=fn.copy_src(
            src='pv',
            out='m'),  # compute the output using the source node feature data
        reduce_func=fn.sum(
            msg='m', out='m_sum'))  # sum the messages in the node’s mailbox
    g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum']  # update
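Unlike Example #5, pagerank_builtin expects the 'pv' and 'deg' node fields and the constants N and DAMP to be prepared by the caller. A sketch of that setup (g, K and the constants are assumptions; every node is assumed to have at least one outgoing edge so the division by 'deg' is safe):

import torch

N, DAMP, K = g.number_of_nodes(), 0.85, 10
g.ndata['pv'] = torch.ones(N) / N
g.ndata['deg'] = g.out_degrees(g.nodes()).float()
for _ in range(K):
    pagerank_builtin(g)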
Example #11
    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')

            if self.dropout:
                h = self.dropout(h)

            if i == 0:
                nf.layers[i].data['h'] = h

                sum_nf = []  # replaced below
                self_nodes_nf = []  # IDs of this party's own nodes, within this nf
                sum_nf1 = nf.layer_parent_nid(0).numpy().tolist()
                sum_nf2 = nf.layer_parent_nid(1).numpy().tolist()
                sum_nf3 = nf.layer_parent_nid(2).numpy().tolist()
                sum_nf = sum_nf + sum_nf1 + sum_nf2 + sum_nf3  # IDs of all nodes used, expressed as IDs in new_g
                for index8 in range(len(sum_nf1)):
                    if sum_nf1[index8] in (sub_node_id[0]):  # the layer-0 node belongs to this party
                        self_nodes_nf.append(nf.map_from_parent_nid(0, sum_nf1[index8]).item())
                nf.copy_to_parent()  # write the values back into new_g


                for index8 in range(len(sum_nf1)):
                    if sum_nf1[index8] not in (sub_node_id[0]):  # not this party's node
                        for n2 in range(len(send_index1)):
                            map_g = new_g.parent_nid[sum_nf1[index8]]  # ID in g
                            if map_g in sub_node_list[send_index1[n2]]:  # belongs to party send_index1[n2]
                                fc.weight = nn.Parameter(Model[send_index1[n2]]['layers.0.linear.weight'])
                                fc.bias = nn.Parameter(Model[send_index1[n2]]['layers.0.linear.bias'])
                                new_g.nodes[sum_nf1[index8]].data['activation'] = \
                                    fc(g.nodes[map_g].data['features'])

                nf.copy_from_parent()

                nf.apply_layer(0, layer, v=self_nodes_nf)  # layer 0: dimensionality reduction

                print(i)
                print(layer)

            else:
                nf.layers[i].data['h'] = h
                nf.block_compute(i,
                                 fn.copy_src(src='h', out='m'),
                                 lambda node: {'h': node.mailbox['m'].mean(dim=1)},
                                 layer)
                print(i)
                print(layer)
            # print(i)
            # print(layer)
            # print(len(nf.layers[i].data['h'][0]))  # 'h' at this layer
            # print(len(nf.layers[-1].data['activation'][0]))

        h = nf.layers[-1].data.pop('activation')

        return h
Example #12
 def forward(self, nf):
     nf.layers[0].data['activation'] = nf.layers[0].data['features']
     for i, layer in enumerate(self.layers):
         h = nf.layers[i].data.pop('activation')
         if self.dropout:
             h = mx.nd.Dropout(h, p=self.dropout)
         nf.layers[i].data['h'] = h
         self.layers[i].set_old_features(nf.layers[i + 1].data['features'])
         if (self.msg_fn == "mean"):
             nf.block_compute(
                 i, fn.copy_src(src='h', out='m'),
                 lambda node: {'h': node.mailbox['m'].mean(axis=1)}, layer)
         else:
             nf.block_compute(
                 i, fn.copy_src(src='h', out='m'),
                 lambda node: {'h': node.mailbox['m'].sum(axis=1)}, layer)
     h = nf.layers[-1].data.pop('activation')
     return h
Example #13
    def forward(self, graph, features):
        graph.ndata['h'] = features
        graph.update_all(fn.copy_src(src='h', out='m'),
                         fn.mean(msg='m', out='h'))

        if self.residual:
            graph.ndata['h'] += features

        return graph, graph.ndata['h']
Example #14
    def forward(self, g, node_mask):
        # collect features from source nodes and aggregate them in destination nodes
        g.update_all(fn.copy_src('nodes', 'message'), fn.sum('message', 'message_sum'))
        msg = g.ndata.pop('message_sum')
        nodes = self.update_GRU(msg, g.ndata['nodes'])

        g.apply_edges(fn.u_mul_v('nodes', 'nodes', 'edge_message'))
        edges = g.edata.pop('edge_spans') * g.edata.pop('edge_message').unsqueeze(-1)
        return nodes, edges
Example #15
File: wln.py  Project: zwvews/dgl-lifesci
    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats)
            Updated node representations.
        """
        if self.project_in_feats:
            node_feats = self.project_node_in_feats(node_feats)
        for _ in range(self.n_layers):
            g = g.local_var()
            if g.num_edges() > 0:
                # The following lines do not work for a graph without edges.
                g.ndata['hv'] = node_feats
                g.apply_edges(fn.copy_src('hv', 'he_src'))
                concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
                g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
                g.update_all(fn.copy_edge('he', 'm'), fn.sum('m', 'hv_new'))
                node_feats = self.get_new_node_feats(
                    torch.cat([node_feats, g.ndata['hv_new']], dim=1))
            else:
                # If the graph has no edges, the formula above becomes very simple:
                # the sum over the neighbors is then zero.
                # Refer to equations in section S2.2 of
                # http://www.rsc.org/suppdata/c8/sc/c8sc04228d/c8sc04228d2.pdf
                node_feats = self.get_new_node_feats(
                    torch.cat([node_feats, node_feats*0], dim=1))

        if not self.set_comparison:
            return node_feats
        else:
            if g.num_edges() > 0:
                # The following lines don't work for a graph without edges
                g = g.local_var()
                g.ndata['hv'] = self.project_node_messages(node_feats)
                g.edata['he'] = self.project_edge_messages(edge_feats)
                g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
                h_self = self.project_self(node_feats)  # (V, node_out_feats)
                return g.ndata['h_nbr'] * h_self
            else:
                # If the graph has no edges, the formula becomes very simple.
                # The sum over the neighbors is zero then.
                # Refer to equations in section S2.5 of
                # http://www.rsc.org/suppdata/c8/sc/c8sc04228d/c8sc04228d2.pdf
                return torch.zeros((g.num_nodes(), self.project_self.out_feats),
                                   device=node_feats.device)
Example #16
    def forward(self, feat, graph, mask=None):
        if self._jump:
            _feat = feat

        if self._norm:
            if mask is None:
                norm = torch.pow(graph.in_degrees().float(), -0.5)
                norm.masked_fill_(graph.in_degrees() == 0, 1.0)
                shp = norm.shape + (1, ) * (feat.dim() - 1)
                norm = torch.reshape(norm, shp).to(feat.device)
                feat = feat * norm.unsqueeze(1)
            else:
                graph.ndata['h'] = mask.float()
                graph.update_all(fn.copy_src(src='h', out='m'),
                                 fn.sum(msg='m', out='h'))
                masked_deg = graph.ndata.pop('h')
                norm = torch.pow(masked_deg, -0.5)
                norm.masked_fill_(masked_deg == 0, 1.0)
                feat = feat * norm.unsqueeze(-1)

        if mask is not None:
            feat = mask.float().unsqueeze(-1) * feat

        graph.ndata['h'] = feat
        graph.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m',
                                                               out='h'))
        rst = graph.ndata.pop('h')

        if self._norm:
            rst = rst * norm.unsqueeze(-1)

        if self._jump:
            rst = torch.cat([rst, _feat], dim=-1)

        rst = torch.matmul(rst, self.weight)

        if self.bias is not None:
            rst = rst + self.bias

        if self._activation is not None:
            rst = self._activation(rst)

        return rst
Example #17
    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            nf.layers[i].data['h'] = h
            nf.block_compute(i, fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'), layer)
        h = nf.layers[-1].data.pop('activation')
        return h
Example #18
def check_compute_func(worker_id, graph_name):
    time.sleep(3)
    print("worker starts")
    g = dgl.contrib.graph_store.create_graph_from_store(graph_name,
                                                        "shared_mem",
                                                        port=rand_port)
    g._sync_barrier()
    in_feats = g.nodes[0].data['feat'].shape[1]

    # Test update all.
    g.update_all(fn.copy_src(src='feat', out='m'),
                 fn.sum(msg='m', out='preprocess'))
    adj = g.adjacency_matrix()
    tmp = mx.nd.dot(adj, g.nodes[:].data['feat'])
    assert np.all((g.nodes[:].data['preprocess'] == tmp).asnumpy())
    g._sync_barrier()
    check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])

    # Test apply nodes.
    data = g.nodes[:].data['feat']
    g.apply_nodes(func=lambda nodes: {'feat': mx.nd.ones((1, in_feats)) * 10},
                  v=0)
    assert np.all(data[0].asnumpy() == g.nodes[0].data['feat'].asnumpy())

    # Test apply edges.
    data = g.edges[:].data['feat']
    g.apply_edges(func=lambda edges: {'feat': mx.nd.ones((1, in_feats)) * 10},
                  edges=0)
    assert np.all(data[0].asnumpy() == g.edges[0].data['feat'].asnumpy())

    g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
    data = g.nodes[:].data['tmp']
    # Test pull
    g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
    assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy())

    # Test send_and_recv
    in_edges = g.in_edges(v=2)
    g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'),
                    fn.sum(msg='m', out='tmp'))
    assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())

    g.destroy()
Example #19
    def forward(self, graph, features):
        graph.ndata['h'] = features
        graph.update_all(fn.copy_src(src='h', out='m'),
                         fn.mean(msg='m', out='h'))

        # self-addition
        graph.ndata['h'] += features
        graph.ndata['h'] *= 0.5

        return graph, graph.ndata['h']
Example #20
    def forward(self, g, feature):
        g.ndata["h"] = feature
        g.update_all(message_func=fn.copy_src(src="h", out="m"),
                     reduce_func=fn.sum(msg="m", out="h"))
        g.apply_nodes(func=self.apply_mod)

        res = g.ndata.pop("h")
        if self.batchnorm:
            res = self.bn(res)
        return res
Example #21
    def pool_agg(self, g):
        x = g.ndata['x']
        x = self.dropout(x)
        h = torch.matmul(x, self.weight_pool_in)
        if self.bias:
            h = h + self.bias_in
        g.srcdata['h'] = h
        for i in range(self.K):
            if i == 0:
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = g.dstdata['neigh']
                h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1), self.weight_in)
                if self.activation:
                    h = self.activation(h, inplace=False)
                norm = torch.norm(h, dim=1)
                h = h / (norm.unsqueeze(-1) + 0.05)
                g.srcdata['h'] = h
            elif i == self.K - 1:
                h = torch.matmul(g.srcdata['h'], self.weight_pool_hid)
                if self.bias:
                    h = h + self.bias_hid
                g.srcdata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = g.dstdata['neigh']
                h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1), self.weight_out)

                norm = torch.norm(h, dim=1)
                h = h / (norm.unsqueeze(-1) + 0.05)
                g.ndata['z'] = h
            else:
                h = torch.matmul(g.srcdata['h'], self.weight_pool_hid)
                if self.bias:
                    h = h + self.bias_hid
                g.srcdata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = g.dstdata['neigh']
                h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1), self.weight_hid[i-1, :, :])
                if self.activation:
                    h = self.activation(h, inplace=False)
                norm = torch.norm(h, dim=1)
                h = h / (norm.unsqueeze(-1) + 0.05)
                g.srcdata['h'] = h
        return g
Example #22
File: gcn_v3.py  Project: xcgoner/dgl
 def hybrid_forward(self, F, h):
     self.g.ndata['h'] = h * self.g.ndata['out_norm']
     self.g.update_all(fn.copy_src(src='h', out='m'),
                       fn.sum(msg='m', out='accum'))
     accum = self.g.ndata.pop('accum')
     accum = self.dense(accum * self.g.ndata['in_norm'])
     if self.dropout:
         accum = F.Dropout(accum, p=self.dropout)
     h = accum
     return h
Example #23
def check_flow_compute1(create_node_flow, use_negative_block_id=False):
    num_layers = 2
    g = generate_rand_graph(100)

    # test the case that we register UDFs per block.
    nf = create_node_flow(g, num_layers)
    nf.copy_from_parent()
    g.ndata['h'] = g.ndata['h1']
    nf.layers[0].data['h'] = nf.layers[0].data['h1']
    for i in range(num_layers):
        l = -num_layers + i if use_negative_block_id else i
        nf.register_message_func(fn.copy_src(src='h', out='m'), l)
        nf.register_reduce_func(fn.sum(msg='m', out='t'), l)
        nf.register_apply_node_func(lambda nodes: {'h': nodes.data['t'] + 1},
                                    l)
        nf.block_compute(l)
        g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
                     lambda nodes: {'h': nodes.data['t'] + 1})
        assert_allclose(F.asnumpy(nf.layers[i + 1].data['h']),
                        F.asnumpy(g.nodes[nf.layer_parent_nid(i +
                                                              1)].data['h']),
                        rtol=1e-4,
                        atol=1e-4)

    # test the case that we register UDFs in all blocks.
    nf = create_node_flow(g, num_layers)
    nf.copy_from_parent()
    g.ndata['h'] = g.ndata['h1']
    nf.layers[0].data['h'] = nf.layers[0].data['h1']
    nf.register_message_func(fn.copy_src(src='h', out='m'))
    nf.register_reduce_func(fn.sum(msg='m', out='t'))
    nf.register_apply_node_func(lambda nodes: {'h': nodes.data['t'] + 1})
    for i in range(num_layers):
        l = -num_layers + i if use_negative_block_id else i
        nf.block_compute(l)
        g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
                     lambda nodes: {'h': nodes.data['t'] + 1})
        assert_allclose(F.asnumpy(nf.layers[i + 1].data['h']),
                        F.asnumpy(g.nodes[nf.layer_parent_nid(i +
                                                              1)].data['h']),
                        rtol=1e-4,
                        atol=1e-4)
Example #24
 def forward(self, x):
     x = torch.mm(x, self.weight)
     x = x * self.g.ndata['norm']
     self.g.ndata['x'] = x
     self.g.update_all(fn.copy_src(src='x', out='m'),
                       fn.sum(msg='m', out='x'))
     x = self.g.ndata.pop('x')
     x = x * self.g.ndata['norm']
     if self.bias is not None:
         x = x + self.bias
     return x
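This layer multiplies the features by g.ndata['norm'] both before and after aggregation, which yields the symmetric GCN normalization D^{-1/2} A D^{-1/2} X W when 'norm' holds the inverse square root of the node degree. One plausible way to prepare that field (an assumption; the original model sets it elsewhere, and a DGLGraph g with a PyTorch backend is assumed):

import torch

degs = g.in_degrees().float().clamp(min=1)  # avoid division by zero for isolated nodes
g.ndata['norm'] = torch.pow(degs, -0.5).unsqueeze(1)  # shape (N, 1), broadcasts over features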
Example #25
 def precalc(self, g):
     norm = self.get_norm(g)
     g.ndata['norm'] = norm
     features = g.ndata['feat']
     print("features shape, ", features.shape)
     with torch.no_grad():
         g.update_all(fn.copy_src(src='feat', out='m'),
                      fn.sum(msg='m', out='feat'), None)
         pre_feats = g.ndata['feat'] * norm
         # use graphsage embedding aggregation style
         g.ndata['feat'] = torch.cat([features, pre_feats], dim=1)
Example #26
 def precalc(self, g):
     norm = self.get_norm(g)
     g.ndata['norm'] = norm
     features = g.ndata['features']
     print("features shape, ", features.shape)
     with torch.no_grad():
         g.update_all(fn.copy_src(src='features', out='m'),
                      fn.sum(msg='m', out='features'),
                      None)
         pre_feats = g.ndata['features'] * norm
         g.ndata['features'] = torch.cat([features, pre_feats], dim=1)
Example #27
File: gcn_concat.py  Project: xcgoner/dgl
 def forward(self, h):
     self.g.ndata['h'] = h * self.g.ndata['out_norm']
     self.g.update_all(fn.copy_src(src='h', out='m'),
                       fn.sum(msg='m', out='accum'))
     accum = self.g.ndata.pop('accum')
     accum = self.dense(accum * self.g.ndata['in_norm'])
     if self.dropout:
         accum = mx.nd.Dropout(accum, p=self.dropout)
     h = self.g.ndata.pop('h')
     h = mx.nd.concat(h / self.g.ndata['out_norm'], accum, dim=1)
     return h
Example #28
    def _take_action(self, action):
        undecided = self.x == 2
        self.x[undecided] = action[undecided]
        self.t += 1

        x1 = (self.x == 1)
        self.g = self.g.to(self.device)
        self.g.ndata['h'] = x1.float()
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
        x1_deg = self.g.ndata.pop('h')

        ## forgive clashing
        clashed = x1 & (x1_deg > 0)
        self.x[clashed] = 2
        x1_deg[clashed] = 0

        # graph clean up
        still_undecided = (self.x == 2)
        self.x[still_undecided & (x1_deg > 0)] = 0

        # fill timeout with zeros
        still_undecided = (self.x == 2)
        timeout = (self.t == self.max_epi_t)
        self.x[still_undecided & timeout] = 0

        done = self._check_done()
        self.epi_t[~done] += 1

        # compute reward and solution
        x1 = (self.x == 1).float()
        node_sol = x1

        h = node_sol
        self.g.ndata['h'] = h
        next_sol = dgl.sum_nodes(self.g, 'h')
        self.g.ndata.pop('h')

        reward = (next_sol - self.sol)

        if self.hamming_reward_coef > 0.0 and self.num_samples == 2:
            xl, xr = self.x.split(1, dim=1)
            undecidedl, undecidedr = undecided.split(1, dim=1)
            hamming_d = torch.abs(xl.float() - xr.float())
            hamming_d[(xl == 2) | (xr == 2)] = 0.0
            hamming_d[~undecidedl & ~undecidedr] = 0.0
            self.g.ndata['h'] = hamming_d
            hamming_reward = dgl.sum_nodes(self.g, 'h').expand_as(reward)
            self.g.ndata.pop('h')
            reward += self.hamming_reward_coef * hamming_reward

        reward /= self.max_num_nodes

        return reward, next_sol, done
Example #29
 def call(self, h):
     if self.dropout:
         h = self.dropout(h)
     self.g.ndata['h'] = tf.matmul(h, self.weight)
     self.g.ndata['norm_h'] = self.g.ndata['h'] * self.g.ndata['norm']
     self.g.update_all(fn.copy_src('norm_h', 'm'), fn.sum('m', 'h'))
     h = self.g.ndata['h']
     if self.bias is not None:
         h = h + self.bias
     if self.activation:
         h = self.activation(h)
     return h
Example #30
 def forward(self, g, h):
     h = self.dropout(h)
     g.ndata['h'] = h
     if self.use_bn and not hasattr(self, 'bn'):
         device = h.device
         self.bn = nn.BatchNorm1d(h.size()[1]).to(device)
     g.update_all(fn.copy_src(src='h', out='m'), self.aggregator,
                  self.bundler)
      h = g.ndata.pop('h')
      if self.use_bn:
          h = self.bn(h)
     return h