def _save_checkpoint(exp_batch, exp_alias, iteration, model, optimizer,
                     best_loss, best_loss_iter, accumulated_time):
    """Save the complete training state as ``<iteration>.pth``.

    The file is written to ``_logs/<exp_batch>/<exp_alias>/checkpoints/``.
    The dictionary keys are the ones ``execute`` reads back when resuming.
    """
    state = {
        'iteration': iteration,
        'state_dict': model.state_dict(),
        'best_loss': best_loss,
        'total_time': accumulated_time,
        'optimizer': optimizer.state_dict(),
        'best_loss_iter': best_loss_iter
    }
    torch.save(
        state,
        os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                     str(iteration) + '.pth'))


def execute(gpu, exp_batch, exp_alias, suppress_output=True, number_of_workers=12):
    """
    The main training function. It loads the latest checkpoint for the given
    exp_batch (folder) and exp_alias (experiment configuration). With this
    checkpoint it either starts training from the beginning or continues a
    previous run.

    Args:
        gpu: The GPU number (string assigned to CUDA_VISIBLE_DEVICES)
        exp_batch: the folder with the experiments
        exp_alias: the alias, experiment name
        suppress_output: if True, stdout/stderr are redirected to files in _output_logs
        number_of_workers: the number of threads used for data loading

    Returns:
        None
    """
    try:
        # We set the visible cuda devices to select the GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        g_conf.VARIABLE_WEIGHT = {}
        # At this point the log file with the correct naming is created.
        # You merge the yaml file with the global configuration structure.
        merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
        set_type_of_process('train')
        # Set the process into loading status.
        coil_logger.add_message('Loading', {'GPU': gpu})

        # Seed RNGs
        torch.manual_seed(g_conf.MAGICAL_SEED)
        random.seed(g_conf.MAGICAL_SEED)

        # Put the output to a separate file if it is the case
        if suppress_output:
            if not os.path.exists('_output_logs'):
                os.mkdir('_output_logs')
            sys.stdout = open(os.path.join(
                '_output_logs', exp_alias + '_' + g_conf.PROCESS_NAME + '_'
                + str(os.getpid()) + ".out"), "a", buffering=1)
            sys.stderr = open(os.path.join(
                '_output_logs', exp_alias + '_err_' + g_conf.PROCESS_NAME + '_'
                + str(os.getpid()) + ".out"), "a", buffering=1)

        if coil_logger.check_finish('train'):
            coil_logger.add_message('Finished', {})
            return

        # Preload option: start from a checkpoint of another experiment.
        if g_conf.PRELOAD_MODEL_ALIAS is not None:
            checkpoint = torch.load(
                os.path.join('_logs', g_conf.PRELOAD_MODEL_BATCH,
                             g_conf.PRELOAD_MODEL_ALIAS, 'checkpoints',
                             str(g_conf.PRELOAD_MODEL_CHECKPOINT) + '.pth'))

        # Get the latest checkpoint to be loaded; returns None if there are no
        # checkpoints saved for this model. When present it replaces the
        # preloaded checkpoint, and training resumes from its iteration.
        checkpoint_file = get_latest_saved_checkpoint()
        if checkpoint_file is not None:
            checkpoint = torch.load(
                os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                             str(checkpoint_file)))
            iteration = checkpoint['iteration']
            best_loss = checkpoint['best_loss']
            best_loss_iter = checkpoint['best_loss_iter']
        else:
            iteration = 0
            best_loss = 10000.0
            best_loss_iter = 0

        # Define the dataset.
        # Can specify a list of training datasets or just a single training dataset
        if len(g_conf.TRAIN_DATASET_NAMES) == 0:
            train_dataset_list = [g_conf.TRAIN_DATASET_NAME]
        else:
            train_dataset_list = g_conf.TRAIN_DATASET_NAMES
        full_dataset = [
            os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
            for dataset_name in train_dataset_list
        ]

        # By instantiating the augmenter we get a callable that augments images
        # and transforms them into tensors.
        augmenter = Augmenter(g_conf.AUGMENTATION)

        # Instantiate the class used to read a dataset.
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[
                                  str(g_conf.NUMBER_OF_HOURS) + 'hours_' + dataset_name
                                  for dataset_name in train_dataset_list
                              ],
                              train_dataset=True)
        print("Loaded dataset")

        # Create dataloader, model, and optimizer
        data_loader = select_balancing_strategy(dataset, iteration, number_of_workers)
        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)
        model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=g_conf.LEARNING_RATE)

        # If we have a previous checkpoint, load model, optimizer, and record of
        # previous train loss values (used for the learning rate schedule)
        if checkpoint_file is not None or g_conf.PRELOAD_MODEL_ALIAS is not None:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            accumulated_time = checkpoint['total_time']
            loss_window = coil_logger.recover_loss_window('train', iteration)
        else:
            # We accumulate iteration time and keep the average speed
            accumulated_time = 0
            loss_window = []

        print("Before the loss")

        # Define control loss function
        criterion = Loss(g_conf.LOSS_FUNCTION)

        if iteration == 0 and is_ready_to_save(iteration):
            _save_checkpoint(exp_batch, exp_alias, iteration, model, optimizer,
                             best_loss, best_loss_iter, accumulated_time)

        # Training loop
        for data in data_loader:
            # In this mode of execution we validate every X steps; if validation
            # goes up 3 times, a stop marker is added to the _logs folder and is
            # read here.
            if (g_conf.FINISH_ON_VALIDATION_STALE is not None
                    and check_loss_validation_stopped(
                        iteration, g_conf.FINISH_ON_VALIDATION_STALE)):
                break

            # ---------------- Main optimization loop ----------------
            iteration += 1

            # Adjust learning rate based on training loss
            if iteration % 1000 == 0:
                adjust_learning_rate_auto(optimizer, loss_window)

            capture_time = time.time()
            model.zero_grad()

            controls = data['directions']

            # Run model forward and get outputs.
            # First case: training the squeeze network; second case: driving
            # model without mimicking losses; last case: mimic network.
            if "seg" in g_conf.SENSORS.keys():
                branches = model(data,
                                 dataset.extract_inputs(data).cuda(),
                                 dataset.extract_intentions(data).cuda())
            elif not g_conf.USE_REPRESENTATION_LOSS:
                branches = model(data, dataset.extract_inputs(data).cuda())
            else:
                branches, intermediate_reps = model(
                    data, dataset.extract_inputs(data).cuda())

            # Compute control loss
            targets_to_use = dataset.extract_targets(data)
            loss_function_params = {
                'branches': branches,
                'targets': targets_to_use.cuda(),
                'controls': controls.cuda(),
                'inputs': dataset.extract_inputs(data).cuda(),
                'branch_weights': g_conf.BRANCH_LOSS_WEIGHT,
                'variable_weights': g_conf.VARIABLE_WEIGHT
            }
            loss, _ = criterion(loss_function_params)

            # Compute mimicking loss (sum of squared differences, averaged over
            # the batch dimension of the first branch output)
            if g_conf.USE_REPRESENTATION_LOSS:
                expert_reps = dataset.extract_representations(data)
                # Seg mask mimicking loss
                if g_conf.USE_PERCEPTION_REP_LOSS:
                    perception_rep_loss_elementwise = (
                        intermediate_reps[0] - expert_reps[0].cuda())**2
                    perception_rep_loss = g_conf.PERCEPTION_REP_WEIGHT * torch.sum(
                        perception_rep_loss_elementwise) / branches[0].shape[0]
                else:
                    perception_rep_loss = torch.tensor(0.).cuda()
                # Speed mimicking loss
                if g_conf.USE_SPEED_REP_LOSS:
                    speed_rep_loss_elementwise = (
                        intermediate_reps[1] - expert_reps[1].cuda())**2
                    speed_rep_loss = g_conf.SPEED_REP_WEIGHT * torch.sum(
                        speed_rep_loss_elementwise) / branches[0].shape[0]
                else:
                    speed_rep_loss = torch.tensor(0.).cuda()
                # Stop intentions mimicking loss
                if g_conf.USE_INTENTION_REP_LOSS:
                    intentions_rep_loss_elementwise = (
                        intermediate_reps[2] - expert_reps[2].cuda())**2
                    intentions_rep_loss = g_conf.INTENTIONS_REP_WEIGHT * torch.sum(
                        intentions_rep_loss_elementwise) / branches[0].shape[0]
                else:
                    intentions_rep_loss = torch.tensor(0.).cuda()
                rep_loss = g_conf.REP_LOSS_WEIGHT * (
                    perception_rep_loss + speed_rep_loss + intentions_rep_loss)
                overall_loss = loss + rep_loss
            else:
                overall_loss = loss
            overall_loss.backward()
            optimizer.step()

            # ---------------- Saving the model if necessary ----------------
            if is_ready_to_save(iteration):
                _save_checkpoint(exp_batch, exp_alias, iteration, model, optimizer,
                                 best_loss, best_loss_iter, accumulated_time)

            # ---------------- Tensorboard logs ----------------
            # These logs are monitored by the printer module.
            coil_logger.add_scalar('Loss', loss.data, iteration)
            if g_conf.USE_REPRESENTATION_LOSS:
                coil_logger.add_scalar('Perception Rep Loss',
                                       perception_rep_loss.data, iteration)
                coil_logger.add_scalar('Speed Rep Loss', speed_rep_loss.data,
                                       iteration)
                coil_logger.add_scalar('Intentions Rep Loss',
                                       intentions_rep_loss.data, iteration)
                coil_logger.add_scalar('Overall Rep Loss', rep_loss.data,
                                       iteration)
                coil_logger.add_scalar('Total Loss', overall_loss.data, iteration)
            if 'rgb' in data:
                coil_logger.add_image('Image', torch.squeeze(data['rgb']),
                                      iteration)
            if overall_loss.data < best_loss:
                best_loss = overall_loss.data.tolist()
                best_loss_iter = iteration

            output = model.extract_branch(torch.stack(branches[0:4]), controls)
            error = torch.abs(output - targets_to_use.cuda())

            # Log a random position within the batch. `data` is a dict, so its
            # length is the number of keys, not the batch size; use the batch
            # dimension of the selected-branch output instead.
            position = random.randint(0, output.shape[0] - 1)

            accumulated_time += time.time() - capture_time

            # Log to terminal and log file
            if g_conf.USE_REPRESENTATION_LOSS:
                coil_logger.add_message(
                    'Iterating', {
                        'Iteration': iteration,
                        'Loss': overall_loss.data.tolist(),
                        'Control Loss': loss.data.tolist(),
                        'Rep Loss': rep_loss.data.tolist(),
                        'Images/s': (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                        'BestLoss': best_loss,
                        'BestLossIteration': best_loss_iter,
                        'Output': output[position].data.tolist(),
                        'GroundTruth': targets_to_use[position].data.tolist(),
                        'Error': error[position].data.tolist(),
                        'Inputs': dataset.extract_inputs(data)[position].data.tolist()
                    }, iteration)
            else:
                coil_logger.add_message(
                    'Iterating', {
                        'Iteration': iteration,
                        'Loss': loss.data.tolist(),
                        'Images/s': (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                        'BestLoss': best_loss,
                        'BestLossIteration': best_loss_iter,
                        'Output': output[position].data.tolist(),
                        'GroundTruth': targets_to_use[position].data.tolist(),
                        'Error': error[position].data.tolist(),
                        'Inputs': dataset.extract_inputs(data)[position].data.tolist()
                    }, iteration)

            # Save training loss history (useful for restoring training runs,
            # since the learning rate is adjusted based on training loss)
            loss_window.append(overall_loss.data.tolist())
            coil_logger.write_on_error_csv('train', overall_loss.data)
            print("Iteration: %d Loss: %f" % (iteration, overall_loss.data))

        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})

    except RuntimeError as e:
        coil_logger.add_message('Error', {'Message': str(e)})

    except Exception:
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
def _rep_loss_terms(intermediate_reps, expert_reps, batch_size, squared):
    """Per-representation mimicking losses for one validation batch.

    Returns a (perception, speed, intentions) tuple of Python floats. A term is
    0 when its g_conf.USE_*_REP_LOSS flag is off. With ``squared`` the summed
    squared difference (L2) is used, otherwise the summed absolute difference
    (L1). Each sum is divided by ``3 * batch_size``; the meaning of the factor
    3 is not visible here (presumably a normalization) -- TODO confirm.
    """
    flags = (g_conf.USE_PERCEPTION_REP_LOSS,
             g_conf.USE_SPEED_REP_LOSS,
             g_conf.USE_INTENTION_REP_LOSS)
    terms = []
    for index, enabled in enumerate(flags):
        if not enabled:
            terms.append(0)
            continue
        diff = intermediate_reps[index] - expert_reps[index].cuda()
        distance = diff**2 if squared else torch.abs(diff)
        terms.append(torch.sum(distance).data.tolist() / (3 * batch_size))
    return tuple(terms)


def execute(gpu, exp_batch, exp_alias, dataset_name, suppress_output):
    """
    Validate the scheduled checkpoints of an experiment on one dataset.

    Waits for each checkpoint in g_conf.TEST_SCHEDULE and runs it over the
    whole dataset without gradients. It logs per-batch and per-checkpoint L1
    ("Error") and L2 ("MSE"/"Loss") control losses, plus the representation
    mimicking losses when g_conf.USE_REPRESENTATION_LOSS is set. The
    per-checkpoint total L1 error is written to the validation CSV. When
    g_conf.FINISH_ON_VALIDATION_STALE is set, it is also used for early
    stopping.

    Args:
        gpu: the GPU number (string assigned to CUDA_VISIBLE_DEVICES)
        exp_batch: the folder with the experiments
        exp_alias: the alias, experiment name
        dataset_name: validation dataset folder name under COIL_DATASET_PATH
        suppress_output: if True, stdout/stderr are redirected to files in _output_logs
    """
    latest = None
    try:
        # We set the visible cuda devices
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

        # At this point the log file with the correct naming is created.
        merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
        # The validation dataset is always fully loaded, so we fix a very high
        # number of hours
        g_conf.NUMBER_OF_HOURS = 10000
        set_type_of_process('validation', dataset_name)

        if not os.path.exists('_output_logs'):
            os.mkdir('_output_logs')

        if suppress_output:
            sys.stdout = open(os.path.join(
                '_output_logs', exp_alias + '_' + g_conf.PROCESS_NAME + '_'
                + str(os.getpid()) + ".out"), "a", buffering=1)
            sys.stderr = open(os.path.join(
                '_output_logs', exp_alias + '_err_' + g_conf.PROCESS_NAME + '_'
                + str(os.getpid()) + ".out"), "a", buffering=1)

        # Define the dataset.
        full_dataset = [
            os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
        ]
        augmenter = Augmenter(None)
        # Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[dataset_name])

        # Multi-worker loader; order is kept (shuffle=False).
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=g_conf.BATCH_SIZE,
            shuffle=False,
            num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
            pin_memory=True)

        # Create model.
        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)

        # The window used to keep track of the validation loss
        l1_window = []
        # If we have already evaluated a checkpoint, recover the validation
        # losses of all previously evaluated checkpoints (used for early stopping)
        latest = get_latest_evaluated_checkpoint()
        if latest is not None:
            l1_window = coil_logger.recover_loss_window(dataset_name, None)

        model.cuda()

        best_mse = 1000
        best_error = 1000
        best_mse_iter = 0
        best_error_iter = 0

        # Loop to validate all checkpoints as they are saved during training
        while not maximun_checkpoint_reach(latest, g_conf.TEST_SCHEDULE):

            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE):
                with torch.no_grad():
                    # Get and load latest checkpoint
                    latest = get_next_checkpoint(g_conf.TEST_SCHEDULE)

                    checkpoint = torch.load(
                        os.path.join('_logs', exp_batch, exp_alias,
                                     'checkpoints', str(latest) + '.pth'))
                    checkpoint_iteration = checkpoint['iteration']
                    print("Validation loaded ", checkpoint_iteration)

                    model.load_state_dict(checkpoint['state_dict'])
                    model.eval()

                    accumulated_mse = 0
                    accumulated_error = 0
                    iteration_on_checkpoint = 0
                    if g_conf.USE_REPRESENTATION_LOSS:
                        accumulated_perception_rep_mse = 0
                        accumulated_speed_rep_mse = 0
                        accumulated_intentions_rep_mse = 0
                        accumulated_rep_mse = 0
                        accumulated_perception_rep_error = 0
                        accumulated_speed_rep_error = 0
                        accumulated_intentions_rep_error = 0
                        accumulated_rep_error = 0

                    # Validation loop
                    for data in data_loader:
                        # Forward pass on a batch from the validation dataset
                        controls = data['directions']

                        # First case: squeeze network; second case: driving
                        # model without mimicking losses; last case: mimic network
                        if "seg" in g_conf.SENSORS.keys():
                            output = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls,
                                dataset.extract_intentions(data).cuda())
                        elif not g_conf.USE_REPRESENTATION_LOSS:
                            output = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls)
                        else:
                            output, intermediate_reps = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls)

                        write_regular_output(checkpoint_iteration, output)

                        # Control loss on this batch, accumulated
                        targets_to_use = dataset.extract_targets(data)
                        mse = torch.mean(
                            (output - targets_to_use.cuda())**2).data.tolist()
                        mean_error = torch.mean(
                            torch.abs(output - targets_to_use.cuda())).data.tolist()
                        accumulated_error += mean_error
                        accumulated_mse += mse
                        error = torch.abs(output - targets_to_use.cuda())

                        # Mimicking losses on this batch, accumulated
                        if g_conf.USE_REPRESENTATION_LOSS:
                            expert_reps = dataset.extract_representations(data)

                            # First L1 losses (seg mask, speed, intention)
                            (perception_rep_loss, speed_rep_loss,
                             intentions_rep_loss) = _rep_loss_terms(
                                 intermediate_reps, expert_reps,
                                 output.shape[0], squared=False)
                            rep_error = g_conf.REP_LOSS_WEIGHT * (
                                perception_rep_loss + speed_rep_loss
                                + intentions_rep_loss)
                            accumulated_perception_rep_error += perception_rep_loss
                            accumulated_speed_rep_error += speed_rep_loss
                            accumulated_intentions_rep_error += intentions_rep_loss
                            accumulated_rep_error += rep_error

                            # L2 losses now
                            (perception_rep_loss, speed_rep_loss,
                             intentions_rep_loss) = _rep_loss_terms(
                                 intermediate_reps, expert_reps,
                                 output.shape[0], squared=True)
                            rep_mse = g_conf.REP_LOSS_WEIGHT * (
                                perception_rep_loss + speed_rep_loss
                                + intentions_rep_loss)
                            accumulated_perception_rep_mse += perception_rep_loss
                            accumulated_speed_rep_mse += speed_rep_loss
                            accumulated_intentions_rep_mse += intentions_rep_loss
                            accumulated_rep_mse += rep_mse

                        # Log a random position of the batch
                        position = random.randint(0, output.shape[0] - 1)

                        # Approximate number of samples processed so far
                        progress = (str(iteration_on_checkpoint * g_conf.BATCH_SIZE)
                                    + '/' + str(len(dataset)))

                        # Logging
                        if g_conf.USE_REPRESENTATION_LOSS:
                            total_mse = mse + rep_mse
                            total_error = mean_error + rep_error
                            coil_logger.add_message(
                                'Iterating', {
                                    'Checkpoint': latest,
                                    'Iteration': progress,
                                    'MeanError': mean_error,
                                    'MSE': mse,
                                    'RepMeanError': rep_error,
                                    'RepMSE': rep_mse,
                                    'MeanTotalError': total_error,
                                    'TotalMSE': total_mse,
                                    'Output': output[position].data.tolist(),
                                    'GroundTruth': targets_to_use[position].data.tolist(),
                                    'Error': error[position].data.tolist(),
                                    'Inputs': dataset.extract_inputs(
                                        data)[position].data.tolist()
                                }, latest)
                        else:
                            coil_logger.add_message(
                                'Iterating', {
                                    'Checkpoint': latest,
                                    'Iteration': progress,
                                    'MeanError': mean_error,
                                    'MSE': mse,
                                    'Output': output[position].data.tolist(),
                                    'GroundTruth': targets_to_use[position].data.tolist(),
                                    'Error': error[position].data.tolist(),
                                    'Inputs': dataset.extract_inputs(
                                        data)[position].data.tolist()
                                }, latest)
                        iteration_on_checkpoint += 1

                        if g_conf.USE_REPRESENTATION_LOSS:
                            print("Iteration %d on Checkpoint %d : Error %f" %
                                  (iteration_on_checkpoint, checkpoint_iteration,
                                   total_error))
                        else:
                            print("Iteration %d on Checkpoint %d : Error %f" %
                                  (iteration_on_checkpoint, checkpoint_iteration,
                                   mean_error))

                    # ---- Finish a round of validation, write results ----
                    # Average L1 and L2 losses over the whole round
                    num_batches = len(data_loader)
                    checkpoint_average_mse = accumulated_mse / num_batches
                    checkpoint_average_error = accumulated_error / num_batches
                    coil_logger.add_scalar('L2 Loss', checkpoint_average_mse,
                                           latest, True)
                    coil_logger.add_scalar('Loss', checkpoint_average_error,
                                           latest, True)

                    if g_conf.USE_REPRESENTATION_LOSS:
                        checkpoint_average_perception_rep_mse = (
                            accumulated_perception_rep_mse / num_batches)
                        checkpoint_average_speed_rep_mse = (
                            accumulated_speed_rep_mse / num_batches)
                        checkpoint_average_intentions_rep_mse = (
                            accumulated_intentions_rep_mse / num_batches)
                        checkpoint_average_rep_mse = accumulated_rep_mse / num_batches
                        checkpoint_average_total_mse = (
                            checkpoint_average_mse + checkpoint_average_rep_mse)

                        checkpoint_average_perception_rep_error = (
                            accumulated_perception_rep_error / num_batches)
                        checkpoint_average_speed_rep_error = (
                            accumulated_speed_rep_error / num_batches)
                        checkpoint_average_intentions_rep_error = (
                            accumulated_intentions_rep_error / num_batches)
                        checkpoint_average_rep_error = (
                            accumulated_rep_error / num_batches)
                        # Total L1 error = control L1 error + rep L1 error
                        # (previously mixed in the rep L2/MSE term by mistake).
                        checkpoint_average_total_error = (
                            checkpoint_average_error + checkpoint_average_rep_error)

                        # Log L1/L2 loss terms
                        coil_logger.add_scalar(
                            'Perception Rep Loss',
                            checkpoint_average_perception_rep_mse, latest, True)
                        coil_logger.add_scalar(
                            'Speed Rep Loss',
                            checkpoint_average_speed_rep_mse, latest, True)
                        coil_logger.add_scalar(
                            'Intentions Rep Loss',
                            checkpoint_average_intentions_rep_mse, latest, True)
                        coil_logger.add_scalar('Overall Rep Loss',
                                               checkpoint_average_rep_mse, latest,
                                               True)
                        coil_logger.add_scalar('Total L2 Loss',
                                               checkpoint_average_total_mse, latest,
                                               True)
                        coil_logger.add_scalar(
                            'Perception Rep Error',
                            checkpoint_average_perception_rep_error, latest, True)
                        coil_logger.add_scalar(
                            'Speed Rep Error',
                            checkpoint_average_speed_rep_error, latest, True)
                        coil_logger.add_scalar(
                            'Intentions Rep Error',
                            checkpoint_average_intentions_rep_error, latest, True)
                        coil_logger.add_scalar('Total Rep Error',
                                               checkpoint_average_rep_error, latest,
                                               True)
                        coil_logger.add_scalar('Total Loss',
                                               checkpoint_average_total_error,
                                               latest, True)
                    else:
                        checkpoint_average_total_mse = checkpoint_average_mse
                        checkpoint_average_total_error = checkpoint_average_error

                    if checkpoint_average_total_mse < best_mse:
                        best_mse = checkpoint_average_total_mse
                        best_mse_iter = latest

                    if checkpoint_average_total_error < best_error:
                        best_error = checkpoint_average_total_error
                        best_error_iter = latest

                    # Print for logging / to terminal validation results
                    if g_conf.USE_REPRESENTATION_LOSS:
                        coil_logger.add_message(
                            'Iterating', {
                                'Summary': {
                                    'Control Error': checkpoint_average_error,
                                    'Control Loss': checkpoint_average_mse,
                                    'Rep Error': checkpoint_average_rep_error,
                                    'Rep Loss': checkpoint_average_rep_mse,
                                    'Error': checkpoint_average_total_error,
                                    'Loss': checkpoint_average_total_mse,
                                    'BestError': best_error,
                                    'BestMSE': best_mse,
                                    'BestMSECheckpoint': best_mse_iter,
                                    'BestErrorCheckpoint': best_error_iter
                                },
                                'Checkpoint': latest
                            }, latest)
                    else:
                        coil_logger.add_message(
                            'Iterating', {
                                'Summary': {
                                    'Error': checkpoint_average_error,
                                    'Loss': checkpoint_average_mse,
                                    'BestError': best_error,
                                    'BestMSE': best_mse,
                                    'BestMSECheckpoint': best_mse_iter,
                                    'BestErrorCheckpoint': best_error_iter
                                },
                                'Checkpoint': latest
                            }, latest)

                    # Save validation loss history (used for early stopping)
                    l1_window.append(checkpoint_average_total_error)
                    coil_logger.write_on_error_csv(
                        dataset_name, checkpoint_average_total_error)

                    # Early stopping
                    if g_conf.FINISH_ON_VALIDATION_STALE is not None:
                        if (dlib.count_steps_without_decrease(l1_window) > 3
                                and dlib.count_steps_without_decrease_robust(
                                    l1_window) > 3):
                            coil_logger.write_stop(dataset_name, latest)
                            break

            else:
                latest = get_latest_evaluated_checkpoint()
                time.sleep(1)

                coil_logger.add_message('Loading',
                                        {'Message': 'Waiting Checkpoint'})
                print("Waiting for the next Validation")

        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)

    except RuntimeError as e:
        if latest is not None:
            coil_logger.erase_csv(latest)
        coil_logger.add_message('Error', {'Message': str(e)})

    except Exception:
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)
def execute(gpu, exp_batch, exp_alias, dataset_name, validation_set=False):
    """
    Extract and save the intermediate representations of a squeeze network.

    Loads a checkpoint and runs the model over every sample of
    ``dataset_name`` in order. It collects the perception, speed and intentions
    representations returned by ``model.get_intermediate_representations`` and
    saves them with np.save under
    ``_preloads/<exp_batch>_<exp_alias>_<dataset_name>_representations``.

    If g_conf.FINISH_ON_VALIDATION_STALE is set, it uses the checkpoint at
    which validation went stale. Otherwise it waits through
    g_conf.TEST_SCHEDULE and uses the last checkpoint loaded.

    Args:
        gpu: the GPU number (string assigned to CUDA_VISIBLE_DEVICES)
        exp_batch: the folder with the experiments
        exp_alias: the alias, experiment name
        dataset_name: dataset folder name under COIL_DATASET_PATH
        validation_set: if True, load the dataset as a validation set (all
            hours, preload name is just ``dataset_name``). Otherwise load it as
            a training set limited to g_conf.NUMBER_OF_HOURS.

    Raises:
        RuntimeError: if no representations can be produced (no "seg" sensor
            configured, or the dataset yields no batches).
    """
    latest = None
    # We set the visible cuda devices
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    g_conf.immutable(False)
    # At this point the log file with the correct naming is created.
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    # If using validation dataset, fix a very high number of hours
    if validation_set:
        g_conf.NUMBER_OF_HOURS = 10000
    g_conf.immutable(True)

    # Define the dataset.
    full_dataset = [
        os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
    ]
    augmenter = Augmenter(None)
    if validation_set:
        # Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[dataset_name])
    else:
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[
                                  str(g_conf.NUMBER_OF_HOURS) + 'hours_' + dataset_name
                              ],
                              train_dataset=True)

    # Multi-worker loader; order is kept (shuffle=False) so the saved
    # representations line up with the dataset order.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=g_conf.BATCH_SIZE,
        shuffle=False,
        num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
        pin_memory=True)

    # Define model
    model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)

    if g_conf.FINISH_ON_VALIDATION_STALE is not None:
        # Use the checkpoint at which validation became stale.
        while validation_stale_point(g_conf.FINISH_ON_VALIDATION_STALE) is None:
            time.sleep(0.1)

        validation_state_iteration = validation_stale_point(
            g_conf.FINISH_ON_VALIDATION_STALE)

        checkpoint = torch.load(
            os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                         str(validation_state_iteration) + '.pth'))
        print("Validation loaded ", validation_state_iteration)
    else:
        # Walk the test schedule; each loaded checkpoint replaces the previous
        # one, so only the last one is used below. The original code passed
        # undefined `control_filename`/`task_list` here (a NameError); the
        # one-argument form matches the validation process.
        # NOTE(review): confirm the checkpoint-schedule helpers can advance
        # without a CSV written by this process.
        while not maximun_checkpoint_reach(latest, g_conf.TEST_SCHEDULE):
            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE):
                latest = get_next_checkpoint(g_conf.TEST_SCHEDULE)
                checkpoint = torch.load(
                    os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                                 str(latest) + '.pth'))
                print("Validation loaded ", latest)
            else:
                time.sleep(0.1)

    # Load the model and set it up for evaluation
    model.load_state_dict(checkpoint['state_dict'])
    model.cuda()
    model.eval()

    # Only the squeeze network (configured with a "seg" sensor) exposes the
    # three intermediate representations extracted here.
    if "seg" not in g_conf.SENSORS.keys():
        raise RuntimeError(
            'Representation extraction requires a "seg" sensor in g_conf.SENSORS')

    perception_chunks = []
    speed_chunks = []
    intentions_chunks = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in data_loader:
            perception_rep, speed_rep, intentions_rep = \
                model.get_intermediate_representations(
                    data,
                    dataset.extract_inputs(data).cuda(),
                    dataset.extract_intentions(data).cuda())
            perception_chunks.append(perception_rep.data.cpu())
            speed_chunks.append(speed_rep.data.cpu())
            intentions_chunks.append(intentions_rep.data.cpu())

    if not perception_chunks:
        raise RuntimeError('No batches produced by dataset ' + dataset_name)

    # Concatenate once along the batch dimension (same result as growing the
    # tensors batch by batch, without quadratic copying).
    perception_rep_all = torch.cat(perception_chunks, 0).tolist()
    speed_rep_all = torch.cat(speed_chunks, 0).tolist()
    intentions_rep_all = torch.cat(intentions_chunks, 0).tolist()

    # Save intermediate representations.
    # NOTE(review): the three nested lists generally have different shapes;
    # recent NumPy versions may refuse to build an array from such ragged input
    # implicitly. Confirm against the NumPy version in use and the loader that
    # reads this file.
    np.save(
        os.path.join(
            '_preloads', exp_batch + '_' + exp_alias + '_' + dataset_name +
            '_representations'),
        [perception_rep_all, speed_rep_all, intentions_rep_all])