Example #1
        e_feat = None
        if self.gnn_model == "gin":
            x, all_outputs = self.gnn(g, n_feat, e_feat)
        else:
            x, all_outputs = self.gnn(g, n_feat, e_feat), None
            x = self.set2set(g, x)
            x = self.lin_readout(x)
        if self.norm:
            x = F.normalize(x, p=2, dim=-1, eps=1e-5)
        if return_all_outputs:
            return x, all_outputs
        else:
            return x


if __name__ == "__main__":
    model = GraphEncoder(gnn_model="gin")
    print(model)
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 0, 1, 2], [1, 2, 2, 1])
    g.ndata["pos_directed"] = torch.rand(3, 16)
    g.ndata["pos_undirected"] = torch.rand(3, 16)
    g.ndata["seed"] = torch.zeros(3, dtype=torch.long)
    g.ndata["nfreq"] = torch.ones(3, dtype=torch.long)
    g.edata["efreq"] = torch.ones(4, dtype=torch.long)
    g = dgl.batch([g, g, g])
    y = model(g)
    print(y.shape)
    print(y)
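
For reference on the semantics used above: dgl.batch concatenates the input graphs into a single graph with relabeled node IDs, and dgl.unbatch splits it back into its components. A minimal sketch, assuming a DGL version that provides dgl.graph and batch_num_nodes:

import dgl

g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 1]))  # same 3-node topology as above
bg = dgl.batch([g, g, g])                    # node IDs of later copies are offset by 3
print(bg.batch_num_nodes())                  # tensor([3, 3, 3])
g1, g2, g3 = dgl.unbatch(bg)                 # recover the three components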
Example #2
 def batcher_dev(batch):
     batch_trees = dgl.batch(batch).to(device)
     return GCNBatch(graph=batch_trees,
                     labels=batch_trees.ndata[LABEL_NODE_NAME])
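
A batcher like this is normally passed as collate_fn to a PyTorch DataLoader, so each iteration yields one batched graph. A minimal usage sketch, assuming GCNBatch, LABEL_NODE_NAME, device, and a dataset of DGLGraphs are defined as in the snippet:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, collate_fn=batcher_dev)
for gcn_batch in loader:
    bg, labels = gcn_batch.graph, gcn_batch.labels  # one batched graph per step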
Example #3
 def batcher_dev(batch):
     graph_q, graph_k = zip(*batch)
     graph_q, graph_k = dgl.batch(graph_q), dgl.batch(graph_k)
     return graph_q, graph_k
Example #4
def test_moleculenet():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    for dataset in ['BACE', 'BBBP', 'ClinTox', 'FreeSolv', 'HIV', 'MUV', 'SIDER', 'ToxCast',
                    'PCBA', 'ESOL', 'Lipophilicity', 'Tox21']:
        for featurizer_type in ['canonical', 'attentivefp']:
            if featurizer_type == 'canonical':
                node_featurizer = CanonicalAtomFeaturizer(atom_data_field='hv')
                edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he', self_loop=True)
            else:
                node_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')
                edge_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he', self_loop=True)

            for model_type in ['GCN', 'GAT']:
                g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
                g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
                bg = dgl.batch([g1, g2])

                model = load_pretrained('{}_{}_{}'.format(
                    model_type, featurizer_type, dataset)).to(device)
                with torch.no_grad():
                    model(bg.to(device), bg.ndata.pop('hv').to(device))
                    model.eval()
                    model(g1.to(device), g1.ndata.pop('hv').to(device))
                remove_file('{}_{}_{}_pre_trained.pth'.format(
                    model_type.lower(), featurizer_type, dataset))

            for model_type in ['Weave', 'MPNN', 'AttentiveFP']:
                g1 = smiles_to_bigraph('CO', add_self_loop=True, node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer)
                g2 = smiles_to_bigraph('CCO', add_self_loop=True, node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer)
                bg = dgl.batch([g1, g2])

                model = load_pretrained('{}_{}_{}'.format(
                    model_type, featurizer_type, dataset)).to(device)
                with torch.no_grad():
                    model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device))
                    model.eval()
                    model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device))
                remove_file('{}_{}_{}_pre_trained.pth'.format(
                    model_type.lower(), featurizer_type, dataset))

        if dataset == 'ClinTox':
            continue

        node_featurizer = PretrainAtomFeaturizer()
        edge_featurizer = PretrainBondFeaturizer()
        for model_type in ['gin_supervised_contextpred', 'gin_supervised_infomax',
                           'gin_supervised_edgepred', 'gin_supervised_masking']:
            g1 = smiles_to_bigraph('CO', add_self_loop=True, node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer)
            g2 = smiles_to_bigraph('CCO', add_self_loop=True, node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer)
            bg = dgl.batch([g1, g2])

            model = load_pretrained('{}_{}'.format(model_type, dataset)).to(device)
            with torch.no_grad():
                node_feats = [
                    bg.ndata.pop('atomic_number').to(device),
                    bg.ndata.pop('chirality_type').to(device)
                ]
                edge_feats = [
                    bg.edata.pop('bond_type').to(device),
                    bg.edata.pop('bond_direction_type').to(device)
                ]
                model(bg.to(device), node_feats, edge_feats)
                model.eval()
                node_feats = [
                    g1.ndata.pop('atomic_number').to(device),
                    g1.ndata.pop('chirality_type').to(device)
                ]
                edge_feats = [
                    g1.edata.pop('bond_type').to(device),
                    g1.edata.pop('bond_direction_type').to(device)
                ]
                model(g1.to(device), node_feats, edge_feats)
            remove_file('{}_{}_pre_trained.pth'.format(model_type.lower(), dataset))
Example #5
    def forward(self,
                node_num=None,
                feat=None,
                spatial_feat=None,
                word2vec=None,
                roi_label=None,
                validation=False,
                choose_nodes=None,
                remove_nodes=None):
        # set up graph
        batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, batch_readout_edge_list, batch_readout_h_h_e_list, batch_readout_h_o_e_list = [], [], [], [], [], [], [], [], []
        node_num_cum = np.cumsum(node_num)  # !IMPORTANT
        for i in range(len(node_num)):
            # set node space
            node_space = 0
            if i != 0:
                node_space = node_num_cum[i - 1]
            graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._build_graph(
                node_num[i],
                roi_label[i],
                node_space,
                diff_edge=self.diff_edge)
            # update batch graph
            batch_graph.append(graph)
            batch_h_node_list += h_node_list
            batch_obj_node_list += obj_node_list
            batch_h_h_e_list += h_h_e_list
            batch_o_o_e_list += o_o_e_list
            batch_h_o_e_list += h_o_e_list
            batch_readout_edge_list += readout_edge_list
            batch_readout_h_h_e_list += readout_h_h_e_list
            batch_readout_h_o_e_list += readout_h_o_e_list
        batch_graph = dgl.batch(batch_graph)

        # ipdb.set_trace()
        if self.CONFIG1.feat_type != 'fc7':
            feat = self.graph_head(feat)

        # pass through gnn/gcn
        if self.layer == 1:
            self.grnn1(batch_graph,
                       batch_h_node_list,
                       batch_obj_node_list,
                       batch_h_h_e_list,
                       batch_o_o_e_list,
                       batch_h_o_e_list,
                       feat,
                       spatial_feat,
                       word2vec,
                       validation,
                       initial_feat=True)
            # batch_graph.apply_edges(self.edge_readout, tuple(zip(*(batch_readout_h_o_e_list+batch_readout_h_h_e_list))))
            batch_graph.apply_edges(self.edge_readout,
                                    tuple(zip(*batch_readout_edge_list)))

        elif self.layer == 2:
            feat, feat_lang = self.grnn1(batch_graph,
                                         batch_h_node_list,
                                         batch_obj_node_list,
                                         batch_h_h_e_list,
                                         batch_o_o_e_list,
                                         batch_h_o_e_list,
                                         feat,
                                         spatial_feat,
                                         word2vec,
                                         validation,
                                         pop_feat=True,
                                         initial_feat=True)
            self.grnn2(batch_graph, batch_h_node_list, batch_obj_node_list,
                       batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                       feat, spatial_feat, feat_lang, validation)
            if self.diff_edge:
                # update node feature at the last layer
                if len(batch_h_node_list) != 0:
                    batch_graph.apply_nodes(self.h_node_update,
                                            batch_h_node_list)
                if len(batch_obj_node_list) != 0:
                    batch_graph.apply_nodes(self.o_node_update,
                                            batch_obj_node_list)
            else:
                batch_graph.apply_nodes(
                    self.h_node_update,
                    batch_h_node_list + batch_obj_node_list)
            batch_graph.apply_edges(
                self.edge_readout,
                tuple(
                    zip(*(batch_readout_h_o_e_list +
                          batch_readout_h_h_e_list))))

        else:
            feat, feat_lang = self.grnn1(batch_graph,
                                         batch_h_node_list,
                                         batch_obj_node_list,
                                         batch_h_h_e_list,
                                         batch_o_o_e_list,
                                         batch_h_o_e_list,
                                         feat,
                                         spatial_feat,
                                         word2vec,
                                         validation,
                                         pop_feat=True,
                                         initial_feat=True)
            feat, feat_lang = self.grnn2(batch_graph,
                                         batch_h_node_list,
                                         batch_obj_node_list,
                                         batch_h_h_e_list,
                                         batch_o_o_e_list,
                                         batch_h_o_e_list,
                                         feat,
                                         spatial_feat,
                                         feat_lang,
                                         validation,
                                         pop_feat=True)
            self.grnn3(batch_graph, batch_h_node_list, batch_obj_node_list,
                       batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                       feat, spatial_feat, feat_lang, validation)
            if self.diff_edge:
                # update node feature at the last layer
                if len(batch_h_node_list) != 0:
                    batch_graph.apply_nodes(self.h_node_update,
                                            batch_h_node_list)
                if len(batch_obj_node_list) != 0:
                    batch_graph.apply_nodes(self.o_node_update,
                                            batch_obj_node_list)
            else:
                batch_graph.apply_nodes(
                    self.h_node_update,
                    batch_h_node_list + batch_obj_node_list)
            batch_graph.apply_edges(
                self.edge_readout,
                tuple(
                    zip(*(batch_readout_h_o_e_list +
                          batch_readout_h_h_e_list))))

        # import ipdb; ipdb.set_trace()
        if self.training or validation:
            # return batch_graph.edges[tuple(zip(*(batch_readout_h_o_e_list+batch_readout_h_h_e_list)))].data['pred']
            # !NOTE: cannot use "batch_readout_h_o_e_list+batch_readout_h_h_e_list" because the edge order would be wrong
            return batch_graph.edges[tuple(
                zip(*batch_readout_edge_list))].data['pred']
        else:
            return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \
                   batch_graph.nodes[batch_h_node_list].data['alpha'], \
                   batch_graph.nodes[batch_h_node_list].data['alpha_lang']
Example #6
###############################################################################
# The learning curve of a run is presented below.

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()

###############################################################################
# The trained model is evaluated on the test set created. To keep the
# tutorial fast, the training above was kept short; training for longer yields
# a higher accuracy (:math:`80` % ~ :math:`90` %) than the ones printed below.

model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

###############################################################################
# The animation here plots the probability that a trained model predicts the correct graph type.
#
# .. image:: https://data.dgl.ai/tutorial/batch/test_eval4.gif
#
# To understand the node and graph representations that a trained model learned,
Example #7
def test_graph6():
    """Batched graph with node types and edge distances."""
    g1 = dgl.graph(([0, 0, 1], [1, 2, 2]), idtype=torch.int32)
    g2 = dgl.graph(([0, 1, 1, 1], [1, 2, 3, 4]), idtype=torch.int32)
    bg = dgl.batch([g1, g2])
    return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
Example #8
def test_empty_relation(idtype):
    """Test the features of batched DGLHeteroGraphs"""
    g1 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([], [])
        },
        idtype=idtype,
        device=F.ctx())
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])

    g2 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1], [0, 0])
        },
        idtype=idtype,
        device=F.ctx())
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch([g1, g2])

    # Test number of nodes
    for ntype in bg.ntypes:
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.number_of_nodes(ntype),
            g2.number_of_nodes(ntype)
        ]

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.number_of_edges(etype),
            g2.number_of_edges(etype)
        ]

    # Test features
    assert F.allclose(
        bg.nodes['user'].data['h1'],
        F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']],
              dim=0))
    assert F.allclose(
        bg.nodes['user'].data['h2'],
        F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']],
              dim=0))
    assert F.allclose(bg.nodes['game'].data['h1'], g2.nodes['game'].data['h1'])
    assert F.allclose(bg.nodes['game'].data['h2'], g2.nodes['game'].data['h2'])
    assert F.allclose(
        bg.edges['follows'].data['h1'],
        F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']],
              dim=0))
    assert F.allclose(bg.edges['plays'].data['h1'],
                      g2.edges['plays'].data['h1'])

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(g1,
                                           g3,
                                           node_attrs={
                                               'user': ['h1', 'h2'],
                                               'game': ['h1', 'h2']
                                           },
                                           edge_attrs={
                                               ('user', 'follows', 'user'):
                                               ['h1']
                                           })
    check_equivalence_between_heterographs(g2,
                                           g4,
                                           node_attrs={
                                               'user': ['h1', 'h2'],
                                               'game': ['h1', 'h2']
                                           },
                                           edge_attrs={
                                               ('user', 'follows', 'user'):
                                               ['h1']
                                           })

    # Test graphs without edges
    g1 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 0, 'v': 4})
    g2 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 1, 'v': 5})
    dgl.batch([g1, g2])
Example #9
def test_topology(gs, idtype):
    """Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
    g1, g2 = gs
    g1 = g1.astype(idtype).to(F.ctx())
    g2 = g2.astype(idtype).to(F.ctx())
    bg = dgl.batch([g1, g2])

    assert bg.idtype == idtype
    assert bg.device == F.ctx()
    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Test number of nodes
    for ntype in bg.ntypes:
        print(ntype)
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.number_of_nodes(ntype),
            g2.number_of_nodes(ntype)
        ]
        assert bg.number_of_nodes(ntype) == (g1.number_of_nodes(ntype) +
                                             g2.number_of_nodes(ntype))

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.number_of_edges(etype),
            g2.number_of_edges(etype)
        ]
        assert bg.number_of_edges(etype) == (g1.number_of_edges(etype) +
                                             g2.number_of_edges(etype))

    # Test relabeled nodes
    for ntype in bg.ntypes:
        assert list(F.asnumpy(bg.nodes(ntype))) == list(
            range(bg.number_of_nodes(ntype)))

    # Test relabeled edges
    src, dst = bg.edges(etype=('user', 'follows', 'user'))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
    src, dst = bg.edges(etype=('user', 'follows', 'developer'))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
    src, dst, eid = bg.edges(etype='plays', form='all')
    assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
    assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)

    # Test dtype cast
    if idtype == "int32":
        bg_cast = bg.long()
    else:
        bg_cast = bg.int()
    assert bg.batch_size == bg_cast.batch_size

    # Test local var
    bg_local = bg.local_var()
    assert bg.batch_size == bg_local.batch_size
Example #10
def test_batching_batched(idtype):
    """Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""
    g1 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1], [0, 0])
        },
        idtype=idtype,
        device=F.ctx())
    g2 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1], [0, 0])
        },
        idtype=idtype,
        device=F.ctx())
    bg1 = dgl.batch([g1, g2])
    g3 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0], [1]),
            ('user', 'plays', 'game'): ([1], [0])
        },
        idtype=idtype,
        device=F.ctx())
    bg2 = dgl.batch([bg1, g3])
    assert bg2.idtype == idtype
    assert bg2.device == F.ctx()
    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    # Test number of nodes
    for ntype in bg2.ntypes:
        assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == [
            g1.number_of_nodes(ntype),
            g2.number_of_nodes(ntype),
            g3.number_of_nodes(ntype)
        ]
        assert bg2.number_of_nodes(ntype) == (g1.number_of_nodes(ntype) +
                                              g2.number_of_nodes(ntype) +
                                              g3.number_of_nodes(ntype))

    # Test number of edges
    for etype in bg2.canonical_etypes:
        assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == [
            g1.number_of_edges(etype),
            g2.number_of_edges(etype),
            g3.number_of_edges(etype)
        ]
        assert bg2.number_of_edges(etype) == (g1.number_of_edges(etype) +
                                              g2.number_of_edges(etype) +
                                              g3.number_of_edges(etype))

    # Test relabeled nodes
    for ntype in bg2.ntypes:
        assert list(F.asnumpy(bg2.nodes(ntype))) == list(
            range(bg2.number_of_nodes(ntype)))

    # Test relabeled edges
    src, dst = bg2.edges(etype='follows')
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7]
    src, dst = bg2.edges(etype='plays')
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2]

    # Test unbatching graphs
    g4, g5, g6 = dgl.unbatch(bg2)
    check_equivalence_between_heterographs(g1, g4)
    check_equivalence_between_heterographs(g2, g5)
    check_equivalence_between_heterographs(g3, g6)
Example #11
def test_features(idtype):
    """Test the features of batched DGLHeteroGraphs"""
    g1 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1], [0, 0])
        },
        idtype=idtype,
        device=F.ctx())
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.nodes['game'].data['h1'] = F.tensor([[0.]])
    g1.nodes['game'].data['h2'] = F.tensor([[1.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g1.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    g2 = dgl.heterograph(
        {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1], [0, 0])
        },
        idtype=idtype,
        device=F.ctx())
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    # test default setting
    bg = dgl.batch([g1, g2])
    assert F.allclose(
        bg.nodes['user'].data['h1'],
        F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']],
              dim=0))
    assert F.allclose(
        bg.nodes['user'].data['h2'],
        F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']],
              dim=0))
    assert F.allclose(
        bg.nodes['game'].data['h1'],
        F.cat([g1.nodes['game'].data['h1'], g2.nodes['game'].data['h1']],
              dim=0))
    assert F.allclose(
        bg.nodes['game'].data['h2'],
        F.cat([g1.nodes['game'].data['h2'], g2.nodes['game'].data['h2']],
              dim=0))
    assert F.allclose(
        bg.edges['follows'].data['h1'],
        F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']],
              dim=0))
    assert F.allclose(
        bg.edges['follows'].data['h2'],
        F.cat([g1.edges['follows'].data['h2'], g2.edges['follows'].data['h2']],
              dim=0))
    assert F.allclose(
        bg.edges['plays'].data['h1'],
        F.cat([g1.edges['plays'].data['h1'], g2.edges['plays'].data['h1']],
              dim=0))

    # test specifying ndata/edata
    bg = dgl.batch([g1, g2], ndata=['h2'], edata=['h1'])
    assert F.allclose(
        bg.nodes['user'].data['h2'],
        F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']],
              dim=0))
    assert F.allclose(
        bg.nodes['game'].data['h2'],
        F.cat([g1.nodes['game'].data['h2'], g2.nodes['game'].data['h2']],
              dim=0))
    assert F.allclose(
        bg.edges['follows'].data['h1'],
        F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']],
              dim=0))
    assert F.allclose(
        bg.edges['plays'].data['h1'],
        F.cat([g1.edges['plays'].data['h1'], g2.edges['plays'].data['h1']],
              dim=0))
    assert 'h1' not in bg.nodes['user'].data
    assert 'h1' not in bg.nodes['game'].data
    assert 'h2' not in bg.edges['follows'].data

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(g1,
                                           g3,
                                           node_attrs={
                                               'user': ['h2'],
                                               'game': ['h2']
                                           },
                                           edge_attrs={
                                               ('user', 'follows', 'user'):
                                               ['h1']
                                           })
    check_equivalence_between_heterographs(g2,
                                           g4,
                                           node_attrs={
                                               'user': ['h2'],
                                               'game': ['h2']
                                           },
                                           edge_attrs={
                                               ('user', 'follows', 'user'):
                                               ['h1']
                                           })

    # test legacy
    bg = dgl.batch([g1, g2], edge_attrs=['h1'])
    assert 'h2' not in bg.edges['follows'].data.keys()
Example #12
    def forward(self, batch):

        # ======================================================================================
        #                                    Data processing
        # ======================================================================================
        batch_size = len(batch['facts_num_nodes_list'])

        images = batch['features_list']  # [(36,2048)]
        images = torch.stack(images).to(self.device)  # [batch,36,2048]

        img_relations = batch['img_relations_list']
        img_relations = torch.stack(img_relations).to(self.device)  # shape (batch,36,36,7); tentatively 7 dims


        questions = batch['question_list']  # list((max_length,))
        questions = torch.stack(questions).long().to(self.device)  # [batch,max_length]
        questions_len_list = torch.tensor(batch['question_length_list']).long().to(self.device)

        fact_num_nodes_list = torch.Tensor(batch['facts_num_nodes_list']).long().to(self.device)
        facts_features_list = batch['facts_features_list']
        facts_features_list = [torch.Tensor(features).to(self.device)
                               for features in facts_features_list
                               ]
        facts_e1ids_list = batch['facts_e1ids_list']
        facts_e1ids_list = [
            torch.Tensor(e1ids).long().to(self.device)
            for e1ids in facts_e1ids_list
        ]
        facts_e2ids_list = batch['facts_e2ids_list']
        facts_e2ids_list = [
            torch.tensor(e2ids).long().to(self.device)
            for e2ids in facts_e2ids_list
        ]
        facts_answer_list = batch['facts_answer_list']
        facts_answer_list = [
            torch.tensor(answer).long().to(self.device)
            for answer in facts_answer_list
        ]

        ques_embed = self.que_glove_embed(questions).float()
        _, (ques_embed, _) = self.ques_rnn(ques_embed, questions_len_list)
   
        node_att_proj_ques_embed = self.node_att_proj_ques(ques_embed)  
        node_att_proj_img_embed = self.node_att_proj_img(images)  

        node_att_proj_ques_embed = node_att_proj_ques_embed.unsqueeze(1).repeat(1, images.shape[1],
                                                                                1)  
        node_att_proj_img_sum_ques = torch.tanh(node_att_proj_ques_embed + node_att_proj_img_embed)
        node_att_values = self.node_att_value(node_att_proj_img_sum_ques).squeeze()  
        node_att_values = F.softmax(node_att_values, dim=-1)  

        node_att_values = node_att_values.unsqueeze(-1).repeat(
            1, 1, self.config['model']['img_feature_size'])


        images = node_att_values * images  # (b,36)*(b,36,2048)

       
        rel_att_proj_ques_embed = self.rel_att_proj_ques(
            ques_embed)  # shape(batch,128)
        rel_att_proj_rel_embed = self.rel_att_proj_rel(
            img_relations)  # shape(batch,36,36,128)

        rel_att_proj_ques_embed = rel_att_proj_ques_embed.repeat(
            1, 36 * 36).view(
            batch_size, 36, 36, self.config['model']
            ['rel_att_ques_rel_proj_dims'])  # shape(batch,36,36,128)
        rel_att_proj_rel_sum_ques = torch.tanh(rel_att_proj_ques_embed +
                                               rel_att_proj_rel_embed)
        rel_att_values = self.rel_att_value(
            rel_att_proj_rel_sum_ques).squeeze()  # shape(batch,36,36)
        rel_att_values_2 = rel_att_values.unsqueeze(-1).repeat(
            1, 1, 1, self.config['model']['relation_dims'])


        img_relations = rel_att_values_2 * img_relations  # (batch,36,36,7)

  
        img_graphs = []
        for i in range(batch_size):
            g = dgl.DGLGraph()
            # add nodes
            g.add_nodes(36)
            # add node features
            g.ndata['h'] = images[i]
            g.ndata['batch'] = torch.full([36, 1], i)
            # add edges
            for s in range(36):
                for d in range(36):
                    g.add_edge(s, d)
            # add edge features
            g.edata['rel'] = img_relations[i].view(
                36 * 36,
                self.config['model']['relation_dims'])  # shape(36*36,7)
            g.edata['att'] = rel_att_values[i].view(36 * 36,
                                                    1)  # shape(36*36,1)
            img_graphs.append(g)
        image_batch_graph = dgl.batch(img_graphs)

       
        fact_graphs = []
        for i in range(batch_size):
            graph = dgl.DGLGraph()
            graph.add_nodes(fact_num_nodes_list[i])
            graph.add_edges(facts_e1ids_list[i], facts_e2ids_list[i])
            graph.ndata['h'] = facts_features_list[i]
            graph.ndata['batch'] = torch.full([fact_num_nodes_list[i], 1], i)
            graph.ndata['answer'] = facts_answer_list[i]

            fact_graphs.append(graph)
        fact_batch_graph = dgl.batch(fact_graphs)


        image_batch_graph = self.img_gcn1(image_batch_graph)
        
        fact_batch_graph = self.new_fact_gcn1(fact_batch_graph,
                                              image_batch_graph,
                                              ques_embed=ques_embed)
        fact_batch_graph.ndata['h'] = F.relu(fact_batch_graph.ndata['h'])
      
        image_batch_graph = self.img_gcn2(image_batch_graph)

        fact_batch_graph = self.new_fact_gcn2(
            fact_batch_graph, image_batch_graph,
            ques_embed=ques_embed)  # one feature per node
        fact_batch_graph.ndata['h'] = torch.sigmoid(fact_batch_graph.ndata['h'])
        fact_batch_graph.ndata['h'] = torch.softmax(
            fact_batch_graph.ndata['h'], dim=0)
        return fact_batch_graph
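
Incidentally, the nested add_edge loop above builds the fully connected 36-node graph one edge at a time; the same structure can be created in a single call, which is much faster. A minimal sketch of the vectorized alternative, assuming a DGL version that provides dgl.graph:

import dgl
import torch

n = 36
src = torch.arange(n).repeat_interleave(n)  # 0,0,...,0,1,1,... (n*n entries)
dst = torch.arange(n).repeat(n)             # 0,1,...,n-1,0,1,...
g = dgl.graph((src, dst), num_nodes=n)      # all n*n directed edges at once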
Example #13
def create_batch(batch_all):

    #print('Batch size : ', len(batch_all))

    #X, Y = batch_all[0]

    batch = [ProcessImage(batch_itr) for batch_itr in batch_all]

    lengths = [sample['seq_length'] for sample in batch]

    Input = [torch.FloatTensor(sample['Input']) for sample in batch]
    target = [torch.FloatTensor(sample['target']) for sample in batch]
    #position = [ torch.FloatTensor(sample['position']) for sample in batch  ]
    graph = [sample['gr'] for sample in batch]

    energy = [torch.FloatTensor(sample['energy']) for sample in batch]

    position = [torch.Tensor(sample['point_xyz']) for sample in batch]
    position_idx = [
        torch.LongTensor(sample['point_idx_zxy']) for sample in batch
    ]

    max_length = np.max(lengths)

    n_sequences = len(target)

    Input_tensor = torch.ones((n_sequences, 7, 64, 64)).float() * -5
    targets_tensor = torch.ones((n_sequences, 6, 64, 64)).float() * -5
    position_idx_tensor = torch.ones((n_sequences, max_length, 3)).int() * -1
    position_idx_tensor = position_idx_tensor.long()

    position_tensor = torch.ones((n_sequences, max_length, 3)).float() * -999.
    energy_tensor = torch.zeros((n_sequences, max_length)).float()

    for i in range(n_sequences):
        seq_len = lengths[i]

        Input_tensor[i] = Input[i]
        targets_tensor[i] = target[i]
        position_tensor[i, :seq_len] = position[i]
        position_idx_tensor[i, :seq_len] = position_idx[i]
        energy_tensor[i, :seq_len] = energy[i]

    sequence_lengths = torch.LongTensor(lengths)

    # sequence_lengths, idx = sequence_lengths.sort(dim=0, descending=True)

    # targets_tensor = targets_tensor[idx]
    # energy_tensor =  energy_tensor[idx]
    # position_tensor = position_tensor[idx]
    # position_idx_tensor = position_idx_tensor[idx]
    # Input_tensor = Input_tensor[idx]
    # graph = sorted(graph, key=lambda g: g.number_of_nodes(), reverse=True)

    # pos = torch.where(sequence_lengths > graph_size)[0]

    # sequence_lengths = sequence_lengths[pos]
    # targets_tensor = targets_tensor[pos]
    # energy_tensor =  energy_tensor[pos]
    # position_tensor = position_tensor[pos]
    # position_idx_tensor = position_idx_tensor[pos]
    # Input_tensor = Input_tensor[pos]
    # graph = graph[0: len(pos) ]

    return (dgl.batch(graph), targets_tensor, sequence_lengths, energy_tensor,
            position_tensor, position_idx_tensor, Input_tensor)
Example #14
 def collate_fn(batch):
     graphs, pmpds, labels = zip(*batch)
     batched_graphs = dgl.batch(graphs)
     batched_pmpds = sp.block_diag(pmpds)
     batched_labels = np.concatenate(labels, axis=0)
     return batched_graphs, batched_pmpds, batched_labels
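
Here scipy's block_diag plays the same role for the auxiliary pmpd matrices that dgl.batch plays for the graphs: the per-sample sparse matrices are stacked into one block-diagonal matrix whose blocks line up with the batched node IDs. A small sketch of that behavior:

import scipy.sparse as sp

blocks = [sp.eye(2), sp.eye(3)]
merged = sp.block_diag(blocks)
print(merged.shape)  # (5, 5), one identity block per sample on the diagonal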
Example #15
def test_pickling_graph():
    # graph structures and frames are pickled
    g = dgl.DGLGraph()
    g.add_nodes(3)
    src = F.tensor([0, 0])
    dst = F.tensor([1, 2])
    g.add_edges(src, dst)

    x = F.randn((3, 7))
    y = F.randn((3, 5))
    a = F.randn((2, 6))
    b = F.randn((2, 4))

    g.ndata['x'] = x
    g.ndata['y'] = y
    g.edata['a'] = a
    g.edata['b'] = b

    # registered functions are pickled
    g.register_message_func(_global_message_func)
    reduce_func = fn.sum('x', 'x')
    g.register_reduce_func(reduce_func)

    # custom attributes should be pickled
    g.foo = 2

    new_g = _reconstruct_pickle(g)

    _assert_is_identical(g, new_g)
    assert new_g.foo == 2
    assert new_g._message_func == _global_message_func
    assert isinstance(new_g._reduce_func, type(reduce_func))
    assert new_g._reduce_func._name == 'sum'
    assert new_g._reduce_func.msg_field == 'x'
    assert new_g._reduce_func.out_field == 'x'

    # test batched graph with partial set case
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)
    src2 = F.tensor([0, 1])
    dst2 = F.tensor([2, 3])
    g2.add_edges(src2, dst2)

    x2 = F.randn((4, 7))
    y2 = F.randn((3, 5))
    a2 = F.randn((2, 6))
    b2 = F.randn((2, 4))

    g2.ndata['x'] = x2
    g2.nodes[[0, 1, 3]].data['y'] = y2
    g2.edata['a'] = a2
    g2.edata['b'] = b2

    bg = dgl.batch([g, g2])

    bg2 = _reconstruct_pickle(bg)

    _assert_is_identical(bg, bg2)
    new_g, new_g2 = dgl.unbatch(bg2)
    _assert_is_identical(g, new_g)
    _assert_is_identical(g2, new_g2)

    # readonly graph
    g = dgl.DGLGraph([(0, 1), (1, 2)], readonly=True)
    new_g = _reconstruct_pickle(g)
    _assert_is_identical(g, new_g)

    # multigraph
    g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)])
    new_g = _reconstruct_pickle(g)
    _assert_is_identical(g, new_g)

    # readonly multigraph
    g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], readonly=True)
    new_g = _reconstruct_pickle(g)
    _assert_is_identical(g, new_g)
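
_reconstruct_pickle is not shown in this snippet; a plausible stand-in is a plain round-trip through pickle:

import pickle

def _reconstruct_pickle(obj):
    # serialize and immediately deserialize, as the test presumably does
    return pickle.loads(pickle.dumps(obj))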
Example #16
def main(training_file, dev_file, test_file, task, previous_model=None):
    global stop_training

    graph_type = task['graph_type']
    net = task['gnn_network']
    epochs = task['epochs']
    patience = task['patience']
    grid_width = task['grid_width']
    image_width = task['image_width']
    batch_size = task['batch_size']
    num_hidden = task['num_gnn_units']
    heads = task['num_gnn_heads']
    residual = False
    lr = task['lr']
    weight_decay = task['weight_decay']
    gnn_layers = task['num_gnn_layers']
    cnn_layers = task['num_cnn_layers']
    in_drop = task['in_drop']
    alpha = task['alpha']
    attn_drop = task['attn_drop']
    num_rels = task['num_rels']

    fw = task['fw']
    identifier = task['identifier']
    nonlinearity = task['non-linearity']
    min_train_loss = float("inf")
    min_dev_loss = float("inf")

    if nonlinearity == 'relu':
        nonlinearity = F.relu
    elif nonlinearity == 'elu':
        nonlinearity = F.elu

    output_list_records_train_loss = []
    output_list_records_dev_loss = []

    loss_fcn = torch.nn.MSELoss()  #(reduction='sum')

    print('=========================')
    print('HEADS', heads)
    print('GNN LAYERS', gnn_layers)
    print('CNN LAYERS', cnn_layers)
    print('HIDDEN', num_hidden)
    print('RESIDUAL', residual)
    print('inDROP', in_drop)
    print('atDROP', attn_drop)
    print('LR', lr)
    print('DECAY', weight_decay)
    print('ALPHA', alpha)
    print('BATCH', batch_size)
    print('GRAPH_ALT', graph_type)
    print('ARCHITECTURE', net)
    print('=========================')

    # create the dataset
    print('Loading training set...')
    train_dataset = socnavImg.SocNavDataset(training_file, mode='train')
    print('Loading dev set...')
    valid_dataset = socnavImg.SocNavDataset(dev_file, mode='valid')
    print('Loading test set...')
    test_dataset = socnavImg.SocNavDataset(test_file, mode='test')
    print('Done loading files')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  collate_fn=collate)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  collate_fn=collate)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 collate_fn=collate)

    if num_rels < 0:
        num_rels = len(socnavImg.get_relations())

    cur_step = 0
    best_loss = -1
    num_feats = train_dataset.graphs[0].ndata['h'].shape[1]
    print('Number of features: {}'.format(num_feats))
    g = dgl.batch(train_dataset.graphs)
    # define the model
    if fw == 'dgl':
        if net in ['gat']:
            model = GAT(
                g,  # graph
                gnn_layers,  # gnn_layers
                num_feats,  # in_dimension
                num_hidden,  # num_hidden
                1,
                grid_width,  # grid_width
                heads,  # head
                nonlinearity,  # activation
                in_drop,  # feat_drop
                attn_drop,  # attn_drop
                alpha,  # negative_slope
                residual,  # residual
                cnn_layers  # cnn_layers
            )
        elif net in ['gatmc']:
            model = GATMC(
                g,  # graph
                gnn_layers,  # gnn_layers
                num_feats,  # in_dimension
                num_hidden,  # num_hidden
                grid_width,  # grid_width
                image_width,  # image_width
                heads,  # head
                nonlinearity,  # activation
                in_drop,  # feat_drop
                attn_drop,  # attn_drop
                alpha,  # negative_slope
                residual,  # residual
                cnn_layers  # cnn_layers
            )
        elif net in ['rgcn']:
            print(
                f'CREATING RGCN(GRAPH, gnn_layers:{gnn_layers}, cnn_layers:{cnn_layers}, num_feats:{num_feats}, num_hidden:{num_hidden}, grid_width:{grid_width}, image_width:{image_width}, num_rels:{num_rels}, non-linearity:{nonlinearity}, drop:{in_drop}, num_bases:{num_rels})'
            )
            model = RGCN(g,
                         gnn_layers,
                         cnn_layers,
                         num_feats,
                         num_hidden,
                         grid_width,
                         image_width,
                         num_rels,
                         nonlinearity,
                         in_drop,
                         num_bases=num_rels)
        else:
            print('No valid GNN model specified')
            sys.exit(1)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    if previous_model is not None:
        model.load_state_dict(torch.load(previous_model, map_location=device))

    model = model.to(device)

    for epoch in range(epochs):
        if stop_training:
            print("Stopping training. Please wait.")
            break
        model.train()
        loss_list = []
        for batch, data in enumerate(train_dataloader):
            subgraph, labels = data
            subgraph.set_n_initializer(dgl.init.zero_initializer)
            subgraph.set_e_initializer(dgl.init.zero_initializer)
            feats = subgraph.ndata['h'].to(device)
            labels = labels.to(device)
            if fw == 'dgl':
                model.g = subgraph
                for layer in model.layers:
                    layer.g = subgraph
                if net in ['rgcn']:
                    logits = model(
                        feats.float(),
                        subgraph.edata['rel_type'].squeeze().to(device))
                else:
                    logits = model(feats.float())
            else:
                print('Only DGL is supported at the moment here.')
                sys.exit(1)
                if net in ['pgat', 'pgcn']:
                    data = Data(x=feats.float(),
                                edge_index=torch.stack(
                                    subgraph.edges()).to(device))
                else:
                    data = Data(
                        x=feats.float(),
                        edge_index=torch.stack(subgraph.edges()).to(device),
                        edge_type=subgraph.edata['rel_type'].squeeze().to(
                            device))
                logits = model(data, subgraph)
            a = logits.flatten()
            b = labels.float().flatten()
            ad = a.to(device)
            bd = b.to(device)
            loss = loss_fcn(ad, bd)
            optimizer.zero_grad()
            a = list(model.parameters())[0].clone()
            loss.backward()
            optimizer.step()
            b = list(model.parameters())[0].clone()
            not_learning = torch.equal(a.data, b.data)
            if not_learning:
                print('Not learning')
                sys.exit(1)
            else:
                pass
            loss_list.append(loss.item())
        loss_data = np.array(loss_list).mean()
        if loss_data < min_train_loss:
            min_train_loss = loss_data
        print('Loss: {}'.format(loss_data))
        output_list_records_train_loss.append(float(loss_data))
        if epoch % 5 == 0:
            print("Epoch {:05d} | Loss: {:.6f} | Patience: {} | ".format(
                epoch, loss_data, cur_step),
                  end='')
            score_list = []
            val_loss_list = []
            for batch, valid_data in enumerate(valid_dataloader):
                subgraph, labels = valid_data
                subgraph.set_n_initializer(dgl.init.zero_initializer)
                subgraph.set_e_initializer(dgl.init.zero_initializer)
                feats = subgraph.ndata['h'].to(device)
                labels = labels.to(device)
                score, val_loss = evaluate(feats.float(), model, subgraph,
                                           labels.float(), loss_fcn, fw, net)
                score_list.append(score)
                val_loss_list.append(val_loss)
            mean_score = np.array(score_list).mean()
            mean_val_loss = np.array(val_loss_list).mean()
            print("Score: {:.6f} MEAN: {:.6f} BEST: {:.6f}".format(
                mean_score, mean_val_loss, best_loss))
            output_list_records_dev_loss.append(mean_val_loss)

            # early stop
            if best_loss > mean_val_loss or best_loss < 0:
                print('Saving...')
                directory = str(identifier).zfill(5)

                try:
                    os.mkdir(directory)
                except OSError:
                    print('Exception creating directory', directory)

                best_loss = mean_val_loss
                if best_loss < min_dev_loss:
                    min_dev_loss = best_loss

                # Save the model
                model.eval()
                torch.save(model.state_dict(), directory + '/SNGNN2D.tch')
                params = {
                    'train_loss': min_train_loss,
                    'dev_loss': min_dev_loss,
                    'net': net,
                    'fw': fw,
                    'gnn_layers': gnn_layers,
                    'cnn_layers': cnn_layers,
                    'num_feats': num_feats,
                    'num_hidden': num_hidden,
                    'graph_type': graph_type,
                    'heads': heads,
                    'grid_width': grid_width,
                    'image_width': image_width,
                    'F': nonlinearity,
                    'in_drop': in_drop,
                    'attn_drop': attn_drop,
                    'alpha': alpha,
                    'residual': residual,
                    'num_rels': num_rels,
                    'train_scores': output_list_records_train_loss,
                    'dev_scores': output_list_records_dev_loss
                }
                pickle.dump(params, open(directory + '/SNGNN2D.prms', 'wb'))
                cur_step = 0
            else:
                cur_step += 1
                if cur_step >= patience:
                    break
    time_a = time.time()
    test_score_list = []
    for batch, test_data in enumerate(test_dataloader):
        subgraph, labels = test_data
        subgraph.set_n_initializer(dgl.init.zero_initializer)
        subgraph.set_e_initializer(dgl.init.zero_initializer)
        feats = subgraph.ndata['h'].to(device)
        labels = labels.to(device)
        test_score_list.append(
            evaluate(feats, model, subgraph, labels.float(), loss_fcn, fw,
                     net)[1])
    time_b = time.time()
    time_delta = float(time_b - time_a)
    test_loss = np.array(test_score_list).mean()
    print("MSE for the test set {}".format(test_loss))
    return min_train_loss, min_dev_loss, test_loss, time_delta, num_of_params(
        model
    ), epoch, output_list_records_train_loss, output_list_records_dev_loss
Example #17
def collate(sample):
    graphs, feats, labels = map(list, zip(*sample))
    graph = dgl.batch(graphs)
    feats = torch.from_numpy(np.concatenate(feats))
    labels = torch.from_numpy(np.concatenate(labels))
    return graph, feats, labels
Example #18
 def batcher_dev(batch):
     graphs, labels = zip(*batch)
     batch_graphs = dgl.batch(graphs)
     labels = torch.stack(labels, 0)
     return AlchemyBatcher(graph=batch_graphs, label=labels)
Example #19
def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
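
After collating, a graph-level model typically finishes with a readout over the batched graph; dgl.mean_nodes averages node features per component, respecting the per-graph boundaries. A minimal sketch, assuming the node features are stored under 'h' and samples comes from the dataset:

import dgl
import torch

bg, labels = collate(samples)  # using the collate defined above
bg.ndata['h'] = torch.randn(bg.number_of_nodes(), 16)
hg = dgl.mean_nodes(bg, 'h')   # one 16-dim vector per graph in the batch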
Example #20
 def __init__(self, tree_list):
     graph_list = []
     for tree in tree_list:
         graph_list.append(tree.dgl_graph)
     self.batch_dgl_graph = dgl.batch(graph_list)
Example #21
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
Example #22
    def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol, global_amap,
                     fa_amap, cur_node_id, fa_node_id):
        nodes_dict = mol_tree_msg.nodes_dict
        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
        cur_node = nodes_dict[cur_node_id]

        fa_nid = fa_node['nid'] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children_node_id = [
            v for v in mol_tree_msg.successors(cur_node_id).tolist()
            if nodes_dict[v]['nid'] != fa_nid
        ]
        children = [nodes_dict[v] for v in children_node_id]
        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
        neighbors = sorted(neighbors,
                           key=lambda x: x['mol'].GetNumAtoms(),
                           reverse=True)
        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                    if nid == cur_node['nid']]
        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_mols, cand_amap = list(zip(*cands))

        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
        cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
            tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(
                cands)
        cand_graphs = batch(cand_graphs)
        atom_x = cuda(atom_x)
        bond_x = cuda(bond_x)
        cand_graphs.ndata['x'] = atom_x
        cand_graphs.edata['x'] = bond_x
        cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0],
                                                atom_x.shape[1]).zero_()

        cand_vecs = self.jtmpn(
            (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges,
             tree_mess_tgt_nodes),
            mol_tree_msg,
        )
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = cand_vecs @ mol_vec

        _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(len(cand_idx)):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[
                    cur_node['nid']][ctr_atom]

            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            result = True
            for nei_node_id, nei_node in zip(children_node_id, children):
                if nei_node['is_leaf']:
                    continue
                cur_mol = self.dfs_assemble(mol_tree_msg, mol_vec, cur_mol,
                                            new_global_amap, pred_amap,
                                            nei_node_id, cur_node_id)
                if cur_mol is None:
                    result = False
                    break

            if result:
                return cur_mol

        return None
Example #23
def pack_graph(graphs: Union[Tuple, List]) -> Tuple:
    # merge many dgl graphs into a huge one
    root_indices, node_nums = get_root_node_info(graphs)
    packed_graph = dgl.batch(graphs)
    return packed_graph, root_indices, node_nums
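
get_root_node_info is not shown here; if each graph's root is its node 0, the packed-graph root indices can be recovered from the cumulative node counts. A hedged sketch of one possible implementation:

import torch

def get_root_node_info(graphs):
    # per-graph node counts; each graph's node 0 lands at the running offset
    node_nums = torch.tensor([g.number_of_nodes() for g in graphs])
    root_indices = torch.cumsum(node_nums, dim=0) - node_nums
    return root_indices, node_nums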
Example #24
    def __call__(self, src_buf, tgt_buf, device='cpu', src_deps=None):
        '''
        Return a batched graph for the training phase of Transformer.
        args:
            src_buf: a set of input sequence arrays.
            tgt_buf: a set of output sequence arrays.
            device: 'cpu' or 'cuda:*'
            src_deps: list, optional
                Dependency parses of the source in the form of src_node_id -> dst_node_id,
                where src is the child and dst is the parent, i.e. a child node attends to
                its syntactic parent in a dependency parse.
        '''

        if src_deps is None:
            src_deps = list()
        g_list = []
        src_lens = [len(_) for _ in src_buf]
        tgt_lens = [len(_) - 1 for _ in tgt_buf]

        num_edges = {'ee': [], 'ed': [], 'dd': []}

        # We are running over source and target pairs here
        for src_len, tgt_len in zip(src_lens, tgt_lens):
            i, j = src_len - 1, tgt_len - 1
            g_list.append(self.g_pool[i][j])
            for key in ['ee', 'ed', 'dd']:
                num_edges[key].append(int(self.num_edges[key][i][j]))

        g = dgl.batch(g_list)
        src, tgt, tgt_y = [], [], []
        src_pos, tgt_pos = [], []
        enc_ids, dec_ids = [], []
        e2e_eids, d2d_eids, e2d_eids = [], [], []
        layer_eids = {'dep': [[], []]}
        n_nodes, n_edges, n_tokens = 0, 0, 0
        for src_sample, tgt_sample, src_dep, n, m, n_ee, n_ed, n_dd in zip(
                src_buf, tgt_buf, src_deps, src_lens, tgt_lens,
                num_edges['ee'], num_edges['ed'], num_edges['dd']):
            src.append(th.tensor(src_sample, dtype=th.long, device=device))
            tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long,
                                 device=device))
            tgt_y.append(
                th.tensor(tgt_sample[1:], dtype=th.long, device=device))
            src_pos.append(th.arange(n, dtype=th.long, device=device))
            tgt_pos.append(th.arange(m, dtype=th.long, device=device))
            enc_ids.append(
                th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device))
            n_nodes += n
            dec_ids.append(
                th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device))
            n_nodes += m

            e2e_eids.append(
                th.arange(n_edges,
                          n_edges + n_ee,
                          dtype=th.long,
                          device=device))

            # Select this sample's dependency edges. A plain arange will not
            # work here: the code assumes the ee edges are laid out row-major
            # over node pairs, so edge (u, v) has id n_edges + u * n + v;
            # both directions of each (src, dst) pair from the parse are kept.
            if src_dep:
                for i in range(0, 2):
                    for src_node_id, dst_node_id in dedupe_tuples(
                            get_src_dst_deps(src_dep, i + 1)):
                        layer_eids['dep'][i].append(n_edges + src_node_id * n +
                                                    dst_node_id)
                        layer_eids['dep'][i].append(n_edges + dst_node_id * n +
                                                    src_node_id)

            n_edges += n_ee
            e2d_eids.append(
                th.arange(n_edges,
                          n_edges + n_ed,
                          dtype=th.long,
                          device=device))
            n_edges += n_ed
            d2d_eids.append(
                th.arange(n_edges,
                          n_edges + n_dd,
                          dtype=th.long,
                          device=device))
            n_edges += n_dd
            n_tokens += m

        g.set_n_initializer(dgl.init.zero_initializer)
        g.set_e_initializer(dgl.init.zero_initializer)

        return Graph(g=g,
                     src=(th.cat(src), th.cat(src_pos)),
                     tgt=(th.cat(tgt), th.cat(tgt_pos)),
                     tgt_y=th.cat(tgt_y),
                     nids={
                         'enc': th.cat(enc_ids),
                         'dec': th.cat(dec_ids)
                     },
                     eids={
                         'ee': th.cat(e2e_eids),
                         'ed': th.cat(e2d_eids),
                         'dd': th.cat(d2d_eids)
                     },
                     nid_arr={
                         'enc': enc_ids,
                         'dec': dec_ids
                     },
                     n_nodes=n_nodes,
                     layer_eids={
                         'dep': [
                             th.tensor(layer_eids['dep'][i])
                             for i in range(0, len(layer_eids['dep']))
                         ]
                     },
                     n_edges=n_edges,
                     n_tokens=n_tokens)
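The eids above are plain consecutive slices: each sample contributes its ee, then ed, then dd edges to the batched edge list, and every slice is an arange over a running offset. A self-contained sketch of that bookkeeping, with hypothetical edge counts:

import torch as th

num_edges = {'ee': [4, 6], 'ed': [2, 3], 'dd': [1, 2]}  # two samples
n_edges = 0
eids = {'ee': [], 'ed': [], 'dd': []}
for n_ee, n_ed, n_dd in zip(num_edges['ee'], num_edges['ed'], num_edges['dd']):
    for key, n in (('ee', n_ee), ('ed', n_ed), ('dd', n_dd)):
        eids[key].append(th.arange(n_edges, n_edges + n))
        n_edges += n
# Sample 1's ee slice starts at offset 7, right after sample 0's dd slice.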
Example #25
    axs[0, 0].set_yticks([])
    for row, mode in enumerate(['spmm', 'sddmm']):
        filename = 'reddit_{}.csv'.format(mode)
        Xs, coo_gflops, csr_gflops = get_data(filename)
        axs[row + 1, 0].plot(Xs, coo_gflops, linewidth=1, label='coo')
        axs[row + 1, 0].plot(Xs, csr_gflops, linewidth=1, label='csr')
        axs[row + 1, 0].set_yticks([0, 200, 400])
        axs[row + 1, 0].set_xscale('log', basex=2)
        axs[row + 1, 0].set_xticks(Xs)
        axs[2, 0].set_xlabel('reddit')
    
    # mesh
    for i, k in enumerate([32]): #enumerate([8, 16, 32, 64]):
        f = open('modelnet40_{}.g'.format(k), 'rb')
        gs = pickle.load(f)
        g = dgl.batch(gs)
        axs[0, i + 1].hist(g.in_degrees(), range=(0, 100))
        axs[0, i + 1].set_yticks([])
        for row, mode in enumerate(['spmm', 'sddmm']):
            filename = 'mesh_{}_{}.csv'.format(k, mode)
            Xs, coo_gflops, csr_gflops = get_data(filename)
            axs[row + 1, i + 1].plot(Xs, coo_gflops, linewidth=1, label='coo')
            axs[row + 1, i + 1].plot(Xs, csr_gflops, linewidth=1, label='csr')
            axs[row + 1, i + 1].set_yticks([0, 200, 400])
            axs[row + 1, i + 1].set_xscale('log', basex=2)
            axs[row + 1, i + 1].set_xticks(Xs)
            axs[2, i + 1].set_xlabel('mesh {}'.format(k))

    """
    # ppi
    ppi = PPIDataset('train')
Example #26
    def beam(self,
             src_buf,
             start_sym,
             max_len,
             k,
             device='cpu',
             src_deps=None):
        '''
        Return a batched graph for beam search during inference of Transformer.
        args:
            src_buf: a list of input sequences
            start_sym: the index of the start-of-sequence symbol
            max_len: maximum length for decoding
            k: beam size
            device: 'cpu' or 'cuda:*'
            src_deps: list, optional
                Dependency parses of the source; same format as in __call__.
        '''
        if src_deps is None:
            # One (empty) parse per sample, so the zip below is not cut short.
            src_deps = [[] for _ in src_buf]
        g_list = []
        src_lens = [len(_) for _ in src_buf]
        tgt_lens = [max_len] * len(src_buf)
        num_edges = {'ee': [], 'ed': [], 'dd': []}
        for src_len, tgt_len in zip(src_lens, tgt_lens):
            i, j = src_len - 1, tgt_len - 1
            for _ in range(k):
                g_list.append(self.g_pool[i][j])
            for key in ['ee', 'ed', 'dd']:
                num_edges[key].append(int(self.num_edges[key][i][j]))

        g = dgl.batch(g_list)
        src, tgt = [], []
        src_pos, tgt_pos = [], []
        enc_ids, dec_ids = [], []
        layer_eids = {'dep': [[], []]}
        e2e_eids, e2d_eids, d2d_eids = [], [], []
        n_nodes, n_edges, n_tokens = 0, 0, 0
        for src_sample, src_dep, n, n_ee, n_ed, n_dd in zip(
                src_buf, src_deps, src_lens, num_edges['ee'], num_edges['ed'],
                num_edges['dd']):
            for _ in range(k):
                src.append(th.tensor(src_sample, dtype=th.long, device=device))
                src_pos.append(th.arange(n, dtype=th.long, device=device))
                enc_ids.append(
                    th.arange(n_nodes,
                              n_nodes + n,
                              dtype=th.long,
                              device=device))
                n_nodes += n
                e2e_eids.append(
                    th.arange(n_edges,
                              n_edges + n_ee,
                              dtype=th.long,
                              device=device))

                # Select this sample's dependency edges. A plain arange will
                # not work here: the code assumes the ee edges are laid out
                # row-major over node pairs, so edge (u, v) has id
                # n_edges + u * n + v; both directions of each (src, dst)
                # pair from the parse are kept.
                if src_dep:
                    for i in range(0, 2):
                        for src_node_id, dst_node_id in dedupe_tuples(
                                get_src_dst_deps(src_dep, i + 1)):
                            layer_eids['dep'][i].append(n_edges +
                                                        src_node_id * n +
                                                        dst_node_id)
                            layer_eids['dep'][i].append(n_edges +
                                                        dst_node_id * n +
                                                        src_node_id)

                n_edges += n_ee
                tgt_seq = th.zeros(max_len, dtype=th.long, device=device)
                tgt_seq[0] = start_sym
                tgt.append(tgt_seq)
                tgt_pos.append(th.arange(max_len, dtype=th.long,
                                         device=device))

                dec_ids.append(
                    th.arange(n_nodes,
                              n_nodes + max_len,
                              dtype=th.long,
                              device=device))
                n_nodes += max_len

                e2d_eids.append(
                    th.arange(n_edges,
                              n_edges + n_ed,
                              dtype=th.long,
                              device=device))
                n_edges += n_ed
                d2d_eids.append(
                    th.arange(n_edges,
                              n_edges + n_dd,
                              dtype=th.long,
                              device=device))
                n_edges += n_dd

        g.set_n_initializer(dgl.init.zero_initializer)
        g.set_e_initializer(dgl.init.zero_initializer)

        return Graph(g=g,
                     src=(th.cat(src), th.cat(src_pos)),
                     tgt=(th.cat(tgt), th.cat(tgt_pos)),
                     tgt_y=None,
                     nids={
                         'enc': th.cat(enc_ids),
                         'dec': th.cat(dec_ids)
                     },
                     eids={
                         'ee': th.cat(e2e_eids),
                         'ed': th.cat(e2d_eids),
                         'dd': th.cat(d2d_eids)
                     },
                     nid_arr={
                         'enc': enc_ids,
                         'dec': dec_ids
                     },
                     n_nodes=n_nodes,
                     n_edges=n_edges,
                     layer_eids={
                         'dep': [
                             th.tensor(layer_eids['dep'][i])
                             for i in range(0, len(layer_eids['dep']))
                         ]
                     },
                     n_tokens=n_tokens)
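Each source is replicated k times, so the node ids alternate between an encoder block of n nodes and a decoder block of max_len nodes per hypothesis. A small sketch of the resulting id layout, with hypothetical sizes:

import torch as th

k, n, max_len = 3, 4, 5  # hypothetical beam size, source length, decode limit
n_nodes, enc_ids, dec_ids = 0, [], []
for _ in range(k):
    enc_ids.append(th.arange(n_nodes, n_nodes + n))
    n_nodes += n
    dec_ids.append(th.arange(n_nodes, n_nodes + max_len))
    n_nodes += max_len
# Hypothesis b's encoder nodes start at b * (n + max_len).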
Example #27
    split = int(n * .8)
    index = np.arange(n)
    np.random.seed(32)
    np.random.shuffle(index)
    train_index, test_index = index[:split], index[split:]

    # prep labels
    train_labels = Variable(torch.LongTensor((df['label'].values + 1)[train_index]))
    test_labels = Variable(torch.LongTensor((df['label'].values + 1)[test_index]))

    # prep temporal graph data: one batched graph per period
    k = args.period
    trainGs = [dgls[i] for i in train_index]
    testGs = [dgls[i] for i in test_index]
    trainGs = [dgl.batch([u[t] for u in trainGs]) for t in range(k)]
    testGs = [dgl.batch([u[t] for u in testGs]) for t in range(k)]
    train_inputs = [inputs[i] for i in train_index]
    test_inputs = [inputs[i] for i in test_index]
    train_inputs = [torch.FloatTensor(np.concatenate([inp[t] for inp in train_inputs]))
                    for t in range(k)]
    test_inputs = [torch.FloatTensor(np.concatenate([inp[t] for inp in test_inputs]))
                   for t in range(k)]
    train_xav = [xav[i] for i in train_index]
    test_xav = [xav[i] for i in test_index]
    train_xav = [torch.FloatTensor(np.concatenate([inp[t] for inp in train_xav]))
                 for t in range(k)]
    test_xav = [torch.FloatTensor(np.concatenate([inp[t] for inp in test_xav]))
                for t in range(k)]
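The comprehensions above transpose a per-sample layout into a per-period one: dgls[i][t] holds sample i's graph at period t, and the result is k batched graphs, one per period. The same transposition on toy data (hypothetical, two samples, k = 2):

samples = [['s0_t0', 's0_t1'], ['s1_t0', 's1_t1']]
per_period = [[u[t] for u in samples] for t in range(2)]
# per_period == [['s0_t0', 's1_t0'], ['s0_t1', 's1_t1']]
# With real DGLGraphs, each inner list is what dgl.batch receives.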
Example #28
def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
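    # dgl.batch lays the nodes out graph by graph (15, 5, 15, 5, 15 nodes),
    # hence the per-graph slices h0[:15], h0[15:20], ... in the checks below.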
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack([
        F.sum(h0[:15], 0),
        F.sum(h0[15:20], 0),
        F.sum(h0[20:35], 0),
        F.sum(h0[35:40], 0),
        F.sum(h0[40:55], 0)
    ], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack([
        F.mean(h0[:15], 0),
        F.mean(h0[15:20], 0),
        F.mean(h0[20:35], 0),
        F.mean(h0[35:40], 0),
        F.mean(h0[40:55], 0)
    ], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack([
        F.max(h0[:15], 0),
        F.max(h0[15:20], 0),
        F.max(h0[20:35], 0),
        F.max(h0[35:40], 0),
        F.max(h0[40:55], 0)
    ], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
Example #29
def batcher_dev(batch):
    graph_q, label = zip(*batch)
    graph_q = dgl.batch(graph_q)
    return graph_q, torch.LongTensor(label)
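A usage sketch: batcher_dev is written to serve as a DataLoader collate_fn; the toy dataset of (graph, label) pairs below is hypothetical:

import dgl
import torch
from torch.utils.data import DataLoader

dataset = [(dgl.DGLGraph(), 0), (dgl.DGLGraph(), 1)]  # hypothetical samples
loader = DataLoader(dataset, batch_size=2, collate_fn=batcher_dev)
graph_q, labels = next(iter(loader))  # one batched DGLGraph, one LongTensor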
Example #30
def main():
    logger.info('Reading data and extra features...')
    fp_files = [] if opt.fp is None else opt.fp.split(',')
    fp_extra, y_array, name_array = dataloader.load(opt.input, opt.target, fp_files)
    smiles_list = [name.split()[0] for name in name_array]

    logger.info('Generating molecular graphs with %s...' % opt.graph)
    if opt.graph == 'rdk':
        graph_list, feats_list = smi2dgl(smiles_list)
    elif opt.graph == 'msd':
        msd_list = ['%s.msd' % base64.b64encode(smiles.encode()).decode() for smiles in smiles_list]
        graph_list, feats_list = msd2dgl(msd_list, '../data/msdfiles.zip')
    else:
        raise ValueError('Unknown graph type: %s' % opt.graph)

    logger.info('Node feature example: (size=%d) %s' % (len(feats_list[0][0]), ','.join(map(str, feats_list[0][0]))))
    logger.info('Extra graph feature example: (size=%d) %s' % (len(fp_extra[0]), ','.join(map(str, fp_extra[0]))))
    logger.info('Output example: (size=%d) %s' % (len(y_array[0]), ','.join(map(str, y_array[0]))))

    if fp_extra.shape[-1] > 0:
        logger.info('Normalizing extra graph features...')
        scaler = preprocessing.Scaler()
        scaler.fit(fp_extra)
        scaler.save(opt.output + '/scale.txt')
        fp_extra = scaler.transform(fp_extra)

    logger.info('Selecting data...')
    selector = preprocessing.Selector(smiles_list)
    if opt.part is not None:
        logger.info('Loading partition file %s' % opt.part)
        selector.load(opt.part)
    else:
        logger.warning('Partition file not provided. Using auto-partition instead')
        selector.partition(0.8, 0.2)

    device = torch.device('cuda:0')
    # batched data for training set
    data_list = [[data[i] for i in np.where(selector.train_index)[0]]
                 for data in (graph_list, y_array, feats_list, fp_extra, name_array, smiles_list)]
    n_batch, (graphs_batch, y_batch, feats_node_batch, feats_extra_batch, names_batch) = \
        preprocessing.separate_batches(data_list[:-1], opt.batch, data_list[-1])
    bg_batch_train = [dgl.batch(graphs).to(device) for graphs in graphs_batch]
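    # One batched graph per training minibatch, built and moved to the GPU
    # once up front rather than on every epoch.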
    y_batch_train = [torch.tensor(y, dtype=torch.float32, device=device) for y in y_batch]
    feats_node_batch_train = [torch.tensor(np.concatenate(feats_node), dtype=torch.float32, device=device)
                              for feats_node in feats_node_batch]
    feats_extra_batch_train = [torch.tensor(feats_extra, dtype=torch.float32, device=device)
                               for feats_extra in feats_extra_batch]
    # for plot
    y_train_array = np.concatenate(y_batch)
    names_train = np.concatenate(names_batch)

    # data for validation set
    graphs, y, feats_node, feats_extra, names_valid = \
        [[data[i] for i in np.where(selector.valid_index)[0]]
         for data in (graph_list, y_array, feats_list, fp_extra, name_array)]
    bg_valid, y_valid, feats_node_valid, feats_extra_valid = (
        dgl.batch(graphs).to(device),
        torch.tensor(y, dtype=torch.float32, device=device),
        torch.tensor(np.concatenate(feats_node), dtype=torch.float32, device=device),
        torch.tensor(feats_extra, dtype=torch.float32, device=device),
    )
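    # The whole validation set is batched into one graph, so a single forward
    # pass scores every validation molecule.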
    # for plot
    y_valid_array = y_array[selector.valid_index]

    logger.info('Training size = %d, Validation size = %d' % (len(y_train_array), len(y_valid_array)))
    logger.info('Batches = %d, Batch size ~= %d' % (n_batch, opt.batch))

    in_feats_node = feats_list[0].shape[-1]
    in_feats_extra = fp_extra[0].shape[-1]
    n_heads = list(map(int, opt.head.split(',')))

    logger.info('Building network...')
    logger.info('Attention heads per conv layer = %s' % n_heads)
    logger.info('Learning rate = %s' % opt.lr)
    logger.info('L2 penalty = %f' % opt.l2)

    model = GATModel(in_feats_node, opt.embed, n_head_list=n_heads, extra_feats=in_feats_extra)
    model.cuda()
    print(model)
    for name, param in model.named_parameters():
        print(name, param.data.shape)

    header = 'Step MaxRE(t) Loss MeaSquE MeaSigE MeaUnsE MaxRelE Acc2% Acc5% Acc10%'.split()
    logger.info('%-8s %8s %8s %8s %8s %8s %8s %8s %8s %8s' % tuple(header))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.l2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt.lrsteps, gamma=opt.lrgamma)
    for epoch in range(opt.epoch):
        model.train()
        if (epoch + 1) % opt.check == 0:
            pred_train = [None] * n_batch
        for ib in np.random.permutation(n_batch):
            optimizer.zero_grad()
            pred = model(bg_batch_train[ib], feats_node_batch_train[ib], feats_extra_batch_train[ib])
            loss = F.mse_loss(pred, y_batch_train[ib])
            loss.backward()
            optimizer.step()
            if (epoch + 1) % opt.check == 0:
                pred_train[ib] = pred.detach().cpu().numpy()
        scheduler.step()

        if (epoch + 1) % opt.check == 0:
            model.eval()
            pred_train = np.concatenate(pred_train)
            pred_valid = model(bg_valid, feats_node_valid, feats_extra_valid).detach().cpu().numpy()
            err_line = '%-8i %8.1f %8.2e %8.2e %8.1f %8.1f %8.1f %8.1f %8.1f %8.1f' % (
                epoch + 1,
                metrics.max_relative_error(y_train_array, pred_train) * 100,
                metrics.mean_squared_error(y_train_array, pred_train),
                metrics.mean_squared_error(y_valid_array, pred_valid),
                metrics.mean_signed_error(y_valid_array, pred_valid) * 100,
                metrics.mean_unsigned_error(y_valid_array, pred_valid) * 100,
                metrics.max_relative_error(y_valid_array, pred_valid) * 100,
                metrics.accuracy(y_valid_array, pred_valid, 0.02) * 100,
                metrics.accuracy(y_valid_array, pred_valid, 0.05) * 100,
                metrics.accuracy(y_valid_array, pred_valid, 0.10) * 100)

            logger.info(err_line)
    torch.save(model, opt.output + '/model.pt')

    visualizer = visualize.LinearVisualizer(y_train_array.reshape(-1), pred_train.reshape(-1), names_train, 'train')
    visualizer.append(y_valid_array.reshape(-1), pred_valid.reshape(-1), names_valid, 'valid')
    visualizer.dump(opt.output + '/fit.txt')
    visualizer.dump_bad_molecules(opt.output + '/error-0.10.txt', 'valid', threshold=0.1)
    visualizer.dump_bad_molecules(opt.output + '/error-0.20.txt', 'valid', threshold=0.2)
    visualizer.scatter_yy(savefig=opt.output + '/error-train.png', annotate_threshold=0.1, marker='x', lw=0.2, s=5)
    visualizer.hist_error(savefig=opt.output + '/error-hist.png', label='valid', histtype='step', bins=50)
    plt.show()