class Trainer:
    '''
    Helper class for model training
    '''

    def __init__(self, config, model, dataset: VideoDataset, logger: Logger):
        '''
        :param config: configuration dictionary with "training" and "logging" sections
        :param model: the model whose parameters are optimized
        :param dataset: dataset providing the training video sequences
        :param logger: logger used for console output and wandb access
        '''
        self.config = config
        self.dataset = dataset
        self.logger = logger

        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config["training"]["learning_rate"],
            weight_decay=config["training"]["weight_decay"])
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            self.config["training"]["lr_schedule"],
            gamma=self.config["training"]["lr_gamma"])

        self.dataloader = DataLoader(
            dataset,
            batch_size=self.config["training"]["batching"]["batch_size"],
            drop_last=True,
            shuffle=True,
            collate_fn=single_batch_elements_collate_fn,
            num_workers=self.config["training"]["batching"]["num_workers"],
            pin_memory=True)

        # Initializes losses
        self.weight_mask_calculator = MotionLossWeightMaskCalculator(
            self.config["training"]["motion_weights_bias"])
        self.observations_loss = ObservationsLoss()
        self.states_loss = StatesLoss()
        self.hidden_states_loss = HiddenStatesLoss()
        self.entropy_loss = EntropyLogitLoss()
        self.actions_divergence_loss = KLDivergence()
        self.samples_entropy_loss = EntropyProbabilityLoss()
        self.action_distribution_entropy = EntropyProbabilityLoss()
        # NOTE(review): the original first assigned a plain PerceptualLoss() here and then
        # immediately overwrote it with the parallel version; the dead assignment was removed
        self.perceptual_loss = ParallelPerceptualLoss()
        self.action_state_distribution_kl = KLGeneralGaussianDivergenceLoss()
        self.action_directions_kl_gaussian_divergence_loss = KLGaussianDivergenceLoss()
        self.mutual_information_loss = MutualInformationLoss()

        self.average_meter = AverageMeter()
        self.global_step = 0
        # NOTE(review): attribute name keeps the original "infromation" typo because code
        # outside this file may rely on it
        self.action_mutual_infromation_entropy_lambda = \
            config["training"]["action_mutual_information_entropy_lambda"]

        # Observations count annealing parameters
        self.observations_count_start = \
            self.config["training"]["batching"]["observations_count_start"]
        self.observations_count_end = \
            self.config["training"]["batching"]["observations_count"]
        self.observations_count_steps = \
            self.config["training"]["batching"]["observations_count_steps"]

        # Real observations annealing parameters
        self.real_observations_start = \
            self.config["training"]["ground_truth_observations_start"]
        self.real_observations_end = \
            self.config["training"]["ground_truth_observations_end"]
        self.real_observations_steps = \
            self.config["training"]["ground_truth_observations_steps"]

        # Gumbel temperature annealing parameters
        self.gumbel_temperature_start = \
            self.config["training"]["gumbel_temperature_start"]
        self.gumbel_temperature_end = \
            self.config["training"]["gumbel_temperature_end"]
        self.gumbel_temperature_steps = \
            self.config["training"]["gumbel_temperature_steps"]

    def _get_current_lr(self):
        '''
        Returns the learning rate of the first optimizer parameter group
        '''
        for param_group in self.optimizer.param_groups:
            return param_group['lr']

    def save_checkpoint(self, model, name=None):
        '''
        Saves the current training state

        :param model: the model to save
        :param name: the name to give to the checkpoint. If None the default name is used
        :return:
        '''
        if name is None:
            filename = os.path.join(
                self.config["logging"]["save_root_directory"], "latest.pth.tar")
        else:
            # Fixed: the original saved named checkpoints as "{name}_.pth.tar" while
            # load_checkpoint looks for "{name}.pth.tar", so they could never be reloaded
            filename = os.path.join(
                self.config["logging"]["save_root_directory"], f"{name}.pth.tar")

        # If the model is wrapped, save the internal state
        if isinstance(model, nn.DataParallel):
            model_state_dict = model.module.state_dict()
        else:
            model_state_dict = model.state_dict()

        torch.save(
            {
                "model": model_state_dict,
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "step": self.global_step
            }, filename)

    def load_checkpoint(self, model, name=None):
        """
        Loads the model from a saved state

        :param model: The model to load
        :param name: Name of the checkpoint to use. If None the default name is used
        :return:
        """
        if name is None:
            filename = os.path.join(
                self.config["logging"]["save_root_directory"], "latest.pth.tar")
        else:
            filename = os.path.join(
                self.config["logging"]["save_root_directory"], f"{name}.pth.tar")

        if not os.path.isfile(filename):
            # Fixed: the original f-string had no placeholder and always printed '(unknown)'
            raise Exception(f"Cannot load model: no checkpoint found at '{filename}'")

        loaded_state = torch.load(filename)
        model.load_state_dict(loaded_state["model"])
        self.optimizer.load_state_dict(loaded_state["optimizer"])
        self.lr_scheduler.load_state_dict(loaded_state["lr_scheduler"])
        self.global_step = loaded_state["step"]

    def get_ground_truth_observations_count(self) -> int:
        '''
        Computes the number of ground truth observations to use for the current training step
        according to the annealing parameters

        :return: number of ground truth observations to use in the training sequence at the current step
        '''
        # Linear annealing from start to end over the configured number of steps
        ground_truth_observations_count = self.real_observations_start - \
            (self.real_observations_start - self.real_observations_end) * \
            self.global_step / self.real_observations_steps
        ground_truth_observations_count = math.ceil(ground_truth_observations_count)
        # Never anneal below the configured end value
        ground_truth_observations_count = max(self.real_observations_end,
                                              ground_truth_observations_count)

        return ground_truth_observations_count

    def get_gumbel_temperature(self) -> float:
        '''
        Computes the gumbel temperature to use at the current step

        :return: Gumbel temperature to use at the current step
        '''
        # Linear annealing from start to end, clamped at the end value
        gumbel_temperature = self.gumbel_temperature_start - \
            (self.gumbel_temperature_start - self.gumbel_temperature_end) * \
            self.global_step / self.gumbel_temperature_steps
        gumbel_temperature = max(self.gumbel_temperature_end, gumbel_temperature)

        return gumbel_temperature

    def get_observations_count(self):
        '''
        Computes the number of observations to use for the sequence at the current training step
        according to the annealing parameters

        :return: Number of observations to use in each training sequence at the current step
        '''
        # Linear annealing from start to end, clamped at the end value
        observations_count = self.observations_count_start + \
            (self.observations_count_end - self.observations_count_start) * \
            self.global_step / self.observations_count_steps
        observations_count = math.floor(observations_count)
        observations_count = min(self.observations_count_end, observations_count)

        return observations_count

    def sum_loss_components(self, components: List[torch.Tensor],
                            weights: Union[List[float], float]) -> List[torch.Tensor]:
        '''
        Produces the weighted sum of the loss components

        :param components: List of scalar tensors
        :param weights: List of weights of the same length of components, or single weight to apply to each component
        :return: Weighted sum of the components
        '''
        components_count = len(components)

        # If the weight is a scalar, broadcast it
        # Fixed: collections.Sequence was removed in Python 3.10; use collections.abc.Sequence
        if not isinstance(weights, collections.abc.Sequence):
            weights = [weights] * components_count

        # Starts from a zero tensor with the same device/dtype as the components
        total_sum = components[0] * 0.0
        for current_component, current_weight in zip(components, weights):
            total_sum += current_component * current_weight

        return total_sum

    def compute_average_centroid_distance(self, centroids: torch.Tensor):
        '''
        Computes the average distance between centroids

        :param centroids: (centroids_count, space_dimensions) tensor with centroids
        :return: Average L2 distance between centroids
        '''
        centroids_count = centroids.size(0)

        centroids_1 = centroids.unsqueeze(0)  # (1, centroids_count, space_dimensions)
        centroids_2 = centroids.unsqueeze(1)  # (centroids_count, 1, space_dimensions)

        # Sum of all pairwise L2 distances (each unordered pair counted twice)
        centroids_sum = (centroids_1 - centroids_2).pow(2).sum(2).sqrt().sum()
        # Normalizes by the number of ordered pairs excluding self-distances
        average_centroid_distance = centroids_sum / (centroids_count * (centroids_count - 1))

        return average_centroid_distance

    def plot_action_direction_space(self, estimated_action_centroids: torch.Tensor,
                                    action_directions_distribution: torch.Tensor,
                                    action_logits: torch.Tensor) -> Image:
        '''
        Saves and returns a plot of the action direction space

        :param estimated_action_centroids: estimated action centroids in the format required by TensorDisplayer
        :param action_directions_distribution: distribution of action directions in the format required by TensorDisplayer
        :param action_logits: action logits with the space required by TensorDisplayer. Automatically converted to
                              probabilities before being passed to TensorDisplayer
        :return: Image with the plot
        '''
        with torch.no_grad():
            action_probabilities = torch.softmax(action_logits, dim=-1)

            plot_filename = os.path.join(
                self.config["logging"]["output_images_directory"],
                f"action_direction_space_{self.global_step}.png")
            TensorDisplayer.show_action_directions(estimated_action_centroids,
                                                   action_directions_distribution,
                                                   action_probabilities, plot_filename)

            return Image.open(plot_filename)

    def plot_action_states(self, action_states: torch.Tensor,
                           action_logits: torch.Tensor) -> Image:
        '''
        Saves and returns a plot of the action state trajectories

        :param action_states: (bs, observations_count, action_space_dimension) action state trajectories or
                              (bs, observations_count, 2, action_space_dimension) action state distribution trajectories
        :param action_logits: action logits with the space required by TensorDisplayer. Automatically converted to
                              probabilities before being passed to TensorDisplayer
        :return: Image with the plot
        '''
        with torch.no_grad():
            action_probabilities = torch.softmax(action_logits, dim=-1)

            plot_filename = os.path.join(
                self.config["logging"]["output_images_directory"],
                f"action_state_trajectories_{self.global_step}.png")
            TensorDisplayer.show_action_states(action_states, action_probabilities,
                                               plot_filename)

            return Image.open(plot_filename)

    def compute_losses_pretraining(self, model, batch: Batch,
                                   observations_count: int) -> Tuple:
        '''
        Computes losses for the pretraining phase

        :param model: The network model
        :param batch: Batch of data
        :param observations_count: The number of observations in each sequence
        :return: (total_loss, loss_info, additional_info)
                 total_loss: torch.Tensor with the total loss
                 loss_info: Dict with an entry for every additional information about the loss
                 additional_info: Dict with additional loggable information
        '''
        # Ground truth observations to use at the current step
        ground_truth_observations_count = self.get_ground_truth_observations_count()
        # Since the annealing of the ground truth observations to use may produce a number greater than the number of
        # observations in the sequence, we cap it to the maximum value for the current sequence length
        if ground_truth_observations_count >= observations_count:
            ground_truth_observations_count = observations_count - 1

        # Gumbel temperature to use at the current step
        gumbel_temperature = self.get_gumbel_temperature()

        # Computes forward and losses for the plain batch
        batch_tuple = batch.to_tuple()
        results = model(batch_tuple, pretraining=True,
                        gumbel_temperature=gumbel_temperature)

        reconstructed_observations, multiresolution_reconstructed_observations, reconstructed_states, states, \
            reconstructed_hidden_states, hidden_states, selected_actions, action_logits, \
            action_samples, attention, action_directions_distribution, sampled_action_directions, \
            action_states_distribution, sampled_action_states, action_variations, \
            reconstructed_action_logits, \
            reconstructed_action_directions_distribution, reconstructed_sampled_action_directions, \
            reconstructed_action_states_distribution, reconstructed_sampled_action_states, \
            *other_results = results

        estimated_action_centroids = model.module.centroid_estimator.get_estimated_centroids()

        # Computes the weights mask
        weights_mask = None
        if self.config["training"]["use_motion_weights"]:
            ground_truth_observations = batch_tuple[0]
            weights_mask = self.weight_mask_calculator.compute_weight_mask(
                ground_truth_observations, reconstructed_observations)

        loss_weights = self.config["training"]["loss_weights"]
        perceptual_loss_lambda = loss_weights["perceptual_loss_lambda_pretraining"]

        loss_info_reconstruction = {}
        # Computes perceptual and observation reconstruction losses averaged over all resolutions
        resolutions_count = len(multiresolution_reconstructed_observations)
        perceptual_loss = torch.zeros((1, ), dtype=float).cuda()
        perceptual_loss_term = torch.zeros((1, ), dtype=float).cuda()
        observations_rec_loss = torch.zeros((1, ), dtype=float).cuda()
        for resolution_idx, current_reconstructed_observations in \
                enumerate(multiresolution_reconstructed_observations):
            current_perceptual_loss, current_perceptual_loss_components = self.perceptual_loss(
                batch.observations, current_reconstructed_observations, weights_mask)
            current_perceptual_loss_term = self.sum_loss_components(
                current_perceptual_loss_components, perceptual_loss_lambda)
            current_observations_rec_loss = self.observations_loss(
                batch.observations, current_reconstructed_observations, weights_mask)

            perceptual_loss += current_perceptual_loss
            perceptual_loss_term += current_perceptual_loss_term
            observations_rec_loss += current_observations_rec_loss

            loss_info_reconstruction[f'perceptual_loss_r{resolution_idx}'] = \
                current_perceptual_loss.item()
            loss_info_reconstruction[f'observations_rec_loss_r{resolution_idx}'] = \
                current_observations_rec_loss.item()
            for layer_idx, component in enumerate(current_perceptual_loss_components):
                loss_info_reconstruction[f'perceptual_loss_r{resolution_idx}_l{layer_idx}'] = \
                    current_perceptual_loss_components[layer_idx].item()

        perceptual_loss /= resolutions_count
        perceptual_loss_term /= resolutions_count
        observations_rec_loss /= resolutions_count

        states_rec_loss = self.states_loss(states.detach(), reconstructed_states)
        # Avoids gradient backpropagation from dynamics to representation network
        hidden_states_rec_loss = self.hidden_states_loss(
            hidden_states, reconstructed_hidden_states.detach())
        entropy_loss = self.entropy_loss(action_logits)
        action_directions_kl_divergence_loss = self.action_directions_kl_gaussian_divergence_loss(
            action_directions_distribution)
        action_mutual_information_loss = self.mutual_information_loss(
            torch.softmax(action_logits, dim=-1),
            torch.softmax(reconstructed_action_logits, dim=-1),
            lamb=self.action_mutual_infromation_entropy_lambda)
        # The reconstructed must get closer to the true ones, not the contrary
        action_state_distribution_kl_loss = self.action_state_distribution_kl(
            reconstructed_action_states_distribution,
            action_states_distribution.detach())

        # Additional debug information not used for backpropagation
        with torch.no_grad():
            samples_entropy = self.samples_entropy_loss(action_samples)
            action_distribution_entropy = self.action_distribution_entropy(
                action_samples.mean(dim=(0, 1)).unsqueeze(dim=0))
            states_magnitude = torch.mean(torch.abs(states)).item()
            hidden_states_magnitude = torch.mean(torch.abs(hidden_states)).item()
            # Compute magnitude of the mean
            action_directions_mean_magnitude = torch.mean(
                torch.abs(action_directions_distribution[:, :, 0])).item()
            # Compute magnitude of the variance
            action_directions_variance_magnitude = torch.mean(
                torch.abs(action_directions_distribution[:, :, 1])).item()
            # Compute magnitude of the mean
            reconstructed_action_directions_mean_magnitude = torch.mean(
                torch.abs(reconstructed_action_directions_distribution[:, :, 0])).item()
            # Compute magnitude of the variance
            reconstructed_action_directions_variance_magnitude = torch.mean(
                torch.abs(reconstructed_action_directions_distribution[:, :, 1])).item()
            # Compute differences of the mean
            action_directions_reconstruction_error = torch.mean(
                (reconstructed_action_directions_distribution[:, :, 0] -
                 action_directions_distribution[:, :, 0]).pow(2)).item()
            reconstructed_action_directions_kl_divergence_loss = \
                self.action_directions_kl_gaussian_divergence_loss(
                    reconstructed_action_directions_distribution)
            centroids_mean_magnitude = torch.mean(
                torch.abs(estimated_action_centroids)).item()
            average_centroids_distance = self.compute_average_centroid_distance(
                estimated_action_centroids).item()
            average_action_variations_norm_l2 = \
                action_variations.pow(2).sum(-1).sqrt().mean().item()
            action_variations_mean = action_variations.mean().item()

        # Computes the total loss
        total_loss = loss_weights["reconstruction_loss_lambda_pretraining"] * observations_rec_loss + \
            perceptual_loss_term + \
            loss_weights["hidden_states_rec_lambda_pretraining"] * hidden_states_rec_loss + \
            loss_weights["states_rec_lambda_pretraining"] * states_rec_loss + \
            loss_weights["entropy_lambda_pretraining"] * entropy_loss + \
            loss_weights["action_directions_kl_lambda_pretraining"] * action_directions_kl_divergence_loss + \
            loss_weights["action_mutual_information_lambda_pretraining"] * action_mutual_information_loss + \
            loss_weights["action_state_distribution_kl_lambda_pretraining"] * action_state_distribution_kl_loss

        # Computes loss information
        loss_info = {
            "loss_component_observations_rec":
                loss_weights["reconstruction_loss_lambda_pretraining"] * observations_rec_loss.item(),
            "loss_component_perceptual_loss": perceptual_loss_term.item(),
            "loss_component_hidden_states_rec":
                loss_weights["hidden_states_rec_lambda_pretraining"] * hidden_states_rec_loss.item(),
            "loss_component_states_rec":
                loss_weights["states_rec_lambda_pretraining"] * states_rec_loss.item(),
            "loss_component_entropy":
                loss_weights["entropy_lambda_pretraining"] * entropy_loss.item(),
            "loss_component_action_directions_kl_divergence":
                loss_weights["action_directions_kl_lambda_pretraining"] *
                action_directions_kl_divergence_loss.item(),
            "loss_component_action_mutual_information":
                loss_weights["action_mutual_information_lambda_pretraining"] *
                action_mutual_information_loss.item(),
            "loss_component_action_state_distribution_kl":
                loss_weights["action_state_distribution_kl_lambda_pretraining"] *
                action_state_distribution_kl_loss.item(),
            "avg_observations_rec_loss": observations_rec_loss.item(),
            "avg_perceptual_loss": perceptual_loss.item(),
            "states_rec_loss": states_rec_loss.item(),
            "hidden_states_rec_loss": hidden_states_rec_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "samples_entropy": samples_entropy.item(),
            "action_distribution_entropy": action_distribution_entropy.item(),
            "states_magnitude": states_magnitude,
            "hidden_states_magnitude": hidden_states_magnitude,
            "action_directions_mean_magnitude": action_directions_mean_magnitude,
            "action_directions_variance_magnitude": action_directions_variance_magnitude,
            "reconstructed_action_directions_mean_magnitude":
                reconstructed_action_directions_mean_magnitude,
            "reconstructed_action_directions_variance_magnitude":
                reconstructed_action_directions_variance_magnitude,
            "action_directions_reconstruction_error": action_directions_reconstruction_error,
            "action_directions_kl_loss": action_directions_kl_divergence_loss.item(),
            "centroids_mean_magnitude": centroids_mean_magnitude,
            "average_centroids_distance": average_centroids_distance,
            "average_action_variations_norm_l2": average_action_variations_norm_l2,
            "action_variations_mean": action_variations_mean,
            "reconstructed_action_directions_kl_loss":
                reconstructed_action_directions_kl_divergence_loss.item(),
            "action_mutual_information_loss": action_mutual_information_loss.item(),
            "action_state_distribution_kl_loss": action_state_distribution_kl_loss.item(),
            "ground_truth_observations": ground_truth_observations_count,
            "gumbel_temperature": gumbel_temperature,
            "observations_count": observations_count,
        }
        # Concatenates the info dictionaries
        loss_info = dict(loss_info, **loss_info_reconstruction)

        additional_info = {}
        # Plots the action direction space at regular intervals
        if self.global_step % self.config["training"]["action_direction_plotting_freq"] == 0:
            image = self.plot_action_direction_space(estimated_action_centroids,
                                                     action_directions_distribution,
                                                     action_logits)
            additional_info["action_direction_space"] = wandb.Image(image)
            image = self.plot_action_states(sampled_action_states, action_logits)
            additional_info["action_state_trajectories"] = wandb.Image(image)

        return total_loss, loss_info, additional_info

    def compute_losses(self, model, batch: Batch, observations_count: int) -> Tuple:
        '''
        Computes losses using the full model

        :param model: The network model
        :param batch: Batch of data
        :param observations_count: The number of observations in each sequence
        :return: (total_loss, loss_info, additional_info)
                 total_loss: torch.Tensor with the total loss
                 loss_info: Dict with an entry for every additional information about the loss
                 additional_info: Dict with additional loggable information
        '''
        # Ground truth observations to use at the current step
        ground_truth_observations_count = self.get_ground_truth_observations_count()
        # Since the annealing of the ground truth observations to use may produce a number greater than the number of
        # observations in the sequence, we cap it to the maximum value for the current sequence length
        if ground_truth_observations_count >= observations_count:
            ground_truth_observations_count = observations_count - 1

        # Gumbel temperature to use at the current step
        gumbel_temperature = self.get_gumbel_temperature()

        # Computes forward and losses for the plain batch
        batch_tuple = batch.to_tuple()
        results = model(batch_tuple, ground_truth_observations_count,
                        gumbel_temperature=gumbel_temperature)

        reconstructed_observations, multiresolution_reconstructed_observations, reconstructed_states, states, \
            hidden_states, selected_actions, action_logits, action_samples, \
            attention, reconstructed_attention, action_directions_distribution, sampled_action_directions, \
            action_states_distribution, sampled_action_states, action_variations, \
            reconstructed_action_logits, \
            reconstructed_action_directions_distribution, reconstructed_sampled_action_directions, \
            reconstructed_action_states_distribution, reconstructed_sampled_action_states, \
            *other_results = results

        estimated_action_centroids = model.module.centroid_estimator.get_estimated_centroids()

        # Computes the weights mask
        weights_mask = None
        if self.config["training"]["use_motion_weights"]:
            ground_truth_observations = batch_tuple[0]
            weights_mask = self.weight_mask_calculator.compute_weight_mask(
                ground_truth_observations, reconstructed_observations)

        loss_weights = self.config["training"]["loss_weights"]
        perceptual_loss_lambda = loss_weights["perceptual_loss_lambda"]

        loss_info_reconstruction = {}
        # Computes perceptual and observation reconstruction losses averaged over all resolutions
        resolutions_count = len(multiresolution_reconstructed_observations)
        perceptual_loss = torch.zeros((1, ), dtype=float).cuda()
        perceptual_loss_term = torch.zeros((1, ), dtype=float).cuda()
        observations_rec_loss = torch.zeros((1, ), dtype=float).cuda()
        for resolution_idx, current_reconstructed_observations in \
                enumerate(multiresolution_reconstructed_observations):
            current_perceptual_loss, current_perceptual_loss_components = self.perceptual_loss(
                batch.observations, current_reconstructed_observations, weights_mask)
            current_perceptual_loss_term = self.sum_loss_components(
                current_perceptual_loss_components, perceptual_loss_lambda)
            current_observations_rec_loss = self.observations_loss(
                batch.observations, current_reconstructed_observations, weights_mask)

            perceptual_loss += current_perceptual_loss
            perceptual_loss_term += current_perceptual_loss_term
            observations_rec_loss += current_observations_rec_loss

            loss_info_reconstruction[f'perceptual_loss_r{resolution_idx}'] = \
                current_perceptual_loss.item()
            loss_info_reconstruction[f'observations_rec_loss_r{resolution_idx}'] = \
                current_observations_rec_loss.item()
            for layer_idx, component in enumerate(current_perceptual_loss_components):
                loss_info_reconstruction[f'perceptual_loss_r{resolution_idx}_l{layer_idx}'] = \
                    current_perceptual_loss_components[layer_idx].item()

        perceptual_loss /= resolutions_count
        perceptual_loss_term /= resolutions_count
        observations_rec_loss /= resolutions_count

        states_rec_loss = self.states_loss(states.detach(), reconstructed_states)
        entropy_loss = self.entropy_loss(action_logits)
        action_directions_kl_divergence_loss = self.action_directions_kl_gaussian_divergence_loss(
            action_directions_distribution)
        action_mutual_information_loss = self.mutual_information_loss(
            torch.softmax(action_logits, dim=-1),
            torch.softmax(reconstructed_action_logits, dim=-1),
            lamb=self.action_mutual_infromation_entropy_lambda)
        # The reconstructed must get closer to the true ones, not the contrary
        action_state_distribution_kl_loss = self.action_state_distribution_kl(
            reconstructed_action_states_distribution,
            action_states_distribution.detach())

        # Additional debug information not used for backpropagation
        with torch.no_grad():
            samples_entropy = self.samples_entropy_loss(action_samples)
            action_distribution_entropy = self.action_distribution_entropy(
                action_samples.mean(dim=(0, 1)).unsqueeze(dim=0))
            states_magnitude = torch.mean(torch.abs(states)).item()
            hidden_states_magnitude = torch.mean(torch.abs(hidden_states)).item()
            # Compute magnitude of the mean
            action_directions_mean_magnitude = torch.mean(
                torch.abs(action_directions_distribution[:, :, 0])).item()
            # Compute magnitude of the variance
            action_directions_variance_magnitude = torch.mean(
                torch.abs(action_directions_distribution[:, :, 1])).item()
            # Compute magnitude of the mean
            reconstructed_action_directions_mean_magnitude = torch.mean(
                torch.abs(reconstructed_action_directions_distribution[:, :, 0])).item()
            # Compute magnitude of the variance
            reconstructed_action_directions_variance_magnitude = torch.mean(
                torch.abs(reconstructed_action_directions_distribution[:, :, 1])).item()
            # Compute differences of the mean
            action_directions_reconstruction_error = torch.mean(
                (reconstructed_action_directions_distribution[:, :, 0] -
                 action_directions_distribution[:, :, 0]).pow(2)).item()
            reconstructed_action_directions_kl_divergence_loss = \
                self.action_directions_kl_gaussian_divergence_loss(
                    reconstructed_action_directions_distribution)
            centroids_mean_magnitude = torch.mean(
                torch.abs(estimated_action_centroids)).item()
            average_centroids_distance = self.compute_average_centroid_distance(
                estimated_action_centroids).item()
            average_action_variations_norm_l2 = \
                action_variations.pow(2).sum(-1).sqrt().mean().item()
            action_variations_mean = action_variations.mean().item()

        # Computes the total loss
        total_loss = loss_weights["reconstruction_loss_lambda"] * observations_rec_loss + \
            perceptual_loss_term + \
            loss_weights["states_rec_lambda"] * states_rec_loss + \
            loss_weights["entropy_lambda"] * entropy_loss + \
            loss_weights["action_directions_kl_lambda"] * action_directions_kl_divergence_loss + \
            loss_weights["action_mutual_information_lambda"] * action_mutual_information_loss + \
            loss_weights["action_state_distribution_kl_lambda"] * action_state_distribution_kl_loss

        # Computes loss information
        loss_info = {
            "loss_component_observations_rec":
                loss_weights["reconstruction_loss_lambda"] * observations_rec_loss.item(),
            "loss_component_perceptual_loss": perceptual_loss_term.item(),
            "loss_component_states_rec":
                loss_weights["states_rec_lambda"] * states_rec_loss.item(),
            "loss_component_entropy":
                loss_weights["entropy_lambda"] * entropy_loss.item(),
            "loss_component_action_directions_kl_divergence":
                loss_weights["action_directions_kl_lambda"] *
                action_directions_kl_divergence_loss.item(),
            "loss_component_action_mutual_information":
                loss_weights["action_mutual_information_lambda"] *
                action_mutual_information_loss.item(),
            "loss_component_action_state_distribution_kl":
                loss_weights["action_state_distribution_kl_lambda"] *
                action_state_distribution_kl_loss.item(),
            "avg_observations_rec_loss": observations_rec_loss.item(),
            "avg_perceptual_loss": perceptual_loss.item(),
            "states_rec_loss": states_rec_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "samples_entropy": samples_entropy.item(),
            "action_distribution_entropy": action_distribution_entropy.item(),
            "states_magnitude": states_magnitude,
            "hidden_states_magnitude": hidden_states_magnitude,
            "action_directions_mean_magnitude": action_directions_mean_magnitude,
            "action_directions_variance_magnitude": action_directions_variance_magnitude,
            "reconstructed_action_directions_mean_magnitude":
                reconstructed_action_directions_mean_magnitude,
            "reconstructed_action_directions_variance_magnitude":
                reconstructed_action_directions_variance_magnitude,
            "action_directions_reconstruction_error": action_directions_reconstruction_error,
            "action_directions_kl_loss": action_directions_kl_divergence_loss.item(),
            "centroids_mean_magnitude": centroids_mean_magnitude,
            "average_centroids_distance": average_centroids_distance,
            "average_action_variations_norm_l2": average_action_variations_norm_l2,
            "action_variations_mean": action_variations_mean,
            "reconstructed_action_directions_kl_loss":
                reconstructed_action_directions_kl_divergence_loss.item(),
            "action_mutual_information_loss": action_mutual_information_loss.item(),
            "action_state_distribution_kl_loss": action_state_distribution_kl_loss.item(),
            "ground_truth_observations": ground_truth_observations_count,
            "gumbel_temperature": gumbel_temperature,
            "observations_count": observations_count,
        }
        # Concatenates the info dictionaries
        loss_info = dict(loss_info, **loss_info_reconstruction)

        additional_info = {}
        # Plots the action direction space at regular intervals
        if self.global_step % self.config["training"]["action_direction_plotting_freq"] == 0:
            image = self.plot_action_direction_space(estimated_action_centroids,
                                                     action_directions_distribution,
                                                     action_logits)
            additional_info["action_direction_space"] = wandb.Image(image)
            image = self.plot_action_states(sampled_action_states, action_logits)
            additional_info["action_state_trajectories"] = wandb.Image(image)

        return total_loss, loss_info, additional_info

    def train_epoch(self, model):
        '''
        Trains the model for one epoch, switching between pretraining and full losses
        according to the current global step

        :param model: The network model to train
        '''
        self.logger.print(f'== Train [{self.global_step}] ==')

        # Computes the number of observations to use in the current epoch
        observations_count = self.get_observations_count()
        # Modifies the number of observations to return before instantiating the dataloader
        self.dataset.set_observations_count(observations_count)

        # Number of training steps performed in this epoch
        performed_steps = 0
        for step, batch_group in enumerate(self.dataloader):
            # If the maximum number of training steps per epoch is exceeded, we interrupt the epoch
            if performed_steps > self.config["training"]["max_steps_per_epoch"]:
                break
            self.global_step += 1
            performed_steps += 1

            # If there is a change in the number of observations to use, we interrupt the epoch
            current_observations_count = self.get_observations_count()
            if current_observations_count != observations_count:
                break

            if self.global_step <= self.config["training"]["pretraining_steps"]:
                loss, loss_info, additional_info = self.compute_losses_pretraining(
                    model, batch_group, observations_count)
            else:
                loss, loss_info, additional_info = self.compute_losses(
                    model, batch_group, observations_count)

            # Logs the loss
            loss_info["loss"] = loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            self.average_meter.add(loss_info)

            # NOTE(review): "% 1 == 0" is always true (console logging every step); kept as-is
            if (self.global_step - 1) % 1 == 0:
                self.logger.print(
                    f'step: {self.global_step}/{self.config["training"]["max_steps"]}',
                    end=" ")
                average_values = {
                    description: self.average_meter.pop(description)
                    for description in loss_info
                }
                for description, value in average_values.items():
                    self.logger.print("{}:{:.3f}".format(description, value), end=" ")
                current_lr = self._get_current_lr()
                self.logger.print('lr: %.4f' % (current_lr))

                # Logs to wandb every 10 steps
                if (self.global_step - 1) % 10 == 0:
                    wandb = self.logger.get_wandb()
                    logged_map = {
                        "train/" + description: item
                        for description, item in average_values.items()
                    }
                    logged_map["step"] = self.global_step
                    logged_map["train/lr"] = current_lr
                    wandb.log(logged_map, step=self.global_step)
                    additional_info["step"] = self.global_step
                    wandb.log(additional_info, step=self.global_step)
    def evaluate(self, model, step: int):
        '''
        Evaluates the performances of the given model

        Runs two passes under torch.no_grad(): first a single batch from the
        imaging dataloader to save reconstruction / weight-mask / attention
        sample images, then up to max_evaluation_batches from the main
        dataloader to accumulate reconstruction, entropy, KL and mutual
        information losses, compute the action prediction accuracy against the
        ground truth actions, and plot the action direction space. Averaged
        results are logged to the console and to wandb.

        :param model: The model to evaluate. Must be wrapped (model.module is
                      accessed for the centroid estimator)
        :param step: The current step, used for log keys and plot filenames
        :return:
        '''
        loss_averager = AverageMeter()

        # All the selected actions and ground truth ones
        all_gt_actions = []
        all_pred_actions = []
        # Number of video sequence samples analyzed
        total_sequences = 0

        self.logger.print(f"== Evaluation [{step}][{self.logger_prefix}] ==")
        self.logger.print(f"- Saving sample images")

        # Saves sample images; only the first imaging batch is used (see break below)
        with torch.no_grad():
            for idx, batch in enumerate(self.imaging_dataloader):

                # Performs inference
                batch_tuple = batch.to_tuple()
                observations, actions, rewards, dones = batch_tuple
                ground_truth_observations = batch_tuple[0]
                results = model(batch_tuple, ground_truth_observations_init=1)
                # Trailing results (if any) carry attention maps — see the len(others) check below
                reconstructed_observations, multiresolution_reconstructed_observations, reconstructed_states, states, hidden_states, selected_actions, action_logits, action_samples_distribution, *others = results

                weights_mask = self.weight_mask_calculator.compute_weight_mask(ground_truth_observations, reconstructed_observations)

                # Saves reconstructed observations at all resolutions
                for resolution_idx, current_reconstructed_observations in enumerate(multiresolution_reconstructed_observations):
                    self.save_examples(observations, current_reconstructed_observations, step, max_batches=30, log_key=f"observations_r{resolution_idx}")

                # Plots weight masks
                assert (weights_mask.min().item() >= 0.0)
                weights_mask = weights_mask / torch.max(torch.abs(weights_mask))  # Normalizes the weights for plotting
                self.save_examples_with_weights(observations, weights_mask, reconstructed_observations, weights_mask, step, max_batches=30, log_key="motion_weighted_observations_")

                # If attention is used, plot also attention
                if len(others) > 0:
                    attention = others[0]
                    reconstructed_attention = others[1]
                    self.save_examples_with_weights(observations, attention, reconstructed_observations, reconstructed_attention, step, max_batches=30, log_key="attentive_observations_")
                # Only one batch of sample images is saved
                break

        self.logger.print(f"- Computing evaluation losses")
        current_evaluation_batches = 0

        all_action_direction_distributions = []
        all_action_logits = []
        all_action_states = []
        estimated_action_centroids = None
        with torch.no_grad():
            for idx, batch in enumerate(self.dataloader):
                # Caps the number of evaluated batches when a limit is configured
                if self.max_evaluation_batches is not None and self.max_evaluation_batches <= current_evaluation_batches:
                    self.logger.print(f"- Aborting evaluation, maximum number of evaluation batches reached")
                    break
                current_evaluation_batches += 1

                # Performs evaluation only on the plain batch
                total_sequences += batch.size

                # Performs inference
                batch_tuple = batch.to_tuple()
                results = model(batch_tuple, ground_truth_observations_init=1, action_sampler=self.action_sampler)

                # Extracts the results; the unpacking order must match the model's forward return order
                reconstructed_observations, multiresolution_reconstructed_observations, reconstructed_states, states, hidden_states, selected_actions, action_logits, action_samples_distribution, \
                attention, reconstructed_attention, action_directions_distribution, sampled_action_directions, \
                action_states_distribution, sampled_action_states, action_variations,\
                reconstructed_action_logits, \
                reconstructed_action_directions_distribution, reconstructed_sampled_action_directions, \
                reconstructed_action_states_distribution, reconstructed_sampled_action_states, \
                *other_results = results

                # Index 0 on the last dimension presumably selects the distribution means — TODO confirm
                all_action_states.append(action_states_distribution[:, :, :, 0])
                all_action_direction_distributions.append(action_directions_distribution.cpu())
                all_action_logits.append(action_logits.cpu())
                # Fetches the centroids once; they do not change during evaluation
                if estimated_action_centroids is None:
                    estimated_action_centroids = model.module.centroid_estimator.get_estimated_centroids()

                # Computes losses
                entropy_loss = self.entropy_loss(action_logits)
                samples_entropy = self.samples_entropy_loss(action_samples_distribution)
                # Entropy of the action distribution averaged over batch and sequence dimensions
                action_ditribution_entropy = self.action_distribution_entropy(action_samples_distribution.mean(dim=(0, 1)).unsqueeze(dim=0))
                action_directions_kl_divergence_loss = self.action_directions_kl_gaussian_divergence_loss(action_directions_distribution)
                action_mutual_information_loss = self.mutual_information_loss(torch.softmax(action_logits, dim=-1), torch.softmax(reconstructed_action_logits, dim=-1))

                # Evaluates the sequence losses
                sequence_observation_loss = self.evaluate_loss_on_sequence(batch.observations, reconstructed_observations, self.sequence_observation_loss, "observations_loss")
                sequence_perceptual_loss = self.evaluate_loss_on_sequence(batch.observations, reconstructed_observations, self.sequence_perceptual_loss, "perceptual_loss")
                sequence_states_loss = self.evaluate_loss_on_sequence(states, reconstructed_states, self.sequence_states_loss, "states_loss")

                loss_averager.add(sequence_observation_loss)
                loss_averager.add(sequence_perceptual_loss)
                loss_averager.add(sequence_states_loss)
                loss_averager.add({"entropy": entropy_loss.item()})
                loss_averager.add({"samples_entropy": samples_entropy.item()})
                loss_averager.add({"action_distribution_entropy": action_ditribution_entropy.item()})
                loss_averager.add({"action_directions_kl_loss": action_directions_kl_divergence_loss.item()})
                loss_averager.add({"action_mutual_information_loss": action_mutual_information_loss.item()})

                # Saves the flattened actions
                all_pred_actions.append(selected_actions.reshape((-1)))
                all_gt_actions.append(batch.actions[:, :-1].reshape((-1)))  # The last action of each sequence cannot be predicted

        # Builds (predecessor, successor) action-state pairs across the sequence dimension
        all_action_states = torch.cat(all_action_states)
        all_predecessor_action_states = TensorFolder.flatten(all_action_states[:, :-1])
        all_successor_action_states = TensorFolder.flatten(all_action_states[:, 1:])
        samples = torch.cat([all_predecessor_action_states, all_successor_action_states], dim=-1).cpu().numpy()
        # NOTE(review): covariance_matrix is computed but never used or logged below
        covariance_matrix = np.cov(samples, rowvar=False)

        all_pred_actions = torch.cat(all_pred_actions)  # Concatenate on the batch size dimension
        all_gt_actions = torch.cat(all_gt_actions)
        actions_accuracy, actions_match = self.compute_actions_accuracy(all_pred_actions, all_gt_actions)

        # Plots the distribution of action directions
        all_action_direction_distributions = torch.cat(all_action_direction_distributions, dim=0)
        all_action_logits = torch.cat(all_action_logits, dim=0)
        all_action_probabilities = torch.softmax(all_action_logits, dim=-1)
        action_directions_plot_filename = os.path.join(self.config["logging"]["output_images_directory"], f"action_direction_space_eval_{step}.pdf")
        TensorDisplayer.show_action_directions(estimated_action_centroids.detach().cpu(), all_action_direction_distributions, all_action_probabilities, action_directions_plot_filename)

        # Registers the best match found
        self.best_action_mappings = actions_match

        # Populates data to log at the current step
        log_data = {
            f"{self.logger_prefix}/actions_accuracy": actions_accuracy,
            "step": step
        }
        for key in loss_averager.data:
            log_data[f'{self.logger_prefix}/{key}'] = loss_averager.pop(key)

        # Logs results
        wandb = self.logger.get_wandb()
        wandb.log(log_data, step=step)

        self.logger.print("- observations_loss: {:.3f}".format(log_data[f'{self.logger_prefix}/observations_loss/avg']))
        self.logger.print("- perceptual_loss: {:.3f}".format(log_data[f'{self.logger_prefix}/perceptual_loss/avg']))
        self.logger.print("- states_loss: {:.3f}".format(log_data[f'{self.logger_prefix}/states_loss/avg']))
        self.logger.print("- actions_accuracy: {:.3f}".format(actions_accuracy))
        self.logger.print("- entropy: {:.3f}".format(log_data[f'{self.logger_prefix}/entropy']))
        self.logger.print("- samples entropy: {:.3f}".format(log_data[f'{self.logger_prefix}/samples_entropy']))
        self.logger.print("- action distribution entropy: {:.3f}".format(log_data[f'{self.logger_prefix}/action_distribution_entropy']))

        return