Example #1
def execute(gpu,
            exp_batch,
            exp_alias,
            suppress_output=True,
            number_of_workers=12):
    """
        The main training function. This function loads the latest checkpoint
        for a given exp_batch (folder) and exp_alias (experiment configuration).
        With this checkpoint it either starts training from scratch or resumes
        a previous training run.
    Args:
        gpu: The GPU number
        exp_batch: the folder with the experiments
        exp_alias: the alias (experiment name)
        suppress_output: whether the output is redirected to a file
        number_of_workers: the number of threads used for data loading

    Returns:
        None

    """
    try:
        # We set the visible cuda devices to select the GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        g_conf.VARIABLE_WEIGHT = {}
        # At this point the log file with the correct naming is created.
        # Merge the yaml file into the global configuration structure.
        merge_with_yaml(os.path.join('configs', exp_batch,
                                     exp_alias + '.yaml'))
        set_type_of_process('train')
        # Set the process into loading status.
        coil_logger.add_message('Loading', {'GPU': gpu})

        # Seed RNGs
        torch.manual_seed(g_conf.MAGICAL_SEED)
        random.seed(g_conf.MAGICAL_SEED)

        # Redirect the output to separate log files if requested

        if suppress_output:
            if not os.path.exists('_output_logs'):
                os.mkdir('_output_logs')
            sys.stdout = open(os.path.join(
                '_output_logs', exp_alias + '_' + g_conf.PROCESS_NAME + '_' +
                str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)
            sys.stderr = open(os.path.join(
                '_output_logs', exp_alias + '_err_' + g_conf.PROCESS_NAME +
                '_' + str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)

        if coil_logger.check_finish('train'):
            coil_logger.add_message('Finished', {})
            return

        # Preload option: optionally initialize from a checkpoint of another experiment
        if g_conf.PRELOAD_MODEL_ALIAS is not None:
            checkpoint = torch.load(
                os.path.join('_logs', g_conf.PRELOAD_MODEL_BATCH,
                             g_conf.PRELOAD_MODEL_ALIAS, 'checkpoints',
                             str(g_conf.PRELOAD_MODEL_CHECKPOINT) + '.pth'))

        # Get the latest checkpoint to be loaded
        # returns None if there are no checkpoints saved for this model
        checkpoint_file = get_latest_saved_checkpoint()
        if checkpoint_file is not None:
            checkpoint = torch.load(
                os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                             str(checkpoint_file)))
            iteration = checkpoint['iteration']
            best_loss = checkpoint['best_loss']
            best_loss_iter = checkpoint['best_loss_iter']
        else:
            iteration = 0
            best_loss = 10000.0
            best_loss_iter = 0

        # Define the dataset.
        # Can specify a list of training datasets or just a single training dataset
        if len(g_conf.TRAIN_DATASET_NAMES) == 0:
            train_dataset_list = [g_conf.TRAIN_DATASET_NAME]
        else:
            train_dataset_list = g_conf.TRAIN_DATASET_NAMES
        full_dataset = [
            os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
            for dataset_name in train_dataset_list
        ]

        # By instantiating the augmenter we get a callable that augments images and
        # transforms them into tensors.
        augmenter = Augmenter(g_conf.AUGMENTATION)

        # Instantiate the class used to read the dataset (CoILDataset).
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[
                                  str(g_conf.NUMBER_OF_HOURS) + 'hours_' +
                                  dataset_name
                                  for dataset_name in train_dataset_list
                              ],
                              train_dataset=True)
        print("Loaded dataset")

        # Create dataloader, model, and optimizer
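        # select_balancing_strategy (defined elsewhere) builds the DataLoader with the
        # sampling/balancing scheme requested in the config, resuming at `iteration`.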
        data_loader = select_balancing_strategy(dataset, iteration,
                                                number_of_workers)
        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)
        model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=g_conf.LEARNING_RATE)

        # If we have a previous checkpoint, load model, optimizer, and record of previous
        # train loss values (used for the learning rate schedule)
        if checkpoint_file is not None or g_conf.PRELOAD_MODEL_ALIAS is not None:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            accumulated_time = checkpoint['total_time']
            loss_window = coil_logger.recover_loss_window('train', iteration)
        else:  # Fresh run: start the accumulated time and loss window from scratch
            accumulated_time = 0
            loss_window = []

        print("Before the loss")

        # Define control loss function
        criterion = Loss(g_conf.LOSS_FUNCTION)

        if iteration == 0 and is_ready_to_save(iteration):

            state = {
                'iteration': iteration,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'total_time': accumulated_time,
                'optimizer': optimizer.state_dict(),
                'best_loss_iter': best_loss_iter
            }
            torch.save(
                state,
                os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                             str(iteration) + '.pth'))
        # Training loop
        for data in data_loader:

            # In this mode of execution we validate every X steps; if the validation loss
            # goes up 3 times, a stop flag is written to the _logs folder and read by this
            # process to end training early.
            if g_conf.FINISH_ON_VALIDATION_STALE is not None and \
                    check_loss_validation_stopped(iteration, g_conf.FINISH_ON_VALIDATION_STALE):
                break
            """
                ####################################
                    Main optimization loop
                ####################################
            """

            iteration += 1

            # Adjust learning rate based on training loss
            if iteration % 1000 == 0:
                adjust_learning_rate_auto(optimizer, loss_window)

            capture_time = time.time()
            model.zero_grad()

            controls = data['directions']

            # Run model forward and get outputs
            # First case: training the squeeze network; second case: training the driving
            # model without mimicking losses; last case: training the mimic network
            if "seg" in g_conf.SENSORS.keys():
                branches = model(data,
                                 dataset.extract_inputs(data).cuda(),
                                 dataset.extract_intentions(data).cuda())
            elif not g_conf.USE_REPRESENTATION_LOSS:
                branches = model(data, dataset.extract_inputs(data).cuda())
            else:
                branches, intermediate_reps = model(
                    data,
                    dataset.extract_inputs(data).cuda())

            # Compute control loss
            targets_to_use = dataset.extract_targets(data)
            loss_function_params = {
                'branches': branches,
                'targets': targets_to_use.cuda(),
                'controls': controls.cuda(),
                'inputs': dataset.extract_inputs(data).cuda(),
                'branch_weights': g_conf.BRANCH_LOSS_WEIGHT,
                'variable_weights': g_conf.VARIABLE_WEIGHT
            }
            loss, _ = criterion(loss_function_params)

            # Compute mimicking loss
            if g_conf.USE_REPRESENTATION_LOSS:
                expert_reps = dataset.extract_representations(data)
                # Seg mask mimicking loss
                if g_conf.USE_PERCEPTION_REP_LOSS:
                    perception_rep_loss_elementwise = (
                        intermediate_reps[0] - expert_reps[0].cuda())**2
                    perception_rep_loss = g_conf.PERCEPTION_REP_WEIGHT * torch.sum(
                        perception_rep_loss_elementwise) / branches[0].shape[0]
                else:
                    perception_rep_loss = torch.tensor(0.).cuda()
                # Speed mimicking loss
                if g_conf.USE_SPEED_REP_LOSS:
                    speed_rep_loss_elementwise = (intermediate_reps[1] -
                                                  expert_reps[1].cuda())**2
                    speed_rep_loss = g_conf.SPEED_REP_WEIGHT * torch.sum(
                        speed_rep_loss_elementwise) / branches[0].shape[0]
                else:
                    speed_rep_loss = torch.tensor(0.).cuda()
                # Stop intentions mimicking loss
                if g_conf.USE_INTENTION_REP_LOSS:
                    intentions_rep_loss_elementwise = (
                        intermediate_reps[2] - expert_reps[2].cuda())**2
                    intentions_rep_loss = g_conf.INTENTIONS_REP_WEIGHT * torch.sum(
                        intentions_rep_loss_elementwise) / branches[0].shape[0]
                else:
                    intentions_rep_loss = torch.tensor(0.).cuda()
                rep_loss = g_conf.REP_LOSS_WEIGHT * (
                    perception_rep_loss + speed_rep_loss + intentions_rep_loss)
                overall_loss = loss + rep_loss
            else:
                overall_loss = loss
            overall_loss.backward()
            optimizer.step()
            """
                ####################################
                    Saving the model if necessary
                ####################################
            """

            if is_ready_to_save(iteration):

                state = {
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'total_time': accumulated_time,
                    'optimizer': optimizer.state_dict(),
                    'best_loss_iter': best_loss_iter
                }
                torch.save(
                    state,
                    os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                                 str(iteration) + '.pth'))
            """
                ################################################
                    Adding tensorboard logs.
                    Making calculations for logging purposes.
                    These logs are monitored by the printer module.
                #################################################
            """
            coil_logger.add_scalar('Loss', loss.data, iteration)
            if g_conf.USE_REPRESENTATION_LOSS:
                coil_logger.add_scalar('Perception Rep Loss',
                                       perception_rep_loss.data, iteration)
                coil_logger.add_scalar('Speed Rep Loss', speed_rep_loss.data,
                                       iteration)
                coil_logger.add_scalar('Intentions Rep Loss',
                                       intentions_rep_loss.data, iteration)
                coil_logger.add_scalar('Overall Rep Loss', rep_loss.data,
                                       iteration)
                coil_logger.add_scalar('Total Loss', overall_loss.data,
                                       iteration)
            if 'rgb' in data:
                coil_logger.add_image('Image', torch.squeeze(data['rgb']),
                                      iteration)
            if overall_loss.data < best_loss:
                best_loss = overall_loss.data.tolist()
                best_loss_iter = iteration

            # Log a random position
            position = random.randint(0, len(data) - 1)

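            # branches[0:4] are the four command-conditional branches; extract_branch
            # selects, per sample, the branch output that matches its control command.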
            output = model.extract_branch(torch.stack(branches[0:4]), controls)
            error = torch.abs(output - targets_to_use.cuda())

            accumulated_time += time.time() - capture_time

            # Log to terminal and log file
            if g_conf.USE_REPRESENTATION_LOSS:
                coil_logger.add_message(
                    'Iterating', {
                        'Iteration':
                        iteration,
                        'Loss':
                        overall_loss.data.tolist(),
                        'Control Loss':
                        loss.data.tolist(),
                        'Rep Loss':
                        rep_loss.data.tolist(),
                        'Images/s':
                        (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                        'BestLoss':
                        best_loss,
                        'BestLossIteration':
                        best_loss_iter,
                        'Output':
                        output[position].data.tolist(),
                        'GroundTruth':
                        targets_to_use[position].data.tolist(),
                        'Error':
                        error[position].data.tolist(),
                        'Inputs':
                        dataset.extract_inputs(data)[position].data.tolist()
                    }, iteration)
            else:
                coil_logger.add_message(
                    'Iterating', {
                        'Iteration':
                        iteration,
                        'Loss':
                        loss.data.tolist(),
                        'Images/s':
                        (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                        'BestLoss':
                        best_loss,
                        'BestLossIteration':
                        best_loss_iter,
                        'Output':
                        output[position].data.tolist(),
                        'GroundTruth':
                        targets_to_use[position].data.tolist(),
                        'Error':
                        error[position].data.tolist(),
                        'Inputs':
                        dataset.extract_inputs(data)[position].data.tolist()
                    }, iteration)
            # Save training loss history (useful for restoring training runs since learning rate is adjusted
            # based on training loss)
            loss_window.append(overall_loss.data.tolist())
            coil_logger.write_on_error_csv('train', overall_loss.data)
            print("Iteration: %d  Loss: %f" % (iteration, overall_loss.data))

        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})

    except RuntimeError as e:

        coil_logger.add_message('Error', {'Message': str(e)})

    except:
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
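
A minimal, hypothetical way to invoke the training entry point above; the 'sample' batch folder and 'coil_icra' alias are placeholder names for a configs/<exp_batch>/<exp_alias>.yaml pair, not values taken from the code:

if __name__ == '__main__':
    # Placeholder experiment names; point these at an existing configs/ entry.
    execute(gpu='0',
            exp_batch='sample',
            exp_alias='coil_icra',
            suppress_output=False,
            number_of_workers=12)
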
Example #2
def execute(gpu, exp_batch, exp_alias, dataset_name, suppress_output):
    latest = None
    try:
        # We set the visible cuda devices
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

        # At this point the log file with the correct naming is created.
        merge_with_yaml(os.path.join('configs', exp_batch,
                                     exp_alias + '.yaml'))
        # The validation dataset is always fully loaded, so we fix a very high number of hours
        g_conf.NUMBER_OF_HOURS = 10000
        set_type_of_process('validation', dataset_name)

        if not os.path.exists('_output_logs'):
            os.mkdir('_output_logs')

        if suppress_output:
            sys.stdout = open(os.path.join(
                '_output_logs', exp_alias + '_' + g_conf.PROCESS_NAME + '_' +
                str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)
            sys.stderr = open(os.path.join(
                '_output_logs', exp_alias + '_err_' + g_conf.PROCESS_NAME +
                '_' + str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)

        # Define the dataset.
        full_dataset = [
            os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
        ]
        augmenter = Augmenter(None)
        # Definition of the dataset to be used. Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[dataset_name])

        # The data loader is PyTorch's multi-threaded loading module; it spawns a number
        # of workers to fetch the data.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=g_conf.BATCH_SIZE,
            shuffle=False,
            num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
            pin_memory=True)

        # Create model.
        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)
        # The window used to keep track of the validation loss
        l1_window = []
        # If we have evaluated a checkpoint, get the validation losses of all the previously
        # evaluated checkpoints (validation loss is used for early stopping)
        latest = get_latest_evaluated_checkpoint()
        if latest is not None:  # A checkpoint has already been evaluated
            l1_window = coil_logger.recover_loss_window(dataset_name, None)

        model.cuda()

        best_mse = 1000
        best_error = 1000
        best_mse_iter = 0
        best_error_iter = 0

        # Loop to validate all checkpoints as they are saved during training
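        # maximun_checkpoint_reach / is_next_checkpoint_ready (defined elsewhere) poll the
        # checkpoint folder against g_conf.TEST_SCHEDULE, so this process can run in
        # parallel with a training process that is still writing checkpoints.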
        while not maximun_checkpoint_reach(latest, g_conf.TEST_SCHEDULE):
            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE):
                with torch.no_grad():
                    # Get and load latest checkpoint
                    latest = get_next_checkpoint(g_conf.TEST_SCHEDULE)

                    checkpoint = torch.load(
                        os.path.join('_logs', exp_batch, exp_alias,
                                     'checkpoints',
                                     str(latest) + '.pth'))
                    checkpoint_iteration = checkpoint['iteration']
                    print("Validation loaded ", checkpoint_iteration)

                    model.load_state_dict(checkpoint['state_dict'])
                    model.eval()

                    accumulated_mse = 0
                    accumulated_error = 0
                    iteration_on_checkpoint = 0
                    if g_conf.USE_REPRESENTATION_LOSS:
                        accumulated_perception_rep_mse = 0
                        accumulated_speed_rep_mse = 0
                        accumulated_intentions_rep_mse = 0
                        accumulated_rep_mse = 0
                        accumulated_perception_rep_error = 0
                        accumulated_speed_rep_error = 0
                        accumulated_intentions_rep_error = 0
                        accumulated_rep_error = 0

                    # Validation loop
                    for data in data_loader:

                        # Compute the forward pass on a batch from the validation dataset
                        controls = data['directions']

                        # Run model forward and get outputs
                        # First case: the squeeze network; second case: the driving model
                        # without mimicking losses; last case: the mimic network
                        if "seg" in g_conf.SENSORS.keys():
                            output = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls,
                                dataset.extract_intentions(data).cuda())
                        elif not g_conf.USE_REPRESENTATION_LOSS:
                            output = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls)
                        else:
                            output, intermediate_reps = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls)

                        write_regular_output(checkpoint_iteration, output)

                        # Compute control loss on current validation batch and accumulate it
                        targets_to_use = dataset.extract_targets(data)

                        mse = torch.mean(
                            (output - targets_to_use.cuda())**2).data.tolist()
                        mean_error = torch.mean(
                            torch.abs(output -
                                      targets_to_use.cuda())).data.tolist()

                        accumulated_error += mean_error
                        accumulated_mse += mse

                        error = torch.abs(output - targets_to_use.cuda())

                        # Compute mimicking losses on current validation batch and accumulate it
                        if g_conf.USE_REPRESENTATION_LOSS:
                            expert_reps = dataset.extract_representations(data)
                            # First L1 losses (seg mask, speed, intention mimicking losses)
                            if g_conf.USE_PERCEPTION_REP_LOSS:
                                perception_rep_loss = torch.sum(
                                    torch.abs(intermediate_reps[0] -
                                              expert_reps[0].cuda())
                                ).data.tolist() / (3 * output.shape[0])
                            else:
                                perception_rep_loss = 0
                            if g_conf.USE_SPEED_REP_LOSS:
                                speed_rep_loss = torch.sum(
                                    torch.abs(intermediate_reps[1] -
                                              expert_reps[1].cuda())
                                ).data.tolist() / (3 * output.shape[0])
                            else:
                                speed_rep_loss = 0
                            if g_conf.USE_INTENTION_REP_LOSS:
                                intentions_rep_loss = torch.sum(
                                    torch.abs(intermediate_reps[2] -
                                              expert_reps[2].cuda())
                                ).data.tolist() / (3 * output.shape[0])
                            else:
                                intentions_rep_loss = 0
                            rep_error = g_conf.REP_LOSS_WEIGHT * (
                                perception_rep_loss + speed_rep_loss +
                                intentions_rep_loss)
                            accumulated_perception_rep_error += perception_rep_loss
                            accumulated_speed_rep_error += speed_rep_loss
                            accumulated_intentions_rep_error += intentions_rep_loss
                            accumulated_rep_error += rep_error

                            # L2 losses now
                            if g_conf.USE_PERCEPTION_REP_LOSS:
                                perception_rep_loss = torch.sum(
                                    (intermediate_reps[0] -
                                     expert_reps[0].cuda())**
                                    2).data.tolist() / (3 * output.shape[0])
                            else:
                                perception_rep_loss = 0
                            if g_conf.USE_SPEED_REP_LOSS:
                                speed_rep_loss = torch.sum(
                                    (intermediate_reps[1] -
                                     expert_reps[1].cuda())**
                                    2).data.tolist() / (3 * output.shape[0])
                            else:
                                speed_rep_loss = 0
                            if g_conf.USE_INTENTION_REP_LOSS:
                                intentions_rep_loss = torch.sum(
                                    (intermediate_reps[2] -
                                     expert_reps[2].cuda())**
                                    2).data.tolist() / (3 * output.shape[0])
                            else:
                                intentions_rep_loss = 0
                            rep_mse = g_conf.REP_LOSS_WEIGHT * (
                                perception_rep_loss + speed_rep_loss +
                                intentions_rep_loss)
                            accumulated_perception_rep_mse += perception_rep_loss
                            accumulated_speed_rep_mse += speed_rep_loss
                            accumulated_intentions_rep_mse += intentions_rep_loss
                            accumulated_rep_mse += rep_mse

                        # Log a random position
                        position = random.randint(
                            0,
                            len(output.data.tolist()) - 1)

                        # Logging
                        if g_conf.USE_REPRESENTATION_LOSS:
                            total_mse = mse + rep_mse
                            total_error = mean_error + rep_error
                            coil_logger.add_message(
                                'Iterating', {
                                    'Checkpoint':
                                    latest,
                                    'Iteration':
                                    (str(iteration_on_checkpoint * 120) + '/' +
                                     str(len(dataset))),
                                    'MeanError':
                                    mean_error,
                                    'MSE':
                                    mse,
                                    'RepMeanError':
                                    rep_error,
                                    'RepMSE':
                                    rep_mse,
                                    'MeanTotalError':
                                    total_error,
                                    'TotalMSE':
                                    total_mse,
                                    'Output':
                                    output[position].data.tolist(),
                                    'GroundTruth':
                                    targets_to_use[position].data.tolist(),
                                    'Error':
                                    error[position].data.tolist(),
                                    'Inputs':
                                    dataset.extract_inputs(
                                        data)[position].data.tolist()
                                }, latest)
                        else:
                            coil_logger.add_message(
                                'Iterating', {
                                    'Checkpoint':
                                    latest,
                                    'Iteration':
                                    (str(iteration_on_checkpoint * 120) + '/' +
                                     str(len(dataset))),
                                    'MeanError':
                                    mean_error,
                                    'MSE':
                                    mse,
                                    'Output':
                                    output[position].data.tolist(),
                                    'GroundTruth':
                                    targets_to_use[position].data.tolist(),
                                    'Error':
                                    error[position].data.tolist(),
                                    'Inputs':
                                    dataset.extract_inputs(
                                        data)[position].data.tolist()
                                }, latest)
                        iteration_on_checkpoint += 1

                        if g_conf.USE_REPRESENTATION_LOSS:
                            print("Iteration %d  on Checkpoint %d : Error %f" %
                                  (iteration_on_checkpoint,
                                   checkpoint_iteration, total_error))
                        else:
                            print("Iteration %d  on Checkpoint %d : Error %f" %
                                  (iteration_on_checkpoint,
                                   checkpoint_iteration, mean_error))
                    """
                        ########
                        Finish a round of validation, write results, wait for the next
                        ########
                    """
                    # Compute average L1 and L2 losses over whole round of validation and log them
                    checkpoint_average_mse = accumulated_mse / (
                        len(data_loader))
                    checkpoint_average_error = accumulated_error / (
                        len(data_loader))
                    coil_logger.add_scalar('L2 Loss', checkpoint_average_mse,
                                           latest, True)
                    coil_logger.add_scalar('Loss', checkpoint_average_error,
                                           latest, True)

                    if g_conf.USE_REPRESENTATION_LOSS:
                        checkpoint_average_perception_rep_mse = accumulated_perception_rep_mse / (
                            len(data_loader))
                        checkpoint_average_speed_rep_mse = accumulated_speed_rep_mse / (
                            len(data_loader))
                        checkpoint_average_intentions_rep_mse = accumulated_intentions_rep_mse / (
                            len(data_loader))
                        checkpoint_average_rep_mse = accumulated_rep_mse / (
                            len(data_loader))
                        checkpoint_average_total_mse = checkpoint_average_mse + checkpoint_average_rep_mse

                        checkpoint_average_perception_rep_error = accumulated_perception_rep_error / (
                            len(data_loader))
                        checkpoint_average_speed_rep_error = accumulated_speed_rep_error / (
                            len(data_loader))
                        checkpoint_average_intentions_rep_error = accumulated_intentions_rep_error / (
                            len(data_loader))
                        checkpoint_average_rep_error = accumulated_rep_error / (
                            len(data_loader))
                        checkpoint_average_total_error = checkpoint_average_error + checkpoint_average_rep_error

                        # Log L1/L2 loss terms
                        coil_logger.add_scalar(
                            'Perception Rep Loss',
                            checkpoint_average_perception_rep_mse, latest,
                            True)
                        coil_logger.add_scalar(
                            'Speed Rep Loss', checkpoint_average_speed_rep_mse,
                            latest, True)
                        coil_logger.add_scalar(
                            'Intentions Rep Loss',
                            checkpoint_average_intentions_rep_mse, latest,
                            True)
                        coil_logger.add_scalar('Overall Rep Loss',
                                               checkpoint_average_rep_mse,
                                               latest, True)
                        coil_logger.add_scalar('Total L2 Loss',
                                               checkpoint_average_total_mse,
                                               latest, True)

                        coil_logger.add_scalar(
                            'Perception Rep Error',
                            checkpoint_average_perception_rep_error, latest,
                            True)
                        coil_logger.add_scalar(
                            'Speed Rep Error',
                            checkpoint_average_speed_rep_error, latest, True)
                        coil_logger.add_scalar(
                            'Intentions Rep Error',
                            checkpoint_average_intentions_rep_error, latest,
                            True)
                        coil_logger.add_scalar('Total Rep Error',
                                               checkpoint_average_rep_error,
                                               latest, True)
                        coil_logger.add_scalar('Total Loss',
                                               checkpoint_average_total_error,
                                               latest, True)
                    else:
                        checkpoint_average_total_mse = checkpoint_average_mse
                        checkpoint_average_total_error = checkpoint_average_error

                    if checkpoint_average_total_mse < best_mse:
                        best_mse = checkpoint_average_total_mse
                        best_mse_iter = latest

                    if checkpoint_average_total_error < best_error:
                        best_error = checkpoint_average_total_error
                        best_error_iter = latest

                    # Print the validation results to the log and the terminal
                    if g_conf.USE_REPRESENTATION_LOSS:
                        coil_logger.add_message(
                            'Iterating', {
                                'Summary': {
                                    'Control Error': checkpoint_average_error,
                                    'Control Loss': checkpoint_average_mse,
                                    'Rep Error': checkpoint_average_rep_error,
                                    'Rep Loss': checkpoint_average_rep_mse,
                                    'Error': checkpoint_average_total_error,
                                    'Loss': checkpoint_average_total_mse,
                                    'BestError': best_error,
                                    'BestMSE': best_mse,
                                    'BestMSECheckpoint': best_mse_iter,
                                    'BestErrorCheckpoint': best_error_iter
                                },
                                'Checkpoint': latest
                            }, latest)
                    else:
                        coil_logger.add_message(
                            'Iterating', {
                                'Summary': {
                                    'Error': checkpoint_average_error,
                                    'Loss': checkpoint_average_mse,
                                    'BestError': best_error,
                                    'BestMSE': best_mse,
                                    'BestMSECheckpoint': best_mse_iter,
                                    'BestErrorCheckpoint': best_error_iter
                                },
                                'Checkpoint': latest
                            }, latest)

                    # Save validation loss history (validation loss is used for early stopping)
                    l1_window.append(checkpoint_average_total_error)
                    coil_logger.write_on_error_csv(
                        dataset_name, checkpoint_average_total_error)

                    # Early stopping
                    if g_conf.FINISH_ON_VALIDATION_STALE is not None:
                        if dlib.count_steps_without_decrease(l1_window) > 3 and \
                                dlib.count_steps_without_decrease_robust(l1_window) > 3:
                            coil_logger.write_stop(dataset_name, latest)
                            break

            else:

                latest = get_latest_evaluated_checkpoint()
                time.sleep(1)

                coil_logger.add_message('Loading',
                                        {'Message': 'Waiting Checkpoint'})
                print("Waiting for the next Validation")

        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)

    except RuntimeError as e:
        if latest is not None:
            coil_logger.erase_csv(latest)
        coil_logger.add_message('Error', {'Message': str(e)})

    except:
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)
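
The two per-batch metrics accumulated in the validation loop above are a plain mean squared error and a mean absolute error over the branch outputs. A self-contained sketch with dummy tensors (the batch size of 120 and the three control dimensions are illustrative assumptions):

import torch

output = torch.randn(120, 3)   # predicted controls for a batch, e.g. [steer, throttle, brake]
targets = torch.randn(120, 3)  # ground-truth controls for the same batch

mse = torch.mean((output - targets) ** 2).item()             # accumulated into accumulated_mse
mean_error = torch.mean(torch.abs(output - targets)).item()  # accumulated into accumulated_error
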
Example #3
def execute(gpu, exp_batch, exp_alias, dataset_name, validation_set=False):
    latest = None
    # We set the visible cuda devices
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    g_conf.immutable(False)
    # At this point the log file with the correct naming is created.
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    # If using validation dataset, fix a very high number of hours
    if validation_set:
        g_conf.NUMBER_OF_HOURS = 10000
    g_conf.immutable(True)

    # Define the dataset.
    full_dataset = [
        os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
    ]
    augmenter = Augmenter(None)
    if validation_set:
        # Definition of the dataset to be used. Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[dataset_name])
    else:
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[
                                  str(g_conf.NUMBER_OF_HOURS) + 'hours_' +
                                  dataset_name
                              ],
                              train_dataset=True)

    # The data loader is PyTorch's multi-threaded loading module; it spawns a number
    # of workers to fetch the data.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=g_conf.BATCH_SIZE,
        shuffle=False,
        num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
        pin_memory=True)

    # Define model
    model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)
    """ 
        ######
        Run a single driving benchmark specified by the checkpoint were validation is stale
        ######
    """

    if g_conf.FINISH_ON_VALIDATION_STALE is not None:

        while validation_stale_point(
                g_conf.FINISH_ON_VALIDATION_STALE) is None:
            time.sleep(0.1)

        validation_state_iteration = validation_stale_point(
            g_conf.FINISH_ON_VALIDATION_STALE)

        checkpoint = torch.load(
            os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                         str(validation_state_iteration) + '.pth'))
        print("Validation loaded ", validation_state_iteration)
    else:
        """
        #####
        Main loop: run a benchmark for each checkpoint specified in the "Test Configuration"
        #####
        """
        while not maximun_checkpoint_reach(latest, g_conf.TEST_SCHEDULE):
            # Get the correct checkpoint
            # We check a single task name since all of them become ready at the same time
            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE,
                                        control_filename + '_' + task_list[0]):

                latest = get_next_checkpoint(
                    g_conf.TEST_SCHEDULE,
                    control_filename + '_' + task_list[0])

                checkpoint = torch.load(
                    os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                                 str(latest) + '.pth'))
                print("Validation loaded ", latest)
            else:
                time.sleep(0.1)

    # Load the model weights and set it up for evaluation
    model.load_state_dict(checkpoint['state_dict'])
    model.cuda()
    model.eval()

    first_iter = True
    for data in data_loader:

        # Compute the forward pass on a batch from the dataset and get the intermediate
        # representations of the squeeze network
        if "seg" in g_conf.SENSORS.keys():
            perception_rep, speed_rep, intentions_rep = \
                model.get_intermediate_representations(data,
                                                       dataset.extract_inputs(data).cuda(),
                                                       dataset.extract_intentions(data).cuda())
            perception_rep = perception_rep.data.cpu()
            speed_rep = speed_rep.data.cpu()
            intentions_rep = intentions_rep.data.cpu()
        if first_iter:
            perception_rep_all = perception_rep
            speed_rep_all = speed_rep
            intentions_rep_all = intentions_rep
        else:
            perception_rep_all = torch.cat(
                [perception_rep_all, perception_rep], 0)
            speed_rep_all = torch.cat([speed_rep_all, speed_rep], 0)
            intentions_rep_all = torch.cat(
                [intentions_rep_all, intentions_rep], 0)
        first_iter = False

    # Save intermediate representations
    perception_rep_all = perception_rep_all.tolist()
    speed_rep_all = speed_rep_all.tolist()
    intentions_rep_all = intentions_rep_all.tolist()
    np.save(
        os.path.join(
            '_preloads', exp_batch + '_' + exp_alias + '_' + dataset_name +
            '_representations'),
        [perception_rep_all, speed_rep_all, intentions_rep_all])
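
The representation arrays saved above can be read back with NumPy; since a Python list of lists is stored, allow_pickle is required on load. A small hypothetical loading sketch (np.save appends the '.npy' extension to the path used above; the experiment names are placeholders):

import os
import numpy as np

# Placeholder identifiers; substitute the ones used when the representations were saved.
exp_batch, exp_alias, dataset_name = 'sample', 'coil_icra', 'CoILVal1'
preload_file = os.path.join(
    '_preloads',
    exp_batch + '_' + exp_alias + '_' + dataset_name + '_representations.npy')
perception_rep_all, speed_rep_all, intentions_rep_all = np.load(
    preload_file, allow_pickle=True)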