def __call__(self, data):
    """Coarsen a graph by merging nodes with agglomerative clustering.

    Graphs with fewer than ``self.min_nodes`` nodes are returned unchanged.
    Otherwise nodes are clustered on ``data.x``, with the clustering restricted
    by the dense adjacency returned by ``self.to_dense(data).adj``.  Every
    cluster becomes one node whose features (and positions, if present) are
    the mean over its members.  Edges are remapped onto cluster ids.  Edges
    inside a single cluster are dropped, and duplicate cluster-to-cluster
    edges are collapsed.

    Args:
        data: graph object exposing ``x``, ``pos`` and ``edge_index``.

    Returns:
        The same ``data`` object, modified in place.
    """
    if len(data.x) < self.min_nodes:
        return data

    connectivity = self.to_dense(data).adj
    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=self.distance_threshold,
        connectivity=connectivity)
    # sklearn returns a NumPy label array, so labels live on the CPU.
    labels = torch.from_numpy(clustering.fit_predict(data.x))

    data.x = scatter_mean(data.x, labels, dim=0)
    data.pos = scatter_mean(data.pos, labels, dim=0) if data.pos is not None else None

    # NOTE(review): if self.to_dense clears data.edge_index, this branch is
    # skipped entirely — confirm against the transform used.
    if data.edge_index is not None:
        # Replace both endpoints of every edge by their cluster id: shape [2, E].
        coarse = labels[data.edge_index]
        # Drop edges that now connect a cluster to itself.
        coarse = coarse[:, coarse[0] != coarse[1]]
        # Collapse duplicate cluster-to-cluster edges.  The old code built these
        # pairs but then deduplicated the ORIGINAL edges, so the coarsened graph
        # kept fine-grained node ids.
        if coarse.numel() > 0:
            coarse = torch.unique(coarse, dim=1)
        data.edge_index = coarse
    return data
def __call__(self, img):
    """Turn a ``(C, H, W)`` image tensor into a SLIC segment graph.

    The image is segmented with ``skimage.segmentation.slic`` (extra arguments
    come from ``self.kwargs``).  Node ``x`` is the mean pixel value per segment.
    Node ``pos`` is the mean ``(column, row)`` pixel coordinate per segment.
    The segmentation map and the image are attached when ``self.add_seg`` or
    ``self.add_img`` is set.
    """
    from skimage.segmentation import slic

    hwc = img.permute(1, 2, 0)
    height, width, channels = hwc.size()
    n_pixels = height * width

    segments = torch.from_numpy(
        slic(hwc.to(torch.double).numpy(), start_label=0, **self.kwargs))
    flat_segments = segments.view(n_pixels)

    x = scatter_mean(hwc.view(n_pixels, channels), flat_segments, dim=0)

    # Row-major pixel grid: for pixel (r, c) store (c, r).
    rows = torch.arange(height, dtype=torch.float).view(-1, 1).expand(height, width).reshape(n_pixels)
    cols = torch.arange(width, dtype=torch.float).view(1, -1).expand(height, width).reshape(n_pixels)
    pos = scatter_mean(torch.stack([cols, rows], dim=-1), flat_segments, dim=0)

    data = Data(x=x, pos=pos)
    if self.add_seg:
        data.seg = segments.view(1, height, width)
    if self.add_img:
        data.img = hwc.permute(2, 0, 1).view(1, channels, height, width)
    return data
def forward(self, x_1, x_2, edge_index_pos, edge_index_neg):
    """
    Forward propagation pass with features and indices.

    Neighbour features are aggregated over positive edges (from ``x_1``) and
    negative edges (from ``x_2``).  Each edge set first has its self-loops
    rebuilt.  Aggregation is a mean when ``self.norm`` is set and a sum
    otherwise.  The two aggregates are concatenated with ``x_1`` and projected
    by ``self.weight``.  The bias is added if present, and the result is
    L2-normalised when ``self.norm_embed`` is set.

    :param x_1: Features for left hand side vertices.
    :param x_2: Features for right hand side vertices.
    :param edge_index_pos: Positive indices.
    :param edge_index_neg: Negative indices.
    :return out: Abstract convolved features.
    """
    n_left, n_right = x_1.size(0), x_2.size(0)

    # Normalise self-loops: strip any existing ones, then add exactly one per node.
    edge_index_pos, _ = remove_self_loops(edge_index_pos, None)
    edge_index_pos, _ = add_self_loops(edge_index_pos, num_nodes=n_left)
    edge_index_neg, _ = remove_self_loops(edge_index_neg, None)
    edge_index_neg, _ = add_self_loops(edge_index_neg, num_nodes=n_right)

    dst_pos, src_pos = edge_index_pos
    dst_neg, src_neg = edge_index_neg

    # pos: [x_1: balance features; x_2: unbalance features]
    # neg: [x_1: unbalance features; x_2: balance features]
    reduce = scatter_mean if self.norm else scatter_add
    agg_pos = reduce(x_1[src_pos], dst_pos, dim=0, dim_size=n_left)
    agg_neg = reduce(x_2[src_neg], dst_neg, dim=0, dim_size=n_right)

    out = torch.matmul(torch.cat((agg_pos, agg_neg, x_1), 1), self.weight)
    if self.bias is not None:
        out = out + self.bias
    return F.normalize(out, p=2, dim=-1) if self.norm_embed else out
def forward(self, x, edge_index, edge_attr, h, u, batch):
    """
    Global update of a Graph Net layer with a recurrent (GRU) state.

    @param x: [N x n_outc], where N is the number of nodes.
    @param edge_index: [2 x E] with max entry N - 1.
    @param edge_attr: [E x e_outc]
    @param h: [B x hidden] recurrent state from the previous step.
    @param u: [B x u_inc]
    @param batch: [N] with max entry B - 1.
    @return: tuple (out, h): out is a [B x u_outc] torch tensor and h the
        updated [B x hidden] recurrent state.
    """
    row, _ = edge_index
    num_graphs = u.size(0)
    # Each edge belongs to the graph of its source node. Shape: [E]
    edge_batch = batch[row]
    # dim_size keeps exactly B rows even when trailing graphs have no edges;
    # otherwise the concatenation with u below fails.
    per_batch_edge_aggregations = scatter_mean(
        edge_attr, edge_batch, dim=0, dim_size=num_graphs)  # Shape: [B x e_outc]
    per_batch_node_aggregations = scatter_mean(
        x, batch, dim=0, dim_size=num_graphs)  # Shape: [B x n_outc]
    x = torch.cat(
        [u, per_batch_node_aggregations, per_batch_edge_aggregations],
        dim=1)  # Shape: [B x (u_inc + n_outc + e_outc)]
    x = F.relu(self.in1(self.fc1(x)))

    # One GRU step.  The input is [B, 1, F] and the state is [1, B, hidden]
    # (assumes the GRU is batch_first — confirm).
    x, h = self.gru(x.unsqueeze(1), h.unsqueeze(0))
    # Squeeze only the sequence / layer axes so a batch of one graph keeps its
    # batch dimension.  The new state is taken from the GRU's state output.
    # For a single-layer GRU this equals the step output the old code reused.
    x = x.squeeze(1)
    h = h.squeeze(0)
    x = self.fc2(x)
    return x, h
def forward(self, x, edge_index, edge_attr):
    """Edge-weighted convolution.

    For every edge (row, col), the endpoint features are concatenated and
    multiplied by ``self.weight``, then scaled by the edge attribute.  The
    edge features are scattered onto both endpoints, and the optional bias
    is added.

    Args:
        x: node features, indexed by node id.
        edge_index: pair ``(row, col)`` of edge endpoint indices.
        edge_attr: per-edge weights; a 1-D tensor is treated as one channel.

    Returns:
        Tensor of shape ``[num_nodes, self.out_channels]``.
    """
    row, col = edge_index
    num_node = len(x)
    if edge_attr.dim() == 1:
        edge_attr = edge_attr.unsqueeze(-1)

    # Edge feature: concatenated endpoint features times the filter.
    alpha = torch.mm(torch.cat([x[row], x[col]], dim=-1), self.weight)
    # Scale every edge feature by its edge attribute (e.g. distance).
    alpha = edge_attr * alpha

    # Allocate the accumulator with alpha's dtype and device.  A bare
    # torch.zeros(...) always uses the default float dtype, which breaks the
    # scatter when the model runs in double or half precision.
    out = alpha.new_zeros((num_node, self.out_channels))
    out = scatter_mean(alpha, row, dim=0, out=out)
    # NOTE(review): the first result is fed back as `out`, so presumably it is
    # summed in before dividing by the col-degree.  That is not a mean over both
    # endpoints.  Kept unchanged to preserve trained-model behaviour — confirm intent.
    out = scatter_mean(alpha, col, dim=0, out=out)

    if self.bias is not None:
        out = out + self.bias
    return out
def forward(self, data):
    """Three-level GNN regression returning one scalar per graph (flattened).

    Level 1 runs conv1-3 on the input graph, using edge attributes.  Levels 2
    and 3 average-pool the level-1 node features through
    ``data.assignment_2`` / ``data.assignment_3``, append
    ``iso_type_2`` / ``iso_type_3``, and run two convs each.  The per-graph
    means of all three levels are concatenated and fed through fc1-fc3.
    ``data.x`` is overwritten along the way.
    """
    # Level 1: message passing on the input graph.
    for conv in (self.conv1, self.conv2, self.conv3):
        data.x = F.elu(conv(data.x, data.edge_index, data.edge_attr))
    node_feats = data.x
    readout_1 = scatter_mean(node_feats, data.batch, dim=0)

    # Level 2: pool level-1 features into the second graph.
    data.x = torch.cat([avg_pool(node_feats, data.assignment_2), data.iso_type_2], dim=1)
    for conv in (self.conv4, self.conv5):
        data.x = F.elu(conv(data.x, data.edge_index_2))
    readout_2 = scatter_mean(data.x, data.batch_2, dim=0)

    # Level 3: pool level-1 features (not level-2) into the third graph.
    data.x = torch.cat([avg_pool(node_feats, data.assignment_3), data.iso_type_3], dim=1)
    for conv in (self.conv6, self.conv7):
        data.x = F.elu(conv(data.x, data.edge_index_3))
    readout_3 = scatter_mean(data.x, data.batch_3, dim=0)

    out = torch.cat([readout_1, readout_2, readout_3], dim=1)
    for fc in (self.fc1, self.fc2):
        out = F.elu(fc(out))
    return self.fc3(out).view(-1)
def forward(self, data):
    """Three-level GNN with an output head chosen by ``self.args.type``.

    Level 1 runs conv1-3 on the input graph.  Levels 2 and 3 average-pool the
    level-1 features through ``assignment_index_2`` / ``assignment_index_3``,
    append ``iso_type_2`` / ``iso_type_3``, and run two convs each.  The
    readouts are concatenated, passed through fc1 (with dropout p=0.5), fc2
    and fc3, then:
    - "regression": flattened raw output;
    - "TU": log-softmax over classes;
    - otherwise: ``self.sigmoid`` of the output, flattened.
    ``data.x`` is overwritten along the way.
    """
    for conv in (self.conv1, self.conv2, self.conv3):
        data.x = F.elu(conv(data.x, data.edge_index))
    node_feats = data.x
    # NOTE(review): level 1 uses a sum readout while levels 2/3 use means — confirm intended.
    readout_1 = scatter_add(node_feats, data.batch, dim=0)

    data.x = torch.cat([avg_pool(node_feats, data.assignment_index_2), data.iso_type_2], dim=1)
    for conv in (self.conv4, self.conv5):
        data.x = F.elu(conv(data.x, data.edge_index_2))
    readout_2 = scatter_mean(data.x, data.batch_2, dim=0)

    data.x = torch.cat([avg_pool(node_feats, data.assignment_index_3), data.iso_type_3], dim=1)
    for conv in (self.conv6, self.conv7):
        data.x = F.elu(conv(data.x, data.edge_index_3))
    readout_3 = scatter_mean(data.x, data.batch_3, dim=0)

    hidden = F.elu(self.fc1(torch.cat([readout_1, readout_2, readout_3], dim=1)))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.fc3(F.elu(self.fc2(hidden)))

    if self.args.type == "regression":
        return logits.view(-1)
    if self.args.type == "TU":
        return F.log_softmax(logits, dim=1)
    return self.sigmoid(logits).view(-1)
def forward(self, vertices):
    """Compute a graph-level output with residual message passing.

    ``vertices`` is reshaped to ``(n_data, 1, self.n_nodes, 11)``, so it must
    hold 11 values per node.  (An older comment described it as
    ``[n_data, in_channels, n_nodes]``, which does not match this reshape.)
    Edges come from ``self.compute_msgs`` / ``self.dest_edges``, not from an
    argument.  Messages are averaged into destination vertices once from
    ``self.x_lin`` features.  Five further rounds then add residual updates
    computed from ``self.x_lin_after_first_round``.
    """
    grid = vertices.reshape((vertices.shape[0], 1, self.n_nodes, 11))
    n_data = len(grid)

    def aggregate(features):
        # Messages are repeated twice along the last axis before being
        # averaged into their destination vertices; add a channel axis.
        msgs = self.compute_msgs(features, n_data).repeat((1, 1, 2))
        return torch_scatter.scatter_mean(msgs, self.dest_edges, dim=-1)[:, None, :, :]

    # NOTE(review): the bare .squeeze() calls also drop the batch axis when
    # n_data == 1 (get_vertex_activations guards against this) — confirm.
    new_vertex = aggregate(self.x_lin(grid).squeeze())

    n_msg_passing = 5
    for _ in range(n_msg_passing):
        refined = self.x_lin_after_first_round(new_vertex).squeeze()
        new_vertex = new_vertex + aggregate(refined)

    final_vertex_output = self.vertex_output_lin(new_vertex).squeeze()
    return self.graph_output_lin(final_vertex_output)
def get_aggregation(aggregation):
    """
    Return the edge-to-node aggregation named by hparams["aggregation"].

    Each returned callable has signature ``(e, end, x)``.  It aggregates edge
    features ``e`` onto their ``end`` node indices along dim 0, producing one
    row per node of ``x``.  Combined modes concatenate two aggregations along
    the last axis, in the order max+sum, mean+sum, max+mean.  An unknown name
    raises ``KeyError``.
    """
    def _sum(e, end, x):
        return scatter_add(e, end, dim=0, dim_size=x.shape[0])

    def _mean(e, end, x):
        return scatter_mean(e, end, dim=0, dim_size=x.shape[0])

    def _max(e, end, x):
        # scatter_max returns (values, argmax); keep only the values.
        return scatter_max(e, end, dim=0, dim_size=x.shape[0])[0]

    def _concat(first, second):
        return lambda e, end, x: torch.cat([first(e, end, x), second(e, end, x)], dim=-1)

    factory = {
        "sum": _sum,
        "mean": _mean,
        "max": _max,
        "sum_max": _concat(_max, _sum),
        "mean_sum": _concat(_mean, _sum),
        "mean_max": _concat(_max, _mean),
    }
    return factory[aggregation]
def forward(self, data):
    """Graph-level scores from a GIN with one prediction head per layer.

    Node ids are embedded and passed through ``num_layers - 1`` rounds of
    GIN -> BatchNorm -> ReLU.  Every representation, including the
    embedding, goes through its own ``linears_prediction`` head followed by
    ``self.dp``, and the results are summed.  The node scores are then pooled
    per graph by ``self.aggr``: "mean", "max" or "sum".  Any other value
    returns the node-level scores unchanged.
    """
    # Imported locally so this block does not depend on which torch_scatter
    # names the surrounding file imports.
    from torch_scatter import scatter_add, scatter_max

    edge_index, x_ids = data.edge_index, data.node_ids
    x = self.embed(x_ids)
    hidden_reps = [x]
    for i in range(self.num_layers - 1):
        x = self.gins[i](x, edge_index)
        x = self.batch_norms[i](x)
        x = self.relu(x)
        hidden_reps.append(x)

    score_over_layer = 0
    for layer, h in enumerate(hidden_reps):
        score_over_layer += self.dp(self.linears_prediction[layer](h))

    # Previously "max" tried to unpack scatter_mean's single tensor, and "sum"
    # also called scatter_mean.
    if self.aggr == 'mean':
        score_over_layer = scatter_mean(score_over_layer, data.batch, dim=0)
    elif self.aggr == 'max':
        score_over_layer, _ = scatter_max(score_over_layer, data.batch, dim=0)
    elif self.aggr == 'sum':
        score_over_layer = scatter_add(score_over_layer, data.batch, dim=0)
    return score_over_layer
def forward(self, x, edge_index, edge_attr):
    """Edge-weighted convolution with optional undirected handling.

    For every edge (row, col), the endpoint features are concatenated and
    multiplied by ``self.weight``, then scaled by the edge attribute.  The
    edge features are scattered onto ``row``.  They are also scattered onto
    ``col`` unless ``self.undirected`` is set.  The optional bias is added.

    Args:
        x: node features, indexed by node id.
        edge_index: pair ``(row, col)`` of edge endpoint indices.
        edge_attr: per-edge weights; a 1-D tensor is treated as one channel.

    Returns:
        Tensor of shape ``[num_nodes, self.out_channels]``.
    """
    row, col = edge_index
    num_node = len(x)
    if edge_attr.dim() == 1:
        edge_attr = edge_attr.unsqueeze(-1)

    # Edge feature: concatenated endpoint features times the filter.
    alpha = torch.mm(torch.cat([x[row], x[col]], dim=-1), self.weight)
    # Scale every edge feature by its edge attribute (e.g. distance).
    alpha = edge_attr * alpha

    # Allocate the accumulator with alpha's dtype and device; a bare
    # torch.zeros(...) is always the default float dtype and breaks the
    # scatter in double or half precision.
    out = alpha.new_zeros((num_node, self.out_channels))
    out = scatter_mean(alpha, row, dim=0, out=out)

    # If the graph is undirected and both (i, j) and (j, i) are in edge_index,
    # the row scatter already covers both directions.  The col scatter would
    # then count every edge twice.
    if not self.undirected:
        out = scatter_mean(alpha, col, dim=0, out=out)

    if self.bias is not None:
        out = out + self.bias
    return out
def forward(self, konf, pose, noise):
    """Predict a value from per-vertex features ``konf`` and the robot pose.

    The last 4 entries of ``pose`` are broadcast onto all 618 vertices and
    concatenated with ``konf``.  The result is encoded by
    ``self.first_features``, max-pooled twice over the vertex axis, flattened,
    and scored by ``self.value``.  ``noise`` is accepted but not used.
    """
    n_vertices = 618  # the network is hard-wired to this vertex count

    current_pose = pose[:, -4:].unsqueeze(1).repeat((1, n_vertices, 1)).unsqueeze(-1)
    stacked = torch.cat([current_pose, konf], dim=2)
    dims = stacked.shape
    features = self.first_features(
        stacked.reshape((dims[0], dims[-1], dims[1], dims[2]))).squeeze()

    neighbours = self.edges[1]
    targets = self.edges[0]
    vertex_values = torch_scatter.scatter_mean(features[:, :, neighbours], targets, dim=-1)
    for _ in range(2):
        shape = vertex_values.shape
        vertex_values = self.features(
            vertex_values.view((shape[0], shape[1], shape[2], 1))).squeeze()
        vertex_values = torch_scatter.scatter_mean(
            vertex_values[:, :, neighbours], targets, dim=-1)
    # NOTE(review): vertex_values from this loop is never used below; the value
    # head reads `features` from before message passing — confirm intent.

    pool = torch.nn.MaxPool2d(kernel_size=(2, 1))
    pooled = pool(pool(features.view(
        (features.shape[0], features.shape[1], features.shape[2], 1))))
    pooled = pooled.view((pooled.shape[0], pooled.shape[1] * pooled.shape[2]))
    return self.value(pooled)
def get_vertex_activations(self, vertices):
    """Per-vertex outputs from residual message passing with collision features.

    ``self.get_vertex_features(vertices)`` is concatenated on axis 1 with the
    collision columns ``vertices[:, :, 6:]``, moved to the channel axis.
    Messages are averaged into destination vertices.  ``self.n_msg_passing``
    residual rounds follow; each re-appends the collision features.  The
    batch axis is restored when ``n_data == 1``, because ``squeeze()``
    removes it.
    """
    n_data = len(vertices)
    # Per the original layout: prm_vertices, q0s, qgs, then collisions from column 6 on.
    base_features = self.get_vertex_features(vertices)
    collisions = vertices[:, :, 6:].permute((0, 2, 1))
    v_features = torch.cat((base_features, collisions), 1)

    def propagate(features):
        msgs = self.compute_msgs(features, n_data).repeat((1, 1, 2))
        return torch_scatter.scatter_mean(msgs, self.dest_edges, dim=-1)[:, None, :, :]

    new_vertex = propagate(v_features)
    for _ in range(self.n_msg_passing):
        refined = self.x_lin_after_first_round(new_vertex).squeeze(dim=2)
        new_vertex = new_vertex + propagate(torch.cat((refined, collisions), 1))

    final_vertex_output = self.vertex_output_lin(new_vertex).squeeze()
    if n_data == 1:
        final_vertex_output = final_vertex_output[None, :]
    return final_vertex_output
def forward(self, data):
    """Three-level GNN classifier returning per-graph log-probabilities.

    Level 1 runs conv1-3 on the input graph.  Levels 2 and 3 average-pool the
    level-1 features through ``assignment_index_2`` / ``assignment_index_3``,
    append ``iso_type_2`` / ``iso_type_3``, and run two convs each.  The
    per-graph means of all levels feed an MLP head (dropout p=0.5 after fc1).
    ``data.x`` is overwritten along the way.
    """
    for conv in (self.conv1, self.conv2, self.conv3):
        data.x = F.elu(conv(data.x, data.edge_index))
    node_feats = data.x
    pooled_1 = scatter_mean(node_feats, data.batch, dim=0)

    data.x = torch.cat([avg_pool(node_feats, data.assignment_index_2), data.iso_type_2], dim=1)
    for conv in (self.conv4, self.conv5):
        data.x = F.elu(conv(data.x, data.edge_index_2))
    pooled_2 = scatter_mean(data.x, data.batch_2, dim=0)

    data.x = torch.cat([avg_pool(node_feats, data.assignment_index_3), data.iso_type_3], dim=1)
    for conv in (self.conv6, self.conv7):
        data.x = F.elu(conv(data.x, data.edge_index_3))
    pooled_3 = scatter_mean(data.x, data.batch_3, dim=0)

    x = torch.cat([pooled_1, pooled_2, pooled_3], dim=1)
    # NOTE(review): reads the module-level `args`, not `self.args` — confirm.
    if args.no_train:
        # Cut the graph here so only the MLP head below receives gradients.
        x = x.detach()
    x = F.elu(self.fc1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = F.elu(self.fc2(x))
    return F.log_softmax(self.fc3(x), dim=1)
def forward(self, x, edge_index, edge_attr, u, batch):
    """
    Global Update of Graph Net Layer

    @param x: [N x n_outc], where N is the number of nodes.
    @param edge_index: [2 x E] with max entry N - 1.
    @param edge_attr: [E x e_outc]
    @param u: [B x u_inc]
    @param batch: [N] with max entry B - 1.
    @return: a [B x u_outc] torch tensor
    """
    row, _ = edge_index
    num_graphs = u.size(0)
    # Each edge belongs to the graph of its source node. Shape: [E]
    edge_batch = batch[row]
    # dim_size pins both aggregations to B rows, so graphs without edges (or
    # nodes) at the end of the batch no longer break the concatenation.
    per_batch_edge_aggregations = scatter_mean(
        edge_attr, edge_batch, dim=0, dim_size=num_graphs)  # Shape: [B x e_outc]
    per_batch_node_aggregations = scatter_mean(
        x, batch, dim=0, dim_size=num_graphs)  # Shape: [B x n_outc]
    out = torch.cat(
        [u, per_batch_node_aggregations, per_batch_edge_aggregations],
        dim=1)  # Shape: [B x (u_inc + n_outc + e_outc)]
    return self.global_mlp(out)
def explain_single_layer(self, to_explain, index=None, name=None):
    """Propagate LRP relevance backwards through one layer.

    The layer and its rule are looked up by ``index`` / ``name``.  The rule is
    the z-rule for "z" and the epsilon rule otherwise.  The relevance stored
    at ``to_explain["R"][index + 1]`` is taken and, for the model's special
    layers, first redistributed across nodes/edges.  The rule is then applied
    and the result stored at ``to_explain["R"][index]``.

    Args:
        to_explain: dict holding activations under "A" (keyed by layer name),
            relevances under "R" (keyed by layer index, plus intermediate
            string keys written here) and graph inputs under "inputs".
        index: layer index; resolved from ``name`` when None.
        name: layer name; resolved from ``index`` when None.
    """
    # TODO: deal with the special case where the previous layer has not been explained.
    layer = self.model.get_layer(index=index, name=name)
    rule_name = self.model.get_rule(index=index, layer_name=name)
    # Only "z" selects the z-rule; "eps" and unsupported names fall back to epsilon.
    rule = LRP.z_rule if rule_name == "z" else LRP.eps_rule

    if name is None:
        name = self.model.index2name(index)
    if index is None:
        index = self.model.name2index(name)

    layer_input = to_explain['A'][name]
    relevance = to_explain["R"]
    R = relevance[index + 1]

    if name in self.model.special_layers:
        inputs = to_explain["inputs"]
        n_tracks = inputs["x"].shape[0]
        row, col = inputs["edge_index"]
        width = 48  # width of the first slice of the concatenated features

        if "node_mlp_2.3" in name:
            R = R.repeat(n_tracks, 1) / n_tracks
        elif "node_mlp_1.3" in name:
            r_x, rest = R[:, :width], R[:, width:]
            R = rest[col] / (n_tracks - 1)
            relevance["r_x"] = r_x
        elif "edge_mlp.3" in name:
            r_x_row, rest = R[:, :width], R[:, width:]
            R = rest
            relevance["r_x_row"] = r_x_row
        elif "bn" in name:
            r_src, r_dest = R[:, :width], R[:, width:]
            relevance['r_src'] = r_src
            relevance['r_dest'] = r_dest
            # Aggregate edge-level relevance back onto the nodes.
            # NOTE(review): a mean (not a sum) is used here, so total relevance
            # is not conserved — confirm intent.
            r_x_src = scatter_mean(r_src, row, dim=0, dim_size=n_tracks)
            r_x_dest = scatter_mean(r_dest, col, dim=0, dim_size=n_tracks)
            r_x = relevance['r_x']
            r_x_row = relevance['r_x_row']
            R = (r_x_src + r_x_dest + r_x
                 + scatter_mean(r_x_row, row, dim=0, dim_size=n_tracks) + 1e-10)

    # Backward pass with the selected LRP rule.
    R = rule(layer, layer_input, R)
    relevance[index] = R
def forward(self, data):
    """Bottom-up GCN encoder returning per-graph and per-graph-root embeddings.

    Runs conv1 -> ReLU -> dropout -> conv2 -> ReLU over ``data.BU_edge_index``
    and mean-pools the result per graph.  Separately, each graph's root row of
    the (float) input features is pooled per graph in the same way.

    Returns:
        tuple ``(x, root_extend)``: the per-graph mean of the conv2 output and
        the per-graph root features.
    """
    x, edge_index = data.x, data.BU_edge_index
    x1 = copy.copy(x.float())

    # Give every node the input features of its graph's root.  rootindex[g]
    # is used directly as a global row of x1, so it is assumed to be already
    # offset for batching (TODO confirm).  One gather replaces the old Python
    # loop over graphs, which was O(num_graphs * num_nodes) and also iterated
    # data.batch in Python via the builtin max().
    root_extend = x1[data.rootindex[data.batch]]

    x = self.conv1(x, edge_index)
    x = F.relu(x)
    x = F.dropout(x, training=self.training)
    x = self.conv2(x, edge_index)
    x = F.relu(x)

    x = scatter_mean(x, data.batch, dim=0)  # GCN - mean
    root_extend = scatter_mean(root_extend, data.batch, dim=0)
    return x, root_extend
def forward(self, x_h, x_g, edge_index, edge_attr, u, batch_h, batch_g):
    """Global update from two node sets.

    ``u`` is concatenated with the per-graph mean of ``x_h`` (grouped by
    ``batch_h``) and of ``x_g`` (grouped by ``batch_g``), and the result goes
    through ``self.global_mlp``.  ``edge_index`` and ``edge_attr`` are accepted
    for interface compatibility but unused.
    """
    # dim_size pins both means to one row per graph in u, even if trailing
    # graphs have no nodes of one kind.
    num_graphs = u.size(0)
    out = torch.cat([
        u,
        scatter_mean(x_h, batch_h, dim=0, dim_size=num_graphs),
        scatter_mean(x_g, batch_g, dim=0, dim_size=num_graphs),
    ], dim=1)
    return self.global_mlp(out)
def multi_hop(self, triple_prob, distance, head, tail, concept_label, triple_label,
              gamma=0.8, iteration=3, method="avg"):
    '''
    Propagate concept scores over triples for ``iteration`` hops.

    triple_prob: bsz x L x mem_t
    distance: bsz x mem
    head, tail: bsz x mem_t
    concept_label: bsz x mem
    triple_label: bsz x mem_t

    Scores start as a binary mask with 1 on source concepts (distance == 0),
    expanded to bsz x L x mem.  Each hop scores a triple as
    gamma * s(head) + triple_prob and aggregates onto the tail concept with
    "avg" or "max" (``method``).  Padded concepts/triples (label == -1) are
    zeroed.  The result sums all hop scores and adds -1e5 on source concepts,
    which effectively excludes them.

    Raises:
        ValueError: if ``method`` is not "avg" or "max".  Previously an unknown
            method silently produced all-zero hop scores.

    Returns: bsz x L x mem
    '''
    if method not in ("avg", "max"):
        raise ValueError(f"Unknown multi-hop method {method!r}; expected 'avg' or 'max'")

    concept_probs = []
    cpt_size = (triple_prob.size(0), triple_prob.size(1), distance.size(1))
    init_mask = torch.zeros_like(distance).unsqueeze(1).expand(*cpt_size).to(distance.device).float()
    init_mask.masked_fill_((distance == 0).unsqueeze(1), 1)
    final_mask = init_mask.clone()
    init_mask.masked_fill_((concept_label == -1).unsqueeze(1), 0)
    concept_probs.append(init_mask)

    head = head.unsqueeze(1).expand(triple_prob.size(0), triple_prob.size(1), -1)
    tail = tail.unsqueeze(1).expand(triple_prob.size(0), triple_prob.size(1), -1)

    for _ in range(iteration):
        # Score of each triple's head concept from the previous hop.
        node_score = concept_probs[-1]
        triple_head_score = node_score.gather(2, head)
        triple_head_score.masked_fill_((triple_label == -1).unsqueeze(1), 0)

        # avg: s(v) = Avg_{u in N(v)} gamma * s(u) + R(u->v)
        # max: s(v) = Max_{u in N(v)} gamma * s(u) + R(u->v)
        update_value = triple_head_score * gamma + triple_prob
        out = torch.zeros_like(node_score).to(node_score.device).float()
        if method == "max":
            scatter_max(update_value, tail, dim=-1, out=out)
        else:
            scatter_mean(update_value, tail, dim=-1, out=out)
        out.masked_fill_((concept_label == -1).unsqueeze(1), 0)
        concept_probs.append(out)

    # Push source concepts far down, then sum the scores of every hop.
    total_concept_prob = final_mask * -1e5
    for prob in concept_probs[1:]:
        total_concept_prob += prob
    # bsz x L x mem
    return total_concept_prob
def forward(self, x, edge_index, edge_attr, u, batch):
    """Global update: ``phi_u([mean node feats, mean edge feats, u])`` per graph.

    Each edge is assigned to the graph of its source node (``batch[row]``).
    Both means use ``dim_size=u.size(0)``, so graphs without edges or nodes at
    the end of the batch still get a row and the concatenation stays aligned.
    """
    row, _ = edge_index
    num_graphs = u.size(0)
    e_batch = batch[row]
    e_agg = scatter_mean(edge_attr, e_batch, dim=0, dim_size=num_graphs)
    x_agg = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
    out = torch.cat([x_agg, e_agg, u], 1)
    return self.phi_u(out)
def forward(self, x, edge_index, edge_attr, u, batch):
    """Global update: ``global_mlp([u, mean node feats, mean edge feats])``.

    Each edge is assigned to the graph of its source node.  Both means are
    sized to ``u.size(0)`` rows, one per graph.
    """
    num_graphs = u.size(0)
    src, _ = edge_index
    edge_summary = scatter_mean(edge_attr, batch[src], dim=0, dim_size=num_graphs)
    node_summary = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
    return self.global_mlp(torch.cat([u, node_summary, edge_summary], dim=1))
def forward(self, data):
    """Five-stage sparse voxel graph CNN with a multi-scale readout.

    Each stage applies two convs with ReLU to ``(adj, input)`` and reads out
    its features with ``scatter_mean`` over the batch vector expanded to 64
    columns.  The call is ``scatter_mean(batch, features)``, presumably the
    legacy torch_scatter (index, src) argument order — confirm.  Stages 1-4
    then pool with ``sparse_voxel_max_pool``, weighted by the first feature
    channel.  The concatenated readouts go through dropout and ``fc1``, and
    log-probabilities are returned.
    """
    stages = (
        (self.conv1, self.conv12, (32, 8)),
        (self.conv2, self.conv22, (16, 4)),
        (self.conv3, self.conv32, (8, 2)),
        (self.conv4, self.conv42, (4, 1)),
        (self.conv5, self.conv52, None),  # last stage: no pooling
    )
    readouts = []
    current = data
    for conv_a, conv_b, pool_cfg in stages:
        feats = F.relu(conv_a(current.adj, current.input))
        feats = F.relu(conv_b(current.adj, feats))
        batch = Variable(current.batch.view(-1, 1).expand(current.batch.size(0), 64))
        readouts.append(scatter_mean(batch, feats))
        if pool_cfg is not None:
            size, start = self.pool_args(*pool_cfg)
            current, _ = sparse_voxel_max_pool(
                current, size, start, transform, weight=feats[:, 0])

    x = torch.cat(readouts, dim=1)
    x = x.view(-1, self.fc1.weight.size(1))
    x = F.dropout(x, training=self.training)
    return F.log_softmax(self.fc1(x), dim=1)
def global_model(self, x, edge_index, edge_attr, u, batch):
    """Global update: ``global_msg([u, mean node feats, mean edge feats])``.

    x: [N, F_x], where N is the number of nodes.
    edge_index: [2, E] with max entry N - 1.
    edge_attr: [E, F_e]
    u: [B, F_u]
    batch: [N] with max entry B - 1.

    Both means use ``dim_size=B`` so graphs without edges or nodes at the end
    of the batch still get a row, keeping the concatenation aligned with u.
    """
    row, _ = edge_index
    num_graphs = u.size(0)
    edge_mean = scatter_mean(edge_attr, batch[row], dim=0, dim_size=num_graphs)
    node_mean = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
    out = torch.cat([u, node_mean, edge_mean], dim=1)
    return self.global_msg(out)
def forward(self, x, edge_index, edge_attr, u, batch):
    """Global update: ``phi_u([mean node feats, mean edge feats, u])`` per graph.

    Both means use ``dim_size=u.size(0)``, so graphs without edges or nodes at
    the end of the batch still get a row and the concatenation stays aligned.
    """
    row, _ = edge_index
    num_graphs = u.size(0)
    # compute the batch index for all edges (graph of the source node)
    e_batch = batch[row]
    # aggregate all edges in the graph
    e_agg = scatter_mean(edge_attr, e_batch, dim=0, dim_size=num_graphs)
    # aggregate all nodes in the graph
    x_agg = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
    out = torch.cat([x_agg, e_agg, u], 1)
    return self.phi_u(out)
def global_model(x, edge_attr, u, v_indices, e_indices):
    """Global update closure; uses ``self`` from the enclosing scope.

    When ``self.independent`` is set, only ``u`` is fed to ``self.global_mlp``.
    Otherwise ``u`` is concatenated with the per-graph mean of ``x`` (grouped
    by ``v_indices``) and of ``edge_attr`` (grouped by ``e_indices``).  Both
    means use ``dim_size=u.size(0)`` so graphs lacking nodes or edges at the
    end of the batch still get a row.
    """
    if self.independent:
        return self.global_mlp(u)
    num_graphs = u.size(0)
    out = torch.cat(
        [
            u,
            scatter_mean(x, v_indices, dim=0, dim_size=num_graphs),
            scatter_mean(edge_attr, e_indices, dim=0, dim_size=num_graphs),
        ],
        dim=1,
    )
    return self.global_mlp(out)
def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
    """Normalise node features per graph.

    Within each graph (all nodes form one graph when ``batch`` is None), the
    per-graph feature mean scaled by ``mean_scale`` is subtracted.  The result
    is divided by ``sqrt(per-graph mean of squares + eps)``, multiplied by
    ``weight`` and shifted by ``bias``.
    """
    index = x.new_zeros(x.size(0), dtype=torch.long) if batch is None else batch
    n_graphs = int(index.max()) + 1

    graph_mean = scatter_mean(x, index, dim=0, dim_size=n_graphs)
    centered = x - graph_mean.index_select(0, index) * self.mean_scale
    variance = scatter_mean(centered.pow(2), index, dim=0, dim_size=n_graphs)
    std = (variance + self.eps).sqrt().index_select(0, index)
    return self.weight * centered / std + self.bias
def forward(self, data):
    """Given a :obj:`data` batch, computes the forward pass.

    Args:
        data (torch_geometric.data.Data): The input data, holding subject
            :obj:`sub`, relation :obj:`rel` and object :obj:`obj`
            information with shape :obj:`[batch_size]`.
            In addition, :obj:`data` needs to hold history information for
            subjects, given by a vector of node indices :obj:`h_sub` and
            their relative timestamps :obj:`h_sub_t` and batch assignments
            :obj:`h_sub_batch`.
            The same information must be given for objects (:obj:`h_obj`,
            :obj:`h_obj_t`, :obj:`h_obj_batch`).

    Returns:
        ``(log_prob_obj, log_prob_sub)``: log-softmax scores for the object
        (from the subject side) and the subject (from the object side).
    """
    assert 'h_sub_batch' in data and 'h_obj_batch' in data
    batch_size, seq_len = data.sub.size(0), self.seq_len
    n_slots = batch_size * seq_len

    def history(nodes, times, owners):
        # One slot per (example, timestep); average the entity embeddings per slot.
        slots = times + owners * seq_len
        pooled = scatter_mean(self.ent[nodes], slots, dim=0, dim_size=n_slots)
        return pooled.view(batch_size, seq_len, -1)

    def tile(emb):
        # Repeat one embedding per example across all timesteps.
        return emb.unsqueeze(1).repeat(1, seq_len, 1)

    h_sub = history(data.h_sub, data.h_sub_t, data.h_sub_batch)
    h_obj = history(data.h_obj, data.h_obj_t, data.h_obj_batch)

    sub = tile(self.ent[data.sub])
    rel = tile(self.rel[data.rel])
    obj = tile(self.ent[data.obj])

    # Keep only the final hidden state of each GRU.
    _, h_sub = self.sub_gru(torch.cat([sub, h_sub, rel], dim=-1))
    _, h_obj = self.obj_gru(torch.cat([obj, h_obj, rel], dim=-1))
    h_sub, h_obj = h_sub.squeeze(0), h_obj.squeeze(0)

    h_sub = torch.cat([self.ent[data.sub], h_sub, self.rel[data.rel]], dim=-1)
    h_obj = torch.cat([self.ent[data.obj], h_obj, self.rel[data.rel]], dim=-1)

    h_sub = F.dropout(h_sub, p=self.dropout, training=self.training)
    h_obj = F.dropout(h_obj, p=self.dropout, training=self.training)

    log_prob_obj = F.log_softmax(self.sub_lin(h_sub), dim=1)
    log_prob_sub = F.log_softmax(self.obj_lin(h_obj), dim=1)
    return log_prob_obj, log_prob_sub
def forward(self, batch_data):
    """Graph Network forward pass returning one prediction per graph.

    The initial global feature ``u`` is the per-graph mean of edge attributes.
    Nodes, edges and globals are encoded by their MLPs and normalised.  Then
    ``self.depth`` graph-network blocks run, each followed by graph-wise and
    batch normalisation.  When ``self.graph_level_feature`` is set,
    ``batch_data.graph_attr`` is appended to ``u``.  Finally ``u`` is mapped
    through ``self.mlp1``.
    """
    x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr
    row, _ = edge_index
    ori_batch = batch_data.batch

    # int(tensor.max()) is a single reduction.  The builtin max() used before
    # iterated the batch tensor element by element in Python (and synchronised
    # per element on GPU).
    num_graphs = int(ori_batch.max()) + 1
    u = scatter_mean(edge_attr, ori_batch[row], dim=0, dim_size=num_graphs)

    if self.graph_level_feature:
        # Use rdkit_2d_normalized_features as input graph-level feature.
        aug_feat = batch_data.graph_attr
        if len(aug_feat.shape) != 2:
            aug_feat = torch.reshape(aug_feat, (-1, self.num_global_features))

    x = self.mlp_node(x)
    edge_attr = self.mlp_edge(edge_attr)
    u = self.mlp_global(u)

    x = self.norm_node[-1](x, ori_batch)
    edge_attr = self.norm_edge[-1](edge_attr, ori_batch[row])
    x = self.bn_node[-1](x)
    edge_attr = self.bn_edge[-1](edge_attr)
    u = self.bn_global[-1](u)

    for i in range(self.depth):
        x, edge_attr, u = self.gn[i](x, edge_index, edge_attr, u, ori_batch)
        x = self.norm_node[i](x, ori_batch)
        edge_attr = self.norm_edge[i](edge_attr, ori_batch[row])
        x = self.bn_node[i](x)
        edge_attr = self.bn_edge[i](edge_attr)
        u = self.bn_global[i](u)

    if self.graph_level_feature:
        u = torch.cat([u, aug_feat], dim=1)
    return self.mlp1(u)
def message_step(self, x, start, end, e):
    """One interaction-network step with residual node and edge updates.

    Edge features ``e`` are aggregated onto their ``end`` nodes using
    ``self.hparams["aggregation"]``: one of "sum", "mean", "max", "sum_max",
    "mean_sum" or "mean_max".  The combined modes concatenate two
    aggregations along the last axis.  The node network then updates ``x``,
    and the edge network updates ``e`` from the new endpoint features.  Both
    updates are residual.

    Returns:
        ``(x_out, e_out)``: updated node and edge features.

    Raises:
        ValueError: for an unknown aggregation name.  Previously this fell
            through and failed later with UnboundLocalError.
    """
    aggregation = self.hparams["aggregation"]
    dim_size = x.shape[0]

    def agg_sum():
        return scatter_add(e, end, dim=0, dim_size=dim_size)

    def agg_mean():
        return scatter_mean(e, end, dim=0, dim_size=dim_size)

    def agg_max():
        return scatter_max(e, end, dim=0, dim_size=dim_size)[0]

    if aggregation == "sum":
        edge_messages = agg_sum()
    elif aggregation == "mean":
        edge_messages = agg_mean()
    elif aggregation == "max":
        edge_messages = agg_max()
    elif aggregation == "sum_max":
        edge_messages = torch.cat([agg_max(), agg_sum()], dim=-1)
    elif aggregation == "mean_sum":
        edge_messages = torch.cat([agg_mean(), agg_sum()], dim=-1)
    elif aggregation == "mean_max":
        edge_messages = torch.cat([agg_max(), agg_mean()], dim=-1)
    else:
        raise ValueError(f"Unknown aggregation {aggregation!r}")

    # Residual updates are out-of-place: an in-place `+=` on a network output
    # can break autograd if the last layer saved its output for backward.
    node_inputs = torch.cat([x, edge_messages], dim=-1)
    x_out = self.node_network(node_inputs) + x

    edge_inputs = torch.cat([x_out[start], x_out[end], e], dim=-1)
    e_out = self.edge_network(edge_inputs) + e
    return x_out, e_out
def forward(self, x, edge_index):
    """Normalise neighbour features with per-node scalar statistics.

    For each node ``r``, the mean and variance of its neighbours' features
    ``x[col]`` (over edges with ``row == r``) are computed feature-wise and
    then averaged over the feature axis, giving one scalar each.  The output
    has one row per edge:
    ``gamma * (x[col] - mean[row]) / sqrt(var[row] + eps) + beta``.
    """
    row, col = edge_index
    n_nodes = x.size(0)
    neighbours = x[col]

    mu = scatter_mean(neighbours, row, dim=0, dim_size=n_nodes).mean(dim=-1, keepdim=True)
    centered = neighbours - mu[row]
    sigma2 = scatter_mean(centered ** 2, row, dim=0, dim_size=n_nodes).mean(dim=-1, keepdim=True)

    normed = centered / (sigma2[row] + self.eps).sqrt()
    return self.gamma * normed + self.beta