def compute_loss(
    typ: str,
    trajectories: "Trajectories",
    observations: torch.Tensor,
    predictions: torch.Tensor,
    starts: torch.Tensor,
    targets: torch.Tensor,
    rw_weights: torch.Tensor,
    trajectory_idx: int,
):
    """Compute the training loss for one trajectory's predictions.

    Args:
        typ (str): loss flag from configuration, one of "RMSE", "dot_loss",
            "log_dot_loss", "target_only" or "nll_loss"
        trajectories (Trajectories): full trajectories dataset evaluated
            (only used by "nll_loss" to look up traversed edges)
        observations (torch.Tensor): current trajectory observation
            [traj_length, n_node]
        predictions (torch.Tensor): output prediction of the model
            [n_pred, n_node]
        starts (torch.Tensor): indexes of starts extrapolation in
            observations [n_pred,]
        targets (torch.Tensor): indexes of targets extrapolation in
            observations [n_pred,]
        rw_weights (torch.Tensor): random walk weights output of model
            [n_pred, n_edge]
        trajectory_idx (int): index of evaluated trajectory

    Returns:
        torch.Tensor: scalar loss for this prediction

    Raises:
        ValueError: if ``typ`` is not a recognized loss flag.
    """
    if typ == "RMSE":
        # Sum of squared errors against the observed target distribution.
        return ((predictions - observations[targets])**2).sum()
    elif typ == "dot_loss":
        # Negative inner product: reward probability mass on observed nodes.
        return -1.0 * (predictions * observations[targets]).sum()
    elif typ == "log_dot_loss":
        # Log of the per-prediction inner product; epsilon avoids log(0).
        return -1.0 * ((predictions * observations[targets]).sum(dim=1) +
                       1e-30).log().sum()
    elif typ == "target_only":
        # Only penalize prediction mass at nodes actually observed (> 0).
        return -predictions[observations[targets] > 0].sum()
    elif typ == "nll_loss":
        # Negative log-likelihood of the traversed edges under the model's
        # random-walk weights, accumulated over every jump of every prediction.
        loss = torch.tensor(0.0, device=trajectories.device)
        log_rw_weights = -(rw_weights + 1e-20).log()  # epsilon avoids log(0)
        for pred_id in range(len(starts)):
            for jump_id in range(starts[pred_id], targets[pred_id]):
                traversed_edges = trajectories.traversed_edges(
                    trajectory_idx, jump_id)
                loss += log_rw_weights[pred_id, traversed_edges].sum()
        return loss
    else:
        raise ValueError(f'Unknown loss "{typ}"')
def update_metrics(
    self,
    trajectories: Trajectories,
    graph: Graph,
    observations,
    observed,
    starts,
    targets,
    predictions,
    rw_weights,
    trajectory_idx,
    rw_non_backtracking,
):
    """Accumulate evaluation metrics for one trajectory's predictions.

    Updates ``self.metrics`` in place: target probability, top-1/top-5
    precision, per-step choice accuracy (overall and at nodes of out-degree
    > 2), and path negative log-likelihood. For the "wikispeedia" dataset,
    also records metrics broken down by jump length.

    Args:
        trajectories: dataset providing ``traversed_edges`` lookups.
        graph: graph with ``senders``/``receivers``/``out_degree_counts``.
        observations: node distributions per step; indexed by ``starts`` and
            ``targets`` (presumably [traj_length, n_node] — see compute_loss).
        observed: NOTE(review): unused in this method — confirm whether it
            can be dropped from the signature at the call sites.
        starts / targets: prediction start/target step indexes [n_pred,].
        predictions: model output distributions [n_pred, n_node].
        rw_weights: random-walk edge weights [n_pred, >= n_edge].
        trajectory_idx: index of the evaluated trajectory.
        rw_non_backtracking: whether the walk forbids immediate backtracking.
    """
    n_pred = len(starts)
    # remove added self loops
    rw_weights = rw_weights[:, :graph.n_edge]

    target_distributions = observations[targets]

    target_probabilities = compute_target_probability(
        target_distributions, predictions)
    self.metrics["target_probability"].add_all(target_probabilities)

    top1_contains_target = compute_topk_contains_target(
        target_distributions, predictions, k=1)
    self.metrics["precision_top1"].add_all(top1_contains_target)

    top5_contains_target = compute_topk_contains_target(
        target_distributions, predictions, k=5)
    self.metrics["precision_top5"].add_all(top5_contains_target)

    # Choice accuracy / NLL below require edge-level trajectory data.
    assert trajectories.has_traversed_edges

    noise_level = 1e-6  # very small noise is added to break the uniform cases
    # Argmax of (noised) weights per sender node -> the edge the model would
    # pick at each node. Result indexed as [n_pred, n_node].
    _, chosen_edge_at_each_node = scatter_max(
        rw_weights + torch.rand_like(rw_weights) * noise_level,
        graph.senders,
        fill_value=-1)

    if rw_non_backtracking:
        # Non-backtracking walk: the choice depends on the incoming edge, so
        # build the hyperedge graph (edges become nodes) and take the argmax
        # over outgoing hyperedges per edge.
        nb_rw_graph = graph.update(edges=rw_weights.transpose(
            0, 1)).non_backtracking_random_walk_graph
        # [n_edge, n_pred]
        _, chosen_hyperedge_at_each_edge = scatter_max(
            nb_rw_graph.edges + torch.rand_like(nb_rw_graph.edges) *
            noise_level,
            nb_rw_graph.senders,
            dim=0,
            fill_value=-1000,
        )
        chosen_edge_at_each_edge = nb_rw_graph.receivers[
            chosen_hyperedge_at_each_edge]
        # [n_pred, n_edge]
        chosen_edge_at_each_edge = chosen_edge_at_each_edge.transpose(0, 1)

    for pred_id in range(n_pred):
        # concat all edges traversed between start and target
        traversed_edges = torch.cat([
            trajectories.traversed_edges(trajectory_idx, i)
            for i in range(starts[pred_id], targets[pred_id])
        ])
        # remove consecutive duplicate
        duplicate_mask = torch.zeros_like(traversed_edges, dtype=torch.uint8)
        duplicate_mask[1:] = traversed_edges[:-1] == traversed_edges[1:]
        traversed_edges = traversed_edges[~duplicate_mask]

        # Node at which each traversal decision was made.
        nodes_where_decide = graph.senders[traversed_edges]
        """ choice accuracy """
        if rw_non_backtracking:
            # First step has no incoming edge: use the node-level choice;
            # subsequent steps are conditioned on the previous edge.
            chosen_edges = torch.zeros_like(traversed_edges, dtype=torch.long)
            first_node = nodes_where_decide[0]
            chosen_edges[0] = chosen_edge_at_each_node[pred_id, first_node]
            chosen_edges[1:] = chosen_edge_at_each_edge[
                pred_id, traversed_edges[:-1]]
        else:
            chosen_edges = chosen_edge_at_each_node[pred_id,
                                                    nodes_where_decide]
        correct_choices = (traversed_edges == chosen_edges).float()
        self.metrics["choice_accuracy"].add_all(correct_choices)
        # Restrict to nodes with more than 2 outgoing edges (real decisions);
        # the first step is always counted.
        deg3_mask = graph.out_degree_counts[nodes_where_decide] > 2
        deg3_mask[0] = 1
        self.metrics["choice_accuracy_deg3"].add_all(
            correct_choices[deg3_mask])

        """NLL computation"""
        if not rw_non_backtracking:
            traversed_edges_weights = rw_weights[pred_id, traversed_edges]
        else:
            # Weight of each step comes from the non-backtracking hyperedge
            # (previous edge -> next edge); the first step falls back to the
            # plain edge weight.
            rw_graph = graph.update(edges=rw_weights[pred_id])
            nb_rw_graph = rw_graph.non_backtracking_random_walk_graph
            traversed_edges_weights = torch.zeros(len(traversed_edges))
            traversed_edges_weights[0] = rw_weights[pred_id,
                                                    traversed_edges[0]]
            for i, (s, r) in enumerate(
                    zip(traversed_edges[:-1], traversed_edges[1:])):
                traversed_edges_weights[i + 1] = nb_rw_graph.edge(s, r)

        # epsilon avoids log(0) on zero-weight edges
        neg_log_weights = -(traversed_edges_weights + 1e-20).log()
        self.metrics["path_nll"].add(neg_log_weights.sum().item())
        deg3_mask = graph.out_degree_counts[
            graph.senders[traversed_edges]] > 2
        deg3_mask[0] = 1
        self.metrics["path_nll_deg3"].add(
            neg_log_weights[deg3_mask].sum().item())

    if self.config.dataset == "wikispeedia":
        jump_lengths = targets - starts
        """top k by jump"""
        self.update_metrics_by_keys("precision_top1", jump_lengths,
                                    top1_contains_target)
        self.update_metrics_by_keys("precision_top5", jump_lengths,
                                    top5_contains_target)
        """cumulative reciprocal rank"""
        # assumes only one target per observations
        target_nodes = observations[targets].nonzero()[:, 1]
        target_ranks = compute_rank(predictions, target_nodes)
        # harmonic_numbers[r-1] = sum_{k<=r} 1/k; used as a cumulative
        # reciprocal-rank score keyed by jump length.
        self.update_metrics_by_keys(
            "target_rank", jump_lengths,
            self.harmonic_numbers[target_ranks - 1])
        """West target accuracy"""
        start_nodes = observations[starts].nonzero()[:, 1]
        target2_acc = target2_accuracy(
            start_nodes,
            target_nodes,
            predictions,
            self.given_as_target,
            trajectories.pairwise_node_distances,
        )
        self.update_metrics_by_keys("target2_acc", jump_lengths, target2_acc)