Exemplo n.º 1
0
    def compute(
        self,
        model: Model,
        graph: Graph,
        trajectories: Trajectories,
        pairwise_features: torch.Tensor,
    ):
        """Update the metrics for all trajectories in `trajectories`.

        Resets the metrics with ``self.init_metrics()``, runs ``model`` on
        every trajectory with gradient tracking disabled, and passes each
        trajectory's masks, predictions and random-walk weights to
        ``self.update_metrics``.

        Args:
            model: Model to evaluate; called once per trajectory.
            graph: Graph the trajectories are defined on.
            trajectories: Trajectories to evaluate, indexed by position.
            pairwise_features: Forwarded to the model as
                ``pairwise_node_features``.
        """
        self.init_metrics()
        config = self.config
        # Loop-invariant: whether the model needs per-leg step counts.
        needs_number_steps = (config.rw_edge_weight_see_number_step
                              or config.rw_expected_steps)

        # NOTE(review): model.eval() is not called here, so dropout or
        # normalization layers run in whatever mode the caller left them in.
        # Confirm that callers switch the model to eval mode before computing
        # metrics.
        with torch.no_grad():
            # The diffusion graph does not depend on the trajectory, so build
            # it once instead of once per loop iteration.
            # NOTE(review): assumes graph.add_self_loops() returns a new graph
            # without mutating `graph`; confirm.
            diffusion_graph = (graph.add_self_loops()
                               if config.diffusion_self_loops else graph)

            for trajectory_idx in tqdm(range(len(trajectories))):
                observations = trajectories[trajectory_idx]

                number_steps = None
                if needs_number_steps:
                    if config.use_shortest_path_distance:
                        # Shortest-path leg lengths scaled by 1.1, then
                        # truncated back to integers.
                        number_steps = (trajectories.leg_shortest_lengths(
                            trajectory_idx).float() * 1.1).long()
                    else:
                        number_steps = trajectories.leg_lengths(trajectory_idx)

                observed, starts, targets = generate_masks(
                    trajectory_length=observations.shape[0],
                    number_observations=config.number_observations,
                    predict=config.target_prediction,
                    with_interpolation=config.with_interpolation,
                    device=config.device,
                )

                # The second model output is not used by the metrics.
                predictions, _, rw_weights = model(
                    observations,
                    graph,
                    diffusion_graph,
                    observed=observed,
                    starts=starts,
                    targets=targets,
                    pairwise_node_features=pairwise_features,
                    number_steps=number_steps,
                )

                self.update_metrics(
                    trajectories,
                    graph,
                    observations,
                    observed,
                    starts,
                    targets,
                    predictions,
                    rw_weights,
                    trajectory_idx,
                    model.rw_non_backtracking,
                )
Exemplo n.º 2
0
def train_epoch(
    model: Model,
    graph: Graph,
    optimizer: torch.optim.Optimizer,
    config: Config,
    train_trajectories: Trajectories,
    pairwise_node_features: torch.Tensor,
):
    """Run one epoch of training over ``train_trajectories``.

    The trajectories are split into consecutive mini-batches of
    ``config.batch_size`` indices, shuffled first when
    ``config.shuffle_samples`` is set. A trailing partial batch is dropped.

    For each batch:

    * every trajectory's loss is divided by its number of predictions;
    * the results are averaged over the batch;
    * one backward pass and one optimizer step are taken.

    About ``config.print_per_epoch`` times per epoch, a progress line is
    printed with the throughput and the mean per-prediction loss.

    Args:
        model: Model to train; switched to training mode on entry.
        graph: Graph the trajectories are defined on.
        optimizer: Optimizer stepped once per batch.
        config: Training configuration.
        train_trajectories: Training trajectories, indexed by position.
        pairwise_node_features: Forwarded to the model unchanged.
    """
    model.train()

    num_batches = len(train_trajectories) // config.batch_size
    # Clamp to 1. With fewer than `print_per_epoch` batches the integer
    # division yields 0, and the `% print_every` below would raise
    # ZeroDivisionError.
    print_every = max(1, num_batches // config.print_per_epoch)

    print_cum_loss = 0.0
    print_num_preds = 0
    # perf_counter is monotonic, so throughput is not skewed by wall-clock
    # adjustments.
    print_time = time.perf_counter()

    trajectories_shuffle_indices = np.arange(len(train_trajectories))
    if config.shuffle_samples:
        np.random.shuffle(trajectories_shuffle_indices)

    # Loop-invariant values, computed once per epoch instead of once per
    # trajectory.
    diffusion_graph = (graph.add_self_loops()
                       if config.diffusion_self_loops else graph)
    needs_number_steps = (config.rw_edge_weight_see_number_step
                          or config.rw_expected_steps)

    for iteration in range(num_batches):
        batch_start = iteration * config.batch_size
        batch_indices = trajectories_shuffle_indices[
            batch_start:batch_start + config.batch_size]

        optimizer.zero_grad()
        loss = torch.tensor(0.0, device=config.device)

        for trajectory_idx in batch_indices:
            observations = train_trajectories[trajectory_idx]

            number_steps = None
            if needs_number_steps:
                if config.use_shortest_path_distance:
                    # Shortest-path leg lengths scaled by 1.1, then truncated
                    # back to integers.
                    number_steps = (train_trajectories.leg_shortest_lengths(
                        trajectory_idx).float() * 1.1).long()
                else:
                    number_steps = train_trajectories.leg_lengths(
                        trajectory_idx)

            observed, starts, targets = generate_masks(
                trajectory_length=observations.shape[0],
                number_observations=config.number_observations,
                predict=config.target_prediction,
                with_interpolation=config.with_interpolation,
                device=config.device,
            )

            # The second model output is not needed for the loss.
            predictions, _, rw_weights = model(
                observations,
                graph,
                diffusion_graph,
                observed=observed,
                starts=starts,
                targets=targets,
                pairwise_node_features=pairwise_node_features,
                number_steps=number_steps,
            )

            num_preds = starts.shape[0]
            print_num_preds += num_preds

            # Normalize by the number of predictions made for this
            # trajectory.
            trajectory_loss = compute_loss(
                config.loss,
                train_trajectories,
                observations,
                predictions,
                starts,
                targets,
                rw_weights,
                trajectory_idx,
            ) / num_preds
            loss = loss + trajectory_loss

        loss = loss / config.batch_size
        print_cum_loss += loss.item()
        loss.backward()
        optimizer.step()

        if (iteration + 1) % print_every == 0:
            # Each batch loss is already a mean per-prediction loss, so only
            # average over the logged batches. No further division by the
            # prediction count.
            print_loss = print_cum_loss / print_every
            elapsed = time.perf_counter() - print_time
            pred_per_second = print_num_preds / max(elapsed, 1e-9)

            print_cum_loss = 0.0
            print_num_preds = 0
            print_time = time.perf_counter()

            progress_percent = int(100.0 * (iteration + 1) / num_batches)

            print(
                f"Progress {progress_percent}% | iter {iteration} | {pred_per_second:.1f} pred/s | loss {config.loss} {print_loss}"
            )