def test_radius_graph(dtype, device):
    """Check ``radius_graph`` on the four corners of a 2x2 square, for both flows.

    Every corner is expected to be linked to its two side-adjacent corners
    (distance 2) and never to the diagonally opposite one (distance 2*sqrt(2)).
    """
    corners = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    expected_t2s = {(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2)}
    expected_s2t = {(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)}

    cases = (
        ('target_to_source', expected_t2s),
        ('source_to_target', expected_s2t),
    )
    for flow, expected in cases:
        assert to_set(radius_graph(corners, r=2, flow=flow)) == expected
def forward(self, data) -> torch.Tensor:
    """Predict one output per graph from node positions and atom types.

    Nodes closer than 10.0 are connected. Edge attributes combine a one-hot
    encoding of the ordered (source, destination) atom-type pair with the
    (unnormalised) spherical harmonics of the edge vector via ``self.mul``.
    Two tensor-product layers (``self.tp1``, ``self.tp2``) are applied, and
    node features are summed per graph.

    Args:
        data: batch object exposing ``pos``, ``z`` (integer atom types in
            ``[0, self.num_z)``) and ``batch``.

    Returns:
        Tensor with one row per graph in ``data.batch``.
    """
    num_neighbors = 3  # typical number of neighbors
    num_nodes = 4  # typical number of nodes
    num_z = self.num_z  # number of atom types
    # Every scatter onto nodes uses the full node count. Nodes with no
    # neighbour inside the cutoff then keep a (zero) row. Without this, the
    # node tensor can be shorter than ``data.batch`` and the final per-graph
    # scatter fails.
    total_nodes = data.pos.shape[0]

    # graph
    edge_src, edge_dst = radius_graph(data.pos, 10.0, data.batch)

    # spherical harmonics
    edge_vec = data.pos[edge_src] - data.pos[edge_dst]
    edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=False, normalization='component')

    # edge types
    edge_zz = num_z * data.z[edge_src] + data.z[edge_dst]  # from 0 to num_z^2 - 1
    edge_zz = torch.nn.functional.one_hot(edge_zz, num_z**2).mul(num_z)
    edge_zz = edge_zz.to(edge_sh.dtype)

    # edge attributes
    edge_attr = self.mul(edge_zz, edge_sh)

    # For each node, the initial features are the sum of the spherical harmonics of the neighbors
    node_features = scatter(edge_sh, edge_dst, dim=0, dim_size=total_nodes).div(num_neighbors**0.5)

    # For each edge, tensor product the features on the source node with the spherical harmonics
    edge_features = self.tp1(node_features[edge_src], edge_attr)
    node_features = scatter(edge_features, edge_dst, dim=0, dim_size=total_nodes).div(num_neighbors**0.5)

    edge_features = self.tp2(node_features[edge_src], edge_attr)
    node_features = scatter(edge_features, edge_dst, dim=0, dim_size=total_nodes).div(num_neighbors**0.5)

    # For each graph, all the node's features are summed
    return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
def forward(self, x, batch=None):
    """Build a graph in a learned coordinate space and propagate features over it.

    ``self.lin_s`` projects ``x`` to coordinates, and neighbours are found there
    either by k-NN (``self.k`` neighbours) or by radius search (``self.radius``,
    at most ``self.k`` neighbours), with no self-loops. Messages are weighted by
    ``exp(-10 * squared distance)``.

    Args:
        x: node features.
        batch: optional graph assignment vector forwarded to the neighbour search.

    Returns:
        ``(edge_index, out)`` where ``out = self.lin_fout(cat([propagated, x]))``.

    Raises:
        ValueError: if ``self.neighbor_algo`` is neither ``"knn"`` nor ``"radius"``.
    """
    # Validate the configuration before doing any work. ValueError subclasses
    # Exception, so existing ``except Exception`` handlers still catch it.
    if self.neighbor_algo not in ("knn", "radius"):
        raise ValueError("Unknown neighbor algo {}".format(self.neighbor_algo))

    spatial = self.lin_s(x)
    to_propagate = self.lin_flr(x)
    if self.neighbor_algo == "knn":
        edge_index = knn_graph(spatial, self.k, batch, loop=False, flow=self.flow, cosine=False)
    else:
        edge_index = radius_graph(spatial, self.radius, batch, loop=False,
                                  flow=self.flow, max_num_neighbors=self.k)

    reference = spatial.index_select(0, edge_index[1])
    neighbors = spatial.index_select(0, edge_index[0])
    distancessq = torch.sum((reference - neighbors)**2, dim=-1)
    # Factor 10 gives a better initial spread
    distance_weight = torch.exp(-10. * distancessq)

    prop_feat = self.propagate(edge_index, x=to_propagate, edge_weight=distance_weight)
    return edge_index, self.lin_fout(torch.cat([prop_feat, x], dim=-1))
def forward(self, data) -> torch.Tensor:
    """Predict one output per graph from node positions only.

    Nodes closer than 1.1 are connected. The spherical harmonics of the
    (unnormalised) edge vectors serve as the initial node features (summed
    over incoming edges) and as edge attributes for the two tensor-product
    layers ``self.tp1`` and ``self.tp2``. Node features are summed per graph.

    Args:
        data: batch object exposing ``pos`` and ``batch``.

    Returns:
        Tensor with one row per graph in ``data.batch``.
    """
    num_neighbors = 2  # typical number of neighbors
    num_nodes = 4  # typical number of nodes
    # Every scatter onto nodes uses the full node count. Nodes with no
    # neighbour inside the cutoff then keep a (zero) row. Without this, the
    # node tensor can be shorter than ``data.batch`` and the final per-graph
    # scatter fails.
    total_nodes = data.pos.shape[0]

    # tensors of indices representing the graph
    edge_src, edge_dst = radius_graph(x=data.pos, r=1.1, batch=data.batch)
    edge_vec = data.pos[edge_src] - data.pos[edge_dst]
    edge_sh = o3.spherical_harmonics(
        l=self.irreps_sh,
        x=edge_vec,
        normalize=False,  # here we don't normalize otherwise it would not be a polynomial
        normalization='component',
    )

    # For each node, the initial features are the sum of the spherical harmonics of the neighbors
    node_features = scatter(edge_sh, edge_dst, dim=0, dim_size=total_nodes).div(num_neighbors**0.5)

    # For each edge, tensor product the features on the source node with the spherical harmonics
    edge_features = self.tp1(node_features[edge_src], edge_sh)
    node_features = scatter(edge_features, edge_dst, dim=0, dim_size=total_nodes).div(num_neighbors**0.5)

    edge_features = self.tp2(node_features[edge_src], edge_sh)
    node_features = scatter(edge_features, edge_dst, dim=0, dim_size=total_nodes).div(num_neighbors**0.5)

    # For each graph, all the node's features are summed
    return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
def train(model, optimizer, scheduler, loss_fn, dataloader, epoch):
    """Run one training epoch, then step ``scheduler`` on the epoch's mean loss.

    Each batch is moved to CUDA. Particles are connected (self-loops included,
    at most 255 neighbours) when they lie within ``deltaR`` of each other in
    (column 3, atan2(column 1, column 0)) space. The batch is then passed
    through ``model`` and optimised with ``loss_fn``. ``deltaR`` is read from
    module scope rather than passed in.
    NOTE(review): columns 0/1/3 of ``x`` are presumably px/py/eta — confirm.
    """
    model.train()
    epoch_losses = []
    running = utils.RunningAverage()

    with tqdm(total=len(dataloader)) as progress:
        for batch_data in dataloader:
            optimizer.zero_grad()
            batch_data = batch_data.to('cuda')
            features = batch_data.x

            continuous = features[:, :8]
            categorical = features[:, 8:].long()

            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            phi = torch.atan2(features[:, 1], features[:, 0])
            eta_phi = torch.cat([features[:, 3][:, None], phi[:, None]], dim=1)
            edge_index = radius_graph(eta_phi, r=deltaR, batch=batch_data.batch,
                                      loop=True, max_num_neighbors=255)

            prediction = model(continuous, categorical, edge_index, batch_data.batch)
            loss = loss_fn(prediction, batch_data.x, batch_data.y, batch_data.batch)
            loss.backward()
            optimizer.step()

            # update the average loss
            loss_value = loss.item()
            epoch_losses.append(loss_value)
            running.update(loss_value)
            progress.set_postfix(loss='{:05.3f}'.format(running()))
            progress.update()

    mean_loss = np.mean(epoch_losses)
    scheduler.step(mean_loss)
    print('Training epoch: {:02d}, MSE: {:.4f}'.format(epoch, mean_loss))
def forward(self, data):
    """Run the configured message-passing stack and pool to one vector per graph.

    Edges connect nodes within radius ``self.k_nn`` (the attribute is used as
    a radius here). After every layer, the layer output is stored in
    ``self.h`` and then passed through ReLU. The concatenation of sum, mean and
    max pooling plus the global features ``u`` is stored in ``self.pooled`` and
    fed to ``self.lin``.
    """
    node_x, positions, graph_index, globals_u = data.x, data.pos, data.batch, data.u

    # Edges from all neighbours within the radius.
    edge_index = radius_graph(positions, r=self.k_nn, batch=graph_index, loop=self.loop)

    # Message passing; the layer call signature depends on the model family.
    model_name = self.namemodel
    for layer in self.layers:
        if model_name == "DeepSet":
            node_x = layer(node_x)
        elif model_name == "PointNet":
            node_x = layer(x=node_x, pos=positions, edge_index=edge_index)
        elif model_name == "MetaNet":
            node_x, _, globals_u = layer(node_x, edge_index, None, globals_u, graph_index)
        else:
            node_x = layer(x=node_x, edge_index=edge_index)
        self.h = node_x
        node_x = node_x.relu()

    # Mix different global pooling layers, then append the global features.
    pools = (global_add_pool, global_mean_pool, global_max_pool)
    pooled_parts = [pool(node_x, graph_index) for pool in pools]
    pooled_parts.append(globals_u)
    self.pooled = torch.cat(pooled_parts, dim=1)

    # Final linear layer
    return self.lin(self.pooled)
def radius(x, r=0.5, loop=False, dtype=None, device=None):
    """Return an (N, N) SparseTensor with a 1 for every pair of points within ``r``.

    Args:
        x: 2-D tensor of point coordinates, one row per point.
        r: search radius forwarded to ``radius_graph``.
        loop: whether self-loops are included.
        dtype: dtype of the stored edge values.
        device: device the edge index and values are moved to.

    Returns:
        ``SparseTensor`` of shape ``(N, N)`` built from the radius graph.
    """
    N, _ = x.shape
    # All points belong to a single graph. The batch vector must live on the
    # same device as ``x``: a CPU batch vector with a CUDA ``x`` makes the
    # neighbour search fail with a device mismatch.
    batch = torch.zeros(N, dtype=torch.long, device=x.device)
    edge_index = radius_graph(x, r, batch=batch, loop=loop).to(device)
    edge_val = torch.ones(edge_index.shape[-1], dtype=dtype, device=device)
    return SparseTensor(
        row=edge_index[0], col=edge_index[1], value=edge_val, sparse_sizes=(N, N)
    )
def test_radius_graph(dtype, device):
    """Check ``radius_graph`` row/col output on a 2x2 square for both flows.

    Each corner has exactly two neighbours. The order in which those two are
    returned is not fixed, so neighbour entries are sorted pairwise before
    comparing.
    """
    points = tensor([[-1, -1], [-1, +1], [+1, +1], [+1, -1]], dtype, device)

    def pairwise_sorted(t):
        # sort each consecutive pair so the check ignores neighbour order
        return t.view(-1, 2).sort(dim=-1)[0].view(-1)

    centres = [0, 0, 1, 1, 2, 2, 3, 3]
    neighbours = [1, 3, 0, 2, 1, 3, 0, 2]

    row, col = radius_graph(points, r=2, flow='target_to_source')
    assert row.tolist() == centres
    assert pairwise_sorted(col).tolist() == neighbours

    row, col = radius_graph(points, r=2, flow='source_to_target')
    assert pairwise_sorted(row).tolist() == neighbours
    assert col.tolist() == centres
def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Evaluate the network.

    Parameters
    ----------
    data : `torch_geometric.data.Data` or dict
        data object containing
        - ``pos`` the position of the nodes (atoms)
        - ``x`` the input features of the nodes, optional
        - ``z`` the attributes of the nodes, for instance the atom type, optional
        - ``batch`` the graph to which the node belong, optional

    Returns
    -------
    `torch.Tensor`
        The node features after the last layer. When ``self.reduce_output``
        is set, their per-graph sums divided by ``sqrt(self.num_nodes)``
        are returned instead.
    """
    pos = data['pos']
    n_nodes = pos.shape[0]

    # Without an explicit batch vector every node belongs to graph 0.
    batch = data['batch'] if 'batch' in data else pos.new_zeros(n_nodes, dtype=torch.long)

    # Neighbourhood graph and per-edge geometry.
    edge_src, edge_dst = radius_graph(pos, self.max_radius, batch)
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization='component')
    edge_length = edge_vec.norm(dim=1)
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_length,
        start=0.0,
        end=self.max_radius,
        number=self.number_of_basis,
        basis='gaussian',
        cutoff=False,
    ).mul(self.number_of_basis**0.5)
    # Spherical harmonics scaled per edge by smooth_cutoff(length / max_radius).
    edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

    # Node input features: provided ones, or a constant 1 per node.
    if self.input_has_node_in and 'x' in data:
        assert self.irreps_in is not None
        x = data['x']
    else:
        assert self.irreps_in is None
        x = pos.new_ones((n_nodes, 1))

    # Node attributes: provided ones, or a constant 1 per node.
    if self.input_has_node_attr and 'z' in data:
        z = data['z']
    else:
        assert self.irreps_node_attr == o3.Irreps("0e")
        z = pos.new_ones((n_nodes, 1))

    for layer in self.layers:
        x = layer(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

    if not self.reduce_output:
        return x
    return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
def test_radius_graph(dtype, device):
    """Check the coalesced ``radius_graph`` of a 2x2 square.

    After coalescing, the edges are sorted, so the exact index lists can be
    compared: each corner links to its two side-adjacent corners.
    """
    square = tensor([[-1, -1], [-1, +1], [+1, +1], [+1, -1]], dtype, device)
    expected_rows = [0, 0, 1, 1, 2, 2, 3, 3]
    expected_cols = [1, 3, 0, 2, 1, 3, 0, 2]

    edge_index = coalesce(radius_graph(square, r=2))
    assert edge_index.tolist() == [expected_rows, expected_cols]
def forward(self, data) -> torch.Tensor:
    """Predict one scalar per molecule from atom types and positions.

    Args:
        data: mapping with ``z`` (atomic numbers), ``pos`` and ``batch``.

    Returns:
        Per-graph sums of the per-atom outputs (plus ``self.atomref`` when set).
    """
    atom_types = data['z']
    positions = data['pos']
    graph_index = data['batch']

    # The graph
    edge_src, edge_dst = radius_graph(positions, r=self.max_radius, batch=graph_index, max_num_neighbors=1000)

    # Edge attributes: spherical harmonics of the normalised edge vectors.
    edge_vec = positions[edge_src] - positions[edge_dst]
    edge_sh = o3.spherical_harmonics(l=range(self.sh_lmax + 1), x=edge_vec, normalize=True, normalization='component')

    # Edge length embedding
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        self.num_basis,
        basis='smooth_finite',
        cutoff=True,
    ).mul(self.num_basis**0.5)

    # Constant scalar input per atom.
    node_input = positions.new_ones(positions.shape[0], 1)

    # Lookup table mapping atomic numbers 1, 6, 7, 8, 9 to classes 0..4;
    # other entries are -1, which is not a valid one_hot class.
    type_table = atom_types.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
    node_attr = torch.nn.functional.one_hot(type_table[atom_types], 5).mul(5**0.5)

    raw = self.mp(
        node_features=node_input,
        node_attr=node_attr,
        edge_src=edge_src,
        edge_dst=edge_dst,
        edge_attr=edge_sh,
        edge_scalars=edge_length_embedding,
    )

    # Channel 0 plus half the square of channel 1, as a column vector.
    per_atom = (raw[:, 0] + raw[:, 1].pow(2).mul(0.5)).view(-1, 1).div(self.num_nodes**0.5)

    if self.atomref is not None:
        per_atom = per_atom + self.atomref[atom_types]

    # for target=7, MAE of 75eV
    return scatter(per_atom, graph_index, dim=0)
def test_radius_graph(dtype, device):
    """Check the exact row/col output of ``radius_graph`` on a 2x2 square."""
    square = tensor([[-1, -1], [-1, +1], [+1, +1], [+1, -1]], dtype, device)
    row, col = radius_graph(square, r=2)

    expectations = (
        (row, [0, 0, 1, 1, 2, 2, 3, 3]),
        (col, [1, 3, 0, 2, 1, 3, 0, 2]),
    )
    for got, want in expectations:
        assert got.tolist() == want
def test_radius_graph_large(dtype, device):
    """Compare ``radius_graph`` on 1000 random 3-D points against SciPy's cKDTree.

    Self-loops are enabled and the neighbour cap (2000) exceeds the point
    count, so the result is compared against the full ``query_ball_point``
    answer.
    """
    points = torch.randn(1000, 3, dtype=dtype, device=device)
    edge_index = radius_graph(points, r=0.5, flow='target_to_source', loop=True, max_num_neighbors=2000)

    points_cpu = points.cpu()
    tree = scipy.spatial.cKDTree(points_cpu.numpy())
    neighbour_lists = tree.query_ball_point(points_cpu, r=0.5)
    truth = {(i, j) for i, neighbours in enumerate(neighbour_lists) for j in neighbours}

    assert to_set(edge_index.cpu()) == truth
def forward(self, inputs):
    """Apply forward pass of the model.

    Builds an initial embedding and a radius graph (radius ``self.r``, no
    self-loops, at most 30 neighbours). It then runs ``self.n_graph_iters``
    rounds of edge network, node network and residual update, re-embedding and
    rebuilding the graph every round.

    Returns:
        ``(edge_network(features, edge_index), spatial, edge_index)`` for the
        final graph.
    """
    node_x = inputs.x

    def embed(feats, coords):
        # embedding network applied to [raw input, features, current coords]
        return self.emb_network(torch.cat([node_x, feats, coords], dim=-1))

    def build_graph(coords):
        return radius_graph(coords, r=self.r, batch=inputs.batch, loop=False, max_num_neighbors=30)

    spatial = self.input_spatial_network(node_x)
    features = self.input_feature_network(node_x)
    spatial = embed(features, spatial)
    edge_index = build_graph(spatial)

    # Loop over iterations of edge and node networks
    for _ in range(self.n_graph_iters):
        residual = features
        edge_weights = torch.sigmoid(self.edge_network(features, edge_index))
        features = self.node_network(features, edge_weights, edge_index)
        spatial = embed(features, spatial)
        edge_index = build_graph(spatial)
        features = residual + features

    return self.edge_network(features, edge_index), spatial, edge_index
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32,
                 flow='source_to_target', num_workers=1):
    r"""Compute graph edges between all points lying within distance ``r``.

    Thin wrapper that forwards every argument, unchanged, to
    ``torch_cluster.radius_graph``.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        r (float): The radius.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each node. (default: :obj:`32`)
        flow (string, optional): The flow direction when using in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch` is not :obj:`None`, or the input lies
            on the GPU. (default: :obj:`1`)

    Raises:
        ImportError: if ``torch-cluster`` is not available.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_geometric.nn import radius_graph

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
    """
    backend = torch_cluster
    if backend is None:
        raise ImportError('`radius_graph` requires `torch-cluster`.')

    forwarded = (x, r, batch, loop, max_num_neighbors, flow, num_workers)
    return backend.radius_graph(*forwarded)
def test():
    """Smoke-test ``MessagePassing``: output shape, equivariance and TorchScript."""
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    n_nodes = 4
    positions = torch.randn(n_nodes, 3)
    edge_index = radius_graph(positions, 3.0)
    edge_src, edge_dst = edge_index
    n_edges = edge_index.shape[1]

    # Random inputs (drawn in the same order as before).
    edge_vectors = positions[edge_src] - positions[edge_dst]
    features = torch.randn(n_nodes, 1)
    attributes = torch.randn(n_nodes, 4)
    scalars = torch.randn(n_edges, 2)

    call_args = [features, attributes, edge_src, edge_dst, edge_vectors, scalars]
    assert mp(*call_args).shape == (n_nodes, 3)

    assert_equivariant(
        mp,
        irreps_in=[
            mp.irreps_node_input, mp.irreps_node_attr, None, None, mp.irreps_edge_attr, None
        ],
        args_in=list(call_args),
        irreps_out=[mp.irreps_node_output],
    )
    assert_auto_jitable(mp.layers[0].first)
def forward(self, pos, batch):
    """Compute neighbour pairs within ``cutoff_upper`` and their edge lengths.

    Args:
        pos: node positions, differenced per edge.
        batch: graph assignment vector forwarded to ``radius_graph``.

    Returns:
        ``(edge_index, edge_weight, edge_vec)``. ``edge_weight`` holds the edge
        lengths. Edges shorter than ``cutoff_lower`` are dropped, except that
        self-loops are always kept when ``self.loop`` is set. ``edge_vec`` is
        ``None`` unless ``self.return_vecs``.

    Raises:
        AssertionError: if some node has more than ``max_num_neighbors``
            neighbours, meaning the search may have been truncated.
    """
    # Ask for one extra neighbour so truncation can be detected below.
    edge_index = radius_graph(
        pos,
        r=self.cutoff_upper,
        batch=batch,
        loop=self.loop,
        max_num_neighbors=self.max_num_neighbors + 1,
    )

    # make sure we didn't miss any neighbors due to max_num_neighbors
    assert not (torch.unique(
        edge_index[0], return_counts=True
    )[1] > self.max_num_neighbors).any(), (
        "The neighbor search missed some atoms due to max_num_neighbors being too low. "
        "Please increase this parameter to include the maximum number of atoms within the cutoff."
    )

    edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

    mask: Optional[torch.Tensor] = None
    if self.loop:
        # mask out self loops when computing distances because
        # the norm of 0 produces NaN gradients
        # NOTE: might influence force predictions as self loop gradients are ignored
        mask = edge_index[0] != edge_index[1]
        # Match the dtype of the positions: the default float32 would make the
        # masked assignment below fail for float64 inputs.
        edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype)
        edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
    else:
        edge_weight = torch.norm(edge_vec, dim=-1)

    lower_mask = edge_weight >= self.cutoff_lower
    if self.loop and mask is not None:
        # keep self loops even though they might be below the lower cutoff
        lower_mask = lower_mask | ~mask
    edge_index = edge_index[:, lower_mask]
    edge_weight = edge_weight[lower_mask]

    if self.return_vecs:
        edge_vec = edge_vec[lower_mask]
        return edge_index, edge_weight, edge_vec
    # TODO: return only `edge_index` and `edge_weight` once
    # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
    return edge_index, edge_weight, None
def preprocess(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Build the radius graph and per-edge vectors for ``data``.

    ``max_num_neighbors`` is set to the number of nodes minus one, so a node
    can be connected to every other node.

    NOTE(review): despite the ``torch.Tensor`` annotation, this returns the
    5-tuple ``(batch, x, edge_src, edge_dst, edge_vec)``.
    """
    pos = data['pos']

    if 'batch' in data:
        batch = data['batch']
    else:
        # every node belongs to graph 0
        batch = pos.new_zeros(pos.shape[0], dtype=torch.long)

    # Create graph
    edge_src, edge_dst = radius_graph(pos, self.max_radius, batch, max_num_neighbors=len(pos) - 1)

    # Edge attributes
    edge_vec = pos[edge_src] - pos[edge_dst]
    return batch, data['x'], edge_src, edge_dst, edge_vec
def prepare_data():
    """Load GeometricShapes as point clouds and visualise two neighbourhood graphs.

    Shape 0 is grouped with a radius graph (r=0.2) and shape 4 with a k-NN
    graph (k=6). For each, the edge-index shape is printed and the points are
    plotted.

    Returns:
        The dataset, with a ``SamplePoints(num=256)`` transform installed.
    """
    ### Load dataset
    dataset = GeometricShapes(root='data/GeometricShapes')
    print(dataset)

    ### Generate point cloud: seed torch's RNG, then sample 256 points per shape
    torch.manual_seed(42)
    dataset.transform = SamplePoints(num=256)

    ### Grouping
    groupings = (
        (0, lambda pos: radius_graph(pos, r=0.2)),
        (4, lambda pos: knn_graph(pos, k=6)),
    )
    for index, build_edges in groupings:
        sample = dataset[index]
        sample.edge_index = build_edges(sample.pos)
        print(sample.edge_index.shape)
        visualize_points(sample.pos, edge_index=sample.edge_index)

    return dataset
def forward(self, data) -> torch.Tensor:
    """Predict one output per graph with a convolution, gate and final convolution.

    Nodes within 2.5 are connected. Edge attributes are normalised spherical
    harmonics, and edge lengths are embedded with 3 ``smooth_finite`` basis
    functions on [0.5, 2.5]. Initial node features are the neighbour sums of
    the edge attributes.

    Returns:
        Tensor with one row per graph in ``data.batch``.
    """
    typical_num_nodes = 4  # typical number of nodes

    edge_src, edge_dst = radius_graph(x=data.pos, r=2.5, batch=data.batch)
    edge_vec = data.pos[edge_src] - data.pos[edge_dst]
    edge_lengths = edge_vec.norm(dim=1)

    edge_attr = o3.spherical_harmonics(l=self.irreps_sh, x=edge_vec, normalize=True, normalization='component')
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_lengths,
        start=0.5,
        end=2.5,
        number=3,
        basis='smooth_finite',
        cutoff=True,
    ) * 3**0.5

    node_x = scatter(edge_attr, edge_dst, dim=0).div(self.num_neighbors**0.5)
    node_x = self.gate(self.conv(node_x, edge_src, edge_dst, edge_attr, edge_length_embedded))
    node_x = self.final(node_x, edge_src, edge_dst, edge_attr, edge_length_embedded)

    return scatter(node_x, data.batch, dim=0).div(typical_num_nodes**0.5)
def forward(self, x, batch: OptTensor = None):
    """Embed hits, classify edges, cluster connected hits and predict cluster properties.

    Args:
        x: node input features, one row per node.
        batch: graph assignment per node; defaults to a single graph.

    Returns:
        ``(out, edge_scores, edge_index, cluster_map, cluster_props, cluster_batch)``:
        the final embedding; per-edge class scores, one row per edge of the
        radius graph built on ``out``; that edge list; a cluster id in
        ``[0, n_clusters)`` per node; the predicted cluster properties; and the
        cluster batch vector as returned by ``max_pool_x``.
    """
    if batch is None:
        batch = torch.zeros(x.size()[0], dtype=torch.int64, device=x.device)

    # Embedding1: intermediate latent-space features (hiddenDim),
    # refined by residual EdgeConvs over a kNN graph in that space.
    x_emb = self.inputnet(x)
    for ec in self.edgeconvs:
        edge_index = knn_graph(x_emb, self.k, batch, loop=False, flow=ec.flow)
        x_emb = x_emb + ec(x_emb, edge_index)

    # [1] Embedding2: final latent-space embedding coordinates.
    out = self.output(x_emb)

    # Radius graph (r=0.5, at most self.k neighbours) over Embedding2.
    edge_index = radius_graph(out, r=0.5, batch=batch, max_num_neighbors=self.k, loop=False)

    # [2] Edge categories: residual convolutions over Embedding1 features.
    x_cat = self.inputnet_cat(x)
    for ec in self.edgecatconvs:
        x_cat = x_cat + ec(torch.cat([x_cat, x_emb.detach(), x], dim=1), edge_index)

    # Only drop a trailing singleton dimension. A bare ``squeeze()`` would also
    # collapse the edge dimension when exactly one edge exists, and the
    # ``argmax(dim=1)`` below would then fail.
    edge_scores = self.edge_classifier(
        torch.cat([x_cat[edge_index[0]], x_cat[edge_index[1]]], dim=1)).squeeze(-1)

    # Use the predicted graph to build disjoint subgraphs (the physics objects):
    # an edge is kept when its highest-scoring class is not class 0.
    objects = UnionFind(x.size()[0])
    good_edges = edge_index[:, torch.argmax(edge_scores, dim=1) > 0]
    good_edges_cpu = good_edges.cpu().numpy()
    for edge in good_edges_cpu.T:
        objects.union(edge[0], edge[1])
    cluster_map = torch.from_numpy(
        np.array([objects.find(i) for i in range(x.shape[0])], dtype=np.int64)).to(x.device)
    # Remap roots to [0, ..., nclusters-1]: the inverse index of unique does exactly that.
    _, cluster_map = torch.unique(cluster_map, return_inverse=True)

    # [3] Segmented cluster properties: residual convolutions restricted to
    # the selected edges, then max-pooled per cluster.
    x_prop = self.inputnet_prop(x)
    for ec in self.propertyconvs:
        x_prop = x_prop + ec(torch.cat([x_prop, x_emb.detach(), x], dim=1), good_edges)
    props_pooled, cluster_batch = max_pool_x(cluster_map, x_prop, batch)
    cluster_props = self.property_predictor(props_pooled)
    return out, edge_scores, edge_index, cluster_map, cluster_props, cluster_batch
def balanced_train(model, train_loader, optimizer, loss_fn, m_configs):
    """Train for one pass over ``train_loader`` with an edge loss and an embedding loss.

    For every batch, the model returns edge logits ``pred`` for its own edge
    list ``e`` and an embedding ``spatial``. Two losses are combined by
    ``loss_fn``:

    * weighted BCE between ``pred`` and "both hits share a ``pid``";
    * a hinge-embedding loss on squared embedding distances. It is computed
      over fake pairs (radius-graph neighbours with different ``pid``, radius
      ``m_configs["r_train"]``) plus the batch's ``true_edges``.

    Relies on module-level ``device`` and ``sig`` (presumably a sigmoid —
    confirm).

    Returns:
        ``(edge_acc, cluster_pur, total_loss)``.
    """
    edge_correct, edge_total_positive, edge_total_true, edge_true_positive, total = (
        0,
        1,
        0,
        0,
        0,
    )
    (
        cluster_correct,
        cluster_total_positive,
        cluster_total_true,
        cluster_total_true_positive,
        cluster_total,
    ) = (0, 1, 0, 0, 0)
    correct = 0
    total = 0
    total_loss = 0
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        # NOTE(review): ``batch`` is used after this as if it were also on
        # ``device`` (PyG ``Data.to`` moves in place) — confirm for your loader.
        data = batch.to(device)
        pred, spatial, e = model(data)

        # Get fake edge list: embedding neighbours whose hits belong to different particles
        candidates = radius_graph(
            spatial,
            r=m_configs["r_train"],
            batch=batch.batch,
            loop=False,
            max_num_neighbors=200,
        )
        fake_list = candidates[:, batch.pid[candidates[0]] != batch.pid[candidates[1]]]

        # Concatenate all candidates
        e_spatial = torch.cat(
            [fake_list, batch.true_edges.T.to(device)], axis=-1)
        reference = spatial.index_select(0, e_spatial[1])
        neighbors = spatial.index_select(0, e_spatial[0])
        # squared embedding distance per candidate pair
        d = torch.sum((reference - neighbors)**2, dim=-1)

        y_edge = batch.pid[e[0]] == batch.pid[e[1]]
        y_cluster = batch.pid[e_spatial[0]] == batch.pid[e_spatial[1]]
        # hinge targets: +1 for same-particle pairs, -1 otherwise
        hinge = y_cluster.float()
        hinge[batch.pid[e_spatial[0]] != batch.pid[e_spatial[1]]] = -1

        # NOTE(review): pos_weight is created on the default device; confirm
        # this does not clash with ``pred`` when ``device`` is CUDA.
        loss_1 = F.binary_cross_entropy_with_logits(pred.float(),
                                                    y_edge.float(),
                                                    pos_weight=torch.tensor(
                                                        m_configs["weight"]))
        loss_2 = torch.nn.functional.hinge_embedding_loss(
            d,
            hinge,
            margin=m_configs["margin"],
            reduction=m_configs["reduction"])
        loss = loss_fn([loss_1.to(device), loss_2.to(device)])
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Cluster performance
        # (pids/counts are computed but not used below)
        batch_cpu = batch.pid.cpu()
        pids, counts = np.unique(batch_cpu, return_counts=True)
        cluster_true_positive = (y_cluster.float()).sum().item()
        cluster_total_true_positive += cluster_true_positive
        cluster_positive = len(e_spatial[0])
        # counter starts at 1 and adds at least 1 per batch, so it is never 0
        cluster_total_positive += max(cluster_positive, 1)

        # Edge accuracy at a 0.5 threshold on sig(pred)
        edge_correct += ((sig(pred) > 0.5) == (y_edge.float() > 0.5)).sum().item()
        total += len(pred)

    edge_acc = edge_correct / total
    cluster_pur = cluster_total_true_positive / cluster_total_positive

    return edge_acc, cluster_pur, total_loss
def evaluate(model, test_loader, loss_fn, m_configs):
    """Evaluate the edge classifier and the learned embedding over ``test_loader``.

    For every batch, the model returns edge logits ``pred`` for its own edge
    list ``e`` and an embedding ``spatial``. Candidate pairs come from
    ``radius_graph`` in embedding space (radius ``m_configs["r_val"]``).
    A pair counts as true when both hits share a ``pid`` AND the second hit
    lies exactly one layer after the first. Losses are weighted BCE on
    ``pred`` and a hinge-embedding loss on squared candidate distances,
    combined by ``loss_fn``.

    NOTE(review): this function neither calls ``model.eval()`` nor disables
    autograd; confirm the caller does. Relies on module-level ``device`` and
    ``sig`` (presumably a sigmoid — confirm).

    Returns:
        ``(edge_acc, edge_pur, edge_eff, cluster_pur, cluster_eff, total_loss)``.
    """
    edge_correct, edge_total_positive, edge_total_true, edge_true_positive, total = (
        0,
        1,
        0,
        0,
        0,
    )
    (
        cluster_correct,
        cluster_total_positive,
        cluster_total_true,
        cluster_total_true_positive,
        cluster_total,
    ) = (0, 1, 0, 0, 0)
    total_loss = 0
    for batch in test_loader:
        data = batch.to(device)
        pred, spatial, e = model(data)

        # Candidate pairs: all embedding-space neighbours within r_val
        e_spatial = radius_graph(
            spatial,
            r=m_configs["r_val"],
            batch=batch.batch,
            loop=False,
            max_num_neighbors=200,
        )
        # e_spatial = knn_graph(spatial, k=m_configs["k"], batch=batch.batch, loop=False)

        reference = spatial.index_select(0, e_spatial[1])
        neighbors = spatial.index_select(0, e_spatial[0])
        # squared embedding distance per candidate pair
        d = torch.sum((reference - neighbors)**2, dim=-1)

        # truth: same particle and consecutive layers
        y_edge = (batch.pid[e[0]] == batch.pid[e[1]]) & (
            batch.layers[e[1]] - batch.layers[e[0]] == 1)
        y_cluster = (batch.pid[e_spatial[0]] == batch.pid[e_spatial[1]]) & (
            batch.layers[e_spatial[1]] - batch.layers[e_spatial[0]] == 1)
        # hinge targets: +1 for true pairs, -1 otherwise
        hinge = y_cluster.float()
        hinge[hinge == 0] = -1

        # NOTE(review): pos_weight is created on the default device; confirm
        # this does not clash with ``pred`` when ``device`` is CUDA.
        loss_1 = F.binary_cross_entropy_with_logits(pred.float(),
                                                    y_edge.float(),
                                                    pos_weight=torch.tensor(
                                                        m_configs["weight"]))
        loss_2 = torch.nn.functional.hinge_embedding_loss(
            d,
            hinge,
            margin=m_configs["margin"],
            reduction=m_configs["reduction"])
        loss = loss_fn([loss_1, loss_2]).item()
        total_loss += loss

        # Cluster performance
        cluster_true = len(batch.true_edges)
        cluster_true_positive = (y_cluster.float()).sum().item()
        cluster_total_true_positive += cluster_true_positive
        cluster_positive = len(e_spatial[0])
        # counter starts at 1 and adds at least 1 per batch, so it is never 0
        cluster_total_positive += max(cluster_positive, 1)
        cluster_total_true += cluster_true

        # Edge performance at a 0.5 threshold on sig(pred)
        edge_true, edge_false = y_edge.float() > 0.5, y_edge.float() < 0.5
        edge_positive, edge_negative = sig(pred) > 0.5, sig(pred) < 0.5

        edge_correct += ((sig(pred) > 0.5) == (y_edge.float() > 0.5)).sum().item()
        edge_true_positive += (edge_true & edge_positive).sum().item()
        edge_total_true += edge_true.sum().item()
        # NOTE(review): starts at 1, so edge purity is slightly underestimated
        edge_total_positive += edge_positive.sum().item()

        total += len(pred)

    edge_acc = edge_correct / total
    edge_eff = edge_true_positive / max(edge_total_true, 1)
    edge_pur = edge_true_positive / max(edge_total_positive, 1)

    cluster_eff = cluster_total_true_positive / max(cluster_total_true, 1)
    cluster_pur = cluster_total_true_positive / max(cluster_total_positive, 1)

    return edge_acc, edge_pur, edge_eff, cluster_pur, cluster_eff, total_loss
def forward(self, node_atom, node_pos, batch) -> torch.Tensor:
    """Predict one scalar per molecule from atomic numbers and positions.

    Per-atom outputs are optionally de-standardised (``* std + mean``) and
    shifted by ``self.atomref``. They are then summed per graph and
    optionally scaled by ``self.scale``.

    Args:
        node_atom: atomic number per atom, used to index a 10-entry lookup table.
        node_pos: atom positions.
        batch: graph index per atom.
    """
    # The graph
    edge_src, edge_dst = radius_graph(node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000)

    # Edge attributes
    edge_vec = node_pos[edge_src] - node_pos[edge_dst]
    edge_sh = o3.spherical_harmonics(
        l=range(self.lmax + 1), x=edge_vec, normalize=True, normalization='component'
    )

    # Edge length embedding
    basis_size = self.number_of_basis
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        basis_size,
        basis='cosine',  # the cosine basis with cutoff = True goes to zero at max_radius
        cutoff=True,  # no need for an additional smooth cutoff
    ).mul(basis_size**0.5)

    # Constant scalar input per atom; atom types 1, 6, 7, 8, 9 -> classes 0..4.
    node_input = node_pos.new_ones(node_pos.shape[0], 1)
    class_of = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
    node_attr = torch.nn.functional.one_hot(class_of[node_atom], 5).mul(5**0.5)

    raw = self.mp(
        node_features=node_input,
        node_attr=node_attr,
        edge_src=edge_src,
        edge_dst=edge_dst,
        edge_attr=edge_sh,
        edge_scalars=edge_length_embedding,
    )

    # Channel 0 plus half the square of channel 1, as a column vector.
    per_atom = (raw[:, 0] + raw[:, 1].pow(2).mul(0.5)).view(-1, 1).div(self.num_nodes**0.5)

    if self.mean is not None and self.std is not None:
        per_atom = per_atom * self.std + self.mean

    if self.atomref is not None:
        per_atom = per_atom + self.atomref[node_atom]
    # for target=7, MAE of 75eV

    outputs = scatter(per_atom, batch, dim=0)
    if self.scale is not None:
        outputs = self.scale * outputs
    return outputs
def plot_weight(model, loss_fn, dataloader, metrics, deltaR, model_dir, saveplot):
    """Accumulate per-particle weight histograms and qT spectra, then save them.

    For every batch:

    * particles are connected by ``radius_graph`` in (eta, phi) space (radius
      ``deltaR``, self-loops included) and passed through ``model``;
    * true / graph / PF / PUPPI / DeepMET qT spectra are counted;
    * the model's per-particle weight is averaged in pT and eta bins per
      particle type, and in PUPPI-weight bins for types 1, 22 and 130;
    * the charged-hadron (211) weight distribution is counted separately for
      PUPPI weight 0 and 1.

    All histograms are saved via ``utils.save(weights, 'weight.plt')``.

    Args:
        model: (torch.nn.Module) the neural network, called as
            ``model(x_cont, x_cat, edge_index, batch)``.
        loss_fn: not used in this function.
        dataloader: (DataLoader) yields batches exposing ``x``, ``y`` and ``batch``.
        metrics: not used in this function.
        deltaR: (float) radius of the (eta, phi) neighbourhood graph.
        model_dir: not used in this function.
        saveplot: not used in this function.

    Returns:
        The ``(N, 6)`` tensor built for the LAST batch, with columns
        |pdgId|, |pT|, |eta|, |puppiWeight|, model weight, event true qT (per
        the column comments in the loop). If ``dataloader`` is empty, the
        empty CUDA placeholder is returned.

    NOTE(review): the per-bin averages divide by the bin count. With an empty
    dataloader, the bins still hold the Python int 0 and ``0 / 0.0`` raises
    ZeroDivisionError. Otherwise, empty bins give numpy NaN, which
    ``np.nan_to_num`` turns into 0.
    """
    # NOTE(review): model.eval() is not called here; confirm the caller sets the mode.
    #model.eval()

    # summary for current eval loop (qT_arr is not used below)
    qT_arr = []
    # particle types (|pdgId|); 1 also collects |pdgId| == 2 below
    weight_arr = [1, 11, 13, 22, 130, 211]
    label = {
        1: 'HF Candidate',
        11: 'Electron',
        13: 'Muon',
        22: 'Gamma',
        130: 'Neutral Hadron',
        211: 'Charged Hadron',
    }
    binedges_list = {
        'Pt': np.arange(-0.05, 25.05, 0.1),
        'eta': np.arange(-0.1, 5.1, 0.2),
        'Puppi': [
            -0.05, 0.05, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
            0.7, 0.75, 0.8, 0.85, 0.9, 1.1
        ],
        'graph_weight': np.arange(-0.05, 1.15, 0.01),
        'qT1D': np.arange(0, 420, 20),
    }
    # *_hist: sum of weights per bin (later divided into means);
    # *_histN: number of entries per bin.
    weight_pt_hist = {}
    weight_eta_hist = {}
    weight_puppi_hist = {}
    weight_pt_histN = {}
    weight_eta_histN = {}
    weight_puppi_histN = {}
    weight_CH_hist = {}
    weight_qT_hist = {}
    # placeholder; replaced by the per-batch table inside the loop
    result = torch.empty(0, 6).to('cuda')

    # Zero-initialise one entry per bin (len(edges) - 1 bins).
    for key in weight_arr:
        weight_pt_hist[label[key]] = []
        weight_pt_histN[label[key]] = []
        for i in range(1, len(binedges_list['Pt'])):
            weight_pt_hist[label[key]].append(0)
            weight_pt_histN[label[key]].append(0)
        weight_eta_hist[label[key]] = []
        weight_eta_histN[label[key]] = []
        for i in range(1, len(binedges_list['eta'])):
            weight_eta_hist[label[key]].append(0)
            weight_eta_histN[label[key]].append(0)
    for key in (1, 22, 130):
        weight_puppi_hist[label[key]] = []
        weight_puppi_histN[label[key]] = []
        for i in range(1, len(binedges_list['Puppi'])):
            weight_puppi_hist[label[key]].append(0)
            weight_puppi_histN[label[key]].append(0)
    weight_CH_hist['puppi0'] = []
    weight_CH_hist['puppi1'] = []
    for i in range(1, len(binedges_list['graph_weight'])):
        weight_CH_hist['puppi0'].append(0)
        weight_CH_hist['puppi1'].append(0)
    weight_qT_hist['TrueMET'] = []
    weight_qT_hist['GraphMET'] = []
    weight_qT_hist['PFMET'] = []
    weight_qT_hist['PUPPIMET'] = []
    weight_qT_hist['DeepMETResponse'] = []
    weight_qT_hist['DeepMETResolution'] = []
    for i in range(1, len(binedges_list['qT1D'])):
        weight_qT_hist['TrueMET'].append(0)
        weight_qT_hist['GraphMET'].append(0)
        weight_qT_hist['PFMET'].append(0)
        weight_qT_hist['PUPPIMET'].append(0)
        weight_qT_hist['DeepMETResponse'].append(0)
        weight_qT_hist['DeepMETResolution'].append(0)

    # compute metrics over the dataset
    for data in dataloader:
        data = data.to('cuda')
        # first 8 columns are continuous, the rest categorical
        x_cont = data.x[:, :8]
        x_cat = data.x[:, 8:].long()
        print("x_cont shape:", (x_cont.shape))
        print("x_cat shape:", (x_cat.shape))
        phi = torch.atan2(data.x[:, 1], data.x[:, 0])
        etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
        # NB: there is a problem right now for comparing hits at the +/- pi boundary
        edge_index = radius_graph(etaphi,
                                  r=deltaR,
                                  batch=data.batch,
                                  loop=True,
                                  max_num_neighbors=255)
        # compute model output (one weight per particle, stacked as a column below)
        result = model(x_cont, x_cat, edge_index, data.batch)

        # per-event qT magnitudes from consecutive (x, y) column pairs of data.y
        TrueqT = torch.sqrt(data.y[:, 0]**2 + data.y[:, 1]**2).cpu().detach().numpy()
        pfqT = torch.sqrt(data.y[:, 2]**2 + data.y[:, 3]**2).cpu().detach().numpy()
        puppiqT = torch.sqrt(data.y[:, 4]**2 + data.y[:, 5]**2).cpu().detach().numpy()
        deepMETResponseqT = torch.sqrt(data.y[:, 6]**2 + data.y[:, 7]**2).cpu().detach().numpy()
        deepMETResolutionqT = torch.sqrt(data.y[:, 8]**2 + data.y[:, 9]**2).cpu().detach().numpy()
        # graph MET: per-event sum of weighted x/y components
        graphMETx = scatter_add(result * data.x[:, 0], data.batch)
        graphMETy = scatter_add(result * data.x[:, 1], data.batch)
        graphMETqT = torch.sqrt(graphMETx**2 + graphMETy**2).cpu().detach().numpy()

        # qT 1D distribution: count events per bin [edge[i-1], edge[i])
        for i in range(1, len(binedges_list['qT1D'])):
            binnedqT = TrueqT[np.where((TrueqT >= binedges_list['qT1D'][i - 1]) &
                                       (TrueqT < binedges_list['qT1D'][i]))]
            weight_qT_hist['TrueMET'][i - 1] += len(binnedqT)
            binnedqT = graphMETqT[
                np.where((graphMETqT >= binedges_list['qT1D'][i - 1]) &
                         (graphMETqT < binedges_list['qT1D'][i]))]
            weight_qT_hist['GraphMET'][i - 1] += len(binnedqT)
            binnedqT = pfqT[np.where((pfqT >= binedges_list['qT1D'][i - 1]) &
                                     (pfqT < binedges_list['qT1D'][i]))]
            weight_qT_hist['PFMET'][i - 1] += len(binnedqT)
            binnedqT = puppiqT[
                np.where((puppiqT >= binedges_list['qT1D'][i - 1]) &
                         (puppiqT < binedges_list['qT1D'][i]))]
            weight_qT_hist['PUPPIMET'][i - 1] += len(binnedqT)
            binnedqT = deepMETResponseqT[
                np.where((deepMETResponseqT >= binedges_list['qT1D'][i - 1]) &
                         (deepMETResponseqT < binedges_list['qT1D'][i]))]
            weight_qT_hist['DeepMETResponse'][i - 1] += len(binnedqT)
            binnedqT = deepMETResolutionqT[
                np.where((deepMETResolutionqT >= binedges_list['qT1D'][i - 1]) &
                         (deepMETResolutionqT < binedges_list['qT1D'][i]))]
            weight_qT_hist['DeepMETResolution'][i - 1] += len(binnedqT)

        #pX,pY,pT,eta,d0,dz,mass,puppiWeight,pdgId,charge,fromPV
        #pdg, pt, eta, puppi, weight
        # true qT of each particle's event, broadcast to particles via data.batch
        ZQt = torch.gather(torch.sqrt(data.y[:, 0]**2 + data.y[:, 1]**2), 0, data.batch)
        result = torch.stack(
            (torch.abs(x_cat[:, 0]), torch.abs(x_cont[:, 2]),
             torch.abs(x_cont[:, 3]), torch.abs(x_cont[:, 7]), result, ZQt),
            dim=1)

        # weight vs pt
        # weight vs eta
        for key in weight_arr:
            if (key == 1):
                W_arr = result[np.where((result[:, 0].cpu() == key) | (
                    result[:, 0].cpu() == 2))].cpu().detach().numpy()
            else:
                W_arr = result[np.where(
                    result[:, 0].cpu() == key)].cpu().detach().numpy()
            for i in range(1, len(binedges_list['Pt'])):
                W_i = W_arr[
                    np.where((W_arr[:, 1] >= binedges_list['Pt'][i - 1]) &
                             (W_arr[:, 1] < binedges_list['Pt'][i]))][:, 4]
                weight_pt_hist[label[key]][i - 1] += np.sum(W_i)
                weight_pt_histN[label[key]][i - 1] += len(W_i)
            for i in range(1, len(binedges_list['eta'])):
                W_i = W_arr[
                    np.where((W_arr[:, 2] >= binedges_list['eta'][i - 1]) &
                             (W_arr[:, 2] < binedges_list['eta'][i]))][:, 4]
                weight_eta_hist[label[key]][i - 1] += np.sum(W_i)
                weight_eta_histN[label[key]][i - 1] += len(W_i)

        # weight vs puppi
        for key in (1, 22, 130):
            if (key == 1):
                W_arr = result[np.where((result[:, 0].cpu() == key) | (
                    result[:, 0].cpu() == 2))].cpu().detach().numpy()
            else:
                W_arr = result[np.where(
                    result[:, 0].cpu() == key)].cpu().detach().numpy()
            for i in range(1, len(binedges_list['Puppi'])):
                W_i = W_arr[np.where(
                    (W_arr[:, 3] >= binedges_list['Puppi'][i - 1]) &
                    (W_arr[:, 3] < binedges_list['Puppi'][i]))][:, 4]
                weight_puppi_hist[label[key]][i - 1] += np.sum(W_i)
                weight_puppi_histN[label[key]][i - 1] += len(W_i)

        # weight distribution (charged hadrons, split by PUPPI weight exactly 0 or 1)
        W_arr = result[np.where(
            result[:, 0].cpu() == 211)].cpu().detach().numpy()
        for i in range(1, len(binedges_list['graph_weight'])):
            W_i = W_arr[np.where(
                (W_arr[:, 3] == 0) &
                (W_arr[:, 4] >= binedges_list['graph_weight'][i - 1]) &
                (W_arr[:, 4] < binedges_list['graph_weight'][i]))][:, 4]
            weight_CH_hist['puppi0'][i - 1] += len(W_i)
            W_i = W_arr[np.where(
                (W_arr[:, 3] == 1) &
                (W_arr[:, 4] >= binedges_list['graph_weight'][i - 1]) &
                (W_arr[:, 4] < binedges_list['graph_weight'][i]))][:, 4]
            weight_CH_hist['puppi1'][i - 1] += len(W_i)

    # Turn per-bin weight sums into means; empty bins (NaN) become 0.
    for key in weight_arr:
        for i in range(1, len(binedges_list['Pt'])):
            weight_pt_hist[label[key]][
                i - 1] /= 1.0 * weight_pt_histN[label[key]][i - 1]
        weight_pt_hist[label[key]] = np.nan_to_num(
            weight_pt_hist[label[key]])
        for i in range(1, len(binedges_list['eta'])):
            weight_eta_hist[label[key]][
                i - 1] /= 1.0 * weight_eta_histN[label[key]][i - 1]
        weight_eta_hist[label[key]] = np.nan_to_num(
            weight_eta_hist[label[key]])
    for key in (1, 22, 130):
        for i in range(1, len(binedges_list['Puppi'])):
            weight_puppi_hist[label[key]][
                i - 1] /= 1.0 * weight_puppi_histN[label[key]][i - 1]
        weight_puppi_hist[label[key]] = np.nan_to_num(
            weight_puppi_hist[label[key]])

    weights = {
        'bin_edges': binedges_list,
        'weight_pt_hist': weight_pt_hist,
        'weight_eta_hist': weight_eta_hist,
        'weight_puppi_hist': weight_puppi_hist,
        'weight_CH_hist': weight_CH_hist,
        'weight_qT_hist': weight_qT_hist,
    }
    utils.save(weights, 'weight.plt')

    return result
def evaluate(model, device, loss_fn, dataloader, metrics, deltaR, deltaR_dz, model_dir):
    """Evaluate the model on every batch of `dataloader` and bin MET resolutions in qT.

    Args:
        model: (torch.nn.Module) the neural network, called as
            ``model(x_cont, x_cat, edge_index, batch)``
        device: device each batch is moved to before inference
        loss_fn: called as ``loss_fn(result, data.x, data.y, data.batch)``; its
            ``.item()`` is recorded as the batch loss
        dataloader: iterable of graph batches exposing ``x``, ``y``, ``batch`` and ``to()``
        metrics: (dict) must provide ``'resolution'``, called as
            ``metrics['resolution'](result, data.x, data.y, data.batch)`` and returning
            ``(resolutions, qT)``; ``resolutions[key]`` holds the per-event
            ``(u_perp, u_par, R)`` arrays for every MET definition tracked here
        deltaR: radius of the (eta, phi) neighbourhood used to build graph edges
        deltaR_dz: unused; only a disabled dz-based edge experiment referenced it
        model_dir: unused

    Returns:
        ``(metrics_mean, resolution_hists)``. ``metrics_mean['loss']`` is the mean
        batch loss. ``resolution_hists[key]`` maps ``'u_perp_resolution'``,
        ``'u_perp_scaled_resolution'``, ``'u_par_resolution'``,
        ``'u_par_scaled_resolution'`` and ``'R'`` to ``np.histogram`` results with
        ``x_n`` bins over ``[0, max_x)`` in qT. Resolutions are half the spread
        between the 16th and 84th percentiles; qT bins without events are NaN.
    """
    max_x = 400  # upper edge of the qT range
    x_n = 40  # number of qT bins

    def _half_width_68(values):
        # Half the 16th-84th percentile spread. np.quantile raises on an empty
        # array, so an empty qT bin yields NaN instead of aborting the evaluation.
        if values.size == 0:
            return np.nan
        return (np.quantile(values, 0.84) - np.quantile(values, 0.16)) / 2.

    # set model to evaluation mode
    model.eval()

    loss_avg_arr = []
    qT_chunks = []
    # Per MET definition: per-batch chunks of (u_perp, u_par, R). They are
    # concatenated once after the loop rather than re-concatenated every batch.
    resolution_chunks = {
        'MET': [[], [], []],
        'pfMET': [[], [], []],
        'puppiMET': [[], [], []],
    }

    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for data in dataloader:
            # When data.y has more than 6 columns, also track the DeepMET resolutions.
            if data.y.size()[1] > 6 and 'deepMETResponse' not in resolution_chunks:
                resolution_chunks['deepMETResponse'] = [[], [], []]
                resolution_chunks['deepMETResolution'] = [[], [], []]
            data = data.to(device)
            x_cont = data.x[:, :8]  # continuous features (includes the PUPPI column)
            x_cat = data.x[:, 8:].long()  # categorical features
            phi = torch.atan2(data.x[:, 1], data.x[:, 0])
            etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            edge_index = radius_graph(etaphi, r=deltaR, batch=data.batch, loop=True,
                                      max_num_neighbors=255)
            # NOTE: a disabled experiment added a second radius graph over
            # data.x[:, 5] with radius deltaR_dz and concatenated its edges.
            result = model(x_cont, x_cat, edge_index, data.batch)
            loss = loss_fn(result, data.x, data.y, data.batch)

            # compute all metrics on this batch
            resolutions, qT = metrics['resolution'](result, data.x, data.y, data.batch)
            for key, parts in resolution_chunks.items():
                for i, part in enumerate(parts):
                    part.append(resolutions[key][i])
            qT_chunks.append(qT)
            loss_avg_arr.append(loss.item())

    # The leading empty float64 array keeps the dtype promotion of concatenating
    # onto an initial [] and also handles an empty dataloader.
    qT_arr = np.concatenate([np.empty(0), *qT_chunks])

    # x_n + 1 edges give exactly one filled bin per output histogram bin; an
    # arange(0, max_x, step) grid stops one edge short and drops the last bin.
    bin_edges = np.linspace(0, max_x, x_n + 1)
    inds = np.digitize(qT_arr, bin_edges)
    qT_hist = (bin_edges[1:] + bin_edges[:-1]) / 2.

    def _hist(weights):
        # One entry per qT bin centre, weighted by that bin's value.
        return np.histogram(qT_hist, bins=x_n, range=(0, max_x), weights=weights)

    resolution_hists = {}
    for key, (u_perp_parts, u_par_parts, R_parts) in resolution_chunks.items():
        u_perp_arr = np.concatenate([np.empty(0), *u_perp_parts])
        u_par_arr = np.concatenate([np.empty(0), *u_par_parts])
        R_arr = np.concatenate([np.empty(0), *R_parts])

        u_perp_hist = []
        u_perp_scaled_hist = []
        u_par_hist = []
        u_par_scaled_hist = []
        R_hist = []
        for i in range(1, len(bin_edges)):
            in_bin = inds == i
            R_i = R_arr[in_bin]
            R_mean = R_i.mean() if R_i.size else np.nan
            u_perp_i = u_perp_arr[in_bin]
            u_par_i = u_par_arr[in_bin]
            R_hist.append(R_mean)
            u_perp_hist.append(_half_width_68(u_perp_i))
            u_perp_scaled_hist.append(_half_width_68(u_perp_i / R_mean))
            u_par_hist.append(_half_width_68(u_par_i))
            u_par_scaled_hist.append(_half_width_68(u_par_i / R_mean))

        resolution_hists[key] = {
            'u_perp_resolution': _hist(u_perp_hist),
            'u_perp_scaled_resolution': _hist(u_perp_scaled_hist),
            'u_par_resolution': _hist(u_par_hist),
            'u_par_scaled_resolution': _hist(u_par_scaled_hist),
            'R': _hist(R_hist),
        }

    metrics_mean = {
        'loss': np.mean(loss_avg_arr),
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    print("- Eval metrics : " + metrics_string)
    return metrics_mean, resolution_hists
def evaluate(model, loss_fn, dataloader, metrics, deltaR, device='cuda'):
    """Evaluate the model on every batch of `dataloader` and bin MET resolutions in qT.

    Args:
        model: (torch.nn.Module) the neural network, called as
            ``model(x_cont, x_cat, edge_index, batch)``
        loss_fn: called as ``loss_fn(result, data.x, data.y, data.batch)``; its
            ``.item()`` is recorded as the batch loss
        dataloader: iterable of graph batches exposing ``x``, ``y``, ``batch`` and ``to()``
        metrics: (dict) must provide ``'resolution'``, called as
            ``metrics['resolution'](result, data.x, data.y, data.batch)`` and returning
            ``(resolutions, qT)``; ``resolutions[key]`` holds the per-event
            ``(u_perp, u_par, R)`` arrays for 'MET', 'pfMET' and 'puppiMET'
        deltaR: radius of the (eta, phi) neighbourhood used to build graph edges
        device: device each batch is moved to (defaults to 'cuda', as before)

    Returns:
        ``(metrics_mean, resolution_hists)``. ``metrics_mean['loss']`` is the mean
        batch loss. ``resolution_hists[key]`` maps ``'u_perp_resolution'``,
        ``'u_perp_scaled_resolution'``, ``'u_par_resolution'``,
        ``'u_par_scaled_resolution'`` and ``'R'`` to ``np.histogram`` results with
        ``x_n`` bins over ``[0, max_x)`` in qT. Resolutions are half the spread
        between the 16th and 84th percentiles; qT bins without events are NaN.
    """
    max_x = 500  # upper edge of the qT range
    x_n = 20  # number of qT bins

    def _half_width_68(values):
        # Half the 16th-84th percentile spread. np.quantile raises on an empty
        # array, so an empty qT bin yields NaN instead of aborting the evaluation.
        if values.size == 0:
            return np.nan
        return (np.quantile(values, 0.84) - np.quantile(values, 0.16)) / 2.

    # set model to evaluation mode
    model.eval()

    loss_avg_arr = []
    qT_chunks = []
    # Per MET definition: per-batch chunks of (u_perp, u_par, R). They are
    # concatenated once after the loop rather than re-concatenated every batch.
    resolution_chunks = {
        'MET': [[], [], []],
        'pfMET': [[], [], []],
        'puppiMET': [[], [], []],
    }

    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            x_cont = data.x[:, :8]  # continuous features
            x_cat = data.x[:, 8:].long()  # categorical features
            phi = torch.atan2(data.x[:, 1], data.x[:, 0])
            etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            edge_index = radius_graph(etaphi, r=deltaR, batch=data.batch, loop=True,
                                      max_num_neighbors=255)
            # compute model output
            result = model(x_cont, x_cat, edge_index, data.batch)
            loss = loss_fn(result, data.x, data.y, data.batch)

            # compute all metrics on this batch
            resolutions, qT = metrics['resolution'](result, data.x, data.y, data.batch)
            for key, parts in resolution_chunks.items():
                for i, part in enumerate(parts):
                    part.append(resolutions[key][i])
            qT_chunks.append(qT)
            loss_avg_arr.append(loss.item())

    # The leading empty float64 array keeps the dtype promotion of concatenating
    # onto an initial [] and also handles an empty dataloader.
    qT_arr = np.concatenate([np.empty(0), *qT_chunks])

    # x_n + 1 edges give exactly one filled bin per output histogram bin; an
    # arange(0, max_x, step) grid stops one edge short and drops the last bin.
    bin_edges = np.linspace(0, max_x, x_n + 1)
    inds = np.digitize(qT_arr, bin_edges)
    qT_hist = (bin_edges[1:] + bin_edges[:-1]) / 2.

    def _hist(weights):
        # One entry per qT bin centre, weighted by that bin's value.
        return np.histogram(qT_hist, bins=x_n, range=(0, max_x), weights=weights)

    resolution_hists = {}
    for key, (u_perp_parts, u_par_parts, R_parts) in resolution_chunks.items():
        u_perp_arr = np.concatenate([np.empty(0), *u_perp_parts])
        u_par_arr = np.concatenate([np.empty(0), *u_par_parts])
        R_arr = np.concatenate([np.empty(0), *R_parts])

        u_perp_hist = []
        u_perp_scaled_hist = []
        u_par_hist = []
        u_par_scaled_hist = []
        R_hist = []
        for i in range(1, len(bin_edges)):
            in_bin = inds == i
            R_i = R_arr[in_bin]
            R_mean = R_i.mean() if R_i.size else np.nan
            u_perp_i = u_perp_arr[in_bin]
            u_par_i = u_par_arr[in_bin]
            R_hist.append(R_mean)
            u_perp_hist.append(_half_width_68(u_perp_i))
            u_perp_scaled_hist.append(_half_width_68(u_perp_i / R_mean))
            u_par_hist.append(_half_width_68(u_par_i))
            u_par_scaled_hist.append(_half_width_68(u_par_i / R_mean))

        resolution_hists[key] = {
            'u_perp_resolution': _hist(u_perp_hist),
            'u_perp_scaled_resolution': _hist(u_perp_scaled_hist),
            'u_par_resolution': _hist(u_par_hist),
            'u_par_scaled_resolution': _hist(u_par_scaled_hist),
            'R': _hist(R_hist),
        }

    metrics_mean = {
        'loss': np.mean(loss_avg_arr),
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    print("- Eval metrics : " + metrics_string)
    return metrics_mean, resolution_hists