Example #1
def mergeGraph(graph_a, graph_b, train_data):
    g = DGLGraph()
    g.add_nodes(len(graph_a.id2idx) + len(graph_b.id2idx))

    g.add_edges(graph_a.edge_src, graph_a.edge_dst)
    g.add_edges(graph_a.edge_dst, graph_a.edge_src)

    # offset: graph_b's nodes follow graph_a's in the merged graph, so shift its edge endpoints by len(graph_a.id2idx)
    g.add_edges(list(map(lambda x: x + len(graph_a.id2idx), graph_b.edge_src)),
                list(map(lambda x: x + len(graph_a.id2idx), graph_b.edge_dst)))
    g.add_edges(list(map(lambda x: x + len(graph_a.id2idx), graph_b.edge_dst)),
                list(map(lambda x: x + len(graph_a.id2idx), graph_b.edge_src)))

    print(train_data.shape)

    print(g.number_of_edges())
    for i in range(train_data.shape[0]):
        g.add_edge(train_data[i, 0], train_data[i, 1] + len(graph_a.id2idx))
        g.add_edge(train_data[i, 1] + len(graph_a.id2idx), train_data[i, 0])

    num_edges = g.number_of_edges()
    g.ndata['features'] = torch.cat([
        torch.FloatTensor(graph_a.features),
        torch.FloatTensor(graph_b.features)
    ], 0).cuda()

    return g
Example #2
def mol2dgl(smiles_batch):
    n_nodes = 0
    graph_list = []
    for smiles in smiles_batch:
        atom_feature_list = []
        bond_feature_list = []
        bond_source_feature_list = []
        graph = DGLGraph()
        mol = get_mol(smiles)
        for atom in mol.GetAtoms():
            graph.add_node(atom.GetIdx())
            atom_feature_list.append(atom_features(atom))
        for bond in mol.GetBonds():
            begin_idx = bond.GetBeginAtom().GetIdx()
            end_idx = bond.GetEndAtom().GetIdx()
            features = bond_features(bond)
            graph.add_edge(begin_idx, end_idx)
            bond_feature_list.append(features)
            # set up the reverse direction
            graph.add_edge(end_idx, begin_idx)
            bond_feature_list.append(features)

        atom_x = torch.stack(atom_feature_list)
        graph.set_n_repr({'x': atom_x})
        if len(bond_feature_list) > 0:
            bond_x = torch.stack(bond_feature_list)
            graph.set_e_repr({
                'x': bond_x,
                'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
            })
        graph_list.append(graph)

    return graph_list
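# Hedged usage sketch for the example above (assumes RDKit plus the module's get_mol,
# atom_features, bond_features and ATOM_FDIM helpers are available): convert a small
# SMILES batch and merge the per-molecule graphs with dgl.batch for batched message passing.
import dgl

smiles_batch = ['CCO', 'c1ccccc1']          # hypothetical inputs
graph_list = mol2dgl(smiles_batch)
batched_graph = dgl.batch(graph_list)       # one disconnected graph for the whole batch
print(batched_graph.number_of_nodes(), batched_graph.number_of_edges())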
Example #3
def test_pull_0deg():
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + F.zeros(shape, dtype, ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2, 'h')
    # test#1: pull both 0deg and non-0deg nodes
    old = F.randn((2, 5))
    g.ndata['h'] = old
    g.pull([0, 1])
    new = g.ndata.pop('h')
    # 0deg check: initialized with the func and got applied
    assert F.allclose(new[0], F.full_1d(5, 4, dtype=F.float32))
    # non-0deg check
    assert F.allclose(new[1], F.sum(old, 0) * 2)

    # test#2: pull only 0deg node
    old = F.randn((2, 5))
    g.ndata['h'] = old
    g.pull(0)
    new = g.ndata.pop('h')
    # 0deg check: fallback to apply
    assert F.allclose(new[0], 2*old[0])
    # non-0deg check: not touched
    assert F.allclose(new[1], old[1])
Example #4
class DataEntry:
    def __init__(self, dataset, num_nodes, features, edges, target):
        self.dataset = dataset
        self.num_nodes = num_nodes
        self.target = target
        self.graph = DGLGraph()
        self.features = torch.FloatTensor(features)
        self.graph.add_nodes(self.num_nodes, data={'features': self.features})
        for s, _type, t in edges:
            etype_number = self.dataset.get_edge_type_number(_type)
            self.graph.add_edge(
                s, t, data={'etype': torch.LongTensor([etype_number])})
Example #5
def batch_transform_bert_fever(inst, bert_max_len, bert_tokenizer):
    g = DGLGraph()

    g.add_nodes(len(inst['node']))

    question = inst['question']

    for i in range(len(inst['node'])):
        for j in range(len(inst['node'])):
            if i == j:
                continue
            g.add_edge(i, j)

    for i, node in enumerate(inst['node']):

        if node['label'] == 1:
            g.nodes[i].data['label'] = torch.tensor(1).unsqueeze(0).type(
                torch.FloatTensor)
        elif node['label'] == 0:
            g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0).type(
                torch.FloatTensor)
        node['name'] = str(node['name'])
        title = node['name'].replace('_', ' ')
        context = node['context']
        sent_num = node['sent_num']
        encoding_inputs, encoding_masks, encoding_ids = encode_sequence_fever(
            question, title, context, bert_max_len, bert_tokenizer, sent_num)
        g.nodes[i].data['encoding'] = encoding_inputs.unsqueeze(0)
        g.nodes[i].data['encoding_mask'] = encoding_masks.unsqueeze(0)
        g.nodes[i].data['segment_id'] = encoding_ids.unsqueeze(0)
    """if inst['label'] == 'REFUTES':
        label = 0
    elif inst['label'] == 'NOT ENOUGH INFO':
        label = 1
    elif inst['label'] == 'SUPPORTS':
        label = 2
    else:
        print('Problem!')
    """
    if inst['label'] == 'false':
        label = 0
    elif inst['label'] == 'half-true':
        label = 1
    elif inst['label'] == 'true':
        label = 2
    else:
        print('Problem!')

    return g, inst['qid'], label
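# A minimal side sketch on the edge-construction pattern above: the fully connected
# (no self-loop) edge set built by the nested Python loop can also be added in a single
# vectorized add_edges call, which scales better with the node count.
import itertools
from dgl import DGLGraph

n = 5                                                        # hypothetical node count
src, dst = zip(*((i, j) for i, j in itertools.product(range(n), repeat=2) if i != j))
g = DGLGraph()
g.add_nodes(n)
g.add_edges(list(src), list(dst))
assert g.number_of_edges() == n * (n - 1)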
Example #6
    def process(mol: Mol, device: torch.device, **kwargs) -> GATData:
        n = mol.GetNumAtoms() + 1

        graph = DGLGraph()
        graph.add_nodes(n)
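        # self-loop on every node, plus a directed edge from each atom node (1..n-1)
        # into node 0, an extra virtual node whose feature row is zeroed below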
        graph.add_edges(graph.nodes(), graph.nodes())
        graph.add_edges(range(1, n), 0)
        # graph.add_edges(0, range(1, n))
        for e in mol.GetBonds():
            u, v = e.GetBeginAtomIdx(), e.GetEndAtomIdx()
            graph.add_edge(u + 1, v + 1)
            graph.add_edge(v + 1, u + 1)

        v, m = feature.mol_feature(mol)
        vec = torch.cat([torch.zeros((1, m)), v]).to(device)

        return GATData(n, graph, vec)
Example #7
    def process(mol: Mol, device: torch.device, **kwargs):
        n = mol.GetNumAtoms() + 1

        graph = DGLGraph()
        graph.add_nodes(n)
        graph.add_edges(graph.nodes(), graph.nodes())
        graph.add_edges(range(1, n), 0)
        # graph.add_edges(0, range(1, n))
        for e in mol.GetBonds():
            u, v = e.GetBeginAtomIdx(), e.GetEndAtomIdx()
            graph.add_edge(u + 1, v + 1)
            graph.add_edge(v + 1, u + 1)
        adj = graph.adjacency_matrix(transpose=False).to_dense()

        v, m = feature.mol_feature(mol)
        vec = torch.cat([torch.zeros((1, m)), v]).to(device)

        return ChebNetData(n, adj, vec)
Example #8
def test_recv_0deg_newfld():
    # test recv with 0deg nodes; the reducer also creates a new field
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _message(edges):
        return {'m': edges.src['h']}

    def _reduce(nodes):
        return {'h1': nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def _apply(nodes):
        return {'h1': nodes.data['h1'] * 2}

    def _init2(shape, dtype, ctx, ids):
        return 2 + F.zeros(shape, dtype=dtype, ctx=ctx)

    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    # test#1: recv both 0deg and non-0deg nodes
    old = F.randn((2, 5))
    g.set_n_initializer(_init2, 'h1')
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata.pop('h1')
    # 0deg check: initialized with the func and got applied
    assert F.allclose(new[0], F.full_1d(5, 4, dtype=F.float32))
    # non-0deg check
    assert F.allclose(new[1], F.sum(old, 0) * 2)

    # test#2: recv only 0deg node
    old = F.randn((2, 5))
    g.ndata['h'] = old
    g.ndata['h1'] = F.full((2, 5), -1, F.int64)  # this is necessary
    g.send((0, 1))
    g.recv(0)
    new = g.ndata.pop('h1')
    # 0deg check: fallback to apply
    assert F.allclose(new[0], F.full_1d(5, -2, F.int64))
    # non-0deg check: not changed
    assert F.allclose(new[1], F.full_1d(5, -1, F.int64))
Example #9
def create_g(file_path, use_cuda=False):
    # print(file_path)
    npz = np.load(file_path, allow_pickle=True)
    labels = npz['labels']
    fts_nodes = npz['fts_node']
    edge_type = npz['edge_type'].tolist()
    edge_norm = npz['edge_norm'].tolist()
    edges = npz['edges']
    # print(npz['nums'],len(labels),len(fts_nodes))
    num_nodes = len(labels)
    # print(np.dtype(fts_nodes))

    g = DGLGraph()
    g.add_nodes(num_nodes)

    edge_type = np.array(edge_type)
    edge_norm = np.array(edge_norm)

    for i in edges:  #BASIC EDGES
        g.add_edge(i[0].item(), i[1].item())

    edge_type = torch.from_numpy(edge_type)
    edge_norm = torch.from_numpy(edge_norm).unsqueeze(1)
    edge_type = edge_type.long()
    edge_norm = edge_norm.float()

    # fts_nodes = fts_nodes.astype(float)
    # fts_nodes = fts_nodes.astype(int)
    fts_nodes = torch.from_numpy(fts_nodes)
    fts_nodes = fts_nodes.long()

    labels = torch.from_numpy(labels)

    if (use_cuda):
        labels = labels.cuda()
        edge_type = edge_type.cuda()
        edge_norm = edge_norm.cuda()
        fts_nodes = fts_nodes.cuda()

    g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
    g.ndata['id'] = fts_nodes

    return [g, labels, fts_nodes]
Example #10
def test_update_all_0deg():
    # test#1
    g = DGLGraph()
    g.add_nodes(5)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)

    def _message(edges):
        return {'m': edges.src['h']}

    def _reduce(nodes):
        return {'h': nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def _apply(nodes):
        return {'h': nodes.data['h'] * 2}

    def _init2(shape, dtype, ctx, ids):
        return 2 + F.zeros(shape, dtype, ctx)

    g.set_n_initializer(_init2, 'h')
    old_repr = F.randn((5, 5))
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # the first row of the new_repr should be the sum of all the node
    # features; while the 0-deg nodes should be initialized by the
    # initializer and applied with UDF.
    assert F.allclose(new_repr[1:], 2 * (2 + F.zeros((4, 5))))
    assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0))

    # test#2: graph with no edge
    g = DGLGraph()
    g.add_nodes(5)
    g.set_n_initializer(_init2, 'h')
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # should fallback to apply
    assert F.allclose(new_repr, 2 * old_repr)
Example #11
def generate_graph(grad=False):
    g = DGLGraph()
    g.add_nodes(10)  # 10 nodes
    # create a graph where 0 is the source and 9 is the sink
    # 17 edges
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = F.randn((10, D))
    ecol = F.randn((17, D))
    if grad:
        ncol = F.attach_grad(ncol)
        ecol = F.attach_grad(ecol)

    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    return g
Example #12
def test_dynamic_addition():
    N = 3
    D = 1

    g = DGLGraph()
    g = g.to(F.ctx())

    # Test node addition
    g.add_nodes(N)
    g.ndata.update({'h1': F.randn((N, D)), 'h2': F.randn((N, D))})
    g.add_nodes(3)
    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3

    # Test edge addition
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': F.randn((2, D)), 'h2': F.randn((2, D))})
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2

    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = F.randn((4, D))
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = F.randn((1, D))
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5

    # test add edge with part of the features
    g.add_edge(2, 1, {'h1': F.randn((1, D))})
    assert len(g.edata['h1']) == len(g.edata['h2'])
Example #13
def reduced_graph(graph: DGLGraph, center_node: int, paths: dict, node_attr_name, edge_attr_name):
	"""
	Reduce the graph into a simpler star graph centered on a single reference node.
	:param graph: the graph to be reduced
	:param center_node: the reference (center) node
	:param paths: the BFS traversal path (sequence of edges) from the center node to each node
	:param node_attr_name: name of the node feature to copy over
	:param edge_attr_name: name of the edge feature holding the path weights
	:return: new_graph
	"""
	new_graph = DGLGraph()
	new_graph.add_nodes(num=graph.number_of_nodes())
	new_graph.ndata[node_attr_name] = graph.ndata[node_attr_name]
	for node, path in paths.items():
		path_weight = torch.tensor([1.])
		for index,edge in enumerate(path):
			path_weight *= graph.edata[edge_attr_name][graph.edge_id(edge[0],edge[1])]*math.exp(-index)
		new_graph.add_edge(center_node,node,data={edge_attr_name:path_weight})
		new_graph.add_edge(node, center_node, data={edge_attr_name: path_weight})
	new_graph.add_edges(new_graph.nodes(), new_graph.nodes(),
	                    data={edge_attr_name: torch.ones(new_graph.number_of_nodes(), )})
	new_graph.edata[edge_attr_name] = new_graph.edata[edge_attr_name].softmax(dim=0)
	return new_graph
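# Hedged usage sketch for reduced_graph (assumes the module's own imports such as math
# and torch are in place): a 3-node path graph 0-1-2 with scalar edge weights is reduced
# around center node 0; `paths` maps each reachable node to its BFS edge sequence, and
# the feature names 'h' and 'w' are placeholders chosen for this illustration.
import torch
from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
g.ndata['h'] = torch.arange(3, dtype=torch.float32).unsqueeze(1)
g.edata['w'] = torch.tensor([0.5, 0.25])
paths = {1: [(0, 1)], 2: [(0, 1), (1, 2)]}
rg = reduced_graph(g, 0, paths, node_attr_name='h', edge_attr_name='w')
print(rg.edata['w'])    # softmax-normalized path weights followed by the self-loop weights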
Example #14
    def process(self, mol: Mol, atom_map):
        n = mol.GetNumAtoms() + 1

        graph = DGLGraph()
        graph.add_nodes(n)
        graph.add_edges(graph.nodes(), graph.nodes())
        graph.add_edges(range(1, n), 0)
        # graph.add_edges(0, range(1, n))
        for e in mol.GetBonds():
            u, v = e.GetBeginAtomIdx(), e.GetEndAtomIdx()
            graph.add_edge(u + 1, v + 1)
            graph.add_edge(v + 1, u + 1)

        feature = torch.cat([
            torch.zeros((1, self.feature_dim), device=self.device),  # node 0
            torch.nn.functional.one_hot(torch.tensor(
                [atom_map[u.GetAtomicNum()] for u in mol.GetAtoms()],
                device=self.device),
                                        num_classes=self.feature_dim).to(
                                            torch.float)
        ])

        return GCNData(n, graph, feature)
Example #15
def get_graph_from_smile(molecule_smile):
    """
    Method that constructs a molecular graph with nodes being the atoms
    and bonds being the edges.
    :param molecule_smile: SMILE sequence
    :return: DGL graph object, Node features and Edge features
    """

    G = DGLGraph()
    molecule = Chem.MolFromSmiles(molecule_smile)
    features = rdDesc.GetFeatureInvariants(molecule)

    stereo = Chem.FindMolChiralCenters(molecule)
    chiral_centers = [0] * molecule.GetNumAtoms()
    for i in stereo:
        chiral_centers[i[0]] = i[1]

    G.add_nodes(molecule.GetNumAtoms())
    node_features = []
    edge_features = []
    for i in range(molecule.GetNumAtoms()):

        atom_i = molecule.GetAtomWithIdx(i)
        atom_i_features = get_atom_features(atom_i, chiral_centers[i], features[i])
        node_features.append(atom_i_features)

        for j in range(molecule.GetNumAtoms()):
            bond_ij = molecule.GetBondBetweenAtoms(i, j)
            if bond_ij is not None:
                G.add_edge(i, j)
                bond_features_ij = get_bond_features(bond_ij)
                edge_features.append(bond_features_ij)

    # DGL node/edge features must be framework tensors, so convert the numpy arrays
    G.ndata['x'] = torch.FloatTensor(np.array(node_features))
    G.edata['w'] = torch.FloatTensor(np.array(edge_features))
    return G
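# Hedged usage sketch (assumes RDKit and the module's get_atom_features /
# get_bond_features helpers): build the molecular graph for ethanol.
g = get_graph_from_smile('CCO')
print(g.number_of_nodes(), g.number_of_edges())    # 3 atoms, 4 directed bond edges
print(g.ndata['x'].shape, g.edata['w'].shape)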
Example #16
def vectorize_qanta(ex, tokenizer, device, istrain, max_seq_length=64):
    bert_model.eval()
    t_id = ex['id']
    text = ex['text']
    positive_entity = ex['pos_et']
    negative_entities = ex['neg_ets']
    ## In the QANTA setting, we cap each edge at three evidence sentences (for efficient training and evaluation)
    num_edges = 3
    g = DGLGraph()
    question_node_list = list()
    candidate_node_list = list()
    first_sent_tokens = list()
    first_sent_masks = list()
    question_tokens = list()
    question_masks = list()

    input_ids, input_mask = text_tokenize(text, tokenizer, max_seq_length)
    question_tokens.append(input_ids)
    question_masks.append(input_mask)
    node_sub_questions = list()

    for sup_q in ex['q_et']:
        sub_question = sup_q['text']
        input_ids, input_mask = text_tokenize(sub_question, tokenizer,
                                              max_seq_length)
        question_tokens.append(input_ids)
        question_masks.append(input_mask)

        for et in sup_q['entity']:
            topic = et['et']
            node_first_sent = et['first_sent']
            if topic is None:
                continue

            question_node_list.append(topic)
            question_idx = len(question_tokens) - 1
            node_sub_questions.append(question_idx)
            input_ids, input_mask = text_tokenize(node_first_sent, tokenizer,
                                                  max_seq_length)

            first_sent_tokens.append(input_ids)
            first_sent_masks.append(input_mask)

    candidate_node_list.append(normalize(positive_entity['et']))
    input_ids, input_mask = text_tokenize(positive_entity['first_sent'],
                                          tokenizer, max_seq_length)
    first_sent_tokens.append(input_ids)
    first_sent_masks.append(input_mask)
    node_sub_questions.append(0)

    for neg_et in negative_entities:
        candidate_node_list.append(normalize(neg_et['et']))
        input_ids, input_mask = text_tokenize(neg_et['first_sent'], tokenizer,
                                              max_seq_length)

        first_sent_tokens.append(input_ids)
        first_sent_masks.append(input_mask)
        node_sub_questions.append(0)

    num_nodes = len(question_node_list) + len(candidate_node_list)
    g.add_nodes(num_nodes)

    num_questions = len(question_tokens)

    ### combine question and first sentence
    all_tokens = question_tokens + first_sent_tokens
    all_masks = question_masks + first_sent_masks

    all_tensor = torch.LongTensor(all_tokens).to(device)
    all_masks_tensor = torch.LongTensor(all_masks).to(device)
    all_encodings = list()
    num_exs = 50
    for iii in range(int(all_tensor.size(0) / num_exs)):
        encoding, _ = bert_model(
            all_tensor[iii * num_exs:(iii + 1) * num_exs], None,
            all_masks_tensor[iii * num_exs:(iii + 1) * num_exs])
        encoding = encoding.detach().cpu()
        all_encodings.append(encoding)
    if all_tensor.size(0) % num_exs > 0:
        encoding, _ = bert_model(
            all_tensor[int(all_tensor.size(0) / num_exs) * num_exs:], None,
            all_masks_tensor[int(all_tensor.size(0) / num_exs) * num_exs:])
        encoding = encoding.detach().cpu()
        all_encodings.append(encoding)
    all_encodings = torch.cat(all_encodings, dim=0)

    all_masks_tensor = all_masks_tensor.cpu()

    g.ndata['first_sent'] = all_encodings[num_questions:].cpu()
    g.ndata['first_sent_mask'] = all_masks_tensor[num_questions:].cpu().eq(0)

    for i in range(len(question_node_list)):
        sub_q_num = node_sub_questions[i]

        g.nodes[i].data['question'] = all_encodings[sub_q_num].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[
            sub_q_num].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0)

    g.nodes[len(
        question_node_list)].data['question'] = all_encodings[0].unsqueeze(0)
    g.nodes[len(question_node_list)].data['question_mask'] = all_masks_tensor[
        0].unsqueeze(0).eq(0)
    g.nodes[len(question_node_list)].data['label'] = torch.tensor(1).unsqueeze(
        0)
    #### for candidates, we only use the full question sentence
    for i in range(
            len(question_node_list) + 1,
            len(question_node_list) + len(candidate_node_list)):
        g.nodes[i].data['question'] = all_encodings[0].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[0].unsqueeze(0).eq(
            0)
        g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0)

    for k_entity in positive_entity['evidence']:
        normalized_k_entity = normalize(k_entity)
        if normalized_k_entity in question_node_list:
            s_id = question_node_list.index(normalized_k_entity)
            g.add_edge(question_node_list.index(normalized_k_entity),
                       len(question_node_list))
            evidence_tokens = list()
            evidence_masks = list()
            evidence_ids = list()
            all_evidences = positive_entity['evidence'][k_entity]

            for evi_text in all_evidences[:num_edges]:
                input_ids, input_mask = text_tokenize(evi_text, tokenizer,
                                                      max_seq_length)

                evidence_tokens.append(input_ids)
                evidence_masks.append(input_mask)

            evidence_tensor = torch.LongTensor(evidence_tokens)
            evidence_masks_tensor = torch.LongTensor(evidence_masks)

            edge_features = torch.LongTensor(1, num_edges,
                                             max_seq_length).zero_()
            edge_feature_masks = torch.LongTensor(1, num_edges,
                                                  max_seq_length).zero_()
            egde_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
            edge_features[0, :len(evidence_tokens), :].copy_(evidence_tensor)
            edge_feature_masks[0, :len(evidence_tokens), :].copy_(
                evidence_masks_tensor)
            egde_sent_mask[0, :len(evidence_tokens)].fill_(0)

            g.edges[s_id,
                    len(question_node_list)].data['evidence'] = edge_features
            g.edges[s_id, len(question_node_list
                              )].data['evidence_mask'] = edge_feature_masks
            g.edges[s_id, len(question_node_list
                              )].data['evidence_sent_mask'] = egde_sent_mask

    for neg_et in negative_entities:
        for k_entity in neg_et['evidence']:
            normalized_k_entity = normalize(k_entity)
            if normalized_k_entity in question_node_list:
                s_id = question_node_list.index(normalized_k_entity)
                t_id = len(question_node_list) + candidate_node_list.index(
                    normalize(neg_et['et']))
                g.add_edge(s_id, t_id)

                evidence_tokens = list()
                evidence_masks = list()
                evidence_ids = list()
                all_evidences = neg_et['evidence'][normalized_k_entity]
                for evi_text in all_evidences[:num_edges]:
                    input_ids, input_mask = text_tokenize(
                        evi_text, tokenizer, max_seq_length)
                    evidence_tokens.append(input_ids)
                    evidence_masks.append(input_mask)

                evidence_tensor = torch.LongTensor(evidence_tokens)
                evidence_masks_tensor = torch.LongTensor(evidence_masks)

                edge_features = torch.LongTensor(1, num_edges,
                                                 max_seq_length).zero_()
                edge_feature_masks = torch.LongTensor(1, num_edges,
                                                      max_seq_length).zero_()
                egde_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
                edge_features[0, :len(evidence_tokens), :].copy_(
                    evidence_tensor)
                edge_feature_masks[0, :len(evidence_tokens), :].copy_(
                    evidence_masks_tensor)
                egde_sent_mask[0, :len(evidence_tokens)].fill_(0)

                g.edges[s_id, t_id].data['evidence'] = edge_features
                g.edges[s_id, t_id].data['evidence_mask'] = edge_feature_masks
                g.edges[s_id, t_id].data['evidence_sent_mask'] = egde_sent_mask

    ### Batch the sentences and get BERT embeddings
    if 'evidence' in g.edata:
        evi = g.edata['evidence'].to(device)
        evi_mask = g.edata['evidence_mask'].to(device)
        batch_size, sent_max_len, word_max_len = evi.size(0), evi.size(
            1), evi.size(2)
        evi = evi.view(batch_size * sent_max_len, word_max_len)
        evi_mask = evi_mask.view(batch_size * sent_max_len, word_max_len)
        all_encodings = list()
        num_exs = 50
        for iii in range(int(evi.size(0) / num_exs)):
            encoding, _ = bert_model(
                evi[iii * num_exs:(iii + 1) * num_exs], None,
                evi_mask[iii * num_exs:(iii + 1) * num_exs])
            encoding = encoding.detach().cpu()
            all_encodings.append(encoding)
        if evi.size(0) % num_exs > 0:
            encoding, _ = bert_model(
                evi[int(evi.size(0) / num_exs) * num_exs:], None,
                evi_mask[int(evi.size(0) / num_exs) * num_exs:])
            encoding = encoding.detach().cpu()
            all_encodings.append(encoding)

        g.edata['evidence'] = torch.cat(all_encodings,
                                        dim=0).view(batch_size, sent_max_len,
                                                    word_max_len, -1)
        g.edata['evidence_mask'] = g.edata['evidence_mask'].eq(0)

    return g
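# The chunked BERT-encoding pattern used twice above, isolated as a minimal sketch:
# run an encoder over the rows of a tensor in fixed-size chunks so GPU memory stays
# bounded, then concatenate the detached results on CPU. `encoder` is a stand-in for
# a callable like bert_model restricted to (inputs, masks).
import torch

def encode_in_chunks(encoder, inputs, masks, chunk=50):
    outputs = []
    for start in range(0, inputs.size(0), chunk):
        out = encoder(inputs[start:start + chunk], masks[start:start + chunk])
        outputs.append(out.detach().cpu())
    return torch.cat(outputs, dim=0)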
Example #17
def test_send_multigraph():
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.data['a']}

    def _message_b(edges):
        return {'a': edges.data['a'] * 3}

    def _reduce(nodes):
        return {'a': F.max(nodes.mailbox['a'], 1)}

    def answer(*args):
        return F.max(F.stack(args, 0), 0)

    # send by eid
    old_repr = F.randn((4, 5))
    g.ndata['a'] = F.zeros((3, 5))
    g.edata['a'] = old_repr
    g.send([0, 2], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert F.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))

    g.ndata['a'] = F.zeros((3, 5))
    g.edata['a'] = old_repr
    g.send([0, 2, 3], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert F.allclose(new_repr[1], answer(old_repr[0], old_repr[2],
                                          old_repr[3]))

    # send on multigraph
    g.ndata['a'] = F.zeros((3, 5))
    g.edata['a'] = old_repr
    g.send(([0, 2], [1, 1]), _message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert F.allclose(new_repr[1], F.max(old_repr, 0))

    # consecutive send and send_on
    g.ndata['a'] = F.zeros((3, 5))
    g.edata['a'] = old_repr
    g.send((2, 1), _message_a)
    g.send([0, 1], message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert F.allclose(new_repr[1],
                      answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))

    # consecutive send_on
    g.ndata['a'] = F.zeros((3, 5))
    g.edata['a'] = old_repr
    g.send(0, message_func=_message_a)
    g.send(1, message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert F.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))

    # send_and_recv_on
    g.ndata['a'] = F.zeros((3, 5))
    g.edata['a'] = old_repr
    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    new_repr = g.ndata['a']
    assert F.allclose(new_repr[1], answer(old_repr[0], old_repr[2],
                                          old_repr[3]))
    assert F.allclose(new_repr[[0, 2]], F.zeros((2, 5)))
Example #18
def get_batch_dist(filename, batch_size, gpu_rank):

    doc_dict = read_bert_features(
        '../../Recommenders/data/MINDlarge_train/news_token_bert_features.txt')
    f = open(filename, 'r').readlines()
    length = int(len(f) / 4)
    f = f[gpu_rank * length:(gpu_rank + 1) * length]
    batch_data = []
    all_nodes = 0
    i = 0
    batch_graph = []
    bert_max_len = 20
    imp_index = []
    batch_t = 0
    while i < len(f):
        data = json.loads(f[i])
        neg_list = data['neg_list']
        neg_sample = np.random.choice(len(neg_list), 1, replace=False)[0]
        #ex['neg_list']=doc_dict[neg_list[neg_sample]]
        data['node'][-1]['context'] = doc_dict[neg_list[neg_sample]]
        # data['node'][-3]['label']=-2

        all_nodes += len(data['node'])
        g = DGLGraph()
        g.add_nodes(len(data['node']))
        # edge_dict = dict()
        # for edge in data['edges']:
        #     e_start = edge['start']
        #     e_end = edge['end']

        #     idx = (e_start, e_end)
        #     if idx not in edge_dict:
        #         edge_dict[idx] =list()
        #     # if edge['sent'] not in edge_dict[idx]:
        #     #     edge_dict[idx].append(edge['sent'])

        # for idx, context in edge_dict.items():
        #     start, end = idx
        #     g.add_edge(start, end)

        for k in range(len(data['node']) - 2):
            for j in range(k + 1, len(data['node'])):
                #if i != len(node) - 2 and j != len(node) - 1:
                # data['edges'].append({'start': i, 'end': j})
                # data['edges'].append({'start': j, 'end': i})
                g.add_edge(k, j)
                g.add_edge(j, k)

        for k, node in enumerate(data['node']):

            context = node['context']
            context = [int(x) for x in context]
            #evidence = list(set(node['evidence']))
            encoding_inputs, encoding_masks, encoding_ids = encode_sequence_hotpot(
                context, bert_max_len)
            g.nodes[k].data['encoding'] = encoding_inputs.unsqueeze(0)
            g.nodes[k].data['encoding_mask'] = encoding_masks.unsqueeze(0)
            g.nodes[k].data['segment_id'] = encoding_ids.unsqueeze(0)

            #print('????',node)

            if node['label'] == 1:
                g.nodes[k].data['label'] = torch.tensor(1).unsqueeze(0).type(
                    torch.FloatTensor)
            elif node['label'] == 0:
                g.nodes[k].data['label'] = torch.tensor(0).unsqueeze(0).type(
                    torch.FloatTensor)
            elif k == len(data['node']) - 3:
                g.nodes[k].data['label'] = torch.tensor(-2).unsqueeze(0).type(
                    torch.FloatTensor)
            else:
                g.nodes[k].data['label'] = torch.tensor(-1).unsqueeze(0).type(
                    torch.FloatTensor)

        if batch_t >= batch_size:
            batch_t = 0
            g = dgl.batch(batch_graph)
            batch_graph = []
            all_nodes = 0
            yield g, imp_index, torch.tensor([0] * len(imp_index))
            imp_index = []
        else:
            batch_t += 1
            batch_graph.append(g)
            imp_index.append(data['imp_id'])
            #print('???',data['imp_id'],imp_index)
            i += 1
Example #19
def batch_transform_bert_hotpot(inst, bert_max_len, bert_tokenizer):
    g = DGLGraph()

    g.add_nodes(len(inst['node']))

    question = inst['question']

    for i, node in enumerate(inst['node']):
        inst['node'][i]['evidence'] = list()

    #### concatenate all edge sentences

    edge_dict = dict()
    for edge in inst['edge']:
        e_start = edge['start']
        e_end = edge['end']

        idx = (e_start, e_end)
        if idx not in edge_dict:
            edge_dict[idx] = list()
        if edge['sent'] not in edge_dict[idx]:
            edge_dict[idx].append(edge['sent'])

    for idx, context in edge_dict.items():
        start, end = idx
        g.add_edge(start, end)
        for sent in context:
            inst['node'][end]['evidence'].append(sent)

    for i, node in enumerate(inst['node']):

        context = node['context']
        evidence = list(set(node['evidence']))

        encoding_inputs, encoding_masks, encoding_ids, B_start = encode_sequence_hotpot(
            question, context, evidence, bert_max_len, bert_tokenizer)
        g.nodes[i].data['encoding'] = encoding_inputs.unsqueeze(0)
        g.nodes[i].data['encoding_mask'] = encoding_masks.unsqueeze(0)
        g.nodes[i].data['segment_id'] = encoding_ids.unsqueeze(0)
        g.nodes[i].data['B_start'] = B_start

        if node['is_ans'] == 1:
            g.nodes[i].data['label'] = torch.tensor(1).unsqueeze(0).type(
                torch.FloatTensor)
        elif node['is_ans'] == 0:
            g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0).type(
                torch.FloatTensor)
        else:
            g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0).type(
                torch.FloatTensor)

        spans = node['spans']
        if node['is_ans'] != 1:
            g.nodes[i].data['span_label'] = torch.tensor(-1).unsqueeze(0).type(
                torch.FloatTensor)
        else:
            g.nodes[i].data['span_label'] = torch.tensor(0).unsqueeze(0).type(
                torch.FloatTensor)

        if len(spans) == 0:
            spans = [(-1, -1)]
        start_spans = torch.LongTensor([p[0] for p in spans])
        end_spans = torch.LongTensor([p[1] for p in spans])
        start_spans = start_spans + B_start
        end_spans = end_spans + B_start

        g.nodes[i].data['label_start'] = start_spans
        g.nodes[i].data['label_end'] = end_spans

    return g, inst['qid']
Example #20
def vectorize_qanta(ex, model, istrain, max_seq_length=128):
    q_id = ex['id']
    text = ex['text']
    positive_entity = ex['pos_et']
    negative_entities = ex['neg_ets']
    ### Maximum 3 sentences per edge
    num_edges = 3

    g = DGLGraph()

    question_node_list = list()
    candidate_node_list = list()
    first_sent_tokens = list()
    first_sent_masks = list()
    question_tokens = list()
    question_masks = list()

    input_ids, input_mask = text_tokenize(text, model.word_dict, max_seq_length)
    question_tokens.append(input_ids)
    question_masks.append(input_mask)
    node_sub_questions = list()
    

    for sup_q in ex['q_et']:
        
        sub_question = sup_q['text']
        input_ids, input_mask = text_tokenize(word_tokenize(sub_question), model.word_dict, max_seq_length)
        question_tokens.append(input_ids)
        question_masks.append(input_mask)

        for et in sup_q['entity']:
            topic = et['et']
            node_first_sent = et['first_sent']
            if topic is None:
                continue

            question_node_list.append(topic)
            question_idx = len(question_tokens) - 1
            node_sub_questions.append(question_idx)
            input_ids, input_mask = text_tokenize(node_first_sent, model.word_dict, max_seq_length)
            
            first_sent_tokens.append(input_ids)
            first_sent_masks.append(input_mask)


    candidate_node_list.append(normalize(positive_entity['et']))
    input_ids, input_mask = text_tokenize(positive_entity['first_sent'], model.word_dict, max_seq_length)
    first_sent_tokens.append(input_ids)
    first_sent_masks.append(input_mask)
    node_sub_questions.append(0)
    for neg_et in negative_entities:
        input_ids, input_mask = text_tokenize(neg_et['first_sent'], model.word_dict, max_seq_length)
        
        candidate_node_list.append(normalize(neg_et['et']))

        first_sent_tokens.append(input_ids)
        first_sent_masks.append(input_mask)
        node_sub_questions.append(0)

    num_nodes = len(question_node_list) + len(candidate_node_list)
    g.add_nodes(num_nodes)

    num_questions = len(question_tokens)

    ### combine question and first sentence
    all_tokens = question_tokens + first_sent_tokens
    all_masks = question_masks + first_sent_masks

    all_tensor = torch.LongTensor(all_tokens)
    all_masks_tensor = torch.LongTensor(all_masks)

    #### add node features
    g.ndata['first_sent'] = all_tensor[num_questions:].cpu()
    g.ndata['first_sent_mask'] = all_masks_tensor[num_questions:].eq(0)
  

    for i in range(len(question_node_list)):
        sub_q_num = node_sub_questions[i]
        
        g.nodes[i].data['question'] = all_tensor[sub_q_num].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[sub_q_num].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0)
        
    g.nodes[len(question_node_list)].data['question'] = all_tensor[0].unsqueeze(0)
    g.nodes[len(question_node_list)].data['question_mask'] = all_masks_tensor[0].unsqueeze(0).eq(0)
    g.nodes[len(question_node_list)].data['label'] = torch.tensor(1).unsqueeze(0)
    #### for candidates, we only use the full question sentence
    for i in range(len(question_node_list) + 1, len(question_node_list) + len(candidate_node_list)):
        g.nodes[i].data['question'] = all_tensor[0].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[0].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0)
        

    #### add positive edges
        
    for k_entity in positive_entity['evidence']:
        normalized_k_entity = normalize(k_entity)
        if normalized_k_entity in question_node_list:
            s_id = question_node_list.index(normalized_k_entity)
            g.add_edge(question_node_list.index(normalized_k_entity), len(question_node_list))
            evidence_tokens = list()
            evidence_masks = list()
            all_evidences = positive_entity['evidence'][k_entity]
            
            for evi_text in all_evidences[:num_edges]:
                input_ids, input_mask = text_tokenize(evi_text, model.word_dict, max_seq_length)
                
                evidence_tokens.append(input_ids)
                evidence_masks.append(input_mask)

            evidence_tensor = torch.LongTensor(evidence_tokens)
            evidence_masks_tensor = torch.LongTensor(evidence_masks)

            edge_features = torch.LongTensor(1, num_edges, max_seq_length).zero_()
            edge_feature_masks = torch.LongTensor(1, num_edges, max_seq_length).zero_()
            egde_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
            edge_features[0, :len(evidence_tokens), :].copy_(evidence_tensor)
            edge_feature_masks[0, :len(evidence_tokens), :].copy_(evidence_masks_tensor)
            egde_sent_mask[0, :len(evidence_tokens)].fill_(0)
     
            g.edges[s_id, len(question_node_list)].data['evidence'] = edge_features
            g.edges[s_id, len(question_node_list)].data['evidence_mask'] = edge_feature_masks.eq(0)
            g.edges[s_id, len(question_node_list)].data['evidence_sent_mask'] = egde_sent_mask


        
    for neg_et in negative_entities:
        #### 
        for k_entity in neg_et['evidence']:
            normalized_k_entity = normalize(k_entity)
            if normalized_k_entity in question_node_list:
                s_id = question_node_list.index(normalized_k_entity)
                t_id = len(question_node_list) + candidate_node_list.index(normalize(neg_et['et']))
                g.add_edge(s_id, t_id)
                    
                evidence_tokens = list()
                evidence_masks = list()
                all_evidences = neg_et['evidence'][normalized_k_entity]
                for evi_text in all_evidences[:num_edges]:
                    input_ids, input_mask = text_tokenize(evi_text, model.word_dict, max_seq_length)
                    evidence_tokens.append(input_ids)
                    evidence_masks.append(input_mask)
                evidence_tensor = torch.LongTensor(evidence_tokens)
                evidence_masks_tensor = torch.LongTensor(evidence_masks)

                edge_features = torch.LongTensor(1, num_edges, max_seq_length).zero_()
                edge_feature_masks = torch.LongTensor(1, num_edges, max_seq_length).zero_()
                egde_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
                edge_features[0, :len(evidence_tokens), :].copy_(evidence_tensor)
                edge_feature_masks[0, :len(evidence_tokens), :].copy_(evidence_masks_tensor)
                egde_sent_mask[0, :len(evidence_tokens)].fill_(0)
     
                g.edges[s_id, t_id].data['evidence'] = edge_features
                g.edges[s_id, t_id].data['evidence_mask'] = edge_feature_masks.eq(0)
                g.edges[s_id, t_id].data['evidence_sent_mask'] = egde_sent_mask

    return g
Example #21
class MoleculeEnv(object):
    """MDP environment for generating molecules.

    Parameters
    ----------
    atom_types : list
        E.g. ['C', 'N']
    bond_types : list
        E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    """
    def __init__(self, atom_types, bond_types):
        super(MoleculeEnv, self).__init__()

        self.atom_types = atom_types
        self.bond_types = bond_types

        self.atom_type_to_id = dict()
        self.bond_type_to_id = dict()

        for id, a_type in enumerate(atom_types):
            self.atom_type_to_id[a_type] = id

        for id, b_type in enumerate(bond_types):
            self.bond_type_to_id[b_type] = id

    def get_decision_sequence(self, mol, atom_order):
        """Extract a decision sequence with which DGMG can generate the
        molecule with a specified atom order.

        Parameters
        ----------
        mol : Chem.rdchem.Mol
        atom_order : list
            Specifies a mapping between the original atom
            indices and the new atom indices. In particular,
            atom_order[i] is re-labeled as i.

        Returns
        -------
        decisions : list
            decisions[i] is a 2-tuple (i, j)
            - If i = 0, j specifies either the type of the atom to add
              self.atom_types[j] or termination with j = len(self.atom_types)
            - If i = 1, j specifies either the type of the bond to add
              self.bond_types[j] or termination with j = len(self.bond_types)
            - If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, j must be created before the decision.
        """
        decisions = []
        old2new = dict()

        for new_id, old_id in enumerate(atom_order):
            atom = mol.GetAtomWithIdx(old_id)
            a_type = atom.GetSymbol()
            decisions.append((0, self.atom_type_to_id[a_type]))
            for bond in atom.GetBonds():
                u = bond.GetBeginAtomIdx()
                v = bond.GetEndAtomIdx()
                if v == old_id:
                    u, v = v, u
                if v in old2new:
                    decisions.append(
                        (1, self.bond_type_to_id[bond.GetBondType()]))
                    decisions.append((2, old2new[v]))
            decisions.append((1, len(self.bond_types)))
            old2new[old_id] = new_id
        decisions.append((0, len(self.atom_types)))
        return decisions

    def reset(self, rdkit_mol=False):
        """Setup for generating a new molecule

        Parameters
        ----------
        rdkit_mol : bool
            Whether to keep a Chem.rdchem.Mol object so
            that we know what molecule is being generated
        """
        self.dgl_graph = DGLGraph()
        # If there are some features for nodes and edges,
        # zero tensors will be set for those of new nodes and edges.
        self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
        self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)

        self.mol = None
        if rdkit_mol:
            # RWMol is a molecule class that is intended to be edited.
            self.mol = Chem.RWMol(Chem.MolFromSmiles(''))

    def num_atoms(self):
        """Get the number of atoms for the current molecule.

        Returns
        -------
        int
        """
        return self.dgl_graph.number_of_nodes()

    def add_atom(self, type):
        """Add an atom of the specified type.

        Parameters
        ----------
        type : int
            Should be in the range of [0, len(self.atom_types) - 1]
        """
        self.dgl_graph.add_nodes(1)
        if self.mol is not None:
            self.mol.AddAtom(Chem.Atom(self.atom_types[type]))

    def add_bond(self, u, v, type, bi_direction=True):
        """Add a bond of the specified type between atom u and v.

        Parameters
        ----------
        u : int
            Index for the first atom
        v : int
            Index for the second atom
        type : int
            Index for the bond type
        bi_direction : bool
            Whether to add edges for both directions in the DGLGraph.
            If not, we will only add the edge (u, v).
        """
        if bi_direction:
            self.dgl_graph.add_edges([u, v], [v, u])
        else:
            self.dgl_graph.add_edge(u, v)

        if self.mol is not None:
            self.mol.AddBond(u, v, self.bond_types[type])

    def get_current_smiles(self):
        """Get the generated molecule in SMILES

        Returns
        -------
        s : str
            SMILES
        """
        assert self.mol is not None, 'Expect a Chem.rdchem.Mol object initialized.'
        s = Chem.MolToSmiles(self.mol)
        return s
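# Hedged usage sketch for MoleculeEnv (assumes RDKit): build ethanol atom by atom and
# read back its SMILES; atom/bond type indices refer to the lists passed to __init__.
from rdkit import Chem

env = MoleculeEnv(atom_types=['C', 'O'],
                  bond_types=[Chem.rdchem.BondType.SINGLE])
env.reset(rdkit_mol=True)
env.add_atom(0)            # C
env.add_atom(0)            # C
env.add_atom(1)            # O
env.add_bond(0, 1, 0)      # C-C
env.add_bond(1, 2, 0)      # C-O
print(env.num_atoms(), env.get_current_smiles())   # 3 'CCO'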
Example #22
def vectorize_trivia(ex, tokenizer, device, istrain, max_seq_length=64):
    bert_model.eval()
    t_id = ex['id']
    text = ex['text']
    positive_entity = ex['pos_et']
    negative_entities = ex['neg_ets']

    num_edges = 5  # maximum number of evidence sentences kept per edge
    g = DGLGraph()  #instance

    question_node_list = list()
    candidate_node_list = list()
    first_sent_tokens = list()
    first_sent_masks = list()
    question_tokens = list()
    question_masks = list()

    input_ids, input_mask = text_tokenize(text, tokenizer, max_seq_length)
    question_tokens.append(input_ids)
    question_masks.append(input_mask)
    node_sub_questions = list()

    ### Since TriviaQA questions have far fewer entities, we also incorporate IR-retrieved pages
    ### as additional entities linked to candidate nodes (so that we get more edge evidence sentences)
    for sup_q in ex['q_et']:  #each question entity

        sub_question = sup_q['text']
        input_ids, input_mask = text_tokenize(
            sub_question, tokenizer,
            max_seq_length)  # token ids and attention mask
        question_tokens.append(input_ids)
        question_masks.append(input_mask)

        for et in sup_q['entity']:
            topic = et['et']
            node_first_sent = et['first_sent']
            if topic is None:
                continue

            question_node_list.append(topic)
            question_idx = len(question_tokens) - 1
            node_sub_questions.append(question_idx)
            input_ids, input_mask = text_tokenize(
                node_first_sent, tokenizer,
                max_seq_length)  # token ids and attention mask

            first_sent_tokens.append(input_ids)
            first_sent_masks.append(input_mask)

    candidate_node_list.append(normalize(positive_entity['et']))  #normalized
    input_ids, input_mask = text_tokenize(positive_entity['first_sent'],
                                          tokenizer, max_seq_length)
    first_sent_tokens.append(input_ids)
    first_sent_masks.append(input_mask)
    node_sub_questions.append(0)

    for neg_et in negative_entities:
        candidate_node_list.append(normalize(neg_et['et']))
        input_ids, input_mask = text_tokenize(neg_et['first_sent'], tokenizer,
                                              max_seq_length)

        first_sent_tokens.append(input_ids)
        first_sent_masks.append(input_mask)
        node_sub_questions.append(0)
    #nodes
    num_nodes = len(question_node_list) + len(candidate_node_list)
    g.add_nodes(num_nodes)

    num_questions = len(question_tokens)

    ### combine question and first sentence
    all_tokens = question_tokens + first_sent_tokens
    all_masks = question_masks + first_sent_masks

    all_tensor = torch.LongTensor(all_tokens).to(device)  #token as tensor
    all_masks_tensor = torch.LongTensor(all_masks).to(
        device)  #masked token as tensor
    all_encodings = list()
    num_exs = 50
    # encode all question/first-sentence sequences with BERT in chunks of num_exs
    for iii in range(int(all_tensor.size(0) / num_exs)):
        encoding, _ = bert_model(all_tensor[iii * num_exs:(iii + 1) * num_exs],
                                 None,
                                 all_masks_tensor[iii * num_exs:(iii + 1) *
                                                  num_exs])  #bert_model()?
        encoding = encoding.detach().cpu()
        all_encodings.append(encoding)
    if all_tensor.size(0) % num_exs > 0:
        encoding, _ = bert_model(
            all_tensor[int(all_tensor.size(0) / num_exs) * num_exs:], None,
            all_masks_tensor[int(all_tensor.size(0) / num_exs) * num_exs:])
        encoding = encoding.detach().cpu()
        all_encodings.append(encoding)
    all_encodings = torch.cat(all_encodings, dim=0)

    all_masks_tensor = all_masks_tensor.cpu()

    #saved for graph
    g.ndata['first_sent'] = all_encodings[num_questions:].cpu()
    g.ndata['first_sent_mask'] = all_masks_tensor[num_questions:].cpu().eq(0)

    for i in range(len(question_node_list)):
        sub_q_num = node_sub_questions[i]
        # assign each question node its sub-question encoding and mask
        g.nodes[i].data['question'] = all_encodings[sub_q_num].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[
            sub_q_num].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0)

    g.nodes[len(
        question_node_list)].data['question'] = all_encodings[0].unsqueeze(0)
    g.nodes[len(question_node_list)].data['question_mask'] = all_masks_tensor[
        0].unsqueeze(0).eq(0)
    g.nodes[len(question_node_list)].data['label'] = torch.tensor(1).unsqueeze(
        0)
    #### for candidates, we only use the full question sentence
    for i in range(
            len(question_node_list) + 1,
            len(question_node_list) + len(candidate_node_list)):
        g.nodes[i].data['question'] = all_encodings[0].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[0].unsqueeze(0).eq(
            0)
        g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0)

    #positive entities
    for k_entity in positive_entity['evidence']:
        normalized_k_entity = normalize(k_entity)  #normalize
        if normalized_k_entity in question_node_list:
            s_id = question_node_list.index(normalized_k_entity)  #list.index()
            g.add_edge(question_node_list.index(normalized_k_entity),
                       len(question_node_list))  #edge between
            evidence_tokens = list()
            evidence_masks = list()
            evidence_ids = list()
            all_evidences = positive_entity['evidence'][k_entity]

            for evi_text in all_evidences[:num_edges]:
                input_ids, input_mask = text_tokenize(evi_text, tokenizer,
                                                      max_seq_length)

                evidence_tokens.append(input_ids)
                evidence_masks.append(input_mask)

            evidence_tensor = torch.LongTensor(evidence_tokens)
            evidence_masks_tensor = torch.LongTensor(evidence_masks)

            edge_features = torch.LongTensor(1, num_edges,
                                             max_seq_length).zero_()
            edge_feature_masks = torch.LongTensor(1, num_edges,
                                                  max_seq_length).zero_()
            egde_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
            edge_features[0, :len(evidence_tokens), :].copy_(evidence_tensor)
            edge_feature_masks[0, :len(evidence_tokens), :].copy_(
                evidence_masks_tensor)
            egde_sent_mask[0, :len(evidence_tokens)].fill_(0)

            g.edges[s_id,
                    len(question_node_list)].data['evidence'] = edge_features
            g.edges[s_id, len(question_node_list)].data[
                'evidence_mask'] = edge_feature_masks  #edge feature mask
            g.edges[s_id, len(question_node_list)].data[
                'evidence_sent_mask'] = egde_sent_mask  #edge sentence mask

    #negative entities
    for neg_et in negative_entities:
        for k_entity in neg_et['evidence']:
            normalized_k_entity = normalize(k_entity)
            if normalized_k_entity in question_node_list:
                s_id = question_node_list.index(normalized_k_entity)
                t_id = len(question_node_list) + candidate_node_list.index(
                    normalize(neg_et['et']))
                g.add_edge(s_id, t_id)

                evidence_tokens = list()
                evidence_masks = list()
                evidence_ids = list()
                all_evidences = neg_et['evidence'][normalized_k_entity]
                for evi_text in all_evidences[:num_edges]:
                    input_ids, input_mask = text_tokenize(
                        evi_text, tokenizer, max_seq_length)
                    evidence_tokens.append(input_ids)
                    evidence_masks.append(input_mask)

                evidence_tensor = torch.LongTensor(evidence_tokens)
                evidence_masks_tensor = torch.LongTensor(evidence_masks)

                edge_features = torch.LongTensor(1, num_edges,
                                                 max_seq_length).zero_()
                edge_feature_masks = torch.LongTensor(1, num_edges,
                                                      max_seq_length).zero_()
                egde_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
                edge_features[0, :len(evidence_tokens), :].copy_(
                    evidence_tensor)
                edge_feature_masks[0, :len(evidence_tokens), :].copy_(
                    evidence_masks_tensor)
                egde_sent_mask[0, :len(evidence_tokens)].fill_(0)

                g.edges[s_id, t_id].data['evidence'] = edge_features
                g.edges[s_id, t_id].data['evidence_mask'] = edge_feature_masks
                g.edges[s_id, t_id].data['evidence_sent_mask'] = egde_sent_mask

    ### Batch the sentences and get BERT embeddings
    #get embeddings
    if 'evidence' in g.edata:
        evi = g.edata['evidence'].to(device)
        evi_mask = g.edata['evidence_mask'].to(device)
        batch_size, sent_max_len, word_max_len = evi.size(0), evi.size(
            1), evi.size(2)
        evi = evi.view(batch_size * sent_max_len, word_max_len)
        evi_mask = evi_mask.view(batch_size * sent_max_len, word_max_len)
        all_encodings = list()
        num_exs = 50
        for iii in range(int(evi.size(0) / num_exs)):
            encoding, _ = bert_model(
                evi[iii * num_exs:(iii + 1) * num_exs], None,
                evi_mask[iii * num_exs:(iii + 1) * num_exs])
            encoding = encoding.detach().cpu()
            all_encodings.append(encoding)
        if evi.size(0) % num_exs > 0:
            encoding, _ = bert_model(
                evi[int(evi.size(0) / num_exs) * num_exs:], None,
                evi_mask[int(evi.size(0) / num_exs) * num_exs:])
            encoding = encoding.detach().cpu()
            all_encodings.append(encoding)

        g.edata['evidence'] = torch.cat(all_encodings,
                                        dim=0).view(batch_size, sent_max_len,
                                                    word_max_len, -1)
        g.edata['evidence_mask'] = g.edata['evidence_mask'].eq(0)

    #return graph
    return g
Example #23
def mol2dgl(cand_batch, mol_tree_batch):
    cand_graphs = []
    tree_mess_source_edges = []  # map these edges from trees to...
    tree_mess_target_edges = []  # these edges on candidate graphs
    tree_mess_target_nodes = []
    n_nodes = 0

    for mol, mol_tree, ctr_node_id in cand_batch:
        atom_feature_list = []
        bond_feature_list = []
        ctr_node = mol_tree.nodes[ctr_node_id]
        ctr_bid = ctr_node['idx']
        g = DGLGraph()

        for atom in mol.GetAtoms():
            atom_feature_list.append(atom_features(atom))
            g.add_node(atom.GetIdx())

        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            begin_idx = a1.GetIdx()
            end_idx = a2.GetIdx()
            features = bond_features(bond)

            g.add_edge(begin_idx, end_idx)
            bond_feature_list.append(features)
            g.add_edge(end_idx, begin_idx)
            bond_feature_list.append(features)

            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            # Tree node ID in the batch
            x_bid = mol_tree.nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
            y_bid = mol_tree.nodes[y_nid - 1]['idx'] if y_nid > 0 else -1
            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if (x_bid, y_bid) in mol_tree_batch.edge_list:
                    tree_mess_target_edges.append(
                        (begin_idx + n_nodes, end_idx + n_nodes))
                    tree_mess_source_edges.append((x_bid, y_bid))
                    tree_mess_target_nodes.append(end_idx + n_nodes)
                if (y_bid, x_bid) in mol_tree_batch.edge_list:
                    tree_mess_target_edges.append(
                        (end_idx + n_nodes, begin_idx + n_nodes))
                    tree_mess_source_edges.append((y_bid, x_bid))
                    tree_mess_target_nodes.append(begin_idx + n_nodes)

        n_nodes += len(g.nodes)

        atom_x = torch.stack(atom_feature_list)
        g.set_n_repr({'x': atom_x})
        if len(bond_feature_list) > 0:
            bond_x = torch.stack(bond_feature_list)
            g.set_e_repr({
                'x':
                bond_x,
                'src_x':
                atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
            })
        cand_graphs.append(g)

    return cand_graphs, tree_mess_source_edges, tree_mess_target_edges, \
           tree_mess_target_nodes
Example #24
class D2GCN(nn.Module):
    def __init__(self, in_feat_dim, out_feat_dim):
        super(D2GCN, self).__init__()
        self.fedge = nn.Sequential(
            nn.Linear(in_feat_dim * 2, in_feat_dim // 64),
            nn.BatchNorm1d(in_feat_dim // 64), nn.Dropout(dropout),
            nn.LeakyReLU(), nn.Linear(in_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim), nn.Dropout(dropout), nn.ReLU())
        if feature_drop:
            self.feat_drop = nn.Dropout(feature_drop)
        else:
            self.feat_drop = lambda x: x
        if att_drop:
            self.att_drop = nn.Dropout(att_drop)
        else:
            self.att_drop = lambda x: x
        self.attn_l = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))

        self.relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        self.fnode = nn.Sequential(
            nn.Linear(in_feat_dim + out_feat_dim, out_feat_dim // 64),
            nn.BatchNorm1d(out_feat_dim // 64), nn.Dropout(dropout),
            nn.LeakyReLU(), nn.Linear(out_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim), nn.Dropout(dropout), nn.ReLU())

        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)

    def build_graph(self, num_nodes, device):
        self.g = DGLGraph()
        self.g.add_nodes(num_nodes)
        for i in range(0, num_nodes):
            for j in range(i + 1, num_nodes):
                # add each directed edge exactly once; letting j run over all nodes
                # would insert every directed edge twice
                self.g.add_edge(i, j)
                self.g.add_edge(j, i)

        self.g.to(device)
        self.g.register_message_func(self.send_source)
        self.g.register_reduce_func(self.simple_reduce)

    def send_source(self, edges):
        edge_feature = self.fedge.forward(
            torch.cat((edges.src["h"], edges.dst["h"]), dim=1))
        msg = self.fnode.forward(
            torch.cat((edges.src["h"], edge_feature), dim=1))
        m = torch.mul(msg, edges.data['a_drop'])
        return {"m": m}

    def simple_reduce(self, nodes):
        return {"h": torch.sum(nodes.mailbox['m'], dim=1) + nodes.data["h"]}

    def edge_attention(self, edges):
        a = self.relu(edges.src['a1'] + edges.dst['a2'])
        return {'a': a}

    def edge_softmax(self):
        att = self.softmax(self.g, self.g.edata.pop('a'))
        self.g.edata['a_drop'] = self.att_drop(att)

    def forward(self, n_feature):
        a1 = (n_feature * self.attn_l).sum(dim=-1).unsqueeze(-1)
        a2 = (n_feature * self.attn_r).sum(dim=-1).unsqueeze(-1)
        self.g.ndata.update({'h': n_feature, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        self.g.send(self.g.edges())
        self.g.recv(self.g.nodes())
        return self.g.ndata.pop('h')
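# Hedged usage sketch for D2GCN, assuming it runs in the same module as the class and
# that edge_softmax is already imported there (the class requires it). The layer also
# reads the module-level hyperparameters below at construction time; these values are
# placeholders for illustration only, not the original settings.
import torch

dropout, feature_drop, att_drop, alpha = 0.1, 0.0, 0.0, 0.2

layer = D2GCN(in_feat_dim=128, out_feat_dim=128)
layer.build_graph(num_nodes=4, device=torch.device('cpu'))
out = layer(torch.randn(4, 128))
print(out.shape)           # torch.Size([4, 128])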