def compute(
    self,
    model: Model,
    graph: Graph,
    trajectories: Trajectories,
    pairwise_features: torch.Tensor,
):
    """Evaluate `model` on every trajectory in `trajectories`, accumulating metrics.

    Resets the metric accumulators, then runs the model once per trajectory
    (gradient-free) and feeds the predictions into `self.update_metrics`.
    """
    self.init_metrics()
    cfg = self.config

    with torch.no_grad():
        for idx in tqdm(range(len(trajectories))):
            obs = trajectories[idx]

            # Expected number of random-walk steps per leg; only needed when
            # the model consumes step counts.
            n_steps = None
            if cfg.rw_edge_weight_see_number_step or cfg.rw_expected_steps:
                if cfg.use_shortest_path_distance:
                    # Shortest-path leg lengths inflated by 10%, truncated to int.
                    n_steps = (trajectories.leg_shortest_lengths(idx).float() * 1.1).long()
                else:
                    n_steps = trajectories.leg_lengths(idx)

            observed, starts, targets = generate_masks(
                trajectory_length=obs.shape[0],
                number_observations=cfg.number_observations,
                predict=cfg.target_prediction,
                with_interpolation=cfg.with_interpolation,
                device=cfg.device,
            )

            # Diffusion may run on a graph augmented with self-loops.
            if cfg.diffusion_self_loops:
                diffusion_graph = graph.add_self_loops()
            else:
                diffusion_graph = graph

            predictions, _, rw_weights = model(
                obs,
                graph,
                diffusion_graph,
                observed=observed,
                starts=starts,
                targets=targets,
                pairwise_node_features=pairwise_features,
                number_steps=n_steps,
            )

            self.update_metrics(
                trajectories,
                graph,
                obs,
                observed,
                starts,
                targets,
                predictions,
                rw_weights,
                idx,
                model.rw_non_backtracking,
            )
def train_epoch(
    model: Model,
    graph: Graph,
    optimizer: torch.optim.Optimizer,
    config: Config,
    train_trajectories: Trajectories,
    pairwise_node_features: torch.Tensor,
):
    """Run one epoch of training over `train_trajectories`.

    Iterates over full batches (a trailing partial batch is dropped),
    accumulates the per-sample loss, and performs one optimizer step per
    batch. Progress (predictions/second and mean loss) is printed
    `config.print_per_epoch` times per epoch.
    """
    model.train()

    print_cum_loss = 0.0
    print_num_preds = 0
    print_time = time.time()
    # FIX: the integer division can yield 0 (small dataset, large batch size,
    # or high print_per_epoch), which previously crashed the
    # `(iteration + 1) % print_every` check below with ZeroDivisionError.
    print_every = max(
        1, len(train_trajectories) // config.batch_size // config.print_per_epoch
    )

    trajectories_shuffle_indices = np.arange(len(train_trajectories))
    if config.shuffle_samples:
        np.random.shuffle(trajectories_shuffle_indices)

    for iteration, batch_start in enumerate(
            range(0,
                  len(trajectories_shuffle_indices) - config.batch_size + 1,
                  config.batch_size)):
        optimizer.zero_grad()
        loss = torch.tensor(0.0, device=config.device)

        for i in range(batch_start, batch_start + config.batch_size):
            trajectory_idx = trajectories_shuffle_indices[i]
            observations = train_trajectories[trajectory_idx]
            # NOTE: the original also read train_trajectories.lengths here into
            # an unused local (`length`); removed.

            # Expected number of random-walk steps per leg; only needed when
            # the model consumes step counts.
            number_steps = None
            if config.rw_edge_weight_see_number_step or config.rw_expected_steps:
                if config.use_shortest_path_distance:
                    # Shortest-path leg lengths inflated by 10%, truncated to int.
                    number_steps = (train_trajectories.leg_shortest_lengths(
                        trajectory_idx).float() * 1.1).long()
                else:
                    number_steps = train_trajectories.leg_lengths(
                        trajectory_idx)

            observed, starts, targets = generate_masks(
                trajectory_length=observations.shape[0],
                number_observations=config.number_observations,
                predict=config.target_prediction,
                with_interpolation=config.with_interpolation,
                device=config.device,
            )

            # Diffusion may run on a graph augmented with self-loops.
            diffusion_graph = (graph.add_self_loops()
                               if config.diffusion_self_loops else graph)

            # Potentials are not used during training; discard them.
            predictions, _, rw_weights = model(
                observations,
                graph,
                diffusion_graph,
                observed=observed,
                starts=starts,
                targets=targets,
                pairwise_node_features=pairwise_node_features,
                number_steps=number_steps,
            )

            print_num_preds += starts.shape[0]
            # Mean loss over this trajectory's predictions (was named `l`,
            # an ambiguous name per PEP 8 E741).
            sample_loss = compute_loss(
                config.loss,
                train_trajectories,
                observations,
                predictions,
                starts,
                targets,
                rw_weights,
                trajectory_idx,
            ) / starts.shape[0]
            loss += sample_loss

        loss /= config.batch_size
        print_cum_loss += loss.item()

        loss.backward()
        optimizer.step()

        if (iteration + 1) % print_every == 0:
            print_loss = print_cum_loss / print_every
            print_loss /= print_num_preds
            pred_per_second = 1.0 * print_num_preds / (time.time() - print_time)
            print_cum_loss = 0.0
            print_num_preds = 0
            print_time = time.time()
            progress_percent = int(100.0 * ((iteration + 1) // print_every)
                                   / config.print_per_epoch)
            print(
                f"Progress {progress_percent}% | iter {iteration} | {pred_per_second:.1f} pred/s | loss {config.loss} {print_loss}"
            )