def test_to_dense_batch():
    """Exercise ``to_dense_batch`` padding, masks, ``max_num_nodes`` and ``batch_size``."""
    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    # Default: one row per graph, zero-padded up to the largest graph (3 nodes).
    out, mask = to_dense_batch(x, batch)
    expected = [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    assert out.size() == (3, 3, 2)
    assert out.tolist() == expected
    assert mask.tolist() == [[1, 1, 0], [1, 0, 0], [1, 1, 1]]

    # A larger max_num_nodes only appends padding columns.
    out, mask = to_dense_batch(x, batch, max_num_nodes=5)
    assert out.size() == (3, 5, 2)
    assert out[:, :3].tolist() == expected
    assert mask.tolist() == [[1, 1, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 0, 0]]

    # Without a batch vector all nodes belong to a single graph.
    out, mask = to_dense_batch(x)
    assert out.size() == (1, 6, 2)
    assert out[0].tolist() == x.tolist()
    assert mask.tolist() == [[1, 1, 1, 1, 1, 1]]

    out, mask = to_dense_batch(x, max_num_nodes=10)
    assert out.size() == (1, 10, 2)
    assert out[0, :6].tolist() == x.tolist()
    assert mask.tolist() == [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]

    # A batch_size larger than the number of graphs appends fully padded graphs,
    # so check their content and mask too, not only the size.
    out, mask = to_dense_batch(x, batch, batch_size=4)
    assert out.size() == (4, 3, 2)
    assert out.tolist() == expected + [[[0, 0], [0, 0], [0, 0]]]
    assert mask.tolist() == [[1, 1, 0], [1, 0, 0], [1, 1, 1], [0, 0, 0]]
Beispiel #2
0
def batch_sparse(scores, labels, batch):
    """
    Densify per-node ("sparse" PyG) ``scores`` and ``labels`` into padded per-graph tensors.

    Score padding uses ``-10e8`` (i.e. -1e9), presumably so padded slots rank below any
    real score. Label padding uses ``0``. The node masks from ``to_dense_batch`` are
    discarded.

    :return: ``(dense_scores, dense_labels)``.
    """
    dense_scores = to_dense_batch(scores, batch, fill_value=-10e8)[0]
    dense_labels = to_dense_batch(labels, batch, fill_value=0)[0]
    return dense_scores, dense_labels
Beispiel #3
0
    def forward(self, inputs: MINDBatch):
        """Score candidate items against a click history and optionally compute a loss.

        History and candidate node sets are encoded with ``self.encoder`` unless
        ``is_precomputed`` reports they already are, then densified per sample with
        ``to_dense_batch``. The history goes through ``self.self_attn`` and is pooled by
        ``self.additive_attn``. Each candidate is scored by a batched dot product with the
        pooled history.

        :param inputs: Mapping with keys ``'x_hist'``, ``'batch_hist'``, ``'x_cand'``,
            ``'batch_cand'`` and ``'targets'``.
        :return: ``logits`` (flat, one score per real candidate, padding removed via
            ``mask_cand``) when ``inputs['targets']`` is None, otherwise ``(loss, logits)``.
        """
        if is_precomputed(inputs['x_hist']):
            x_hist = inputs['x_hist']
        else:
            x_hist = self.encoder.forward(inputs['x_hist'])
        x_hist, mask_hist = to_dense_batch(x_hist, inputs['batch_hist'])
        # NOTE(review): mask_hist is True for real (non-padded) positions. If self_attn is a
        # torch.nn.MultiheadAttention, its boolean attn_mask/key_padding_mask treat True as
        # "do not attend" (the opposite convention) — confirm which attention module this is.
        x_hist = self.self_attn.forward(x_hist,
                                        attn_mask=mask_hist)[0]  # DistilBERT
        # Pool the history into one vector per sample (it is unsqueezed to 3-D for bmm below).
        x_hist, _ = self.additive_attn(x_hist)

        if is_precomputed(inputs['x_cand']):
            x_cand = inputs['x_cand']
        else:
            x_cand = self.encoder.forward(inputs['x_cand'])
        x_cand, mask_cand = to_dense_batch(x_cand, inputs['batch_cand'])

        # (B, 1, D) @ (B, D, N_cand) -> (B, 1, N_cand) -> (B, N_cand)
        logits = torch.bmm(x_hist.unsqueeze(1), x_cand.permute(0, 2,
                                                               1)).squeeze(1)
        # Drop padded candidate slots; the result is flattened across samples.
        logits = logits[mask_cand]

        targets = inputs['targets']
        if targets is None:
            return logits

        if self.training:
            criterion = nn.CrossEntropyLoss()
            # criterion = LabelSmoothingCrossEntropy()
            # Reshape back to (num_samples, candidates_per_sample); this assumes every
            # sample has the same number of candidates during training.
            loss = criterion(logits.reshape(targets.size(0), -1), targets)
        else:
            # In case of val, targets are multi label. It's not comparable with train.
            with torch.no_grad():
                criterion = nn.BCEWithLogitsLoss()
                loss = criterion(logits, targets.float())

        return loss, logits
Beispiel #4
0
    def forward(self, data):
        """Classify each graph with an attention-weighted readout.

        Two parallel message-passing branches run over the graph:
        ``mp_a1``/``mp_a2`` produce per-node attention scores, and ``mp_x1``/``mp_x2``
        produce node features (with skip connections when ``self.skip``). The scores are
        normalized with ``softmax(a_2, batch)``, presumably
        ``torch_geometric.utils.softmax`` (per-graph normalization) — confirm. Both branches
        are densified per graph, and ``A^T @ X`` is flattened and fed to ``self.linear2``.

        :param data: Batched graph exposing ``x``, ``edge_index``, ``batch`` and ``num_graphs``.
        :return: ``F.softmax`` over the last dim of ``self.linear2``'s output (probabilities).
        """
        # NOTE(review): num_graphs is unpacked but never used below.
        x, edge_index, batch, num_graphs = data.x, data.edge_index, data.batch, data.num_graphs

        a_1 = F.relu(self.mp_a1(x, edge_index))
        x_1 = F.relu(self.mp_x1(x, edge_index))
        if self.skip:
            # Skip connection: append the raw input features, then mix with a linear layer.
            a_1 = torch.cat([a_1, x], dim=1)
            x_1 = torch.cat([x_1, x], dim=1)
            a_1 = self.linear_a(a_1)
            x_1 = self.linear_x(x_1)


        # No ReLU on the attention branch: these are raw scores for the softmax below.
        a_2 = self.mp_a2(a_1, edge_index)
        x_2 = F.relu(self.mp_x2(x_1, edge_index))

        if self.skip:
            a_2 = torch.cat([a_2, a_1], dim=1)
            x_2 = torch.cat([x_2, x_1], dim=1)

        a_2 = softmax(a_2, batch)

        # Padded slots use the default fill value (zero), so they add nothing to A^T @ X.
        a_batch, _ = to_dense_batch(a_2, batch)
        a_t = a_batch.transpose(2, 1)  # (graphs, attention channels, max nodes)
        x_batch, _ = to_dense_batch(x_2, batch)
        prods = torch.bmm(a_t, x_batch)  # (graphs, attention channels, feature channels)
        flat = torch.flatten(prods, 1, -1)
        batch_out = self.linear2(flat)

        # NOTE(review): this returns probabilities, not logits. A loss such as
        # nn.CrossEntropyLoss (which applies log-softmax itself) would normalize twice.
        final = F.softmax(batch_out, dim=-1)

        return final
Beispiel #5
0
    def calculate_histogram(self, abstract_features_1, abstract_features_2, batch_1, batch_2):
        """
        Calculate a normalized histogram of pairwise similarity scores for each graph pair.
        :param abstract_features_1: Node feature matrix for the first graph of each pair.
        :param abstract_features_2: Node feature matrix for the second graph of each pair.
        :param batch_1: Batch vector for abstract_features_1, which assigns each node to a specific example.
        :param batch_2: Batch vector for abstract_features_2, which assigns each node to a specific example.
        :return hist: Histogram of sigmoid similarity scores, shape (num_pairs, self.args.bins),
            each row normalized to sum to 1.
        """
        abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
        abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)

        B1, N1, _ = abstract_features_1.size()
        B2, N2, _ = abstract_features_2.size()

        # Keep the masks as 2-D (num_pairs, max_nodes).
        mask_1 = mask_1.view(B1, N1)
        mask_2 = mask_2.view(B2, N2)
        # Per pair, the larger of the two real node counts; the element-wise max requires
        # both sides to have the same number of pairs (B1 == B2).
        num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))

        # (B, N1, D) @ (B, D, N2) -> (B, N1, N2) dot-product similarities, no gradient.
        scores = torch.matmul(abstract_features_1, abstract_features_2.permute([0,2,1])).detach()

        hist_list = []
        for i, mat in enumerate(scores):
            # NOTE(review): when the two graphs differ in size, this square slice also covers
            # the smaller graph's zero-padded rows/columns (score 0 -> sigmoid 0.5), so those
            # values enter the histogram. Confirm this is intended; mat[:n1, :n2] would
            # keep only real node pairs.
            mat = torch.sigmoid(mat[:num_nodes[i], :num_nodes[i]]).view(-1)
            hist = torch.histc(mat, bins=self.args.bins)
            # NOTE(review): divides by zero (NaN row) if both graphs of a pair are empty.
            hist = hist/torch.sum(hist)
            hist = hist.view(1, -1)
            hist_list.append(hist)
        
        return torch.stack(hist_list).view(-1, self.args.bins)
Beispiel #6
0
    def forward(self, w, edge_index, batch):
        """Predict a probability per node, then compute a loss and a decoded result.

        A stack of graph convolutions (each followed by batch-norm and ReLU, the last by a
        sigmoid) maps the node weights ``w`` to one probability per node. Probabilities,
        weights and the adjacency are densified per graph and passed to
        ``self.calculate_loss_thresholds`` and ``self.conditional_expectation``.

        :param w: Per-node weights; ``w.unsqueeze(1)`` makes them a single input channel,
            so presumably shape (num_nodes,) — TODO confirm.
        :param edge_index: Connectivity used by the convolutions and ``to_dense_adj``.
        :param batch: Batch vector assigning each node to a graph.
        :return: ``(loss, mis)``. ``loss`` is ``loss_thresholds.sum()`` divided by the number of
            graphs (``adj.size(0)``). ``mis`` is ``self.conditional_expectation``'s output,
            computed on detached tensors. The name suggests a maximum independent set
            (unconfirmed).
        """
        prob = torch.relu(self.bn1(self.conv1(w.unsqueeze(1), edge_index)))
        prob = torch.relu(self.bn2(self.conv2(prob, edge_index)))
        prob = torch.relu(self.bn3(self.conv3(prob, edge_index)))
        # prob = torch.relu(self.bn4(self.conv4(prob, edge_index)))
        prob = torch.relu(self.bn5(self.conv5(prob, edge_index)))

        prob = torch.sigmoid(self.conv6(prob, edge_index))

        prob_dense, prob_mask = to_dense_batch(prob, batch)
        # NOTE(review): w_mask is never used (prob_mask carries the same padding information).
        w_dense, w_mask = to_dense_batch(w, batch)
        # Total node weight per graph; padded slots use the default fill value (zero).
        gammas = w_dense.sum(dim=1)

        adj = to_dense_adj(edge_index, batch)

        loss_thresholds = self.calculate_loss_thresholds(
            w_dense, prob_dense, adj, gammas)

        # Average over graphs in the batch.
        loss = loss_thresholds.sum() / adj.size(0)

        # Decoding runs on detached tensors, so it does not contribute gradients.
        mis = self.conditional_expectation(w_dense.detach(),
                                           prob_dense.detach(), adj,
                                           loss_thresholds.detach(),
                                           gammas.detach(), prob_mask.detach())

        return loss, mis
    def chamfer_loss(self, x, y, batch):
        """Symmetric Chamfer loss between two node sets that share one batch vector.

        ``x`` and ``y`` are densified per graph with the same ``batch``. Assuming
        ``pairwise_distance`` returns a (graphs, Nx, Ny) distance tensor, each point's
        distance to its nearest neighbour in the other set is taken in both directions.
        These distances are summed and divided by the number of graphs (``len`` of the
        dense ``x``).

        Adapted from
        https://github.com/zichunhao/mnist_graph_autoencoder/blob/master/utils/loss.py
        """
        x_dense, _ = to_dense_batch(x, batch)
        y_dense, _ = to_dense_batch(y, batch)

        dist = pairwise_distance(x_dense, y_dense, self.device)
        # Reducing over dim -2 gives the y -> x direction without permuting the tensor.
        nearest_in_y = torch.min(dist, dim=-1).values
        nearest_in_x = torch.min(dist, dim=-2).values

        # NOTE(review): padded (zero) slots take part in the distances as well.
        num_graphs = len(x_dense)
        return torch.sum(nearest_in_y + nearest_in_x) / num_graphs
Beispiel #8
0
 def readout(self, atoms: Tensor, edge_index: Tensor, edge_ids: Tensor,
             word_pos: Tensor, word_batch: Tensor, word_ids: Tensor,
             word_starts: Tensor) -> Tuple[Tensor, Tensor]:
     """Build padded word-context and node representations grouped by ``word_batch``.

     Node representations from ``self.base.contextualize_nodes`` are gathered at
     ``word_pos``. The steps for the word side are:

     * pad ``word_ids`` per ``word_batch`` with ``self.word_encoder.pad_value`` and encode;
     * unpad with the ``to_dense_batch`` mask;
     * keep positions where ``word_starts == 1`` and apply dropout.

     Both results are then densified using the ``word_batch`` entries of those
     word-start positions.

     :return: ``(ctx, node_reprs)``, both padded per ``word_batch`` group.
     """
     starts = word_starts.eq(1)
     start_batch = word_batch[starts]

     contextualized = self.base.contextualize_nodes(atoms, edge_index, edge_ids)
     gathered_nodes = contextualized[word_pos]

     padded_ids, token_mask = to_dense_batch(
         word_ids, word_batch, fill_value=self.word_encoder.pad_value)
     encoded = self.word_encoder(padded_ids)
     ctx = self.dropout(encoded[token_mask][starts])

     dense_ctx, _ = to_dense_batch(ctx, start_batch)
     dense_nodes, _ = to_dense_batch(gathered_nodes, start_batch)
     return dense_ctx, dense_nodes
Beispiel #9
0
    def forward(self, batch_protein_tokenized,batch_chem_graphs, **kwargs):
        """Predict interaction logits for a batch of protein / ligand pairs.

        :param batch_protein_tokenized: Tokenized proteins fed to ``self.proteinEmbedding``.
        :param batch_chem_graphs: Batched ligand graphs exposing ``x``, ``edge_index``,
            ``edge_attr`` and ``batch``.
        :param kwargs: Accepted for interface compatibility; unused.
        :return: Output of ``self.binary_predictor`` (annotated as ``(batch_size, 2)``).
        :raises ValueError: If ``all_config['protein_descriptor']`` is not ``'DISAE'``, the
            only protein encoder implemented here. Previously this surfaced later as an
            ``UnboundLocalError``.
        """
        # ---------------protein embedding ready -------------
        if self.all_config['protein_descriptor'] != 'DISAE':
            raise ValueError(
                "Unsupported protein_descriptor: {!r} (only 'DISAE' is implemented)".format(
                    self.all_config['protein_descriptor']))
        if self.all_config['frozen'] == 'whole':
            # Fully frozen encoder: build no autograd graph for the protein embedding.
            with torch.no_grad():
                batch_protein_repr = self.proteinEmbedding(batch_protein_tokenized)[0]
        else:
            batch_protein_repr = self.proteinEmbedding(batch_protein_tokenized)[0]

        # Reshape with the actual batch size instead of all_config['batch_size'], so a smaller
        # final batch still works. The result is identical when the two agree.
        protein_batch_size = batch_protein_repr.size(0)
        batch_protein_repr_resnet = self.resnet(batch_protein_repr.unsqueeze(1)).reshape(
            protein_batch_size, 1, -1)  # (batch_size,1,256)

        # ---------------ligand embedding ready -------------
        node_representation = self.ligandEmbedding(batch_chem_graphs.x, batch_chem_graphs.edge_index,
                                                   batch_chem_graphs.edge_attr)
        # Pad node embeddings per molecule (default zero padding) and sum-pool over nodes.
        batch_chem_graphs_repr_masked, _ = to_dense_batch(node_representation, batch_chem_graphs.batch)
        batch_chem_graphs_repr_pooled = batch_chem_graphs_repr_masked.sum(axis=1).unsqueeze(1)  # (batch_size,1,300)

        # ---------------interaction embedding ready -------------
        ((chem_vector, chem_score), (prot_vector, prot_score)) = self.attentive_interaction_pooler(
            batch_chem_graphs_repr_pooled, batch_protein_repr_resnet)  # same as input dimension

        # Flatten all non-batch dims. A bare ``squeeze()`` would also drop the batch dim for a
        # single-pair batch, which makes the ``torch.cat(..., 1)`` below fail.
        chem_flat = chem_vector.reshape(chem_vector.size(0), -1)
        prot_flat = prot_vector.reshape(prot_vector.size(0), -1)
        interaction_vector = self.interaction_pooler(
            torch.cat((chem_flat, prot_flat), 1))  # (batch_size,64)
        logits = self.binary_predictor(interaction_vector)  # (batch_size,2)
        return logits
Beispiel #10
0
    def forward(self, batched_data, mask=None):
        """Two-level DiffPool encoder producing one output per sequence position.

        Nodes are encoded with ``self.node_encoder`` (using ``node_depth``) and densified.
        They go through two ``dense_diff_pool`` stages and a final dense GNN, are
        mean-pooled over nodes, and pass through ``self.lin1`` with ReLU.

        :param batched_data: Batch exposing ``x``, ``edge_index``, ``edge_attr``,
            ``node_depth`` and ``batch``.
        :param mask: Ignored. It is overwritten by the node mask from ``to_dense_batch``.
        :return: List of ``self.max_seq_len`` tensors, one from each
            ``self.graph_pred_linear_list[i]``.
        """
        # NOTE(review): edge_attr is unpacked but never used.
        x, edge_index, edge_attr, node_depth, batch = batched_data.x, batched_data.edge_index,  batched_data.edge_attr, batched_data.node_depth, batched_data.batch

        x = self.node_encoder(x, node_depth.view(-1,))

        x, mask = to_dense_batch(x, batch=batch)
        adj = to_dense_adj(edge_index, batch=batch)

        # First pooling stage: the padding mask is only meaningful on the original nodes.
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)

        s = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed(x, adj)

        x = x.mean(dim=1)
        x = F.relu(self.lin1(x))
        # x = self.lin2(x)
        # NOTE(review): the auxiliary DiffPool losses (l1, e1, l2, e2) are discarded, so they
        # do not reach the training objective — see the commented-out return below.
        # return self.activation(x)  #, l1 + l2, e1 + e2

        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](x))

        return pred_list
Beispiel #11
0
    def forward(self, data):
        """Two edge-aware graph convolutions followed by a fully connected head.

        After each convolution: ReLU, then dropout (active only in training mode). Node
        features are densified per graph, transposed to (batch, channels, nodes),
        flattened per graph and passed through ``self.layer3`` and ``self.layer4``
        (both with ReLU).

        :param data: Batch exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        :return: Output of ``F.relu(self.layer4(...))``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        #print('x ', x.shape)
        #print('edge_index ', edge_index.shape)
        #print('edge_attr ', edge_attr.shape)
        #print('conv1 weight ', self.conv1.weight.shape)
        x = self.conv1(x, edge_index, edge_attr)
        #print('conv1 out x ', x.shape)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv2(x, edge_index, edge_attr)
        #print('conv1 out x ', x.shape)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        #print('conv2 out x ', x.shape)
        # Convert to a regular (dense) batched layout.
        x, mask = to_dense_batch(x, data.batch)
        x = x.transpose(1, 2)  # [batch_size, in_channels, num_nodes]
        #print('to_dense_batch out x ', x.shape)
        # Flatten.
        # NOTE(review): the flattened width is channels * (max nodes in THIS batch), which
        # changes between batches unless all graphs have the same node count. layer3
        # presumably has a fixed input size — confirm, or pass max_num_nodes to
        # to_dense_batch.
        #x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), x.size(1)*x.size(2))
        #print('layer2 in x ', x.shape)
        #x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(self.layer3(x))
        #print('layer2 out x ', x.shape)
        x = F.relu(self.layer4(x))

        return x
Beispiel #12
0
    def _sparse_to_dense_input(self, data):
        """Convert a sparse PyG batch into dense per-graph tensors.

        :param data: Batch exposing ``x``, ``edge_index``, ``batch`` and ``y``.
        :return: ``(dense_x, dense_adj, node_info, label)``. ``dense_adj`` is the dense
            adjacency from ``to_dense_adj``, although the original code named it
            ``edge_index``. ``node_info`` is the second output of ``to_dense_batch``:
            per-graph node counts in older PyG, a boolean node mask in newer PyG.
            ``label`` is ``data.y``.
        """
        dense_adj = to_dense_adj(data.edge_index, data.batch)
        dense_x, node_info = to_dense_batch(data.x, data.batch)
        return dense_x, dense_adj, node_info, data.y
    def forward(self, data):
        """Hierarchical DiffPool graph classifier.

        Optionally encodes atoms and applies an edge-aware first convolution (when
        ``self.encode_edge``). It then densifies the batch, embeds with
        ``self.initial_embed`` and runs ``self.num_pooling_layers`` DiffPool layers, each
        followed by an ``after_pool_layers`` GNN. Finally it max-pools over nodes and
        applies ``lin1`` (ReLU) and ``lin2``.

        :param data: Batch exposing ``x``, ``edge_index``, ``batch`` and (if
            ``encode_edge``) ``edge_attr``.
        :return: ``(out, l_total, e_total)``: predictions plus the summed auxiliary
            losses returned by the DiffPool layers.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)

        x, mask = to_dense_batch(x, batch=batch)
        adj = to_dense_adj(edge_index, batch=batch)

        x = self.initial_embed(x, adj, mask)

        # NOTE(review): x_all is never used in this method.
        x_all, l_total, e_total = [], 0, 0

        for i in range(self.num_pooling_layers):
            # After the first pooling the cluster dimension has no padding, so no mask.
            if i != 0:
                mask = None

            x, adj, l, e = self.diffpool_layers[i](
                x, adj,
                mask)  # x has shape (batch, MAX_no_nodes, feature_size)

            x = self.after_pool_layers[i](x, adj)

            l_total += l
            e_total += e

        # Max-pool over the (cluster) node dimension.
        x = torch.max(x, dim=1)[0]

        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x, l_total, e_total
Beispiel #14
0
    def forward(self, data):
        """DiffPool classifier with a weighted adjacency and a jumping-knowledge readout.

        Densifies the batch (edge weights become adjacency values) and runs
        ``self.num_diffpool_layers`` DiffPool layers plus ``self.final_embed``. After each
        stage it max-pools over nodes, concatenates those readouts and applies ``lin1``
        (ReLU) and ``lin2``.

        :param data: Batch exposing ``x``, ``edge_index``, ``batch`` and ``edge_weight``.
        :return: Predictions only.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_weight = data.edge_weight

        x, mask = to_dense_batch(x, batch=batch)
        adj = to_dense_adj(edge_index, batch=batch, edge_attr=edge_weight)

        # NOTE(review): l_total/e_total are accumulated but not returned, so the DiffPool
        # auxiliary losses never reach the training objective — confirm this is intended.
        x_all, l_total, e_total = [], 0, 0

        for i in range(self.num_diffpool_layers):
            # Only the first layer sees padded nodes; later ones work on clusters.
            if i != 0:
                mask = None

            x, adj, l, e = self.diffpool_layers[i](
                x, adj,
                mask)  # x has shape (batch, MAX_no_nodes, feature_size)
            x_all.append(torch.max(x, dim=1)[0])

            l_total += l
            e_total += e

        x = self.final_embed(x, adj)
        x_all.append(torch.max(x, dim=1)[0])

        x = torch.cat(x_all,
                      dim=1)  # shape (batch, feature_size x diffpool layers)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
Beispiel #15
0
    def forward(self, data):
        """DiffPool classifier with a jumping-knowledge readout.

        Densifies the batch and runs ``self.num_diffpool_layers`` DiffPool layers plus
        ``self.final_embed``. After each stage it max-pools over nodes, concatenates those
        readouts and applies ``lin1`` (ReLU) and ``lin2``.

        :param data: Batch exposing ``x``, ``edge_index`` and ``batch``.
        :return: ``(out, l_total, e_total)``: predictions plus the summed auxiliary
            losses returned by the DiffPool layers.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x, mask = to_dense_batch(x, batch=batch)
        adj = to_dense_adj(edge_index, batch=batch)
        # data = ToDense(data.num_nodes)(data)
        # mask comes from to_dense_batch: (num_graphs, max_num_nodes), True for real nodes and
        # False for padding. Each graph occupies one row of x / adj, zero-padded to the
        # largest graph in the batch.

        # adj, mask, x = data.adj, data.mask, data.x
        x_all, l_total, e_total = [], 0, 0

        for i in range(self.num_diffpool_layers):
            # Only the first layer sees padded nodes; later ones work on clusters.
            if i != 0:
                mask = None

            x, adj, l, e = self.diffpool_layers[i](
                x, adj,
                mask)  # x has shape (batch, MAX_no_nodes, feature_size)
            x_all.append(torch.max(x, dim=1)[0])

            l_total += l
            e_total += e

        x = self.final_embed(x, adj)
        x_all.append(torch.max(x, dim=1)[0])

        x = torch.cat(x_all,
                      dim=1)  # shape (batch, feature_size x diffpool layers)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x, l_total, e_total
Beispiel #16
0
def reinforce_train_batch(
    model: nn.Module,
    baseline: nn.Module,
    optimizer: optim.Optimizer,
    batch: Batch,
    epoch: int,
    batch_id: int,
    step: int,
    env: TSPEnv,
    logger,
    args,
) -> None:
    """Run one REINFORCE update of ``model`` on a batch of TSP instances.

    The policy is rolled out in ``env`` until it reports ``done``, and the log-likelihood
    of the chosen actions is computed. That log-likelihood is weighted by the advantage
    ``cost - bl_val`` from ``baseline.evaluate``. The policy loss and the baseline's own
    loss are summed and backpropagated. Gradients are clipped to ``args.max_grad_norm``
    and the ``optimizer`` step is applied. Every ``args.log_step`` steps the statistics
    are passed to ``log_values``.

    :param batch: PyG batch; ``batch.pos`` (presumably node coordinates) is densified per
        graph to initialise the environment.
    :param step: Global step counter, used only to decide when to log.
    :param args: Namespace providing ``device``, ``max_grad_norm`` and ``log_step``.
    """
    batch = batch.to(args.device)
    node_pos = to_dense_batch(batch.pos, batch.batch)[0]
    log_p_s = []
    action_s = []
    reward_s = []
    done = False
    state = env.reset(node_pos)
    # Encode the graphs once; the fixed context is reused at every decoding step.
    embed_data = model.init_embed(batch)
    node_embeddings, graph_feat = model.encoder(embed_data)
    fixed = model.precompute_fixed(node_embeddings, graph_feat)
    # Roll out the policy until the environment signals completion.
    while not done:
        action, log_p = model(state, fixed)
        state, reward, done, _ = env.step(action)
        log_p_s.append(log_p)
        action_s.append(action)
        reward_s.append(reward)
    log_p = torch.stack(log_p_s, 1)
    a = torch.stack(action_s, 1)
    # Calculate policy's log_likelihood and reward
    log_likelihood = _calc_log_likelihood(log_p, a)
    # reward is a negative value of tour length
    # let baseline to predict positive value
    cost = -(reward_s[-1])
    bl_val, bl_loss = baseline.evaluate(batch, cost)
    rl_loss = ((cost - bl_val) * log_likelihood).mean()
    loss = rl_loss + bl_loss

    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_grad_norms(optimizer.param_groups, args.max_grad_norm)
    optimizer.step()

    # Logging
    if step % int(args.log_step) == 0:
        log_values(
            cost=cost,
            grad_norms=grad_norms,
            bl_val=bl_val,
            epoch=epoch,
            batch_id=batch_id,
            step=step,
            log_likelihood=log_likelihood,
            reinforce_loss=rl_loss,
            bl_loss=bl_loss,
            log_p=log_p,
            logger=logger,
            args=args,
        )
Beispiel #17
0
 def forward(self, x, edge_index):
     """Encode nodes with a GNN stack and optionally add variational, prediction and clustering outputs.

     :param x: Node features.
     :param edge_index: Graph connectivity.
     :return: Dict with key ``'z'`` (node embeddings; reparametrized samples when
         ``self.variational``), plus:

         * ``'mu'``/``'logvar'`` when ``self.variational``;
         * ``'y'`` when ``self.prediction_task``;
         * ``'s'``, ``'mc1'``, ``'o1'`` (assignments and the two ``dense_mincut_pool``
           losses) when ``self.use_mincut``;
         * ``'s'`` from ``self.kmeans`` when ``self.activate_kmeans`` (and not
           ``use_mincut``).
     """
     z = x
     for conv in self.convs[:-1]:
         z = self.relu(conv(z, edge_index))
     # if not self.variational:
     # The last conv has no activation.
     z = self.convs[-1](z, edge_index)
     if self.use_mincut:
         # batch=None: all nodes are treated as a single graph, so the dense tensors have
         # a leading batch dimension of 1.
         z_p, mask = to_dense_batch(z, None)
         adj = to_dense_adj(edge_index, None)
         # Cluster assignments are computed on the sparse (unbatched) embeddings.
         s = self.pool1(z)
         # print(s.shape)
         # print(np.bincount(s.detach().argmax(1).numpy().flatten()))
         # Only the auxiliary losses (and adj) are kept; the pooled features are discarded.
         _, adj, mc1, o1 = dense_mincut_pool(z_p, adj, s, mask)
     output = dict()
     if self.variational:
         output['mu'], output['logvar'] = self.conv_mu(
             z, edge_index), self.conv_logvar(z, edge_index)
         output['z'] = self.reparametrize(output['mu'], output['logvar'])
         # output=[self.conv_mu(z,edge_index), self.conv_logvar(z,edge_index)]
     else:
         output['z'] = z
         # output=[z]
     if self.prediction_task:
         # Uses the deterministic embeddings z, not the reparametrized sample.
         output['y'] = self.classification_layer(z)
     if self.use_mincut:
         output['s'] = s
         output['mc1'] = mc1
         output['o1'] = o1
         # output.extend([s, mc1, o1])
     elif self.activate_kmeans:
         s = self.kmeans(z)
         output['s'] = s
         # output.extend([s])
     return output
Beispiel #18
0
    def forward(self,
                Q,
                K,
                attention_mask=None,
                graph=None,
                return_attn=False):
        """Multi-head attention block: queries ``Q`` attend over keys/values from ``K`` or a graph.

        :param Q: Queries, projected by ``self.fc_q`` and split into heads along dim 2.
            Presumably (batch, num_queries, dim_V) after projection — TODO confirm.
        :param K: Key/value source, projected by ``fc_k``/``fc_v``, used when ``graph`` is None.
        :param attention_mask: Optional additive mask added to the scaled scores before the
            softmax; it is replicated ``self.num_heads`` times along dim 0.
        :param graph: Optional ``(x, edge_index, batch)``. When given, keys and values come
            from ``fc_k``/``fc_v`` applied as graph layers to ``(x, edge_index)`` and densified
            per graph, and ``K`` is ignored.
        :param return_attn: If True, also return the attention weights ``A``.
        :return: ``O``, or ``(O, A)`` when ``return_attn``.
        """
        Q = self.fc_q(Q)

        # Adj: Exist (graph is not None), or Identity (else)
        if graph is not None:

            (x, edge_index, batch) = graph

            K, V = self.fc_k(x, edge_index), self.fc_v(x, edge_index)

            K, _ = to_dense_batch(K, batch)
            V, _ = to_dense_batch(V, batch)

        else:

            K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature dim into heads and stack the heads along the batch dim.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Scores are scaled by sqrt(dim_V), the full width, not the per-head width.
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(
                self.dim_V)
            A = torch.softmax(attention_mask + attention_score,
                              self.softmax_dim)
        else:
            A = torch.softmax(
                Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V),
                self.softmax_dim)

        # Residual on the projected queries, then un-stack heads back into the feature dim.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        if return_attn:
            return O, A
        else:
            return O
def test_to_dense_batch():
    """Check the older ``to_dense_batch`` API whose mask is flattened over all graph slots."""
    features = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    graph_ids = torch.tensor([0, 0, 1, 2, 2, 2])

    dense, node_mask = to_dense_batch(features, graph_ids)
    assert dense.size() == (3, 3, 2)
    assert dense.tolist() == [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    # One flat entry per (graph, slot) pair in this API version.
    assert node_mask.tolist() == [1, 1, 0, 1, 0, 0, 1, 1, 1]

    # Without a batch vector every node lands in a single graph.
    single = to_dense_batch(features)[0]
    assert single.size() == (1, 6, 2)
    assert single[0].tolist() == features.tolist()
Beispiel #20
0
 def __call__(self, scores, labels, batch_vec):
     """
     Pairwise ranking loss over (positive, negative) node pairs within each graph.

     * the three input tensors have shape (N, ), N being the number of nodes in the batch
     * the batch_vec vector, indicating which node belongs to which graph, is what makes it
       possible to split values by query (i.e. graph)
     we want to compute all the pairwise contributions in the batch, dealing with:
     1. not mixing between graphs
     2. variable number of valid pairs between graphs (using masking)
     For every graph g, every positive i and every negative j, the difference
     ``pos_score[g, i] - neg_score[g, j]`` is collected and passed to ``self.compute_loss``.
     """
     ids_pos = labels == 1
     ids_neg = labels == 0
     batch_vec_pos = batch_vec[ids_pos]
     batch_vec_neg = batch_vec[ids_neg]
     pos_scores = scores[ids_pos]
     neg_scores = scores[ids_neg]
     # densify the tensors (see: https://rusty1s.github.io/pytorch_geometric/build/html/modules/utils.html?highlight=to_dense#torch_geometric.utils.to_dense_batch)
     # NOTE(review): positives and negatives are densified separately. If to_dense_batch
     # infers the graph count from the largest graph index, a graph with no positives (or
     # no negatives) at the end of the batch yields a different number of rows on the two
     # sides, and the views below then pair up mismatched graphs — confirm every graph has
     # both.
     dense_pos_scores, pos_mask = to_dense_batch(pos_scores,
                                                 batch_vec_pos,
                                                 fill_value=0)
     # dense_pos_scores has shape (nb_graphs, padding => max number nodes for graphs in batch)
     pos_len = torch.sum(
         pos_mask,
         dim=-1)  # shape (nb_graphs, ), actual number of nodes per graph
     dense_neg_scores, neg_mask = to_dense_batch(neg_scores,
                                                 batch_vec_neg,
                                                 fill_value=0)
     neg_len = torch.sum(neg_mask, dim=-1)
     max_pos_len = pos_len.max(
     )  # == the padding value for the positive scores
     max_neg_len = neg_len.max()
     # The masks are rebuilt from the lengths by the external `masking` helper (presumably
     # the first `len` entries of each row set) — this replaces the to_dense_batch masks.
     pos_mask = masking(pos_len, max_pos_len.item())
     neg_mask = masking(neg_len, max_neg_len.item())
     # (G, 1, P) - (G, Nn, 1) -> (G, Nn, P): diff_[g, j, i] = pos[g, i] - neg[g, j]
     diff_ = dense_pos_scores.view(
         -1, 1, dense_pos_scores.size(1)) - dense_neg_scores.view(
             -1, dense_neg_scores.size(1), 1)
     # now we use the mask and some reshaping to only extract the valid pair contributions:
     # both masks are laid out as (G, Nn * P) with flat index j * P + i, matching diff_.
     pos_mask_ = pos_mask.repeat(1, neg_mask.size(1))
     neg_mask_ = neg_mask.view(-1, neg_mask.size(1), 1).repeat(
         1, 1, pos_mask.size(1)).view(-1,
                                      neg_mask.size(1) * pos_mask.size(1))
     flattened_mask = (pos_mask_ * neg_mask_).view(-1).long()
     valid_diff_ = diff_.view(-1)[flattened_mask > 0]
     loss = self.compute_loss(valid_diff_)
     return loss
 def diffpool(self, abstract_features, edge_index, batch):
     """
     Apply differentiable pooling (``self.attention``) to the densified batch.
     :param abstract_features: Node feature matrix.
     :param edge_index: Edge index of the batched graphs, densified with ``to_dense_adj``.
     :param batch: Batch vector, which assigns each node to a specific example.
     :return pooled_features: Graph feature matrix.
     """
     dense_adj = to_dense_adj(edge_index, batch)
     dense_x, node_mask = to_dense_batch(abstract_features, batch)
     return self.attention(dense_x, dense_adj, node_mask)
Beispiel #22
0
    def forward(
        self,
        Q: Tensor,
        K: Tensor,
        graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Multi-head attention block: queries ``Q`` attend over keys/values from ``K`` or a graph.

        :param Q: Queries, projected by ``self.fc_q`` and split into heads along dim 2.
        :param K: Key/value source, projected by ``layer_k``/``layer_v``, used when ``graph``
            is None.
        :param graph: Optional ``(x, edge_index, batch)``. When given, keys and values come
            from ``layer_k``/``layer_v`` applied as graph layers and densified per graph, and
            ``K`` is ignored.
        :param mask: Optional additive mask added to the scaled scores before the softmax;
            it is replicated ``self.num_heads`` times along dim 0.
        :return: Output of shape matching the projected ``Q``, after the residual
            feed-forward and optional layer norms.
        """

        Q = self.fc_q(Q)

        if graph is not None:
            x, edge_index, batch = graph
            K, V = self.layer_k(x, edge_index), self.layer_v(x, edge_index)
            K, _ = to_dense_batch(K, batch)
            V, _ = to_dense_batch(V, batch)
        else:
            K, V = self.layer_k(K), self.layer_v(K)

        # Split the feature dim into heads and stack the heads along the batch dim.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # NOTE(review): the softmax is taken over dim 1 (the query axis of the
        # (batch*heads, Nq, Nk) scores), not over the keys — confirm this is intended.
        if mask is not None:
            mask = torch.cat([mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2))
            attention_score = attention_score / math.sqrt(self.dim_V)
            A = torch.softmax(mask + attention_score, 1)
        else:
            A = torch.softmax(
                Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 1)

        # Residual on the projected queries, then un-stack heads back into the feature dim.
        out = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)

        if self.layer_norm:
            out = self.ln0(out)

        out = out + self.fc_o(out).relu()

        if self.layer_norm:
            out = self.ln1(out)

        return out
    def forward(self, x, edge_index, batch, edge_attr, perturb=None):
        """Densify initial node queries, run the memory layers and read out per graph.

        :return: ``(self.mlp(mean over the node axis), kl_total / len(batch))``.
            NOTE(review): ``len(batch)`` is the number of nodes, not graphs — confirm which
            normalization the KL term should use.
        """
        node_queries = self._get_q0(batch, x, edge_index, edge_attr, perturb)
        dense, mask = to_dense_batch(node_queries, batch=batch)
        # Batch-norm over the flattened (graphs * slots, features) view, then restore shape.
        num_feats = dense.shape[-1]
        dense = self.bn(dense.view(-1, num_feats)).view(*dense.size())

        q = dense
        kl_sum = 0
        for layer_idx, mem_layer in enumerate(self.mem_layers):
            # Only the first memory layer works on padded nodes and needs the mask.
            layer_mask = mask if layer_idx == 0 else None
            q, kl = mem_layer(q, layer_mask)
            kl_sum += kl

        return self.mlp(q.mean(dim=-2)), kl_sum / len(batch)
def test_to_dense_batch():
    """Check the older ``to_dense_batch`` API that returns per-graph node counts."""
    feats = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    graph_ids = torch.tensor([0, 0, 1, 2, 2, 2])

    dense, counts = to_dense_batch(feats, graph_ids)
    assert dense.tolist() == [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    assert counts.tolist() == [2, 1, 3]
Beispiel #25
0
 def forward(self, data):
     """Dense GCN -> pooling -> dense GCN graph classifier.

     :param data: Batch exposing ``x``, ``edge_index`` and ``batch``.
     :return: ``(logits, l_lp, l_e)``: classifier output plus the two auxiliary values
         returned by ``self.pooling``.
     """
     dense_x, node_mask = to_dense_batch(data.x, data.batch)
     dense_adj = to_dense_adj(data.edge_index, data.batch)

     # Pre-pooling layers see the padding mask (ReLU applied in place).
     for layer in (self.gcn_1, self.gcn_2):
         dense_x = F.relu(layer(dense_x, dense_adj, node_mask), True)

     dense_x, dense_adj, l_lp, l_e = self.pooling(dense_x, dense_adj, node_mask)

     # Post-pooling layers run on the coarsened graph without a mask.
     for layer in (self.gcn_3, self.gcn_4):
         dense_x = F.relu(layer(dense_x, dense_adj))

     graph_repr = dense_x.mean(1)
     return self.classifier(graph_repr), l_lp, l_e
Beispiel #26
0
    def forward(self, data):
        """Graph autoencoder: GNN encoder, attention pooling, then unpooling and GNN decoder.

        The encoder applies ``self.args.num_convs`` convolutions with ReLU. Nodes are
        densified and pooled by ``self.pools`` (``'GMPool_G'`` blocks also receive the
        sparse graph). The last attention map is then used to scatter the pooled vectors
        back onto node slots (``A^T @ pooled``). Padding is dropped and the
        ``self.args.num_unconvs`` decoder layers run, with ReLU on all but the last.

        :param data: Batch exposing ``x``, ``edge_index`` and ``batch``.
        :return: Reconstructed per-node outputs in sparse (unpadded) node order.
        """

        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Encoder
        for _ in range(self.args.num_convs):

            x = F.relu(self.convs[_](x, edge_index))

        # Pooling
        # NOTE(review): if self.model_sequence is empty, batch_x/attn/mask are never bound
        # and the decoder below raises.
        for _index, _model_str in enumerate(self.model_sequence):

            if _index == 0:

                batch_x, mask = to_dense_batch(x, batch)

                # Additive attention mask, (B, 1, N): 0 for real nodes, -1e9 for padding,
                # cast from bool to the model's parameter dtype.
                extended_attention_mask = mask.unsqueeze(1)
                extended_attention_mask = extended_attention_mask.to(
                    dtype=next(self.parameters()).dtype)
                extended_attention_mask = (1.0 -
                                           extended_attention_mask) * -1e9

            if _model_str == 'GMPool_G':

                batch_x, attn = self.pools[_index](
                    batch_x,
                    attention_mask=extended_attention_mask,
                    graph=(x, edge_index, batch),
                    return_attn=True)

            else:

                batch_x, attn = self.pools[_index](
                    batch_x,
                    attention_mask=extended_attention_mask,
                    return_attn=True)

            # Only the first pooling block operates on padded nodes.
            extended_attention_mask = None

        # Decoder
        # Unpool with the last attention map, then drop padded slots via the node mask.
        x = torch.bmm(attn.transpose(1, 2), batch_x)

        x = x[mask]

        for _ in range(self.args.num_unconvs):

            x = self.unconvs[_](x, edge_index)

            if _ < (self.args.num_unconvs - 1):
                x = F.relu(x)

        return x
Beispiel #27
0
 def aggregate(self, x_j, index):
     """Median-aggregate messages ``x_j`` per target ``index``.

     :param x_j: Messages, presumably one row per edge (PyG ``MessagePassing`` convention).
     :param index: Target index of each message.
     :return: Tensor with one row per group produced by ``to_dense_batch`` (indices
         ``0 .. index.max()``). Each row is the element-wise ``torch.median`` (lower median
         for even counts) of the messages sent to that index, or zeros if it received none.
     """
     # `to_dense_batch` requires the `index` is sorted
     # TODO: is there any way to avoid `argsort`?
     ix = torch.argsort(index)
     index = index[ix]
     x_j = x_j[ix]
     dense_x, mask = to_dense_batch(x_j, index)
     out = x_j.new_zeros(dense_x.size(0), dense_x.size(-1))
     deg = mask.sum(dim=1)  # number of messages per target index
     # Group targets by degree so each group needs one rectangular median call.
     for i in deg.unique():
         if i == 0:
             # Indices that receive no message get an all-padding row. A median over zero
             # elements raises, so leave their output at zero instead.
             continue
         deg_mask = deg == i
         out[deg_mask] = dense_x[deg_mask, :i].median(dim=1).values
     return out
Beispiel #28
0
    def forward(self, data, negative_data):
        """Encode positive and negative graphs and compute a summary-based contrastive loss.

        :param data: Positive batch exposing ``x``, ``edge_index`` and ``batch``.
        :param negative_data: Negative (presumably corrupted) batch with the same attributes.
        :return: ``(graph_emb, loss_val)``. ``graph_emb`` is ``self.outgc`` applied to the
            mean-pooled positive node embeddings. ``loss_val`` comes from ``self.loss``.
        """

        # NOTE(review): x, edge_index, x_n and edge_index_n are unpacked but never used;
        # the encoder receives the whole Data objects.
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x_n, edge_index_n, batch_n = negative_data.x, negative_data.edge_index, negative_data.batch

        pos_z = self.encoder(data)
        neg_z = self.encoder(negative_data)

        #graph
        # Graph summary: mean of the positive node embeddings per graph.
        summary = global_mean_pool(pos_z, batch)

        graph_emb = self.outgc(summary)

        pos_z, mask = to_dense_batch(pos_z, batch=batch)
        neg_z, mask_n = to_dense_batch(neg_z, batch=batch_n)

        # Flatten the node masks to one column: (graphs * max_nodes, 1).
        mask = mask.contiguous().view(pos_z.size(0) * pos_z.size(1), -1)
        mask_n = mask_n.contiguous().view(neg_z.size(0) * neg_z.size(1), -1)

        loss_val = self.loss(pos_z, neg_z, mask, mask_n, self.sigm(summary))

        return graph_emb, loss_val
Beispiel #29
0
def test_model(model, args, testset, pin_memory):
    """Evaluate ``model`` on ``testset`` and return the mean per-batch MSE.

    Predictions and targets are densified per graph and collected as NumPy arrays for
    ``metric``. The averaged loss is logged through ``args.logger``.

    :param model: Model called as ``model(data, args.adj)``, returning a 3-tuple whose first
        element is the prediction.
    :param args: Namespace providing ``device``, ``adj``, ``logger`` (and whatever ``metric``
        reads).
    :param testset: Iterable of batches exposing ``y`` and ``batch``.
    :param pin_memory: Forwarded as ``non_blocking`` to ``.to(args.device)``.
    :return: Mean of the per-batch ``mse_loss`` values (a tensor).
    """
    model.eval()
    pred_ = []
    truth_ = []
    loss = 0.0
    with torch.no_grad():
        cn = 0
        for data in testset:
            data = data.to(args.device, non_blocking=pin_memory)
            pred, _, _ = model(data, args.adj)
            loss += func.mse_loss(data.y, pred, reduction="mean")
            pred, _ = to_dense_batch(pred, batch=data.batch)
            # Overwrites data.y on the batch object with its dense version.
            data.y, _ = to_dense_batch(data.y, batch=data.batch)
            pred_.append(pred.cpu().data.numpy())
            truth_.append(data.y.cpu().data.numpy())
            cn += 1
        # NOTE(review): an empty testset makes cn == 0 (ZeroDivisionError here) and leaves
        # np.concatenate with an empty list.
        loss = loss / cn
        args.logger.info("[*] loss:{:.4f}".format(loss))
        pred_ = np.concatenate(pred_, 0)
        truth_ = np.concatenate(truth_, 0)
        # NOTE(review): mae is computed but neither logged nor returned.
        mae = metric(truth_, pred_, args)
        return loss
Beispiel #30
0
    def forward(self, x: Tensor, batch: Tensor,
                edge_index: Optional[Tensor] = None) -> Tensor:
        """Pool node features into one vector per graph with a sequence of pooling blocks.

        :param x: Node features, first projected by ``self.lin1``.
        :param batch: Batch vector assigning each node to a graph.
        :param edge_index: Connectivity, handed only to ``'GMPool_G'`` pooling blocks.
        :return: ``self.lin2`` applied to the pooled output with dim 1 squeezed.
        """
        x = self.lin1(x)
        dense_x, node_mask = to_dense_batch(x, batch)
        # Additive attention bias of shape (B, 1, N): 0 for real nodes, -1e9 for padding.
        attn_bias = (~node_mask).unsqueeze(1).to(dtype=x.dtype) * -1e9

        for pool_name, pool in zip(self.pool_sequences, self.pools):
            graph = None
            if pool_name == 'GMPool_G':
                graph = (x, edge_index, batch)
            dense_x = pool(dense_x, graph, attn_bias)
            # Only the first pooling block receives the padding mask.
            attn_bias = None

        return self.lin2(dense_x.squeeze(1))