# Shared imports for the snippets below.
import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn, optim
from torch_geometric.data import Batch
from torch_geometric.nn import dense_diff_pool, dense_mincut_pool, global_mean_pool
from torch_geometric.utils import softmax, to_dense_adj, to_dense_batch


def test_to_dense_batch():
    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    out, mask = to_dense_batch(x, batch)
    expected = [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    assert out.size() == (3, 3, 2)
    assert out.tolist() == expected
    assert mask.tolist() == [[1, 1, 0], [1, 0, 0], [1, 1, 1]]

    out, mask = to_dense_batch(x, batch, max_num_nodes=5)
    assert out.size() == (3, 5, 2)
    assert out[:, :3].tolist() == expected
    assert mask.tolist() == [[1, 1, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 0, 0]]

    out, mask = to_dense_batch(x)
    assert out.size() == (1, 6, 2)
    assert out[0].tolist() == x.tolist()
    assert mask.tolist() == [[1, 1, 1, 1, 1, 1]]

    out, mask = to_dense_batch(x, max_num_nodes=10)
    assert out.size() == (1, 10, 2)
    assert out[0, :6].tolist() == x.tolist()
    assert mask.tolist() == [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]

    out, mask = to_dense_batch(x, batch, batch_size=4)
    assert out.size() == (4, 3, 2)
def batch_sparse(scores, labels, batch):
    """Convert "sparse" PyG node-level vectors of scores and labels into
    dense, per-graph padded tensors."""
    batch_scores, _ = to_dense_batch(scores, batch, fill_value=-10e8)
    batch_labels, _ = to_dense_batch(labels, batch, fill_value=0)
    return batch_scores, batch_labels
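# A minimal usage sketch for `batch_sparse` (toy values, not from a real
# dataset): three graphs with 2, 1 and 3 nodes. Padded score positions get a
# large negative fill value so they can never win a later softmax/argmax,
# while padded label positions stay 0.
scores = torch.tensor([0.9, 0.1, 0.4, 0.7, 0.2, 0.5])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0])
batch_vec = torch.tensor([0, 0, 1, 2, 2, 2])
batch_scores, batch_labels = batch_sparse(scores, labels, batch_vec)
assert batch_scores.size() == (3, 3) and batch_labels.size() == (3, 3)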
def forward(self, inputs: MINDBatch):
    if is_precomputed(inputs['x_hist']):
        x_hist = inputs['x_hist']
    else:
        x_hist = self.encoder.forward(inputs['x_hist'])

    x_hist, mask_hist = to_dense_batch(x_hist, inputs['batch_hist'])
    x_hist = self.self_attn.forward(x_hist, attn_mask=mask_hist)[0]  # DistilBERT
    x_hist, _ = self.additive_attn(x_hist)

    if is_precomputed(inputs['x_cand']):
        x_cand = inputs['x_cand']
    else:
        x_cand = self.encoder.forward(inputs['x_cand'])

    x_cand, mask_cand = to_dense_batch(x_cand, inputs['batch_cand'])

    logits = torch.bmm(x_hist.unsqueeze(1), x_cand.permute(0, 2, 1)).squeeze(1)
    logits = logits[mask_cand]

    targets = inputs['targets']
    if targets is None:
        return logits

    if self.training:
        criterion = nn.CrossEntropyLoss()
        # criterion = LabelSmoothingCrossEntropy()
        loss = criterion(logits.reshape(targets.size(0), -1), targets)
    else:
        # In case of val, targets are multi-label; not comparable with train.
        with torch.no_grad():
            criterion = nn.BCEWithLogitsLoss()
            loss = criterion(logits, targets.float())
    return loss, logits
def forward(self, data):
    x, edge_index, batch, num_graphs = (data.x, data.edge_index, data.batch,
                                        data.num_graphs)

    a_1 = F.relu(self.mp_a1(x, edge_index))
    x_1 = F.relu(self.mp_x1(x, edge_index))
    if self.skip:
        a_1 = torch.cat([a_1, x], dim=1)
        x_1 = torch.cat([x_1, x], dim=1)
        a_1 = self.linear_a(a_1)
        x_1 = self.linear_x(x_1)

    a_2 = self.mp_a2(a_1, edge_index)
    x_2 = F.relu(self.mp_x2(x_1, edge_index))
    if self.skip:
        a_2 = torch.cat([a_2, a_1], dim=1)
        x_2 = torch.cat([x_2, x_1], dim=1)

    a_2 = softmax(a_2, batch)
    a_batch, _ = to_dense_batch(a_2, batch)
    a_t = a_batch.transpose(2, 1)
    x_batch, _ = to_dense_batch(x_2, batch)
    prods = torch.bmm(a_t, x_batch)
    flat = torch.flatten(prods, 1, -1)
    batch_out = self.linear2(flat)
    final = F.softmax(batch_out, dim=-1)
    return final
def calculate_histogram(self, abstract_features_1, abstract_features_2,
                        batch_1, batch_2):
    """
    Calculate histogram from similarity matrix.
    :param abstract_features_1: Feature matrix for target graphs.
    :param abstract_features_2: Feature matrix for source graphs.
    :param batch_1: Batch vector for target graphs, which assigns each node to a specific example.
    :param batch_2: Batch vector for source graphs, which assigns each node to a specific example.
    :return hist: Histogram of similarity scores.
    """
    abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
    abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)

    B1, N1, _ = abstract_features_1.size()
    B2, N2, _ = abstract_features_2.size()

    mask_1 = mask_1.view(B1, N1)
    mask_2 = mask_2.view(B2, N2)
    num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))

    scores = torch.matmul(abstract_features_1,
                          abstract_features_2.permute([0, 2, 1])).detach()

    hist_list = []
    for i, mat in enumerate(scores):
        mat = torch.sigmoid(mat[:num_nodes[i], :num_nodes[i]]).view(-1)
        hist = torch.histc(mat, bins=self.args.bins)
        hist = hist / torch.sum(hist)
        hist = hist.view(1, -1)
        hist_list.append(hist)

    return torch.stack(hist_list).view(-1, self.args.bins)
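# A toy, self-contained illustration of the histogram step above
# (hypothetical values; `bins=4` stands in for `self.args.bins`): pairwise
# similarity scores between two 3-node graphs are squashed with a sigmoid
# and binned into a normalized histogram.
toy_scores = torch.sigmoid(torch.randn(3, 3)).view(-1)
toy_hist = torch.histc(toy_scores, bins=4)
toy_hist = toy_hist / torch.sum(toy_hist)  # normalized: sums to 1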
def forward(self, w, edge_index, batch):
    prob = torch.relu(self.bn1(self.conv1(w.unsqueeze(1), edge_index)))
    prob = torch.relu(self.bn2(self.conv2(prob, edge_index)))
    prob = torch.relu(self.bn3(self.conv3(prob, edge_index)))
    # prob = torch.relu(self.bn4(self.conv4(prob, edge_index)))
    prob = torch.relu(self.bn5(self.conv5(prob, edge_index)))
    prob = torch.sigmoid(self.conv6(prob, edge_index))

    prob_dense, prob_mask = to_dense_batch(prob, batch)
    w_dense, w_mask = to_dense_batch(w, batch)
    gammas = w_dense.sum(dim=1)
    adj = to_dense_adj(edge_index, batch)

    loss_thresholds = self.calculate_loss_thresholds(
        w_dense, prob_dense, adj, gammas)
    loss = loss_thresholds.sum() / adj.size(0)

    mis = self.conditional_expectation(
        w_dense.detach(), prob_dense.detach(), adj,
        loss_thresholds.detach(), gammas.detach(), prob_mask.detach())
    return loss, mis
def chamfer_loss(self, x, y, batch):
    x = to_dense_batch(x, batch)[0]
    y = to_dense_batch(y, batch)[0]
    # https://github.com/zichunhao/mnist_graph_autoencoder/blob/master/utils/loss.py
    dist = pairwise_distance(x, y, self.device)
    min_dist_xy = torch.min(dist, dim=-1)
    min_dist_yx = torch.min(dist, dim=-2)  # equivalent to permuting the last two axes
    loss = torch.sum(min_dist_xy.values + min_dist_yx.values) / len(x)
    return loss
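# `pairwise_distance` is referenced above but not defined in this snippet; a
# minimal sketch consistent with how `chamfer_loss` uses it, assuming it
# returns squared Euclidean distances of shape (batch, n, n). The `device`
# argument is kept only to match the call signature above.
def pairwise_distance(x, y, device=None):
    # x, y: (batch, num_nodes, feat); broadcast to all node pairs.
    diff = x.unsqueeze(2) - y.unsqueeze(1)  # (batch, n, n, feat)
    return (diff ** 2).sum(dim=-1)          # (batch, n, n)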
def readout(self, atoms: Tensor, edge_index: Tensor, edge_ids: Tensor,
            word_pos: Tensor, word_batch: Tensor, word_ids: Tensor,
            word_starts: Tensor) -> Tuple[Tensor, Tensor]:
    node_reprs = self.base.contextualize_nodes(atoms, edge_index, edge_ids)[word_pos]
    words, ids = to_dense_batch(word_ids, word_batch,
                                fill_value=self.word_encoder.pad_value)
    ctx = self.dropout(self.word_encoder(words)[ids][word_starts.eq(1)])
    ctx, _ = to_dense_batch(ctx, word_batch[word_starts.eq(1)])
    node_reprs, _ = to_dense_batch(node_reprs, word_batch[word_starts.eq(1)])
    return ctx, node_reprs
def forward(self, batch_protein_tokenized, batch_chem_graphs, **kwargs):
    # --------------- protein embedding ready -------------
    if self.all_config['protein_descriptor'] == 'DISAE':
        if self.all_config['frozen'] == 'whole':
            with torch.no_grad():
                batch_protein_repr = self.proteinEmbedding(batch_protein_tokenized)[0]
        else:
            batch_protein_repr = self.proteinEmbedding(batch_protein_tokenized)[0]
        batch_protein_repr_resnet = self.resnet(
            batch_protein_repr.unsqueeze(1)).reshape(
                self.all_config['batch_size'], 1, -1)  # (batch_size, 1, 256)

    # --------------- ligand embedding ready -------------
    node_representation = self.ligandEmbedding(
        batch_chem_graphs.x, batch_chem_graphs.edge_index,
        batch_chem_graphs.edge_attr)
    batch_chem_graphs_repr_masked, mask_graph = to_dense_batch(
        node_representation, batch_chem_graphs.batch)
    batch_chem_graphs_repr_pooled = batch_chem_graphs_repr_masked.sum(
        dim=1).unsqueeze(1)  # (batch_size, 1, 300)

    # --------------- interaction embedding ready -------------
    ((chem_vector, chem_score),
     (prot_vector, prot_score)) = self.attentive_interaction_pooler(
         batch_chem_graphs_repr_pooled,
         batch_protein_repr_resnet)  # same as input dimension
    interaction_vector = self.interaction_pooler(
        torch.cat((chem_vector.squeeze(), prot_vector.squeeze()), 1))  # (batch_size, 64)
    logits = self.binary_predictor(interaction_vector)  # (batch_size, 2)
    return logits
def forward(self, batched_data, mask=None):
    x, edge_index, edge_attr, node_depth, batch = (
        batched_data.x, batched_data.edge_index, batched_data.edge_attr,
        batched_data.node_depth, batched_data.batch)

    x = self.node_encoder(x, node_depth.view(-1,))

    x, mask = to_dense_batch(x, batch=batch)
    adj = to_dense_adj(edge_index, batch=batch)

    s = self.gnn1_pool(x, adj, mask)
    x = self.gnn1_embed(x, adj, mask)
    x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)

    s = self.gnn2_pool(x, adj)
    x = self.gnn2_embed(x, adj)
    x, adj, l2, e2 = dense_diff_pool(x, adj, s)

    x = self.gnn3_embed(x, adj)
    x = x.mean(dim=1)
    x = F.relu(self.lin1(x))
    # x = self.lin2(x)
    # return self.activation(x)  # , l1 + l2, e1 + e2

    pred_list = []
    for i in range(self.max_seq_len):
        pred_list.append(self.graph_pred_linear_list[i](x))
    return pred_list
def forward(self, data):
    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

    x = self.conv1(x, edge_index, edge_attr)
    x = F.relu(x)
    x = F.dropout(x, training=self.training)

    x = self.conv2(x, edge_index, edge_attr)
    x = F.relu(x)
    x = F.dropout(x, training=self.training)

    # Convert the sparse node features to a dense batched tensor.
    x, mask = to_dense_batch(x, data.batch)
    x = x.transpose(1, 2)  # [batch_size, in_channels, num_nodes]

    # Flatten for the fully connected layers.
    x = x.reshape(x.size(0), x.size(1) * x.size(2))
    x = F.relu(self.layer3(x))
    x = F.relu(self.layer4(x))
    return x
def _sparse_to_dense_input(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    label = data.y
    adj = to_dense_adj(edge_index, batch)
    # The second return value of `to_dense_batch` is the boolean node mask.
    x, node_mask = to_dense_batch(x, batch)
    return x, adj, node_mask, label
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch

    if self.encode_edge:
        x = self.atom_encoder(x)
        x = self.conv1(x, edge_index, data.edge_attr)

    x, mask = to_dense_batch(x, batch=batch)
    adj = to_dense_adj(edge_index, batch=batch)

    x = self.initial_embed(x, adj, mask)

    x_all, l_total, e_total = [], 0, 0
    for i in range(self.num_pooling_layers):
        if i != 0:
            mask = None
        x, adj, l, e = self.diffpool_layers[i](
            x, adj, mask)  # x has shape (batch, MAX_no_nodes, feature_size)
        x = self.after_pool_layers[i](x, adj)
        l_total += l
        e_total += e

    x = torch.max(x, dim=1)[0]
    x = F.relu(self.lin1(x))
    x = self.lin2(x)
    return x, l_total, e_total
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_weight = data.edge_weight

    x, mask = to_dense_batch(x, batch=batch)
    adj = to_dense_adj(edge_index, batch=batch, edge_attr=edge_weight)

    x_all, l_total, e_total = [], 0, 0
    for i in range(self.num_diffpool_layers):
        if i != 0:
            mask = None
        x, adj, l, e = self.diffpool_layers[i](
            x, adj, mask)  # x has shape (batch, MAX_no_nodes, feature_size)
        x_all.append(torch.max(x, dim=1)[0])
        l_total += l
        e_total += e

    x = self.final_embed(x, adj)
    x_all.append(torch.max(x, dim=1)[0])

    x = torch.cat(x_all, dim=1)  # shape (batch, feature_size x diffpool layers)
    x = F.relu(self.lin1(x))
    x = self.lin2(x)
    return x
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    x, mask = to_dense_batch(x, batch=batch)
    adj = to_dense_adj(edge_index, batch=batch)
    # data = ToDense(data.num_nodes)(data)
    # TODO: describe mask shape and how batching works
    # adj, mask, x = data.adj, data.mask, data.x

    x_all, l_total, e_total = [], 0, 0
    for i in range(self.num_diffpool_layers):
        if i != 0:
            mask = None
        x, adj, l, e = self.diffpool_layers[i](
            x, adj, mask)  # x has shape (batch, MAX_no_nodes, feature_size)
        x_all.append(torch.max(x, dim=1)[0])
        l_total += l
        e_total += e

    x = self.final_embed(x, adj)
    x_all.append(torch.max(x, dim=1)[0])

    x = torch.cat(x_all, dim=1)  # shape (batch, feature_size x diffpool layers)
    x = F.relu(self.lin1(x))
    x = self.lin2(x)
    return x, l_total, e_total
def reinforce_train_batch(
    model: nn.Module,
    baseline: nn.Module,
    optimizer: optim.Optimizer,
    batch: Batch,
    epoch: int,
    batch_id: int,
    step: int,
    env: TSPEnv,
    logger,
    args,
) -> None:
    batch = batch.to(args.device)
    node_pos = to_dense_batch(batch.pos, batch.batch)[0]

    log_p_s = []
    action_s = []
    reward_s = []
    done = False

    state = env.reset(node_pos)
    embed_data = model.init_embed(batch)
    node_embeddings, graph_feat = model.encoder(embed_data)
    fixed = model.precompute_fixed(node_embeddings, graph_feat)

    while not done:
        action, log_p = model(state, fixed)
        state, reward, done, _ = env.step(action)
        log_p_s.append(log_p)
        action_s.append(action)
        reward_s.append(reward)

    log_p = torch.stack(log_p_s, 1)
    a = torch.stack(action_s, 1)

    # Calculate the policy's log-likelihood and reward.
    log_likelihood = _calc_log_likelihood(log_p, a)
    # The reward is the negative tour length, so the baseline is trained to
    # predict a positive value.
    cost = -(reward_s[-1])
    bl_val, bl_loss = baseline.evaluate(batch, cost)

    rl_loss = ((cost - bl_val) * log_likelihood).mean()
    loss = rl_loss + bl_loss

    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_grad_norms(optimizer.param_groups, args.max_grad_norm)
    optimizer.step()

    # Logging
    if step % int(args.log_step) == 0:
        log_values(
            cost=cost,
            grad_norms=grad_norms,
            bl_val=bl_val,
            epoch=epoch,
            batch_id=batch_id,
            step=step,
            log_likelihood=log_likelihood,
            reinforce_loss=rl_loss,
            bl_loss=bl_loss,
            log_p=log_p,
            logger=logger,
            args=args,
        )
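# `_calc_log_likelihood` is called above but not shown; a plausible sketch
# following the usual pattern in attention-based TSP training loops
# (an assumption, not the original implementation):
def _calc_log_likelihood(log_p, a):
    # log_p: (batch, steps, num_nodes) log-probabilities per decoding step;
    # a: (batch, steps) chosen node indices. Gather the log-prob of each
    # chosen action and sum over the tour.
    return log_p.gather(2, a.unsqueeze(-1)).squeeze(-1).sum(1)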
def forward(self, x, edge_index):
    z = x
    for conv in self.convs[:-1]:
        z = self.relu(conv(z, edge_index))
    # if not self.variational:
    z = self.convs[-1](z, edge_index)

    if self.use_mincut:
        z_p, mask = to_dense_batch(z, None)
        adj = to_dense_adj(edge_index, None)
        s = self.pool1(z)
        _, adj, mc1, o1 = dense_mincut_pool(z_p, adj, s, mask)

    output = dict()
    if self.variational:
        output['mu'], output['logvar'] = self.conv_mu(
            z, edge_index), self.conv_logvar(z, edge_index)
        output['z'] = self.reparametrize(output['mu'], output['logvar'])
    else:
        output['z'] = z
    if self.prediction_task:
        output['y'] = self.classification_layer(z)
    if self.use_mincut:
        output['s'] = s
        output['mc1'] = mc1
        output['o1'] = o1
    elif self.activate_kmeans:
        s = self.kmeans(z)
        output['s'] = s
    return output
def forward(self, Q, K, attention_mask=None, graph=None, return_attn=False):
    Q = self.fc_q(Q)

    # Adj: Exist (graph is not None), or Identity (else)
    if graph is not None:
        (x, edge_index, batch) = graph
        K, V = self.fc_k(x, edge_index), self.fc_v(x, edge_index)
        K, _ = to_dense_batch(K, batch)
        V, _ = to_dense_batch(V, batch)
    else:
        K, V = self.fc_k(K), self.fc_v(K)

    dim_split = self.dim_V // self.num_heads
    Q_ = torch.cat(Q.split(dim_split, 2), 0)
    K_ = torch.cat(K.split(dim_split, 2), 0)
    V_ = torch.cat(V.split(dim_split, 2), 0)

    if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask for _ in range(self.num_heads)], 0)
        attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        A = torch.softmax(attention_mask + attention_score, self.softmax_dim)
    else:
        A = torch.softmax(
            Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V),
            self.softmax_dim)

    O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
    O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
    O = O + F.relu(self.fc_o(O))
    O = O if getattr(self, 'ln1', None) is None else self.ln1(O)

    if return_attn:
        return O, A
    return O
def test_to_dense_batch():
    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    out, mask = to_dense_batch(x, batch)
    expected = [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    assert out.size() == (3, 3, 2)
    assert out.tolist() == expected
    # Note: this variant expects a flattened mask, unlike the 2-D mask in the
    # first test above (apparently an older `to_dense_batch` API).
    assert mask.tolist() == [1, 1, 0, 1, 0, 0, 1, 1, 1]

    out = to_dense_batch(x)[0]
    assert out.size() == (1, 6, 2)
    assert out[0].tolist() == x.tolist()
def __call__(self, scores, labels, batch_vec):
    """
    * the three input tensors have shape (N,), N being the number of nodes
      in the batch
    * what makes it possible to split values by query (i.e. graph) is the
      batch_vec vector, indicating which node belongs to which graph

    We want to compute all the pairwise (positive, negative) contributions
    in the batch, dealing with:
    1. not mixing pairs between graphs
    2. a variable number of valid pairs per graph (using masking)
    """
    ids_pos = labels == 1
    ids_neg = labels == 0
    batch_vec_pos = batch_vec[ids_pos]
    batch_vec_neg = batch_vec[ids_neg]
    pos_scores = scores[ids_pos]
    neg_scores = scores[ids_neg]

    # densify the tensors (see: https://rusty1s.github.io/pytorch_geometric/build/html/modules/utils.html?highlight=to_dense#torch_geometric.utils.to_dense_batch)
    dense_pos_scores, pos_mask = to_dense_batch(pos_scores, batch_vec_pos,
                                                fill_value=0)
    # dense_pos_scores has shape (nb_graphs, padding => max number of nodes
    # for graphs in the batch)
    pos_len = torch.sum(pos_mask, dim=-1)  # shape (nb_graphs,), actual number of nodes per graph
    dense_neg_scores, neg_mask = to_dense_batch(neg_scores, batch_vec_neg,
                                                fill_value=0)
    neg_len = torch.sum(neg_mask, dim=-1)

    max_pos_len = pos_len.max()  # == the padding value for the positive scores
    max_neg_len = neg_len.max()
    pos_mask = masking(pos_len, max_pos_len.item())
    neg_mask = masking(neg_len, max_neg_len.item())

    diff_ = dense_pos_scores.view(
        -1, 1, dense_pos_scores.size(1)) - dense_neg_scores.view(
            -1, dense_neg_scores.size(1), 1)

    # now we use the mask and some reshaping to only extract the valid pair
    # contributions:
    pos_mask_ = pos_mask.repeat(1, neg_mask.size(1))
    neg_mask_ = neg_mask.view(-1, neg_mask.size(1), 1).repeat(
        1, 1, pos_mask.size(1)).view(-1, neg_mask.size(1) * pos_mask.size(1))
    flattened_mask = (pos_mask_ * neg_mask_).view(-1).long()
    valid_diff_ = diff_.view(-1)[flattened_mask > 0]
    loss = self.compute_loss(valid_diff_)
    return loss
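# `masking` is used above but not defined in this snippet; a plausible
# sketch, assuming it builds a (nb_graphs, max_len) mask of ones over the
# valid (non-padded) positions given the per-graph lengths.
def masking(lengths, max_len):
    # lengths: (nb_graphs,) actual counts; returns a float mask of shape
    # (nb_graphs, max_len).
    positions = torch.arange(max_len, device=lengths.device)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()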
def diffpool(self, abstract_features, edge_index, batch):
    """
    Make a differentiable pooling step.
    :param abstract_features: Node feature matrix.
    :param edge_index: Edge indices of the batch.
    :param batch: Batch vector, which assigns each node to a specific example.
    :return pooled_features: Graph feature matrix.
    """
    x, mask = to_dense_batch(abstract_features, batch)
    adj = to_dense_adj(edge_index, batch)
    return self.attention(x, adj, mask)
def forward(
    self,
    Q: Tensor,
    K: Tensor,
    graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    Q = self.fc_q(Q)

    if graph is not None:
        x, edge_index, batch = graph
        K, V = self.layer_k(x, edge_index), self.layer_v(x, edge_index)
        K, _ = to_dense_batch(K, batch)
        V, _ = to_dense_batch(V, batch)
    else:
        K, V = self.layer_k(K), self.layer_v(K)

    dim_split = self.dim_V // self.num_heads
    Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
    K_ = torch.cat(K.split(dim_split, 2), dim=0)
    V_ = torch.cat(V.split(dim_split, 2), dim=0)

    if mask is not None:
        mask = torch.cat([mask for _ in range(self.num_heads)], 0)
        attention_score = Q_.bmm(K_.transpose(1, 2))
        attention_score = attention_score / math.sqrt(self.dim_V)
        A = torch.softmax(mask + attention_score, 1)
    else:
        A = torch.softmax(
            Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 1)

    out = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)

    if self.layer_norm:
        out = self.ln0(out)

    out = out + self.fc_o(out).relu()

    if self.layer_norm:
        out = self.ln1(out)

    return out
def forward(self, x, edge_index, batch, edge_attr, perturb=None):
    q0 = self._get_q0(batch, x, edge_index, edge_attr, perturb)
    q0, mask = to_dense_batch(q0, batch=batch)
    q0 = self.bn(q0.view(-1, q0.shape[-1])).view(*q0.size())

    q, kl_total = q0, 0
    for i, mem_layer in enumerate(self.mem_layers):
        q, kl = mem_layer(q, mask if i == 0 else None)
        kl_total += kl

    return self.mlp(q.mean(dim=-2)), kl_total / len(batch)
def test_to_dense_batch():
    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    x, num_nodes = to_dense_batch(x, batch)

    expected = [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    assert x.tolist() == expected
    # In this variant the second return value is the per-graph node count
    # rather than a boolean mask.
    assert num_nodes.tolist() == [2, 1, 3]
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    x, mask = to_dense_batch(x, batch)
    adj = to_dense_adj(edge_index, batch)

    x = F.relu(self.gcn_1(x, adj, mask), True)
    x = F.relu(self.gcn_2(x, adj, mask), True)
    x, adj, l_lp, l_e = self.pooling(x, adj, mask)
    x = F.relu(self.gcn_3(x, adj))
    x = F.relu(self.gcn_4(x, adj))

    x = x.mean(1)
    logits = self.classifier(x)
    return logits, l_lp, l_e
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch

    # Encoder
    for _ in range(self.args.num_convs):
        x = F.relu(self.convs[_](x, edge_index))

    # Pooling
    for _index, _model_str in enumerate(self.model_sequence):
        if _index == 0:
            batch_x, mask = to_dense_batch(x, batch)
            extended_attention_mask = mask.unsqueeze(1)
            extended_attention_mask = extended_attention_mask.to(
                dtype=next(self.parameters()).dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

        if _model_str == 'GMPool_G':
            batch_x, attn = self.pools[_index](
                batch_x, attention_mask=extended_attention_mask,
                graph=(x, edge_index, batch), return_attn=True)
        else:
            batch_x, attn = self.pools[_index](
                batch_x, attention_mask=extended_attention_mask,
                return_attn=True)

        extended_attention_mask = None

    # Decoder
    x = torch.bmm(attn.transpose(1, 2), batch_x)
    x = x[mask]
    for _ in range(self.args.num_unconvs):
        x = self.unconvs[_](x, edge_index)
        if _ < (self.args.num_unconvs - 1):
            x = F.relu(x)
    return x
def aggregate(self, x_j, index):
    # `to_dense_batch` requires `index` to be sorted.
    # TODO: is there any way to avoid `argsort`?
    ix = torch.argsort(index)
    index = index[ix]
    x_j = x_j[ix]

    dense_x, mask = to_dense_batch(x_j, index)
    out = x_j.new_zeros(dense_x.size(0), dense_x.size(-1))
    deg = mask.sum(dim=1)
    for i in deg.unique():
        deg_mask = deg == i
        out[deg_mask] = dense_x[deg_mask, :i].median(dim=1).values
    return out
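# A toy walk-through of the median aggregation above (illustrative numbers
# only): four messages routed to two target nodes. Node 0 receives
# [1, 3, 2] -> median 2; node 1 receives [5] -> median 5.
msgs = torch.tensor([[1.0], [3.0], [2.0], [5.0]])
target = torch.tensor([0, 0, 0, 1])  # already sorted, as `aggregate` requires
dense_msgs, msg_mask = to_dense_batch(msgs, target)
assert dense_msgs[0, :3].median(dim=0).values.item() == 2.0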
def forward(self, data, negative_data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    x_n, edge_index_n, batch_n = (negative_data.x, negative_data.edge_index,
                                  negative_data.batch)

    pos_z = self.encoder(data)
    neg_z = self.encoder(negative_data)

    # graph summary
    summary = global_mean_pool(pos_z, batch)
    graph_emb = self.outgc(summary)

    pos_z, mask = to_dense_batch(pos_z, batch=batch)
    neg_z, mask_n = to_dense_batch(neg_z, batch=batch_n)
    mask = mask.contiguous().view(pos_z.size(0) * pos_z.size(1), -1)
    mask_n = mask_n.contiguous().view(neg_z.size(0) * neg_z.size(1), -1)

    loss_val = self.loss(pos_z, neg_z, mask, mask_n, self.sigm(summary))
    return graph_emb, loss_val
def test_model(model, args, testset, pin_memory):
    model.eval()
    pred_ = []
    truth_ = []
    loss = 0.0
    with torch.no_grad():
        cn = 0
        for data in testset:
            data = data.to(args.device, non_blocking=pin_memory)
            pred, _, _ = model(data, args.adj)
            loss += func.mse_loss(data.y, pred, reduction="mean")
            pred, _ = to_dense_batch(pred, batch=data.batch)
            data.y, _ = to_dense_batch(data.y, batch=data.batch)
            pred_.append(pred.cpu().data.numpy())
            truth_.append(data.y.cpu().data.numpy())
            cn += 1
        loss = loss / cn
    args.logger.info("[*] loss:{:.4f}".format(loss))
    pred_ = np.concatenate(pred_, 0)
    truth_ = np.concatenate(truth_, 0)
    mae = metric(truth_, pred_, args)
    return loss
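# `metric` is called above but not defined in this snippet; a minimal sketch,
# assuming it reports mean absolute error over the stacked arrays (the
# `args` parameter is kept only to match the call site).
def metric(truth, pred, args=None):
    return float(np.mean(np.abs(truth - pred)))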
def forward(self, x: Tensor, batch: Tensor,
            edge_index: Optional[Tensor] = None) -> Tensor:
    """"""
    x = self.lin1(x)
    batch_x, mask = to_dense_batch(x, batch)
    mask = (~mask).unsqueeze(1).to(dtype=x.dtype) * -1e9

    for i, (name, pool) in enumerate(zip(self.pool_sequences, self.pools)):
        graph = (x, edge_index, batch) if name == 'GMPool_G' else None
        batch_x = pool(batch_x, graph, mask)
        mask = None

    return self.lin2(batch_x.squeeze(1))