def forward(self, x, edge_index, edge_attr, u, batch):
    # x: [N, F_x], where N is the number of nodes.
    # edge_index: [2, E] with max entry N - 1.
    # edge_attr: [E, F_e]
    # u: [B, F_u]
    # batch: [N] with max entry B - 1.
    row, col = edge_index
    out = torch.cat([x[col], edge_attr], dim=1)
    out = self.node_mlp_1(out)
    if self.aggregation == "mean":
        out = scatter_mean(out, row, dim=0, dim_size=x.size(0))
    elif self.aggregation == "min":
        out, _ = scatter_min(out, row, dim=0, dim_size=x.size(0))
    elif self.aggregation == "max":
        out, _ = scatter_max(out, row, dim=0, dim_size=x.size(0))
    elif self.aggregation == "minmax":
        out = torch.cat([
            scatter_min(out, row, dim=0, dim_size=x.size(0))[0],
            scatter_max(out, row, dim=0, dim_size=x.size(0))[0]
        ], dim=1)
    else:
        raise ValueError("Unknown aggregation type: {}".format(self.aggregation))
    out = torch.cat([x, out, u[batch]], dim=1)
    return self.node_mlp_2(out)
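# A minimal, self-contained sketch of the aggregation branch above (assumption:
# torch_scatter 2.x is installed). Messages on edges are pooled back onto their
# source nodes; "minmax" simply concatenates the two reductions along the
# feature dimension.
import torch
from torch_scatter import scatter_mean, scatter_min, scatter_max

edge_msg = torch.tensor([[1.0, 2.0], [3.0, 0.0], [5.0, 4.0]])  # [E=3, F=2]
row = torch.tensor([0, 0, 1])                                  # source node of each edge
mean_pool = scatter_mean(edge_msg, row, dim=0, dim_size=2)     # [[2., 1.], [5., 4.]]
minmax_pool = torch.cat([
    scatter_min(edge_msg, row, dim=0, dim_size=2)[0],
    scatter_max(edge_msg, row, dim=0, dim_size=2)[0],
], dim=1)                                                      # shape [2, 4]: mins then maxs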
def assign_edge_labels(self):
    """
    Assigns self.graph_obj edge labels (tensor with shape (num_edges,)), with labels
    defined according to the network flow MOT formulation.
    """
    ids = torch.as_tensor(self.graph_df.id.values, device=self.graph_obj.edge_index.device)
    per_edge_ids = torch.stack([
        ids[self.graph_obj.edge_index[0]],
        ids[self.graph_obj.edge_index[1]]
    ])
    same_id = (per_edge_ids[0] == per_edge_ids[1]) & (per_edge_ids[0] != -1)
    same_ids_ixs = torch.where(same_id)
    same_id_edges = self.graph_obj.edge_index.T[same_id].T
    time_dists = torch.abs(same_id_edges[0] - same_id_edges[1])

    # For every node, we get the index of the node in the future (resp. past)
    # with the same id that is closest in time
    future_mask = same_id_edges[0] < same_id_edges[1]
    active_fut_edges = scatter_min(time_dists[future_mask], same_id_edges[0][future_mask],
                                   dim=0, dim_size=self.graph_obj.num_nodes)[1]
    original_node_ixs = torch.cat((same_id_edges[1][future_mask],
                                   torch.as_tensor([-1], device=same_id.device)))  # -1 at the end for nodes that were not present
    active_fut_edges = original_node_ixs[active_fut_edges]  # Recover the node id of the corresponding edge
    fut_edge_is_active = active_fut_edges[same_id_edges[0]] == same_id_edges[1]

    # Analogous for past edges
    past_mask = same_id_edges[0] > same_id_edges[1]
    active_past_edges = scatter_min(time_dists[past_mask], same_id_edges[0][past_mask],
                                    dim=0, dim_size=self.graph_obj.num_nodes)[1]
    original_node_ixs = torch.cat((same_id_edges[1][past_mask],
                                   torch.as_tensor([-1], device=same_id.device)))  # -1 at the end for nodes that were not present
    active_past_edges = original_node_ixs[active_past_edges]
    past_edge_is_active = active_past_edges[same_id_edges[0]] == same_id_edges[1]

    # Recover the ixs of active edges in the original edge_index tensor
    active_edge_ixs = same_ids_ixs[0][past_edge_is_active | fut_edge_is_active]
    self.graph_obj.edge_labels = torch.zeros_like(same_id, dtype=torch.float)
    self.graph_obj.edge_labels[active_edge_ixs] = 1
    self.graph_obj.tracking_id = ids
def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                         out: Optional[torch.Tensor] = None,
                         dim_size: Optional[int] = None,
                         unbiased: bool = True) -> torch.Tensor:
    if out is not None:
        dim_size = out.size(dim)

    if dim < 0:
        dim = src.dim() + dim

    count_dim = dim
    if index.dim() <= dim:
        count_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    summ = tmp.clone()
    count = broadcast(count, tmp, dim).clamp(1)
    mean = tmp.div(count)

    var = (src - mean.gather(dim, index))
    var = var * var
    var = scatter_sum(var, index, dim, out, dim_size)

    if unbiased:
        count = count.sub(1).clamp_(1)
    var = var.div(count)

    maximum = scatter_max(src, index, dim, out, dim_size)[0]
    minimum = scatter_min(src, index, dim, out, dim_size)[0]

    return torch.cat([summ, mean, var, maximum, minimum], dim=1)
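# A small usage sketch for scatter_distribution (assumptions: torch_scatter 2.x is
# installed and `broadcast` is torch_scatter.utils.broadcast). For src of shape
# [N, F] and integer group labels, the result has shape [num_groups, 5 * F]:
# sum, mean, var, max and min concatenated side by side.
import torch
from torch_scatter import scatter_sum, scatter_min, scatter_max
from torch_scatter.utils import broadcast

src = torch.tensor([[1.0], [3.0], [10.0]])
index = torch.tensor([0, 0, 1])
stats = scatter_distribution(src, index, dim=0)
# stats[0] -> [4., 2., 2., 3., 1.]    (group 0: sum, mean, unbiased var, max, min)
# stats[1] -> [10., 10., 0., 10., 10.]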
def forward(self, data):
    data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
    weight = normalized_cut_2d(data.edge_index, data.pos)
    cluster = graclus(data.edge_index, weight, data.x.size(0))
    data.edge_attr = None
    data = max_pool(cluster, data, transform=transform)

    data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
    weight = normalized_cut_2d(data.edge_index, data.pos)
    cluster = graclus(data.edge_index, weight, data.x.size(0))
    x, batch = max_pool_x(cluster, data.x, data.batch)

    # x = global_mean_pool(x, batch)
    x_min = torch_scatter.scatter_min(x, batch, dim=0)[0]
    gather_idxs = batch.expand(x.shape[1], -1).t()
    gather_mins = torch.gather(x_min, 0, gather_idxs)
    s = F.relu(-gather_mins)
    x = x + s
    x = self.aggregator(x, batch)
    s_out = self.aggregator(s, batch)
    x = x - s_out

    x = F.elu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    return F.log_softmax(self.fc2(x), dim=1)
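# A toy illustration of the shift-by-min trick used above (my reading of the code,
# not part of the original file): per-graph minima are shifted out so the
# aggregator only ever sees non-negative inputs, and the shift's own pooled value
# is removed again afterwards. With a sum aggregator the round trip is exact.
import torch
import torch_scatter

x = torch.tensor([[-2.0], [1.0], [4.0]])
batch = torch.tensor([0, 0, 1])
x_min = torch_scatter.scatter_min(x, batch, dim=0)[0]   # [[-2.], [4.]]
s = torch.relu(-x_min)[batch]                           # [[2.], [2.], [0.]]
shifted = x + s                                         # every entry >= 0
pooled = torch_scatter.scatter_add(shifted, batch, dim=0) - \
         torch_scatter.scatter_add(s, batch, dim=0)     # equals scatter_add(x, batch)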
def take_action_deterministic_batch_dqn(target_net, player, batch_instances):
    with torch.no_grad():
        # We compute the target values
        batch = batch_instances.G_torch.batch
        mask_values = batch_instances.J.eq(0)[:, 0]
        action_values = target_net(
            batch_instances.G_torch,
            batch_instances.n_nodes,
            batch_instances.Omegas,
            batch_instances.Phis,
            batch_instances.Lambdas,
            batch_instances.Omegas_norm,
            batch_instances.Phis_norm,
            batch_instances.Lambdas_norm,
            batch_instances.J,
        )
        action_values = action_values[mask_values]
        batch = batch[mask_values]
        # if it's the turn of the attacker
        if player == 1:
            # we take the argmin
            values, actions = scatter_min(action_values, batch, dim=0)
        else:
            # we take the argmax
            values, actions = scatter_max(action_values, batch, dim=0)
    return actions.view(-1).tolist()
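# A minimal sketch of the per-graph argmin/argmax selection above (assumption:
# torch_scatter 2.x). The second return value of scatter_min/scatter_max indexes
# into the flat (masked) value tensor, giving one chosen position per graph.
import torch
from torch_scatter import scatter_min, scatter_max

action_values = torch.tensor([[0.3], [0.1], [0.9], [0.2]])
batch = torch.tensor([0, 0, 1, 1])                     # graph id of each candidate action
_, argmin = scatter_min(action_values, batch, dim=0)   # [[1], [3]]
_, argmax = scatter_max(action_values, batch, dim=0)   # [[0], [2]]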
def forward(self, x, edge_index, edge_attr, u, batch):
    # x: [N, F_x], where N is the number of nodes.
    # edge_index: [2, E] with max entry N - 1.
    # edge_attr: [E, F_e]
    # u: [B, F_u]
    # batch: [N] with max entry B - 1.
    row, col = edge_index

    # define interaction tensor; every pair contains features from input and
    # output node together with edge features
    # out = torch.cat([x[row], x[col], edge_attr], dim=1)
    out = torch.cat([x[row], x[col]], dim=1)
    # print("node pre", x.shape, out.shape)

    # take the interaction feature tensor and embed it into another tensor
    # out = self.node_mlp_1(out)
    out = self.mlp(out)
    # print("node mlp", out.shape)

    # compute the mean, max and min of each embedded feature tensor for each node
    out1 = scatter_mean(out, col, dim=0, dim_size=x.size(0))
    out3 = scatter_max(out, col, dim=0, dim_size=x.size(0))[0]
    out4 = scatter_min(out, col, dim=0, dim_size=x.size(0))[0]

    # every node contains a feature tensor with the pooling of the messages from
    # neighbors, its own state, and a global feature
    out = torch.cat([x, out1, out3, out4, u[batch]], dim=1)
    # print("node post", out.shape)
    # return self.node_mlp_2(out)
    return out
def __call__(self, data, norm=True):
    row, col = data.edge_index
    N = data.num_nodes

    deg = degree(row, N, dtype=torch.float)
    if norm:
        deg = deg / deg.max()
    deg_col = deg[col]

    min_deg, _ = scatter_min(deg_col, row, dim_size=N)
    min_deg[min_deg > 10000] = 0   # zero out nodes without neighbors (large fill value)
    max_deg, _ = scatter_max(deg_col, row, dim_size=N)
    max_deg[max_deg < -10000] = 0  # zero out nodes without neighbors (large negative fill value)
    mean_deg = scatter_mean(deg_col, row, dim_size=N)
    std_deg = scatter_std(deg_col, row, dim_size=N)

    x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)
    if data.x is not None:
        data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
        data.x = torch.cat([data.x, x], dim=-1)
    else:
        data.x = x
    return data
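# A quick sketch of the neighbor-degree statistics above (assumptions:
# torch_geometric's `degree` helper and torch_scatter 2.x). For each node, the
# degrees of its neighbors are reduced with min; nodes without any outgoing edge
# keep the library's fill value for empty bins, which is why the code above
# zeroes out implausibly large/small entries afterwards.
import torch
from torch_scatter import scatter_min
from torch_geometric.utils import degree

edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])  # node 3 is isolated
row, col = edge_index
deg = degree(row, 4, dtype=torch.float)            # [1., 2., 0., 0.]
min_deg, _ = scatter_min(deg[col], row, dim_size=4)
min_deg[min_deg > 10000] = 0                       # clean up unfilled rows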
def aggregate(self, inputs: Tensor, index: Tensor,
              dim_size: Optional[int] = None) -> Tensor:
    if self.aggr == 'softmax':
        out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
        return scatter(inputs * out, index, dim=self.node_dim,
                       dim_size=dim_size, reduce='sum')
    elif self.aggr == 'softmax_sg':
        out = scatter_softmax(inputs * self.t, index, dim=self.node_dim).detach()
        return scatter(inputs * out, index, dim=self.node_dim,
                       dim_size=dim_size, reduce='sum')
    elif self.aggr == 'stat':
        _mean = scatter_mean(inputs, index, dim=self.node_dim, dim_size=dim_size)
        _std = scatter_std(inputs, index, dim=self.node_dim, dim_size=dim_size).detach()
        _min = scatter_min(inputs, index, dim=self.node_dim, dim_size=dim_size)[0]
        _max = scatter_max(inputs, index, dim=self.node_dim, dim_size=dim_size)[0]
        _mean = _mean.unsqueeze(dim=-1)
        _std = _std.unsqueeze(dim=-1)
        _min = _min.unsqueeze(dim=-1)
        _max = _max.unsqueeze(dim=-1)
        stat = torch.cat([_mean, _std, _min, _max], dim=-1)
        stat = self.lin_stat(stat)
        stat = stat.squeeze(dim=-1)
        return stat
    else:
        min_value, max_value = 1e-7, 1e1
        torch.clamp_(inputs, min_value, max_value)
        out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                      dim_size=dim_size, reduce='mean')
        torch.clamp_(out, min_value, max_value)
        return torch.pow(out, 1 / self.p)
def test_min_fill_value():
    src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
    index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

    out, _ = scatter_min(src, index)

    v = torch.finfo(torch.float).max
    assert out.tolist() == [[v, v, -4, -3, -2, 0], [-2, -4, -3, v, v, v]]
def update_time_(node_time_dict, index, node_type, num_nodes):
    node_time_dict[node_type] = node_time_dict[node_type].clone()
    # `edge_label_time` is not a parameter; it is resolved from the enclosing scope.
    node_time, _ = scatter_min(edge_label_time, index, dim=0, dim_size=num_nodes)
    # NOTE We assume that node_time is always less than edge_time.
    index_unique = index.unique()
    node_time_dict[node_type][index_unique] = node_time[index_unique]
def forward(self, batch_size, encode_coordinates, agent_encodings):
    channel = agent_encodings.shape[-1]
    pool_vector = agent_encodings.transpose(1, 0)  # [C X D]

    init_map_ts = torch.zeros((channel, batch_size * self.pooling_size * self.pooling_size),
                              device=self.device)  # [C X B*H*W]
    out, _ = ts.scatter_min(src=pool_vector, index=encode_coordinates, out=init_map_ts)  # [C X B*H*W]
    out, _ = ts.scatter_max(src=pool_vector, index=encode_coordinates, out=out)  # [C X B*H*W]
    out = out.reshape((channel, batch_size, self.pooling_size, self.pooling_size))  # [C X B X H X W]
    out = out.permute((1, 0, 2, 3))  # [B X C X H X W]
    return out
def forward(self, data):
    # device = self.device
    # mode = self.mode
    k = self.k
    device = self.device
    pos_idx = self.pos_idx
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)

    x = self.GGconv1(x, edge_index)
    x = self.relu(x)
    x = self.nn1(x)
    x = self.relu(x)

    y = self.resblock1(x)
    x = x + y
    z = self.resblock2(x)
    x = x + z
    del y, z

    x = self.nn2(x)
    x = self.relu(x)
    x = self.GGconv2(x, edge_index)
    x = self.relu(x)

    p = self.resblock3(x)
    x = x + p
    o = self.resblock4(x)
    x = x + o
    del p, o

    x = self.nn3(x)
    x = self.relu(x)

    a, _ = scatter_max(x, batch, dim=0)
    b, _ = scatter_min(x, batch, dim=0)
    c = scatter_sum(x, batch, dim=0)
    d = scatter_mean(x, batch, dim=0)
    x = torch.cat((a, b, c, d), dim=1)
    # print("cat size", x.size())
    del a, b, c, d

    x = self.nncat(x)
    x = self.relu(x)
    # if torch.sum(torch.isnan(x)) != 0:
    #     print('NAN ENCOUNTERED AT NN2')
    # print("xsize %s batchsize %s a size %s b size %s y size %s end forward"
    #       % (x.size(), batch.size(), a.size(), b.size(), data.y[:, 0].size()))
    return x
def forward(self, x, edge_index, edge_attr, u, batch):
    # x: [N, F_x], where N is the number of nodes.
    # edge_index: [2, E] with max entry N - 1.
    # edge_attr: [E, F_e]
    # u: [B, F_u]
    # batch: [N] with max entry B - 1.
    out1 = scatter_mean(x, batch, dim=0)
    out3 = scatter_max(x, batch, dim=0)[0]
    out4 = scatter_min(x, batch, dim=0)[0]
    out = torch.cat([u, out1, out3, out4], dim=1)
    # print("global pre", out.shape, x.shape, u.shape)
    out = self.global_mlp(out)
    # print("global post", out.shape)
    return out
def correctness(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
            out3 = segment_csr(x, rowptr, reduce='add')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
            out3 = segment_csr(x, rowptr, reduce='mean')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            x = x.abs_().mul_(-1)
            out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='min')
            out3, _ = segment_csr(x, rowptr, reduce='min')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            x = x.abs_()
            out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='max')
            out3, _ = segment_csr(x, rowptr, reduce='max')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
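# A self-contained variant of the check above on hand-made data (no .mat download
# needed); assumes torch_scatter 2.x. segment_coo/segment_csr require a sorted
# index, which is why the benchmark derives them from a CSR matrix.
import torch
from torch_scatter import scatter_add, segment_coo, segment_csr

row = torch.tensor([0, 0, 1, 1, 1, 3])      # sorted COO row indices
rowptr = torch.tensor([0, 2, 5, 5, 6])      # CSR pointers for 4 rows (row 2 is empty)
x = torch.randn(row.size(0), 8)
out1 = scatter_add(x, row, dim=0, dim_size=4)
out2 = segment_coo(x, row, dim_size=4, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add')
assert torch.allclose(out1, out2, atol=1e-6) and torch.allclose(out1, out3, atol=1e-6)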
def forward(self, data):
    k = self.k
    device = self.device
    mode = self.mode
    pos_idx = self.pos_idx

    # changing x's dtype to float, change back after saving graphs properly
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)
    a = self.conv_add(x, edge_index)

    edge_index = knn_graph(x=a[:, pos_idx], k=k, batch=batch).to(device)
    # check if this recalculation of edge indices is correct; maybe you can do it over all of x
    b = self.conv_add2(a, edge_index)
    edge_index = knn_graph(x=b[:, pos_idx], k=k, batch=batch).to(device)
    c = self.conv_add3(b, edge_index)
    edge_index = knn_graph(x=c[:, pos_idx], k=k, batch=batch).to(device)
    d = self.conv_add4(c, edge_index)

    x = torch.cat((x, a, b, c, d), dim=1)
    del a, b, c, d

    x = self.nn1(x)
    x = self.relu(x)
    x = self.nn2(x)

    a, _ = scatter_max(x, batch, dim=0)
    b, _ = scatter_min(x, batch, dim=0)
    c = scatter_sum(x, batch, dim=0)
    d = scatter_mean(x, batch, dim=0)
    x = torch.cat((a, b, c, d), dim=1)

    x = self.relu(x)
    x = self.nn3(x)
    x = self.relu(x)
    x = self.nn4(x)

    if mode == 'angle':
        x[:, 0] = self.tanh(x[:, 0])
        x[:, 1] = self.tanh(x[:, 1])
    return x
def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                         out: Optional[torch.Tensor] = None,
                         dim_size: Optional[int] = None,
                         unbiased: bool = True) -> torch.Tensor:
    if out is not None:
        dim_size = out.size(dim)

    if dim < 0:
        dim = src.dim() + dim

    count_dim = dim
    if index.dim() <= dim:
        count_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    count = broadcast(count, tmp, dim).clamp(1)
    mean = tmp.div(count)

    src_minus_mean = (src - mean.gather(dim, index))
    var = src_minus_mean * src_minus_mean
    var = scatter_sum(var, index, dim, out, dim_size)

    if unbiased:
        count = count.sub(1).clamp_(1)
    var = var.div(count)

    skew = src_minus_mean * src_minus_mean * src_minus_mean / (var.gather(dim, index) + 1e-7) ** 1.5
    kurtosis = (src_minus_mean * src_minus_mean * src_minus_mean * src_minus_mean) / (var * var + 1e-7).gather(dim, index)

    skew = scatter_sum(skew, index, dim, out, dim_size)
    kurtosis = scatter_sum(kurtosis, index, dim, out, dim_size)

    skew = skew.div(count)
    kurtosis = kurtosis.div(count)

    maximum = scatter_max(src, index, dim, out, dim_size)[0]
    minimum = scatter_min(src, index, dim, out, dim_size)[0]

    return torch.cat([mean, var, skew, kurtosis, maximum, minimum], dim=1)
def __call__(self, data: Data):
    pos = data.pos
    edges = data.edge_index

    if self.weighted_normals:
        normals = weighted_normals(data.face_normals, data.face_areas,
                                   data.faces, len(data.pos))
    else:
        normals = data.vertex_normals

    diffs = pos[edges[0]] - pos[edges[1]]
    projectors = make_projector(normals)
    projected = torch.einsum('pij, pj -> pi', projectors[edges[0]], pos[edges[1]])
    projected /= projected.norm(dim=1, keepdim=True)

    gauge, _ = scatter_min(edges[1], edges[0])
    e1 = projected[gauge]
    e2 = torch.cross(e1, normals)

    log_map = projected * diffs.norm(dim=1, keepdim=True)
    theta_x = torch.einsum('pi, pi -> p', e1[edges[0]], log_map)
    theta_y = torch.einsum('pi, pi -> p', e2[edges[0]], log_map)
    theta = torch.atan2(theta_y, theta_x)

    axis = torch.cross(normals[edges[1]], normals[edges[0]])
    alpha = torch.einsum('pi, pi -> p', normals[edges[0]], normals[edges[1]]).clamp(-1, 1)
    alpha = torch.acos(alpha)
    rotvec = (alpha.unsqueeze(dim=-1) * axis).numpy()
    rotation = R.from_rotvec(rotvec)

    g_x = rotation.apply(e1[edges[1]].numpy())
    g_x = torch.einsum('pi, pi -> p', torch.FloatTensor(g_x), e1[edges[0]])
    g_y = rotation.apply(e2[edges[1]].numpy())
    g_y = torch.einsum('pi, pi -> p', torch.FloatTensor(g_y), e2[edges[0]])
    g = torch.atan2(g_y, g_x)

    del data.vertex_normals
    del data.faces
    del data.face_normals
    del data.face_areas

    if self.distance:
        data.distance = diffs.norm(dim=1)
    data.g = g
    data.theta = theta
    return data
def take_action_deterministic_batch(target_net, player, next_player, rewards,
                                    next_afterstates, weights=None,
                                    id_graphs=None, **kwargs):
    """Take actions in batch"""
    if id_graphs is None:
        n_nodes = sum([len(afterstate) for afterstate in next_afterstates])
        id_graphs = torch.zeros(size=(n_nodes,), dtype=torch.int64).to(device)

    # if the game is finished in the next turn,
    # we know the best action to take
    # because we have the true rewards available
    if next_player == 3:
        # the targets are the true values
        targets = rewards
    # if it's not the end state,
    # we sample from the values
    else:
        with torch.no_grad():
            # Create a Batch of graphs
            G_torch = Batch.from_data_list(next_afterstates).to(device)
            # We compute the target values
            targets = target_net(G_torch, **kwargs)

    if weights is not None:
        weights_tensor = torch.tensor(weights, dtype=torch.float).view(targets.size()).to(device)
        target_decision = targets + weights_tensor
    else:
        target_decision = targets

    # if it's the turn of the attacker
    if player == 1:
        # we take the argmin
        _, actions = scatter_min(target_decision, id_graphs, dim=0)
    else:
        # we take the argmax
        _, actions = scatter_max(target_decision, id_graphs, dim=0)
    values = targets[actions[:, 0]]

    return actions.view(-1).tolist(), targets, values.view(-1).tolist()
def get_depot_info(beam, graph):
    """
    Finds for each group (set of visited nodes) in the beam the lowest cost to return to the depot.
    This is useful since any non-dominated (lowest-cost) expansion via the depot must necessarily
    also arrive at the depot at lowest cost (since remaining demand is reset at the depot, we only
    need to look at cost).
    :param beam:
    :param graph:
    :return:
    """
    # Get the total distance to the depot for each entry in the group; for the first
    # action `current` is undefined, so don't add the cost to the depot
    beam_cost_at_depot = beam.cost if beam.current is None else \
        beam.cost + graph.cost_to_depot[beam.batch_ids, beam.current.long()]

    if beam.sort_by == 'group_idx':
        group_min_cost_at_depot, group_idx_min_cost_at_depot = segment_min_coo(beam_cost_at_depot, beam.group_idx)
    else:
        group_min_cost_at_depot, group_idx_min_cost_at_depot = scatter_min(beam_cost_at_depot, beam.group_idx)

    beam_min_cost_at_depot = group_min_cost_at_depot.gather(0, beam.group_idx)
    beam_idx_min_cost_at_depot = group_idx_min_cost_at_depot.gather(0, beam.group_idx)
    return group_min_cost_at_depot, group_idx_min_cost_at_depot, beam_min_cost_at_depot, beam_idx_min_cost_at_depot
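# A small sketch of the gather-back pattern above (assumption: torch_scatter 2.x):
# scatter_min yields one minimum per group, and gather(0, group_idx) broadcasts
# that per-group minimum back to every beam entry, so each entry can be compared
# against the best of its own group.
import torch
from torch_scatter import scatter_min

cost = torch.tensor([5.0, 2.0, 7.0, 3.0])
group_idx = torch.tensor([0, 0, 1, 1])
group_min, group_argmin = scatter_min(cost, group_idx)   # [2., 3.], [1, 3]
per_entry_min = group_min.gather(0, group_idx)           # [2., 2., 3., 3.]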
def forward(self, batched_data):
    h_node = self.gnn_node(batched_data)

    if self.graph_pooling == 'laf' and isinstance(self.pool, ScatterAggregationLayer):
        x_min = torch_scatter.scatter_min(h_node, batched_data.batch, dim=0)[0]
        gather_idxs = batched_data.batch.expand(h_node.shape[1], -1).t()
        gather_mins = torch.gather(x_min, 0, gather_idxs)
        s = F.relu(-gather_mins)
        h_node = h_node + s
        out = self.pool(h_node, batched_data.batch)
        s_out = self.pool(s, batched_data.batch)
        h_graph = out - s_out
    elif self.graph_pooling == 'laf' and isinstance(self.pool, ScatterExponentialLAF):
        h_graph = self.pool(h_node, batched_data.batch)
    else:
        h_graph = self.pool(h_node, batched_data.batch)

    return self.graph_pred_linear(h_graph)
def __call__(self, data):
    row, col = data.edge_index
    N = data.num_nodes

    deg = degree(row, N, dtype=torch.float)
    deg_col = deg[col]

    # `fill_value` is the older torch_scatter 1.x API; the sentinel marks
    # nodes that received no neighbor values.
    value = 1e16
    min_deg, _ = scatter_min(deg_col, row, dim_size=N, fill_value=value)
    min_deg[min_deg == value] = 0
    max_deg, _ = scatter_max(deg_col, row, dim_size=N)
    mean_deg = scatter_mean(deg_col, row, dim_size=N)
    std_deg = scatter_std(deg_col, row, dim_size=N)

    x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)
    if data.x is not None:
        data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
        data.x = torch.cat([data.x, x], dim=-1)
    else:
        data.x = x
    return data
def predict(self, inputs, labels):
    K = inputs['intrinsics']
    extrinsics = inputs['extrinsics']
    depths = inputs['depth']
    depth_mask = inputs['depth_mask']
    target_T = inputs['target_T']
    segs = inputs['seg']

    if self.ind is not None:
        depths = depths[:, self.ind:self.ind + 1]
        depth_mask = depth_mask[:, self.ind:self.ind + 1]
        target_T = target_T[:, self.ind:self.ind + 1]
        segs = segs[:, self.ind:self.ind + 1]

    # Step 1: back project 2d points into 3D using formula:
    # pts3d = depths * K^-1 * pts2d
    b, inp_t, height, width = depths.shape
    vs, us = torch.meshgrid(
        torch.arange(height, dtype=torch.float, device=depths.device),
        torch.arange(width, dtype=torch.float, device=depths.device),
    )
    pts2d = torch.cat([
        us.reshape(-1, 1),
        vs.reshape(-1, 1),
        torch.ones(height * width, 1, dtype=torch.float, device=us.device)
    ], dim=-1).unsqueeze(0).expand(b, -1, -1)

    K_inv = torch.inverse(K).reshape(b, 1, 3, 3)
    # [b, 1, 3, 3] x [b, hw, 3, 1]. After squeeze, the result is [b, hw, 3]
    pts3d_c = (K_inv @ pts2d.unsqueeze(-1)).squeeze(-1)
    pts3d_c = pts3d_c.unsqueeze(1) * depths.reshape(b, inp_t, -1, 1)
    pts3d_c = torch.cat([
        pts3d_c,
        torch.ones(b, inp_t, height * width, 1, device=K.device)
    ], dim=-1)

    # Step 2: convert camera points (in RDF) to vehicle points (in FLU)
    # Here, pts3d is [b, inp_t, h*w, 4]
    pts3d_v = extrinsics.view(b, 1, 1, 4, 4) @ pts3d_c.unsqueeze(-1)

    # Step 3: transform points such that they lie in the final frame's
    # vehicle coordinate system
    # target_T shape: [b, inp_t, 4, 4]
    result_pts3d_v = target_T.unsqueeze(2) @ pts3d_v

    # Step 4: Project points to 2d (by first transforming to camera coordinates)
    result_pts3d_c = torch.inverse(extrinsics).reshape(b, 1, 1, 4, 4) @ result_pts3d_v
    result_pts3d_c = result_pts3d_c[:, :, :, :3] / result_pts3d_c[:, :, :, 3:4]
    result_depths = result_pts3d_c[:, :, :, 2]
    result2d = K.view(b, 1, 1, 3, 3) @ result_pts3d_c
    result2d = result2d[:, :, :, :2] / result2d[:, :, :, 2:3]
    # result2d = result2d.squeeze(-1).round().long()
    result2d = result2d.squeeze(-1)

    # Valid points have the following properties:
    # - They correspond to valid input depth values
    # - The depth values are > 0 (i.e. they lie in front of the camera)
    # - The u/v coordinates lie within the image
    inbounds_mask = (result2d[:, :, :, 0] >= 0) & \
                    (result2d[:, :, :, 0] < width) & \
                    (result2d[:, :, :, 1] >= 0) & \
                    (result2d[:, :, :, 1] < height)
    result_mask = depth_mask.view(b, inp_t, height * width) * \
                  (result_depths.squeeze(-1) > 0) & \
                  inbounds_mask

    # We need to translate our 2d predictions (which currently take the form
    # [batch, num_predicted_points, 2] and represent the u/v coordinates for each
    # point in the final camera frame) to the actual image, only keeping points
    # with valid depths and moreover keeping the closest valid point.
    # We do this using scatter (a good overview of how this works can be seen at
    # https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter)

    # First: find the points with the smallest depth at each result location
    result_mask = result_mask.reshape(b, inp_t * height * width)
    result_depths = result_depths.reshape(b, inp_t * height * width)
    # Make sure we never select an invalid point for a location when a
    # valid point exists
    result_depths[~result_mask] = result_depths.max() + 1

    result2d = result2d.reshape(b, inp_t * height * width, 2)
    result2d_0 = torch.stack([result2d[:, :, 0].floor().long(),
                              result2d[:, :, 1].floor().long()], dim=-1)
    result2d_1 = torch.stack([result2d[:, :, 0].floor().long(),
                              result2d[:, :, 1].ceil().long()], dim=-1)
    result2d_2 = torch.stack([result2d[:, :, 0].ceil().long(),
                              result2d[:, :, 1].floor().long()], dim=-1)
    result2d_3 = torch.stack([result2d[:, :, 0].ceil().long(),
                              result2d[:, :, 1].ceil().long()], dim=-1)
    result2d = torch.cat([result2d_0, result2d_1, result2d_2, result2d_3], dim=1)
    result2d[:, :, 0].clamp_(0, width - 1)
    result2d[:, :, 1].clamp_(0, height - 1)
    result_depths = result_depths.repeat(1, 4)

    scatter_inds = result2d[:, :, 1] * width + result2d[:, :, 0]
    _, argmin = torch_scatter.scatter_min(result_depths, scatter_inds, -1,
                                          dim_size=inp_t * height * width * 4)
    tmp_mask = (argmin < inp_t * height * width * 4)
    ind0 = tmp_mask.nonzero()[:, 0]
    ind1 = argmin[tmp_mask]
    tgt_ind1 = tmp_mask.nonzero()[:, 1]

    if self.is_img:
        final_seg = torch.zeros(b, height * width, 3, dtype=segs.dtype, device=K.device)
        segs = segs.reshape(b, inp_t * height * width, 3).repeat(1, 4, 1)
    else:
        final_seg = torch.zeros(b, height * width, dtype=segs.dtype, device=K.device)
        segs = segs.reshape(b, inp_t * height * width).repeat(1, 4)
    # The following ensures we don't copy a prediction from an "invalid" point
    segs[~result_mask.repeat(1, 4)] = 0
    final_seg[ind0, tgt_ind1] = segs[ind0, ind1]

    final_depths = torch.zeros(b, height * width, dtype=result_depths.dtype,
                               device=K.device).fill_(-1)
    final_depths[ind0, tgt_ind1] = result_depths[ind0, ind1]

    if self.is_img:
        final_seg = final_seg.view(b, height, width, 3)
    else:
        final_seg = final_seg.view(b, height, width)

    result_dict = {
        'seg': final_seg,
        'result2d': result2d[:, :inp_t * height * width].reshape(b, inp_t, height, width, 2),
        'depth': final_depths.view(b, height, width),
    }
    return result_dict
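# A minimal sketch of the z-buffering idea above (assumption: torch_scatter 2.x):
# when several 3D points land on the same pixel, scatter_min over depth picks the
# closest one, and its argmin tells us which source point to copy attributes from.
# Pixels that received no point keep an out-of-range argmin, the same pattern the
# `argmin < ...` check above relies on.
import torch
import torch_scatter

pixel_idx = torch.tensor([0, 0, 1])            # two points project onto pixel 0
depth = torch.tensor([5.0, 2.0, 3.0])
color = torch.tensor([10, 20, 30])
_, argmin = torch_scatter.scatter_min(depth, pixel_idx, dim=0, dim_size=4)
valid = argmin < depth.numel()                 # pixels 2 and 3 received nothing
img = torch.zeros(4, dtype=color.dtype)
img[valid] = color[argmin[valid]]              # [20, 30, 0, 0]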
def forward(self, data):
    """
    Provides a fractional solution to the data association problem.
    First, node and edge features are independently encoded by the encoder network.
    Then, they are iteratively 'combined' for a fixed number of steps via the
    Message Passing Network (self.MPNet). Finally, they are classified independently
    by the classifier network.

    Args:
        data: object containing attributes
              - x: node features matrix
              - edge_index: tensor with shape [2, M], with M being the number of edges,
                indicating nonzero entries in the graph adjacency (i.e. sparse adjacency)
              - edge_attr: edge features matrix (sorted by edge appearance in edge_index)

    Returns:
        classified_edges: list of unnormalized edge probabilities after each MP step
    """
    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

    x_is_img = len(x.shape) == 4
    if self.node_cnn is not None and x_is_img:
        x = self.node_cnn(x)

        emb_dists = nn.functional.pairwise_distance(x[edge_index[0]],
                                                    x[edge_index[1]]).view(-1, 1)
        edge_attr = torch.cat((edge_attr, emb_dists), dim=1)

    # Encoding features step
    latent_edge_feats, latent_node_feats = self.encoder(edge_attr, x)
    initial_edge_feats = latent_edge_feats
    initial_node_feats = latent_node_feats

    # During training, the feature vectors that the MPNetwork outputs for the last
    # self.num_class_steps message passing steps are classified in order to compute the loss.
    first_class_step = self.num_enc_steps - self.num_class_steps + 1
    first_attention_step = self.num_enc_steps - self.num_attention_steps + 1

    if self.use_attention:
        if self.graph_pruning:
            outputs_dict = {'classified_edges': [], 'att_coefficients': [], 'mask': []}
        else:
            outputs_dict = {'classified_edges': [], 'att_coefficients': []}
    else:
        if self.graph_pruning:
            outputs_dict = {'classified_edges': [], 'mask': []}
        else:
            outputs_dict = {'classified_edges': []}

    mask = torch.full((edge_index.shape[1],), True, dtype=torch.bool)
    for step in range(1, self.num_enc_steps + 1):
        # Reattach the initially encoded embeddings before the update
        if self.reattach_initial_edges:
            latent_edge_feats = torch.cat((initial_edge_feats, latent_edge_feats), dim=1)  # [M,16]+[M,16] -> [M,32]
        if self.reattach_initial_nodes:
            latent_node_feats = torch.cat((initial_node_feats, latent_node_feats), dim=1)

        # Message Passing Step
        if self.use_attention:
            if self.graph_pruning:
                a = torch.zeros(self.attention_head_num, edge_index.shape[1]).cuda()
                edge_feats = torch.zeros(latent_edge_feats.shape[0], 16).cuda()
                latent_node_feats, edge_feats[mask], a_masked = self.MPNet(
                    latent_node_feats, edge_index.T[mask].T, latent_edge_feats[mask])
                latent_edge_feats = edge_feats
                a.T[mask] = a_masked.T
            else:
                latent_node_feats, latent_edge_feats, a = self.MPNet(
                    latent_node_feats, edge_index, latent_edge_feats)
        else:
            if self.graph_pruning:
                edge_feats = torch.zeros(latent_edge_feats.shape[0], 16).cuda()
                latent_node_feats, edge_feats[mask] = self.MPNet(
                    latent_node_feats, edge_index.T[mask].T, latent_edge_feats[mask])
                latent_edge_feats = edge_feats
            else:
                latent_node_feats, latent_edge_feats = self.MPNet(
                    latent_node_feats, edge_index, latent_edge_feats)

        if step >= first_class_step:
            # Classification Step
            logits, _ = self.classifier(latent_edge_feats)

            pruning_this_step = (self.graph_pruning and
                                 step >= self.first_prune_step and
                                 step < self.num_enc_steps)
            if self.use_attention and step >= first_attention_step:
                outputs_dict['att_coefficients'].append(a)
            if pruning_this_step:
                # NOTE: `probabilities` is assigned at the end of this classification
                # block, so the pruning branches below read the value from the
                # previous step; they assume at least one classification step has
                # already run.
                if self.prune_mode == "classifier naive":
                    valid_pro = probabilities[mask]
                    topk_mask = torch.full((valid_pro.shape[0],), True, dtype=torch.bool)
                    _, indice = torch.topk(valid_pro, int(len(valid_pro) * self.prune_factor), largest=False)
                    topk_mask[indice] = False
                elif self.prune_mode == "classifier node wise":
                    valid_pro = probabilities[mask]
                    valid_idx = edge_index[0][mask]
                    topk_mask = torch.ones(len(valid_pro), dtype=torch.bool)
                    valid_pro_copy = valid_pro.clone()
                    k = torch.ones(len(valid_idx)).cuda()
                    k = torch.max(scatter_add(k, valid_idx))
                    k = int(k * self.prune_factor)
                    for i in range(k):
                        _, argmin = torch_scatter.scatter_min(valid_pro_copy, valid_idx)
                        neighbor = scatter_add(topk_mask.long().cuda(), valid_idx)
                        argmin = argmin[neighbor > self.prune_min_edge]
                        topk_mask[argmin] = False
                        valid_pro_copy[argmin] = 2
                mask[mask == True] = topk_mask
                outputs_dict['mask'].append(mask.clone())

            probabilities = torch.zeros_like(logits.view(-1))
            probabilities[mask] = torch.sigmoid(logits.view(-1)[mask])
            outputs_dict['classified_edges'].append(probabilities)

    if self.num_enc_steps == 0:
        logits, _ = self.classifier(latent_edge_feats)
        dec_edge_feats = torch.sigmoid(logits)
        outputs_dict['classified_edges'].append(dec_edge_feats)

    return outputs_dict
def train_step(model, optimizer, train_iterator, args, step, writer):
    optimizer.zero_grad()

    (x_scores, x_relations, y_scores, y_relations, mask_relations, w_scores,
     w_relations, berts, edge_indices, softmax_edge_indices, n_program,
     max_y_score_len, mask_relations_class, question_indices, step_indices,
     noisy_mask_relations) = train_iterator.next_supervised()

    if args.cuda:
        x_scores = x_scores.cuda()
        x_relations = x_relations.cuda()
        y_scores = y_scores.cuda()
        y_relations = y_relations.cuda()
        mask_relations = mask_relations.cuda()
        w_scores = w_scores.cuda()
        w_relations = w_relations.cuda()
        berts = berts.cuda()
        edge_indices = edge_indices.cuda()
        softmax_edge_indices = softmax_edge_indices.cuda()
        mask_relations_class = mask_relations_class.cuda()
        question_indices = question_indices.cuda()
        step_indices = step_indices.cuda()
        noisy_mask_relations = noisy_mask_relations.cuda()

    scores, relations = model(x_scores, x_relations, berts, edge_indices,
                              softmax_edge_indices, n_program, max_y_score_len)

    score_loss = torch.nn.CrossEntropyLoss(reduction='none')(scores, y_scores)
    if args.train_with_masking:
        relations = torch.where(mask_relations_class, relations,
                                torch.tensor(-float('inf')).to(relations.device))
        relation_loss = torch.nn.CrossEntropyLoss(reduction='none')(relations, y_relations) * mask_relations
    else:
        relation_loss = torch.nn.CrossEntropyLoss(reduction='none')(relations, y_relations) * mask_relations

    relation_loss = relation_loss[noisy_mask_relations]
    all_loss = score_loss + args.relation_coeff * relation_loss

    # Sum the per-step losses within each program, then keep the minimum-loss
    # program for each question.
    all_loss = torch_scatter.scatter_add(all_loss, step_indices[1], dim=0,
                                         dim_size=torch.max(step_indices[1]) + 1)
    loss, _ = torch_scatter.scatter_min(all_loss, question_indices[1], dim=0,
                                        dim_size=torch.max(question_indices[1]) + 1)
    loss = torch.mean(loss)

    score_loss = torch.mean(
        torch_scatter.scatter_min(
            torch_scatter.scatter_add(score_loss, step_indices[1], dim=0,
                                      dim_size=torch.max(step_indices[1]) + 1),
            question_indices[1], dim=0,
            dim_size=torch.max(question_indices[1]) + 1)[0])
    relation_loss = torch.mean(
        torch_scatter.scatter_min(
            torch_scatter.scatter_add(relation_loss, step_indices[1], dim=0,
                                      dim_size=torch.max(step_indices[1]) + 1),
            question_indices[1], dim=0,
            dim_size=torch.max(question_indices[1]) + 1)[0])

    loss.backward()
    optimizer.step()

    log = {
        'supervised_loss': loss.item(),
        'supervised_score_loss': score_loss.item(),
        'supervised_relation_loss': relation_loss.item(),
    }
    for metric in log:
        writer.add_scalar(metric, log[metric], step)
    return log
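# A toy version of the add-then-min pattern above (assumption: torch_scatter 2.x):
# scatter_add pools per-step losses into per-program losses, and scatter_min keeps
# the cheapest candidate program for each question.
import torch
import torch_scatter

step_loss = torch.tensor([1.0, 2.0, 4.0, 0.5])
program_of_step = torch.tensor([0, 0, 1, 2])       # 3 candidate programs
question_of_program = torch.tensor([0, 0, 1])      # programs 0 and 1 answer question 0
per_program = torch_scatter.scatter_add(step_loss, program_of_step, dim=0, dim_size=3)        # [3., 4., 0.5]
per_question, _ = torch_scatter.scatter_min(per_program, question_of_program, dim=0, dim_size=2)  # [3., 0.5]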
def train_dqn(batch_size, size_test_data, lr, betas, n_episode, update_target,
              n_time_instance_seen, eps_end, eps_decay, eps_start, dim_embedding,
              dim_values, dim_hidden, n_heads, n_att_layers, n_pool, alpha, p,
              n_free_min, n_free_max, d_edge_min, d_edge_max, Omega_max, Phi_max,
              Lambda_max, weighted, w_max=1, directed=False, num_workers=0,
              resume_training=False, path_train="", path_test_data=None,
              exact_protection=False, rate_display=200, batch_unroll=128):
    """Train a DQN to solve the MCN problem"""
    # Gather the hyperparameters
    dict_args = locals()
    # Gather the date as a string
    date_str = (datetime.now().strftime('%b') + str(datetime.now().day) + "_"
                + str(datetime.now().hour) + "-" + str(datetime.now().minute)
                + "-" + str(datetime.now().second))
    # Tensorboard init
    writer = SummaryWriter()
    # Init the counts
    count_steps = 0
    count_instances = 0
    # Compute n_max
    n_max = n_free_max + Omega_max + Phi_max + Lambda_max
    max_budget = Omega_max + Phi_max + Lambda_max
    list_players = [2] * Lambda_max + [1] * Phi_max + [0] * Omega_max
    # Compute the size of the memory
    size_memory = batch_size * n_time_instance_seen
    # Init the value net
    value_net = DQN(
        dim_input=5,
        dim_embedding=dim_embedding,
        dim_values=dim_values,
        dim_hidden=dim_hidden,
        n_heads=n_heads,
        n_att_layers=n_att_layers,
        n_pool=n_pool,
        K=n_max,
        alpha=alpha,
        p=p,
        weighted=weighted,
    ).to(device)
    # Initialize the optimizer
    optimizer = optim.Adam(value_net.parameters(), lr=lr, betas=betas)
    # Initialize the memory
    replay_memory_states = []
    replay_memory_actions = []
    replay_memory_afterstates = []
    replay_memory_rewards = []
    count_memory = 0
    # If resuming training, load the state dicts of the optimizer and value_net
    if resume_training:
        value_net, optimizer = load_training_param(value_net, optimizer, path_train)
    # Init the target net
    target_net = DQN(
        dim_input=5,
        dim_embedding=dim_embedding,
        dim_values=dim_values,
        dim_hidden=dim_hidden,
        n_heads=n_heads,
        n_att_layers=n_att_layers,
        n_pool=n_pool,
        K=n_max,
        alpha=alpha,
        p=p,
        weighted=weighted,
    ).to(device)
    target_net.load_state_dict(value_net.state_dict())
    target_net.eval()
    # in order to use the current value_net during training for an evaluation task,
    # we first create a second instance of ValueNet in which we will load the
    # state_dicts of the learning value_net before each use
    value_net_bis = DQN(
        dim_input=5,
        dim_embedding=dim_embedding,
        dim_values=dim_values,
        dim_hidden=dim_hidden,
        n_heads=n_heads,
        n_att_layers=n_att_layers,
        n_pool=n_pool,
        K=n_max,
        alpha=alpha,
        p=p,
        weighted=weighted,
    ).to(device)
    # generate the test set
    test_set_generators = load_create_test_set_dqn(
        n_free_min, n_free_max, d_edge_min, d_edge_max, Omega_max, Phi_max,
        Lambda_max, weighted, w_max, directed, size_test_data, path_test_data,
        batch_size, num_workers)
    losses_test = [0] * max_budget
    print("Number of parameters to train = %2d \n" % count_param_NN(value_net))

    for episode in tqdm(range(n_episode)):
        # Sample a random batch of instances from where to begin
        list_instances = generate_random_batch_instance(
            batch_unroll,
            n_free_min,
            n_free_max,
            d_edge_min,
            d_edge_max,
            Omega_max,
            Phi_max,
            Lambda_max,
            Budget_target=max_budget,
            weighted=weighted,
            w_max=w_max,
            directed=directed,
        )
        # Initialize the environment
        env = EnvironmentDQN(list_instances)
        # Init the list of instances for the episode
        current_states = None
        current_actions = None
        current_rewards = None
        cpt_budget = 0
        # Unroll the episode
        while env.Budget >= 1:
            last_states = current_states
            current_states = env.list_instance_torch
            action = sample_action_batch_dqn(value_net,
                                             env.player,
                                             env.batch_instance_torch,
                                             eps_end,
                                             eps_decay,
                                             eps_start,
                                             count_steps)
            env.step(action)
            last_actions = current_actions
            current_actions = action
            last_rewards = current_rewards
            current_rewards = env.rewards
            cpt_budget += 1
            # if we have the (state, afterstate) couples available
            if cpt_budget > 1:
                n_visited = 0
                for k in range(batch_unroll):
                    if len(replay_memory_states) < size_memory:
                        replay_memory_states.append(None)
                        replay_memory_afterstates.append(None)
                        replay_memory_actions.append(None)
                        replay_memory_rewards.append(None)
                    replay_memory_states[count_memory % size_memory] = last_states[k]
                    replay_memory_afterstates[count_memory % size_memory] = current_states[k]
                    replay_memory_rewards[count_memory % size_memory] = last_rewards[k]
                    n_free = int(torch.sum(last_states[k].J.eq(0)[:, 0]))
                    replay_memory_actions[count_memory % size_memory] = last_actions[k] - n_visited
                    n_visited += n_free
                    count_memory += 1
            # If we are in the last step, we push the end rewards to memory
            if env.Budget == 0 and cpt_budget > 1:
                n_visited = 0
                for k in range(batch_unroll):
                    if len(replay_memory_states) < size_memory:
                        replay_memory_states.append(None)
                        replay_memory_afterstates.append(None)
                        replay_memory_actions.append(None)
                        replay_memory_rewards.append(None)
                    replay_memory_states[count_memory % size_memory] = current_states[k]
                    # doesn't matter what we put in the afterstates here
                    replay_memory_afterstates[count_memory % size_memory] = current_states[k]
                    replay_memory_rewards[count_memory % size_memory] = current_rewards[k]
                    n_free = int(torch.sum(current_states[k].J.eq(0)[:, 0]))
                    replay_memory_actions[count_memory % size_memory] = current_actions[k] - n_visited
                    n_visited += n_free
                    count_memory += 1

        # if there are enough new instances in memory
        if count_memory > size_memory:
            # create a list of randomly shuffled indices to sample a batch from
            memory_size = len(replay_memory_states)
            id_batch = random.sample(range(memory_size), batch_size)
            # gather the states, afterstates, actions and rewards
            list_states = [replay_memory_states[k] for k in id_batch]
            list_afterstates = [replay_memory_afterstates[k] for k in id_batch]
            list_actions = [replay_memory_actions[k] for k in id_batch]
            list_rewards = [replay_memory_rewards[k] for k in id_batch]
            # recover the actions' id in the batch
            n_visited = 0
            list_actions_new = []
            for k in range(len(list_actions)):
                n_free = int(torch.sum(list_states[k].J.eq(0)[:, 0]))
                list_actions_new.append(list_actions[k] + n_visited)
                n_visited += n_free
            # create the tensors
            batch_states = collate_fn(list_states)
            batch_afterstates = collate_fn(list_afterstates)
            batch_actions = torch.tensor(list_actions_new,
                                         dtype=torch.long).view([len(list_actions), 1]).to(device)
            batch_rewards = torch.tensor(list_rewards,
                                         dtype=torch.float).view([len(list_rewards), 1]).to(device)
            # Compute the approximate values
            action_values = value_net(
                batch_states.G_torch,
                batch_states.n_nodes,
                batch_states.Omegas,
                batch_states.Phis,
                batch_states.Lambdas,
                batch_states.Omegas_norm,
                batch_states.Phis_norm,
                batch_states.Lambdas_norm,
                batch_states.J,
            )
            # mask the attacked nodes
            mask_values = batch_states.J.eq(0)[:, 0]
            action_values = action_values[mask_values]
            # Gather the approximate values
            approx_values = action_values.gather(0, batch_actions)
            # compute the masks to apply to the target
            mask_attack = batch_states.next_player.eq(1)[:, 0]
            mask_exact = batch_states.next_player.eq(3)[:, 0]
            # Compute the approximate targets
            with torch.no_grad():
                target_values = target_net(
                    batch_afterstates.G_torch,
                    batch_afterstates.n_nodes,
                    batch_afterstates.Omegas,
                    batch_afterstates.Phis,
                    batch_afterstates.Lambdas,
                    batch_afterstates.Omegas_norm,
                    batch_afterstates.Phis_norm,
                    batch_afterstates.Lambdas_norm,
                    batch_afterstates.J,
                ).detach()
                batch = batch_afterstates.G_torch.batch
                mask_J = batch_afterstates.J.eq(0)[:, 0]
                # mask the attacked nodes
                batch = batch[mask_J]
                target_values = target_values[mask_J]
                # Compute the min and max
                val_min, _ = scatter_min(target_values, batch, dim=0)
                val_max, _ = scatter_max(target_values, batch, dim=0)
                # create the target tensor
                target = val_max
                target[mask_attack] = val_min[mask_attack]
                target[mask_exact] = batch_rewards[mask_exact]
            # Init the optimizer
            optimizer.zero_grad()
            # Compute the loss of the batch
            loss = torch.sqrt(torch.mean((approx_values - target) ** 2))
            # Update the parameters of the Value_net
            loss.backward()
            optimizer.step()
            # compute the loss on the test set using the value_net_bis
            value_net_bis.load_state_dict(value_net.state_dict())
            value_net_bis.eval()
            # Check the test losses every 20 steps
            if count_steps % 20 == 0:
                losses_test = compute_loss_test_dqn(test_set_generators,
                                                    list_players,
                                                    value_net=value_net_bis)
                for k in range(len(losses_test)):
                    name_loss = 'Loss test budget ' + str(k + 1)
                    writer.add_scalar(name_loss, float(losses_test[k]), count_steps)
            # Update the tensorboard
            writer.add_scalar("Loss", float(loss), count_steps)
            count_steps += 1
            # Update the target net
            if count_steps % update_target == 0:
                target_net.load_state_dict(value_net.state_dict())
                target_net.eval()
            # Save the models every rate_display steps
            if count_steps % rate_display == 0:
                save_models(date_str, dict_args, value_net, optimizer, count_steps)
                print(
                    " \n Episode: %2d/%2d" % (episode * batch_size, n_episode),
                    " \n Loss of the current value net: %f" % float(loss),
                    " \n Losses on test set : ", losses_test,
                )
def compute_loss_test_dqn(test_set_generators, list_players, value_net=None,
                          list_experts=None, id_to_test=None):
    """Compute the list of losses of the value_net or the list_of_experts over
    the list of exactly solved datasets that constitutes the test set"""
    list_losses = []
    with torch.no_grad():
        if id_to_test is None:
            iterator = range(len(test_set_generators))
        else:
            iterator = [id_to_test]
        for k in iterator:
            target = []
            val_approx = []
            player = list_players[k]
            if list_experts is not None:
                try:
                    target_net = list_experts[k]
                except IndexError:
                    target_net = None
            elif value_net is not None:
                target_net = value_net
            if target_net is None:
                list_losses.append(0)
            else:
                for i_batch, batch_instances in enumerate(test_set_generators[k]):
                    batch = batch_instances.G_torch.batch
                    mask_values = batch_instances.J.eq(0)[:, 0]
                    action_values = target_net(
                        batch_instances.G_torch,
                        batch_instances.n_nodes,
                        batch_instances.Omegas,
                        batch_instances.Phis,
                        batch_instances.Lambdas,
                        batch_instances.Omegas_norm,
                        batch_instances.Phis_norm,
                        batch_instances.Lambdas_norm,
                        batch_instances.J,
                    )
                    action_values = action_values[mask_values]
                    batch = batch[mask_values]
                    # if it's the turn of the attacker
                    if player == 1:
                        # we take the argmin
                        values, actions = scatter_min(action_values, batch, dim=0)
                    else:
                        # we take the argmax
                        values, actions = scatter_max(action_values, batch, dim=0)
                    val_approx.append(values)
                    target.append(batch_instances.target)
                # Compute the loss
                target = torch.cat(target)
                val_approx = torch.cat(val_approx)
                loss_target_net = float(torch.sqrt(torch.mean((val_approx[:, 0] - target[:, 0]) ** 2)))
                list_losses.append(loss_target_net)
    return list_losses
src = torch.Tensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float()
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_ones((2, 6))
out = scatter_div(src, index, out=out)
print(out)
# tensor([[1.0000, 1.0000, 0.2500, 0.5000, 0.5000, 1.0000],
#         [0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])

# max, min, mean and other reductions
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index)
print(out, argmax)
out, argmin = scatter_min(src, index)
print(out, argmin)
out = scatter_mean(src, index)
print(out)
out = scatter_mul(src, index)
print(out)
out = scatter_std(src, index)
print(out)
out = scatter_sub(src, index)
print(out)
def forward(self, data, tvol=None):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    xinit = x.clone()
    row, col = edge_index
    mask = get_mask(x, edge_index, 1).to(x.dtype).unsqueeze(-1)

    x = self.conv1(x, edge_index)
    xpostconv1 = x.detach()
    x = x * mask

    for conv, bn in zip(self.convs, self.bns):
        if x.dim() > 1:
            x = x + conv(x, edge_index)
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = bn(x)

    x = self.conv2(x, edge_index)
    mask = get_mask(mask, edge_index, 1).to(x.dtype)
    x = x * mask

    xpostconvs = x.detach()
    # x = F.leaky_relu(self.lin1(x))
    x = x * mask
    x = self.bn2(x)

    xpostlin1 = x.detach()
    x = F.dropout(x, p=0.5, training=self.training)
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    xprethresh = x.detach()
    N_size = x.shape[0]
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    x = x * mask + mask * 1e-6
    # add dirac in the set
    x = x + xinit.unsqueeze(-1)

    # calculate useful quantities
    x2 = x.detach()
    r, c = edge_index
    tv = total_var(x, edge_index, batch)
    deg = degree(r).unsqueeze(-1)
    conduct_1 = tv
    totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device), batch, 0) + 1e-6
    totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6

    # receptive field
    recvol_hard = scatter_add(deg * mask.float(), batch, 0, dim_size=batch.max().item() + 1) + 1e-6
    reccard_hard = scatter_add(mask.float(), batch, 0, dim_size=batch.max().item() + 1) + 1e-6
    assert recvol_hard.mean() / totalvol.mean() <= 1, \
        "Something went wrong! Receptive field is larger than total volume."

    # generate target vol
    target = torch.zeros_like(totalvol)
    if tvol is None:
        feasible_vols = data.recfield_vol / data.total_vol - 0.0
        target = torch.rand_like(feasible_vols, device=device) * feasible_vols * 0.85 + 0.1
        target = target.squeeze(-1) * totalvol.squeeze(-1)
    else:
        target = tvol * totalvol.squeeze(-1)

    a = torch.ones((batch.max().item() + 1, 1), device=device)
    xfilt = x

    ###############################################################################
    # iterative rescaling
    counter_no2 = 0
    for iteration in range(self.num_iterations):
        counter_no2 += 1
        keep = ((a[batch] * xfilt) < 1).to(x.dtype)
        x_k, d_k, d_nk = xfilt * keep * mask, deg * keep * mask, deg * (1 - keep) * mask
        diff = target.unsqueeze(-1) - scatter_add(d_nk, batch, 0)
        dot = scatter_add(x_k * d_k, batch, 0)
        a = diff / (dot + 1e-5)
        volcur = scatter_add(torch.clamp(a[batch] * xfilt, max=1., min=0.) * deg, batch, 0)
        volcheck = (torch.abs(target - volcur.squeeze(-1)) > 0.1)
        checki = torch.abs(target.squeeze(-1) - volcur.squeeze(-1)) > 0.01
        targetcheck = torch.abs(volcur.squeeze(-1) - target)
        check = (targetcheck <= self.elasticity * target).to(x.dtype)
        if tvol is not None:
            pass
        if check.sum() >= batch.max().item() + 1:
            break

    probs = torch.clamp(a[batch] * x * mask, max=1., min=0.)

    ###############################################################################
    # collect useful numbers
    x2 = ((probs - torch.rand_like(x, device=device)) > 0).float()
    vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
    card_1 = scatter_add(probs, batch, 0)
    rec_field = scatter_add(mask, batch, 0) + 1e-6
    cut_size = scatter_add(x2, batch, 0)
    tv_hard = total_var(x2, edge_index, batch)
    vol_hard = scatter_add(deg * x2, batch, 0, dim_size=batch.max().item() + 1) + 1e-6
    conduct_hard = tv_hard / vol_hard
    rec_field_ratio = cut_size / rec_field
    rec_field_volratio = vol_hard / recvol_hard
    total_vol_ratio = vol_hard / totalvol

    # calculate loss
    expected_cut = scatter_add(probs * deg, batch, 0) - scatter_add((probs[row] * probs[col]), batch[row], 0)
    loss = expected_cut

    # return dict
    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    # retdict["|Expected_vol - Target|"] = [targetcheck, "sequence"]  # absolute distance from target volume
    retdict["Expected_volume"] = [vol_1.mean(), "sequence"]  # volume
    retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
    retdict["volume_hard"] = [vol_hard.mean(), "sequence"]  # volume2
    # retdict["cut1"] = [tv.mean(), "sequence"]  # cut1
    retdict["cut_hard"] = [tv_hard.mean(), "sequence"]  # cut1
    retdict["Average cardinality ratio of receptive field "] = [rec_field_ratio.mean(), "sequence"]
    retdict["Recfield volume/Total volume"] = [recvol_hard.mean() / totalvol.mean(), "sequence"]
    retdict["Average ratio of receptive field volume"] = [rec_field_volratio.mean(), 'sequence']
    retdict["Average ratio of total volume"] = [total_vol_ratio.mean(), 'sequence']
    retdict["mask"] = [mask, "aux"]  # mask
    retdict["xinit"] = [xinit, "hist"]  # layer input diracs
    retdict["xpostlin1"] = [xpostlin1.mean(1), "hist"]  # after first linear layer
    retdict["xprethresh"] = [xprethresh.mean(1), "hist"]  # pre-thresholding activations, 195 x 1
    # NOTE: `lossvol` and `losscard` are never computed in this version of the
    # forward pass, so these entries are commented out to keep the code runnable.
    # retdict["lossvol"] = [lossvol.mean(), "sequence"]  # volume constraint
    # retdict["losscard"] = [losscard.mean(), "sequence"]  # cardinality constraint
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

    return retdict
def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    num_graphs = batch.max().item() + 1
    row, col = edge_index
    total_num_edges = edge_index.shape[1]
    N_size = x.shape[0]

    if edge_dropout is not None:
        edge_index = dropout_adj(edge_index,
                                 edge_attr=(torch.ones(edge_index.shape[1], device=device)).long(),
                                 p=edge_dropout, force_undirected=True)[0]
        edge_index = add_remaining_self_loops(edge_index, num_nodes=batch.shape[0])[0]

    reduced_num_edges = edge_index.shape[1]
    current_edge_percentage = (reduced_num_edges / total_num_edges)
    no_loop_index, _ = remove_self_loops(edge_index)
    no_loop_row, no_loop_col = no_loop_index

    xinit = x.clone()
    x = x.unsqueeze(-1)
    mask = get_mask(x, edge_index, 1).to(x.dtype)
    x = F.leaky_relu(self.conv1(x, edge_index))  # +x
    x = x * mask
    x = self.gnorm(x)
    x = self.bn1(x)

    for conv, bn in zip(self.convs, self.bns):
        if x.dim() > 1:
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

    xpostconvs = x.detach()
    # x = F.leaky_relu(self.lin1(x))
    x = x * mask

    xpostlin1 = x.detach()
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    # calculate min and max
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    probs = x

    x2 = x.detach()
    deg = degree(row).unsqueeze(-1)
    totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device), batch, 0) + 1e-6
    totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6
    x2 = ((x2 - torch.rand_like(x, device=device)) > 0).float()
    vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
    card_1 = scatter_add(probs, batch, 0)
    set_size = scatter_add(x2, batch, 0)
    vol_hard = scatter_add(deg * x2, batch, 0, dim_size=batch.max().item() + 1) + 1e-6
    total_vol_ratio = vol_hard / totalvol

    # calculating the terms for the expected distance between clique and graph
    pairwise_prodsums = torch.zeros(num_graphs, device=device)
    for graph in range(num_graphs):
        batch_graph = (batch == graph)
        pairwise_prodsums[graph] = (torch.conv1d(probs[batch_graph].unsqueeze(-1),
                                                 probs[batch_graph].unsqueeze(-1))).sum() / 2

    ### calculate loss terms
    self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
    expected_weight_G = scatter_add(probs[no_loop_row] * probs[no_loop_col],
                                    batch[no_loop_row], 0, dim_size=num_graphs) / 2.
    expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums) / 1.
    expected_distance = (expected_clique_weight - expected_weight_G)

    ### useful numbers
    max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row],
                                  batch[no_loop_row], 0, dim_size=num_graphs) / 2).squeeze(-1)
    set_weight = (scatter_add(x2[no_loop_row] * x2[no_loop_col],
                              batch[no_loop_row], 0, dim_size=num_graphs) / 2) + 1e-6
    clique_edges_hard = (set_size * (set_size - 1) / 2) + 1e-6
    clique_dist_hard = set_weight / clique_edges_hard
    clique_check = (clique_edges_hard != clique_edges_hard)
    setedge_check = (set_weight != set_weight)

    assert ((clique_dist_hard >= 1.1).sum()) <= 1e-6, "Invalid set vol/clique vol ratio."

    ### calculate loss
    expected_loss = (penalty_coefficient) * expected_distance * 0.5 - 0.5 * expected_weight_G
    loss = expected_loss

    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
    retdict["Expected_cardinality_hist"] = [card_1, "hist"]
    retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
    retdict["Set sizes"] = [set_size.squeeze(-1), "hist"]
    retdict["volume_hard"] = [vol_hard.mean(), "aux"]  # volume2
    retdict["cardinality_hard"] = [set_size[0], "sequence"]  # volume
    retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
    retdict["Expected maximum weight"] = [expected_clique_weight.mean(), "sequence"]
    retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
    retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(), 'sequence']
    retdict["Currvol/Cliquevol all graphs in batch"] = [clique_dist_hard.squeeze(-1), 'hist']
    retdict["Average ratio of total volume"] = [total_vol_ratio.mean(), 'sequence']
    # NOTE: `cardinalities` is never defined in this forward pass, so the entry is
    # commented out to keep the code runnable.
    # retdict["cardinalities"] = [cardinalities.squeeze(-1), "hist"]
    retdict["Current edge percentage"] = [torch.tensor(current_edge_percentage), 'sequence']
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

    return retdict
def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    num_graphs = batch.max().item() + 1
    row, col = edge_index
    total_num_edges = edge_index.shape[1]
    N_size = x.shape[0]

    if edge_dropout is not None:
        edge_index = dropout_adj(edge_index,
                                 edge_attr=(torch.ones(edge_index.shape[1], device=device)).long(),
                                 p=edge_dropout, force_undirected=True)[0]
        edge_index = add_remaining_self_loops(edge_index, num_nodes=batch.shape[0])[0]

    reduced_num_edges = edge_index.shape[1]
    current_edge_percentage = (reduced_num_edges / total_num_edges)
    no_loop_index, _ = remove_self_loops(edge_index)
    no_loop_row, no_loop_col = no_loop_index

    xinit = x.clone()
    x = x.unsqueeze(-1)
    mask = get_mask(x, edge_index, 1).to(x.dtype)
    x = F.leaky_relu(self.conv1(x, edge_index))  # +x
    x = x * mask
    x = self.gnorm(x)
    x = self.bn1(x)

    for conv, bn in zip(self.convs, self.bns):
        if x.dim() > 1:
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

    xpostconvs = x.detach()
    # x = F.leaky_relu(self.lin1(x))
    x = x * mask

    xpostlin1 = x.detach()
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    # calculate min and max
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    probs = x

    # calculating the terms for the expected distance between clique and graph
    pairwise_prodsums = torch.zeros(num_graphs, device=device)
    for graph in range(num_graphs):
        batch_graph = (batch == graph)
        pairwise_prodsums[graph] = (torch.conv1d(probs[batch_graph].unsqueeze(-1),
                                                 probs[batch_graph].unsqueeze(-1))).sum() / 2

    ### calculate loss terms
    self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
    expected_weight_G = scatter_add(probs[no_loop_row] * probs[no_loop_col],
                                    batch[no_loop_row], 0, dim_size=num_graphs) / 2.
    expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums) / 1.
    expected_distance = (expected_clique_weight - expected_weight_G)

    ### calculate loss
    expected_loss = (penalty_coefficient) * expected_distance * 0.5 - 0.5 * expected_weight_G
    loss = expected_loss

    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
    retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
    retdict["Expected maximum weight"] = [expected_clique_weight.mean(), "sequence"]
    retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

    return retdict
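# A small sketch of the conv1d trick used above (my illustration, not part of the
# original file): with per-node probabilities reshaped to [n, 1, 1], torch.conv1d
# acts as an outer product, so the summed output is (sum_i p_i)^2 and halving it
# gives the pairwise product sum the clique loss builds on.
import torch

p = torch.tensor([0.2, 0.5, 0.9]).view(-1, 1, 1)   # [n, 1, 1]
outer = torch.conv1d(p, p)                         # [n, n, 1], outer[i, j] = p_i * p_j
pairwise_prodsum = outer.sum() / 2                 # equals (p.sum() ** 2) / 2
assert torch.isclose(pairwise_prodsum, p.sum() ** 2 / 2)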