def compute_loss(
    typ: str,
    trajectories: Trajectories,
    observations: torch.Tensor,
    predictions: torch.Tensor,
    starts: torch.Tensor,
    targets: torch.Tensor,
    rw_weights: torch.Tensor,
    trajectory_idx: int,
):
    """Compute the †raining loss

    Args:
        typ (str): loss flag from configuration, can be RMSE, dot_loss, log_dot_loss, target_only or nll_loss
        trajectories (Trajectories): full trajectories dataset evaluated
        observations (torch.Tensor): current trajectory observation [traj_length, n_node]
        predictions (torch.Tensor): output prediction of the model [n_pred, n_node]
        starts (torch.Tensor): indexes of starts extrapolation in observations [n_pred,]
        targets (torch.Tensor): indexes of targets extrapolation in observations [n_pred,]
        rw_weights (torch.Tensor): random walk weights output of model [n_pred, n_edge]
        trajectory_idx (int): index of evaluated trajectory

    Returns:
        torch.Tensor(): loss for this prediction
    """

    if typ == "RMSE":
        return ((predictions - observations[targets])**2).sum()
    elif typ == "dot_loss":
        return -1.0 * (predictions * observations[targets]).sum()
    elif typ == "log_dot_loss":
        return -1.0 * ((predictions * observations[targets]).sum(dim=1) +
                       1e-30).log().sum()
    elif typ == "target_only":
        return -predictions[observations[targets] > 0].sum()
    elif typ == "nll_loss":
        loss = torch.tensor(0.0, device=trajectories.device)
        log_rw_weights = -(rw_weights + 1e-20).log()
        for pred_id in range(len(starts)):
            for jump_id in range(starts[pred_id], targets[pred_id]):
                traversed_edges = trajectories.traversed_edges(
                    trajectory_idx, jump_id)
                loss += log_rw_weights[pred_id, traversed_edges].sum()
        return loss
    else:
        raise Exception(f'Unknown loss "{typ}"')
    def update_metrics(
        self,
        trajectories: Trajectories,
        graph: Graph,
        observations,
        observed,
        starts,
        targets,
        predictions,
        rw_weights,
        trajectory_idx,
        rw_non_backtracking,
    ):
        """Accumulate evaluation metrics for one trajectory's predictions.

        Updates ``self.metrics`` in place: target probability, top-1/top-5
        precision, per-decision choice accuracy (overall and at nodes of
        out-degree > 2), and path negative log-likelihood. For the
        "wikispeedia" dataset, also records the same statistics keyed by
        jump length plus reciprocal-rank and target2 accuracy.

        Args:
            trajectories: dataset providing traversed edges per jump
                (``has_traversed_edges`` must be true).
            graph: graph with ``senders``/``receivers``/``out_degree_counts``;
                ``rw_weights`` columns beyond ``graph.n_edge`` are dropped.
            observations: per-step node distributions — assumed
                [traj_length, n_node]; TODO confirm against caller.
            observed: not used in this method (kept for interface parity).
            starts: start step index per prediction, [n_pred].
            targets: target step index per prediction, [n_pred].
            predictions: model node predictions, presumably [n_pred, n_node].
            rw_weights: random-walk edge weights, presumably
                [n_pred, >= n_edge] (self-loop columns appended at the end).
            trajectory_idx: index of the trajectory being evaluated.
            rw_non_backtracking: whether edge choices are scored on the
                non-backtracking random-walk graph.
        """
        n_pred = len(starts)
        # remove added self loops
        rw_weights = rw_weights[:, :graph.n_edge]

        # Node distribution at each prediction's target step.
        target_distributions = observations[targets]

        # Probability mass the model assigns to the true target(s).
        target_probabilities = compute_target_probability(
            target_distributions, predictions)
        self.metrics["target_probability"].add_all(target_probabilities)

        # Whether the true target appears among the top-1 / top-5 predictions.
        top1_contains_target = compute_topk_contains_target(
            target_distributions, predictions, k=1)
        self.metrics["precision_top1"].add_all(top1_contains_target)
        top5_contains_target = compute_topk_contains_target(
            target_distributions, predictions, k=5)
        self.metrics["precision_top5"].add_all(top5_contains_target)

        assert trajectories.has_traversed_edges
        noise_level = 1e-6  # very small noise is added to break the uniform cases

        # Greedy choice: for every node, the outgoing edge with the largest
        # (noise-perturbed) weight, per prediction.
        # [n_pred, n_node]
        _, chosen_edge_at_each_node = scatter_max(
            rw_weights + torch.rand_like(rw_weights) * noise_level,
            graph.senders,
            fill_value=-1)
        if rw_non_backtracking:
            # Non-backtracking walk: choices are conditioned on the previously
            # traversed edge, so argmax over hyperedges of the NB graph.
            nb_rw_graph = graph.update(edges=rw_weights.transpose(
                0, 1)).non_backtracking_random_walk_graph
            # [n_edge, n_pred]
            _, chosen_hyperedge_at_each_edge = scatter_max(
                nb_rw_graph.edges +
                torch.rand_like(nb_rw_graph.edges) * noise_level,
                nb_rw_graph.senders,
                dim=0,
                fill_value=-1000,
            )
            # Map winning hyperedge back to the underlying edge id.
            chosen_edge_at_each_edge = nb_rw_graph.receivers[
                chosen_hyperedge_at_each_edge]
            # [n_pred, n_edge]
            chosen_edge_at_each_edge = chosen_edge_at_each_edge.transpose(0, 1)

        for pred_id in range(n_pred):
            # concat all edges traversed between start and target
            traversed_edges = torch.cat([
                trajectories.traversed_edges(trajectory_idx, i)
                for i in range(starts[pred_id], targets[pred_id])
            ])
            # remove consecutive duplicate
            # NOTE(review): uint8 mask with `~` relies on legacy torch
            # byte-mask indexing semantics — confirm on the targeted version.
            duplicate_mask = torch.zeros_like(traversed_edges,
                                              dtype=torch.uint8)
            duplicate_mask[1:] = traversed_edges[:-1] == traversed_edges[1:]
            traversed_edges = traversed_edges[~duplicate_mask]

            # Decision points: the sender node of each traversed edge.
            nodes_where_decide = graph.senders[traversed_edges]
            """ choice accuracy """

            if rw_non_backtracking:
                # First hop has no previous edge: use the per-node choice;
                # subsequent hops use the per-(previous-edge) choice.
                chosen_edges = torch.zeros_like(traversed_edges,
                                                dtype=torch.long)
                first_node = nodes_where_decide[0]
                chosen_edges[0] = chosen_edge_at_each_node[pred_id, first_node]
                chosen_edges[1:] = chosen_edge_at_each_edge[
                    pred_id, traversed_edges[:-1]]
            else:
                chosen_edges = chosen_edge_at_each_node[pred_id,
                                                        nodes_where_decide]

            # A choice is correct when the greedy edge equals the edge
            # actually taken.
            correct_choices = (traversed_edges == chosen_edges).float()
            self.metrics["choice_accuracy"].add_all(correct_choices)

            # Restrict to "real" decisions (out-degree > 2); the first hop is
            # always kept.
            deg3_mask = graph.out_degree_counts[nodes_where_decide] > 2
            deg3_mask[0] = 1
            self.metrics["choice_accuracy_deg3"].add_all(
                correct_choices[deg3_mask])
            """NLL computation"""

            if not rw_non_backtracking:
                traversed_edges_weights = rw_weights[pred_id, traversed_edges]
            else:
                # Weight of each hop comes from the NB graph edge connecting
                # the previous traversed edge to the next one.
                rw_graph = graph.update(edges=rw_weights[pred_id])
                nb_rw_graph = rw_graph.non_backtracking_random_walk_graph

                traversed_edges_weights = torch.zeros(len(traversed_edges))
                traversed_edges_weights[0] = rw_weights[pred_id,
                                                        traversed_edges[0]]
                for i, (s, r) in enumerate(
                        zip(traversed_edges[:-1], traversed_edges[1:])):
                    traversed_edges_weights[i + 1] = nb_rw_graph.edge(s, r)

            # 1e-20 guards against log(0) for zero-weight edges.
            neg_log_weights = -(traversed_edges_weights + 1e-20).log()
            self.metrics["path_nll"].add(neg_log_weights.sum().item())
            deg3_mask = graph.out_degree_counts[
                graph.senders[traversed_edges]] > 2
            deg3_mask[0] = 1
            self.metrics["path_nll_deg3"].add(
                neg_log_weights[deg3_mask].sum().item())

        if self.config.dataset == "wikispeedia":
            jump_lengths = targets - starts
            """top k by jump"""
            self.update_metrics_by_keys("precision_top1", jump_lengths,
                                        top1_contains_target)
            self.update_metrics_by_keys("precision_top5", jump_lengths,
                                        top5_contains_target)
            """cumulative reciprocal rank"""
            # assumes only one target per observations
            target_nodes = observations[targets].nonzero()[:, 1]
            target_ranks = compute_rank(predictions, target_nodes)
            # harmonic_numbers[r-1] = sum of 1/k for k<=r, used as CRR weight.
            self.update_metrics_by_keys(
                "target_rank", jump_lengths,
                self.harmonic_numbers[target_ranks - 1])
            """West target accuracy"""
            start_nodes = observations[starts].nonzero()[:, 1]
            target2_acc = target2_accuracy(
                start_nodes,
                target_nodes,
                predictions,
                self.given_as_target,
                trajectories.pairwise_node_distances,
            )

            self.update_metrics_by_keys("target2_acc", jump_lengths,
                                        target2_acc)