Пример #1
0
    def batch_step(self, batch_loss):
        """Record a batch loss and decay the learning rate on a plateau.

        The loss is appended to ``self.batch_losses_`` and both
        ``count_steps_without_decrease`` and
        ``count_steps_without_decrease_robust`` are evaluated on that history.
        When both counts exceed ``self.patience * self.batches_per_epoch``,
        the learning rate of every optimizer parameter group is multiplied by
        ``self.factor`` and the loss history is cleared.

        Parameters
        ----------
        batch_loss
            Loss of the batch that was just processed.

        Returns
        -------
        dict
            ``'lr'``: learning rate of the last parameter group after this
            step (``None`` if the optimizer has no parameter group);
            ``'epochs_without_decrease'`` and
            ``'epochs_without_decrease_robust'``: the two counts divided by
            ``self.batches_per_epoch``.
        """

        # store current batch loss
        self.batch_losses_.append(batch_loss)

        # compute statistics on batch loss trend
        count = count_steps_without_decrease(self.batch_losses_)
        count_robust = count_steps_without_decrease_robust(self.batch_losses_)

        # if batch loss hasn't been decreasing for a while
        patience = self.patience * self.batches_per_epoch
        plateau = count > patience and count_robust > patience

        # `lr` has to be bound on every path: it used to be assigned only
        # while decaying, so any ordinary step failed with UnboundLocalError
        # when building the returned dict.
        lr = None
        for param_group in self.optimizer.param_groups:
            if plateau:
                # decrease optimizer learning rate
                param_group['lr'] = param_group['lr'] * self.factor
            lr = param_group['lr']

        if plateau:
            # reset batch loss trend
            self.batch_losses_.clear()

        return {
            'lr': lr,
            'epochs_without_decrease': count / self.batches_per_epoch,
            'epochs_without_decrease_robust':
                count_robust / self.batches_per_epoch}
Пример #2
0
def adjust_learning_rate_auto(optimizer, loss_window):
    """
    Recompute the learning rate from the loss history and apply it.

    Starting from g_conf.LEARNING_RATE, the history is inspected at end
    points 1000, 2000, ... (strictly below len(loss_window)). Whenever both
    dlib counters computed on loss_window[start:end] exceed
    g_conf.LEARNING_RATE_THRESHOLD, the rate is multiplied by
    g_conf.LEARNING_RATE_DECAY_LEVEL and the next segment starts at that end
    point. The result is floored at 1e-7 and written to every parameter
    group of the optimizer. Nothing is returned.
    """
    floor_lr = 0.0000001
    stride = 1000

    current_lr = g_conf.LEARNING_RATE
    plateau_limit = g_conf.LEARNING_RATE_THRESHOLD
    decay = g_conf.LEARNING_RATE_DECAY_LEVEL

    segment_start = 0
    print("Loss window ", loss_window)
    for segment_end in range(stride, len(loss_window), stride):
        print("Startpoint ", segment_start, " N  ", segment_end)

        # Both counters look at the same slice of the history.
        segment = loss_window[segment_start:segment_end]
        plain_count = dlib.count_steps_without_decrease(segment)
        robust_count = dlib.count_steps_without_decrease_robust(segment)
        print("no decrease, ", plain_count, " robust", robust_count)

        if plain_count > plateau_limit and robust_count > plateau_limit:
            # Plateau: decay the rate and restart the segment here.
            segment_start = segment_end
            current_lr = current_lr * decay

    current_lr = max(current_lr, floor_lr)

    for group in optimizer.param_groups:
        print("New Learning rate is ", current_lr)
        group['lr'] = current_lr
Пример #3
0
    def batch_step(self, batch_loss):
        """Record a batch loss and decay the learning rate on a plateau.

        The loss is appended to ``self.batch_losses_`` and both
        ``count_steps_without_decrease`` and
        ``count_steps_without_decrease_robust`` are evaluated on that history.
        When both counts exceed ``self.patience * self.batches_per_epoch``,
        the learning rate of every optimizer parameter group is multiplied by
        ``self.factor`` and the loss history is cleared.

        Parameters
        ----------
        batch_loss
            Loss of the batch that was just processed.

        Returns
        -------
        dict
            ``'lr'``: learning rate of the last parameter group after this
            step (``None`` if the optimizer has no parameter group);
            ``'epochs_without_decrease'`` and
            ``'epochs_without_decrease_robust'``: the two counts divided by
            ``self.batches_per_epoch``.
        """

        # store current batch loss
        self.batch_losses_.append(batch_loss)

        # compute statistics on batch loss trend
        count = count_steps_without_decrease(self.batch_losses_)
        count_robust = count_steps_without_decrease_robust(self.batch_losses_)

        # if batch loss hasn't been decreasing for a while
        patience = self.patience * self.batches_per_epoch
        plateau = count > patience and count_robust > patience

        # `lr` has to be bound on every path: it used to be assigned only
        # while decaying, so any ordinary step failed with UnboundLocalError
        # when building the returned dict.
        lr = None
        for param_group in self.optimizer.param_groups:
            if plateau:
                # decrease optimizer learning rate
                param_group['lr'] = param_group['lr'] * self.factor
            lr = param_group['lr']

        if plateau:
            # reset batch loss trend
            self.batch_losses_.clear()

        return {
            'lr': lr,
            'epochs_without_decrease': count / self.batches_per_epoch,
            'epochs_without_decrease_robust':
                count_robust / self.batches_per_epoch}
Пример #4
0
def adjust_learning_rate_auto(optimizer, loss_window, coil_logger):
    """
    Recompute the learning rate from the loss history, apply it and return it.

    Starting from g_conf.LEARNING_RATE, the history is inspected at end
    points n = 1000, 2000, ... (strictly below len(loss_window)). Whenever
    both dlib counters computed on loss_window[start_point:n] exceed
    g_conf.LEARNING_RATE_THRESHOLD, the rate is multiplied by
    g_conf.LEARNING_RATE_DECAY_LEVEL and start_point moves to n. The result is
    floored at 1e-7, logged through coil_logger and written to every
    parameter group of the optimizer.

    Args:
        optimizer: object exposing `param_groups`, a sequence of dicts whose
            'lr' entry is overwritten.
        loss_window: sliceable sequence of past loss values.
        coil_logger: logger exposing `add_message(phase, message_dict)`.

    Returns:
        The learning rate that was written to the optimizer.
    """
    minlr = 0.0000001
    learning_rate = g_conf.LEARNING_RATE
    n = 1000
    start_point = 0
    while n < len(loss_window):
        # Both counters are evaluated on the same slice of the history.
        window = loss_window[start_point:n]
        # use dlib.net to see for how many steps the noisy loss has gone without noticeably decreasing in value
        steps_no_decrease = dlib.count_steps_without_decrease(window)
        # ibidem, just discarding the 10% largest values
        steps_no_decrease_robust = dlib.count_steps_without_decrease_robust(
            window)
        # The value now carries balanced parentheses, matching the key and
        # the format announced in the message title (it used to be emitted
        # as "a/b)" with a stray closing parenthesis).
        coil_logger.add_message(
            '{(Start_point/n): (Steps not decreased/not decreased robust)}', {
                f'({start_point}/{n})':
                f'({steps_no_decrease}/{steps_no_decrease_robust})'
            })
        if steps_no_decrease > g_conf.LEARNING_RATE_THRESHOLD and \
                steps_no_decrease_robust > g_conf.LEARNING_RATE_THRESHOLD:
            # Plateau detected: decay and restart the window at this point.
            start_point = n
            learning_rate = learning_rate * g_conf.LEARNING_RATE_DECAY_LEVEL

        n += 1000

    learning_rate = max(learning_rate, minlr)
    coil_logger.add_message('Learning rate', {'lr': learning_rate})

    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate

    return learning_rate
Пример #5
0
 def ok(self):
     """Return whether the run should continue, updating ``self.run``.

     For every tag in ``self._tag_order`` the stored history is first
     truncated to its last ``self._max_history`` entries. If both dlib
     plateau counters for that history exceed ``self._threshold`` the tag is
     reported as flatlined; when that tag is ``self._total_loss`` and the run
     is still active, ``self.run`` is switched to False.
     """
     limit = self._threshold
     for tag in self._tag_order:
         history = self._histories[tag]
         if len(history) > self._max_history:
             # Keep only the most recent entries.
             history = history[-self._max_history:]
             self._histories[tag] = history
             assert len(history) == self._max_history

         plain_steps = dlib.count_steps_without_decrease(history)
         robust_steps = dlib.count_steps_without_decrease_robust(history)
         if not (plain_steps > limit and robust_steps > limit):
             continue

         print('Metric {} for {} has flatlined'.format(tag, self._name))
         if self.run and tag == self._total_loss:
             print('REQUESTING STOP', self._name)
             self.run = False
     return self.run
Пример #6
0
def execute(gpu, exp_batch, exp_alias, dataset_name, suppress_output):
    """Validate the scheduled checkpoints of an experiment on one dataset.

    The experiment configuration is read from
    ``configs/<exp_batch>/<exp_alias>.yaml``. Each checkpoint listed in
    ``g_conf.TEST_SCHEDULE`` is loaded from
    ``_logs/<exp_batch>/<exp_alias>/checkpoints/<n>.pth`` once it is ready.
    Every validation batch is first segmented by a pretrained ERFNet
    (road / not-road masks) and the masks are fed to the CoIL model. The mean
    squared error and mean absolute error per checkpoint are logged through
    ``coil_logger``. The loop ends when the schedule is exhausted or, if
    ``g_conf.FINISH_ON_VALIDATION_STALE`` is set, when the error stops
    decreasing.

    Args:
        gpu: value assigned to the ``CUDA_VISIBLE_DEVICES`` environment
            variable.
        exp_batch: experiment folder name, used for config and log paths.
        exp_alias: experiment name, used for config and log paths.
        dataset_name: validation dataset folder, relative to the
            ``COIL_DATASET_PATH`` environment variable.
        suppress_output: when true, ``sys.stdout`` / ``sys.stderr`` are
            redirected to files under ``_output_logs``.

    Returns:
        None. Exceptions are caught here, reported through ``coil_logger`` and
        not re-raised.
    """
    latest = None
    try:
        # We set the visible cuda devices
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

        # At this point the log file with the correct naming is created.
        merge_with_yaml(os.path.join('configs', exp_batch,
                                     exp_alias + '.yaml'))
        # The validation dataset is always fully loaded, so we fix a very high number of hours
        g_conf.NUMBER_OF_HOURS = 10000
        set_type_of_process('validation', dataset_name)

        # exist_ok avoids the check-then-create race when several processes
        # start at the same time.
        os.makedirs('_output_logs', exist_ok=True)

        if suppress_output:
            # The redirected streams stay open for the rest of the process.
            sys.stdout = open(os.path.join(
                '_output_logs', exp_alias + '_' + g_conf.PROCESS_NAME + '_' +
                str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)
            sys.stderr = open(os.path.join(
                '_output_logs', exp_alias + '_err_' + g_conf.PROCESS_NAME +
                '_' + str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)

        # Define the dataset. This structure has the __get_item__ redefined in a way
        # that you can access the HDFILES positions from the root directory as in a vector.
        full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                    dataset_name)
        augmenter = Augmenter(None)
        # Definition of the dataset to be used. Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_name=dataset_name)

        # The data loader is the multi threaded module from pytorch that release a number of
        # workers to get all the data.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=g_conf.BATCH_SIZE,
            shuffle=False,
            num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
            pin_memory=True)

        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)

        # Set ERFnet for segmentation
        model_erf = ERFNet(20)
        model_erf = torch.nn.DataParallel(model_erf)
        model_erf = model_erf.cuda()

        print("LOAD ERFNet - validate")

        def load_my_state_dict(model, state_dict):
            # Custom loader: copies only the entries whose name exists in the
            # model, silently ignoring the others.
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                own_state[name].copy_(param)
            return model

        model_erf = load_my_state_dict(
            model_erf, torch.load('trained_models/erfnet_pretrained.pth'))
        model_erf.eval()
        print("ERFNet and weights LOADED successfully")

        # The window used to keep track of the trainings
        l1_window = []
        latest = get_latest_evaluated_checkpoint()
        if latest is not None:
            # Some checkpoints were already evaluated: recover their errors.
            l1_window = coil_logger.recover_loss_window(dataset_name, None)

        model.cuda()

        # Infinity guarantees that the first evaluated checkpoint becomes the
        # best one, whatever the scale of the metric.
        best_mse = float('inf')
        best_error = float('inf')
        best_mse_iter = 0
        best_error_iter = 0

        while not maximun_checkpoint_reach(latest, g_conf.TEST_SCHEDULE):

            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE):

                latest = get_next_checkpoint(g_conf.TEST_SCHEDULE)

                checkpoint = torch.load(
                    os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                                 str(latest) + '.pth'))
                checkpoint_iteration = checkpoint['iteration']
                print("Validation loaded ", checkpoint_iteration)

                model.load_state_dict(checkpoint['state_dict'])

                model.eval()
                accumulated_mse = 0
                accumulated_error = 0
                iteration_on_checkpoint = 0

                # No gradient is needed for validation; skipping the graph
                # saves memory and computation.
                with torch.no_grad():
                    for data in data_loader:

                        # Compute the forward pass on a batch from the validation dataset
                        controls = data['directions']

                        # Seg batch
                        rgbs = data['rgb']
                        outputs = model_erf(rgbs)
                        labels = outputs.max(1)[1].byte().cpu().data

                        # Class 0 is treated as road, everything else as not-road.
                        seg_road = (labels == 0)
                        seg_not_road = (labels != 0)
                        seg = torch.stack((seg_road, seg_not_road), 1).float()

                        # Extract inputs/targets once per batch and reuse them.
                        inputs = dataset.extract_inputs(data)
                        targets = dataset.extract_targets(data)
                        targets_gpu = targets.cuda()

                        output = model.forward_branch(
                            torch.squeeze(seg).cuda(), inputs.cuda(), controls)

                        # It could be either waypoints or direct control
                        if 'waypoint1_angle' in g_conf.TARGETS:
                            write_waypoints_output(checkpoint_iteration,
                                                   output)
                        else:
                            write_regular_output(checkpoint_iteration, output)

                        difference = output - targets_gpu
                        error = torch.abs(difference)
                        mse = torch.mean(difference**2).data.tolist()
                        mean_error = torch.mean(error).data.tolist()

                        accumulated_error += mean_error
                        accumulated_mse += mse

                        # Log a random position
                        position = random.randint(0, len(output) - 1)

                        coil_logger.add_message(
                            'Iterating', {
                                'Checkpoint':
                                latest,
                                # Samples seen so far, derived from the
                                # configured batch size (was hard-coded 120).
                                'Iteration':
                                (str(iteration_on_checkpoint *
                                     g_conf.BATCH_SIZE) + '/' +
                                 str(len(dataset))),
                                'MeanError':
                                mean_error,
                                'MSE':
                                mse,
                                'Output':
                                output[position].data.tolist(),
                                'GroundTruth':
                                targets[position].data.tolist(),
                                'Error':
                                error[position].data.tolist(),
                                'Inputs':
                                inputs[position].data.tolist()
                            }, latest)
                        iteration_on_checkpoint += 1
                        print("Iteration %d  on Checkpoint %d : Error %f" %
                              (iteration_on_checkpoint, checkpoint_iteration,
                               mean_error))

                ########
                # Finish a round of validation, write results, wait for the next
                ########

                checkpoint_average_mse = accumulated_mse / (len(data_loader))
                checkpoint_average_error = accumulated_error / (
                    len(data_loader))
                coil_logger.add_scalar('Loss', checkpoint_average_mse, latest,
                                       True)
                coil_logger.add_scalar('Error', checkpoint_average_error,
                                       latest, True)

                if checkpoint_average_mse < best_mse:
                    best_mse = checkpoint_average_mse
                    best_mse_iter = latest

                if checkpoint_average_error < best_error:
                    best_error = checkpoint_average_error
                    best_error_iter = latest

                coil_logger.add_message(
                    'Iterating', {
                        'Summary': {
                            'Error': checkpoint_average_error,
                            'Loss': checkpoint_average_mse,
                            'BestError': best_error,
                            'BestMSE': best_mse,
                            'BestMSECheckpoint': best_mse_iter,
                            'BestErrorCheckpoint': best_error_iter
                        },
                        'Checkpoint': latest
                    }, latest)

                l1_window.append(checkpoint_average_error)
                coil_logger.write_on_error_csv(dataset_name,
                                               checkpoint_average_error)

                # If we are using the finish when validation stops, we check the current
                # error window and stop once it has not decreased for more than 3 steps.
                if g_conf.FINISH_ON_VALIDATION_STALE is not None:
                    if dlib.count_steps_without_decrease(l1_window) > 3 and \
                            dlib.count_steps_without_decrease_robust(l1_window) > 3:
                        coil_logger.write_stop(dataset_name, latest)
                        break

            else:

                latest = get_latest_evaluated_checkpoint()
                time.sleep(1)

                coil_logger.add_message('Loading',
                                        {'Message': 'Waiting Checkpoint'})
                print("Waiting for the next Validation")

        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)

    except RuntimeError as e:
        if latest is not None:
            coil_logger.erase_csv(latest)
        coil_logger.add_message('Error', {'Message': str(e)})

    except:
        # NOTE(review): this catch-all also traps SystemExit; it is kept as-is
        # so that the existing behavior of this process entry point is
        # preserved.
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)
Пример #7
0
def execute(gpu, exp_batch, exp_alias, validation_dataset, suppress_output):
    """Validate the scheduled checkpoints of an experiment on one dataset.

    The experiment configuration is read from
    ``configs/<exp_batch>/<exp_alias>.yaml``. Each checkpoint listed in
    ``g_conf.TEST_SCHEDULE`` is loaded from
    ``_logs/<exp_batch>/<exp_alias>/checkpoints/<n>.pth`` once it is ready.
    The running average of the ``g_conf.LOSS_FUNCTION`` loss over the dataset
    is logged through ``coil_logger``. The loop ends when
    ``maximum_checkpoint_reached`` says so or, if
    ``g_conf.FINISH_ON_VALIDATION_STALE`` is set, when the loss stops
    decreasing.

    Args:
        gpu: value assigned to the ``CUDA_VISIBLE_DEVICES`` environment
            variable.
        exp_batch: experiment folder name, used for config and log paths.
        exp_alias: experiment name, used for config and log paths.
        validation_dataset: dataset folder, relative to the
            ``COIL_DATASET_PATH`` environment variable.
        suppress_output: when true, ``save_output(exp_alias)`` is called to
            store the output in a file.

    Returns:
        None. Exceptions are caught here, reported through ``coil_logger`` and
        not re-raised.
    """
    latest = None
    try:
        # We set the visible cuda devices
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

        # At this point the log file with the correct naming is created.
        merge_with_yaml(os.path.join('configs', exp_batch,
                                     f'{exp_alias}.yaml'))
        # The validation dataset is always fully loaded, so we fix a very high number of hours
        g_conf.NUMBER_OF_HOURS = 10000
        set_type_of_process(process_type='validation',
                            param=validation_dataset)

        # Save the output to a file if so desired
        if suppress_output:
            save_output(exp_alias)

        # Define the dataset. This structure has the __get_item__ redefined in a way
        # that you can access the HDFILES positions from the root directory as in a vector.
        full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                    validation_dataset)
        augmenter = Augmenter(None)
        # Definition of the dataset to be used. Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_name=validation_dataset,
                              process_type='validation')

        # The data loader is the multi threaded module from pytorch that release a number of
        # workers to get all the data.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=g_conf.BATCH_SIZE,
            shuffle=False,
            num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
            pin_memory=True)

        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION,
                          g_conf.SENSORS).cuda()
        # The window used to keep track of the trainings
        l1_window = []
        latest = get_latest_evaluated_checkpoint()
        if latest is not None:
            # Some checkpoints were already evaluated: recover their losses.
            l1_window = coil_logger.recover_loss_window(
                validation_dataset, None)

        # Keep track of the best loss and the iteration where it happens.
        # Infinity guarantees that the first evaluated checkpoint becomes the
        # best one, whatever the scale of the loss.
        best_loss = float('inf')
        best_loss_iter = 0

        print(20 * '#')
        print('Starting validation!')
        print(20 * '#')

        # Check if the maximum checkpoint for validating has been reached
        while not maximum_checkpoint_reached(latest):
            # Wait until the next checkpoint is ready (assuming this is run whilst training the model)
            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE):
                # Get next checkpoint for validation according to the test schedule and load it
                latest = get_next_checkpoint(g_conf.TEST_SCHEDULE)
                checkpoint = torch.load(
                    os.path.join('_logs', exp_batch, exp_alias, 'checkpoints',
                                 f'{latest}.pth'))
                checkpoint_iteration = checkpoint['iteration']

                model.load_state_dict(checkpoint['state_dict'])
                model.eval()  # Turn off dropout and batchnorm (if any)
                print(f"Validation loaded, checkpoint {checkpoint_iteration}")

                # Main metric will be the used loss for training the network
                criterion = Loss(g_conf.LOSS_FUNCTION)
                checkpoint_average_loss = 0

                # Counter
                iteration_on_checkpoint = 0

                with torch.no_grad():  # save some computation/memory
                    for data in data_loader:
                        # Compute the forward pass on a batch from the validation dataset
                        controls = data['directions'].cuda()
                        img = torch.squeeze(data['rgb']).cuda()

                        # Extract inputs/targets once per batch and reuse
                        # them for the loss and for logging.
                        inputs = dataset.extract_inputs(data)
                        targets = dataset.extract_targets(data)
                        speed = inputs.cuda()  # this might not always be speed

                        # For auxiliary metrics
                        output = model.forward_branch(img, speed, controls)

                        # For the loss function
                        branches = model(img, speed)
                        loss_function_params = {
                            'branches': branches,
                            'targets': targets.cuda(),
                            'controls': controls,
                            'inputs': speed,
                            'branch_weights': g_conf.BRANCH_LOSS_WEIGHT,
                            'variable_weights': g_conf.VARIABLE_WEIGHT
                        }
                        # It could be either waypoints or direct control
                        if 'waypoint1_angle' in g_conf.TARGETS:
                            write_waypoints_output(checkpoint_iteration,
                                                   output)
                        else:
                            write_regular_output(checkpoint_iteration, output)

                        loss, _ = criterion(loss_function_params)
                        loss = loss.data.tolist()

                        # Log a random position
                        position = random.randint(0, len(output) - 1)

                        coil_logger.add_message(
                            'Iterating', {
                                'Checkpoint':
                                latest,
                                'Iteration':
                                f'{iteration_on_checkpoint * g_conf.BATCH_SIZE}/{len(dataset)}',
                                f'Validation Loss ({g_conf.LOSS_FUNCTION})':
                                loss,
                                'Output':
                                output[position].data.tolist(),
                                'GroundTruth':
                                targets[position].data.tolist(),
                                'Inputs':
                                inputs[position].data.tolist()
                            }, latest)

                        # We get the average with a growing list of values
                        # Thanks to John D. Cook: http://www.johndcook.com/blog/standard_deviation/
                        iteration_on_checkpoint += 1
                        checkpoint_average_loss += (
                            loss -
                            checkpoint_average_loss) / iteration_on_checkpoint

                        # The last batch may be partial, so the sample count
                        # estimate is capped to never report more than 100%.
                        progress = min(
                            100.0, 100 * iteration_on_checkpoint *
                            g_conf.BATCH_SIZE / len(dataset))
                        print(
                            f"\rProgress: {progress:3.4f}% - "
                            f"Average Loss ({g_conf.LOSS_FUNCTION}): {checkpoint_average_loss:.16f}",
                            end='')

                ########
                # Finish a round of validation, write results, wait for the next
                ########

                coil_logger.add_scalar(
                    f'Validation Loss ({g_conf.LOSS_FUNCTION})',
                    checkpoint_average_loss, latest, True)

                # Let's visualize the distribution of the loss
                coil_logger.add_histogram(
                    f'Validation Checkpoint Loss ({g_conf.LOSS_FUNCTION})',
                    checkpoint_average_loss, latest)

                if checkpoint_average_loss < best_loss:
                    best_loss = checkpoint_average_loss
                    best_loss_iter = latest

                coil_logger.add_message(
                    'Iterating', {
                        'Summary': {
                            'Loss': checkpoint_average_loss,
                            'BestLoss': best_loss,
                            'BestLossCheckpoint': best_loss_iter
                        },
                        'Checkpoint': latest
                    }, latest)

                l1_window.append(checkpoint_average_loss)
                coil_logger.write_on_error_csv(validation_dataset,
                                               checkpoint_average_loss, latest)

                # If we are using the finish when validation stops, we check the current checkpoint
                # window and stop once it has not decreased for more than 3 steps.
                if g_conf.FINISH_ON_VALIDATION_STALE is not None:
                    if dlib.count_steps_without_decrease(l1_window) > 3 and \
                            dlib.count_steps_without_decrease_robust(l1_window) > 3:
                        coil_logger.write_stop(validation_dataset, latest)
                        break

            else:
                latest = get_latest_evaluated_checkpoint()
                time.sleep(1)

                coil_logger.add_message('Loading',
                                        {'Message': 'Waiting Checkpoint'})
                print("Waiting for the next Validation")

        print('\n' + 20 * '#')
        print('Finished validation!')
        print(20 * '#')
        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)

    except RuntimeError as e:
        if latest is not None:
            coil_logger.erase_csv(latest)
        coil_logger.add_message('Error', {'Message': str(e)})

    except:
        # NOTE(review): this catch-all also traps SystemExit; it is kept as-is
        # so that the existing behavior of this process entry point is
        # preserved.
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)
Пример #8
0
def execute(gpu, exp_batch, exp_alias, dataset_name, suppress_output):
    """Validate training checkpoints on one validation dataset as they appear.

    The function loads the experiment configuration, builds the validation
    dataset / data loader and the model, and then loops until the last
    checkpoint of ``g_conf.TEST_SCHEDULE`` has been evaluated. For every
    checkpoint that is ready it computes the average L1 error and L2 (MSE)
    loss of the model output against the dataset targets and, when
    ``g_conf.USE_REPRESENTATION_LOSS`` is set, the weighted representation
    (mimicking) L1/L2 terms as well. Results are reported through
    ``coil_logger``. When ``g_conf.FINISH_ON_VALIDATION_STALE`` is not None the
    loop stops early once the total L1 error history stops decreasing.

    Args:
        gpu: Value written to the ``CUDA_VISIBLE_DEVICES`` environment
            variable (must therefore be a string).
        exp_batch: Experiment folder name; used for
            ``configs/<exp_batch>/<exp_alias>.yaml`` and
            ``_logs/<exp_batch>/<exp_alias>/checkpoints``.
        exp_alias: Experiment name inside ``exp_batch``.
        dataset_name: Name of the validation dataset directory under
            ``$COIL_DATASET_PATH``; also used as preload name and log key.
        suppress_output: When truthy, ``sys.stdout`` / ``sys.stderr`` are
            redirected to append-mode files inside ``_output_logs``.

    Returns:
        None. Every exception (including ``KeyboardInterrupt``) is caught,
        reported through ``coil_logger`` and not re-raised.
    """
    latest = None
    try:
        # We set the visible cuda devices
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

        # At this point the log file with the correct naming is created.
        merge_with_yaml(os.path.join('configs', exp_batch,
                                     exp_alias + '.yaml'))
        # The validation dataset is always fully loaded, so we fix a very high number of hours
        g_conf.NUMBER_OF_HOURS = 10000
        set_type_of_process('validation', dataset_name)

        if not os.path.exists('_output_logs'):
            os.mkdir('_output_logs')

        if suppress_output:
            # The redirection is intentionally left in place for the rest of
            # the process, so these files are never explicitly closed here.
            sys.stdout = open(os.path.join(
                '_output_logs', exp_alias + '_' + g_conf.PROCESS_NAME + '_' +
                str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)
            sys.stderr = open(os.path.join(
                '_output_logs', exp_alias + '_err_' + g_conf.PROCESS_NAME +
                '_' + str(os.getpid()) + ".out"),
                              "a",
                              buffering=1)

        # Define the dataset.
        full_dataset = [
            os.path.join(os.environ["COIL_DATASET_PATH"], dataset_name)
        ]
        augmenter = Augmenter(None)
        # Definition of the dataset to be used. Preload name is just the validation data name
        dataset = CoILDataset(full_dataset,
                              transform=augmenter,
                              preload_names=[dataset_name])

        # The data loader is the multi-threaded module from pytorch that
        # spawns a number of workers to load all the data.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=g_conf.BATCH_SIZE,
            shuffle=False,
            num_workers=g_conf.NUMBER_OF_LOADING_WORKERS,
            pin_memory=True)
        # Number of batches per validation round; used to average the
        # accumulated per-batch losses.
        num_batches = len(data_loader)

        # Create model.
        model = CoILModel(g_conf.MODEL_TYPE, g_conf.MODEL_CONFIGURATION)
        # The window used to keep track of the validation loss
        l1_window = []
        # If we have evaluated a checkpoint, get the validation losses of all the previously
        # evaluated checkpoints (validation loss is used for early stopping)
        latest = get_latest_evaluated_checkpoint()
        if latest is not None:  # A checkpoint was already evaluated: resume
            l1_window = coil_logger.recover_loss_window(dataset_name, None)

        model.cuda()

        # Best values seen so far and the checkpoints that produced them.
        best_mse = 1000
        best_error = 1000
        best_mse_iter = 0
        best_error_iter = 0

        # Loop to validate all checkpoints as they are saved during training
        while not maximun_checkpoint_reach(latest, g_conf.TEST_SCHEDULE):
            if is_next_checkpoint_ready(g_conf.TEST_SCHEDULE):
                with torch.no_grad():
                    # Get and load latest checkpoint
                    latest = get_next_checkpoint(g_conf.TEST_SCHEDULE)

                    checkpoint = torch.load(
                        os.path.join('_logs', exp_batch, exp_alias,
                                     'checkpoints',
                                     str(latest) + '.pth'))
                    checkpoint_iteration = checkpoint['iteration']
                    print("Validation loaded ", checkpoint_iteration)

                    model.load_state_dict(checkpoint['state_dict'])
                    model.eval()

                    accumulated_mse = 0
                    accumulated_error = 0
                    iteration_on_checkpoint = 0
                    if g_conf.USE_REPRESENTATION_LOSS:
                        accumulated_perception_rep_mse = 0
                        accumulated_speed_rep_mse = 0
                        accumulated_intentions_rep_mse = 0
                        accumulated_rep_mse = 0
                        accumulated_perception_rep_error = 0
                        accumulated_speed_rep_error = 0
                        accumulated_intentions_rep_error = 0
                        accumulated_rep_error = 0

                    # Validation loop
                    for data in data_loader:

                        # Compute the forward pass on a batch from  the validation dataset
                        controls = data['directions']

                        # Run model forward and get outputs
                        # First case corresponds to squeeze network, second case corresponds to driving model without
                        # mimicking losses, last case corresponds to mimic network
                        # NOTE(review): in the first case `intermediate_reps`
                        # is not produced, so combining "seg" sensors with
                        # USE_REPRESENTATION_LOSS would fail below - confirm
                        # this combination is never configured.
                        if "seg" in g_conf.SENSORS.keys():
                            output = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls,
                                dataset.extract_intentions(data).cuda())
                        elif not g_conf.USE_REPRESENTATION_LOSS:
                            output = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls)
                        else:
                            output, intermediate_reps = model.forward_branch(
                                data,
                                dataset.extract_inputs(data).cuda(), controls)

                        write_regular_output(checkpoint_iteration, output)

                        # Compute control loss on current validation batch and accumulate it
                        targets_to_use = dataset.extract_targets(data)
                        # Move the targets to the GPU once and reuse them for
                        # every loss term of this batch.
                        targets_gpu = targets_to_use.cuda()

                        # Element-wise absolute error; also logged per sample.
                        error = torch.abs(output - targets_gpu)

                        mse = torch.mean(
                            (output - targets_gpu)**2).data.tolist()
                        mean_error = torch.mean(error).data.tolist()

                        accumulated_error += mean_error
                        accumulated_mse += mse

                        # Compute mimicking losses on current validation batch and accumulate it
                        if g_conf.USE_REPRESENTATION_LOSS:
                            expert_reps = dataset.extract_representations(data)
                            # Normalisation applied to every summed
                            # representation loss term.
                            rep_norm = 3 * output.shape[0]
                            # First L1 losses (seg mask, speed, intention mimicking losses)
                            if g_conf.USE_PERCEPTION_REP_LOSS:
                                perception_rep_loss = torch.sum(
                                    torch.abs(intermediate_reps[0] -
                                              expert_reps[0].cuda())
                                ).data.tolist() / rep_norm
                            else:
                                perception_rep_loss = 0
                            if g_conf.USE_SPEED_REP_LOSS:
                                speed_rep_loss = torch.sum(
                                    torch.abs(intermediate_reps[1] -
                                              expert_reps[1].cuda())
                                ).data.tolist() / rep_norm
                            else:
                                speed_rep_loss = 0
                            if g_conf.USE_INTENTION_REP_LOSS:
                                intentions_rep_loss = torch.sum(
                                    torch.abs(intermediate_reps[2] -
                                              expert_reps[2].cuda())
                                ).data.tolist() / rep_norm
                            else:
                                intentions_rep_loss = 0
                            rep_error = g_conf.REP_LOSS_WEIGHT * (
                                perception_rep_loss + speed_rep_loss +
                                intentions_rep_loss)
                            accumulated_perception_rep_error += perception_rep_loss
                            accumulated_speed_rep_error += speed_rep_loss
                            accumulated_intentions_rep_error += intentions_rep_loss
                            accumulated_rep_error += rep_error

                            # L2 losses now
                            if g_conf.USE_PERCEPTION_REP_LOSS:
                                perception_rep_loss = torch.sum(
                                    (intermediate_reps[0] -
                                     expert_reps[0].cuda())**
                                    2).data.tolist() / rep_norm
                            else:
                                perception_rep_loss = 0
                            if g_conf.USE_SPEED_REP_LOSS:
                                speed_rep_loss = torch.sum(
                                    (intermediate_reps[1] -
                                     expert_reps[1].cuda())**
                                    2).data.tolist() / rep_norm
                            else:
                                speed_rep_loss = 0
                            if g_conf.USE_INTENTION_REP_LOSS:
                                intentions_rep_loss = torch.sum(
                                    (intermediate_reps[2] -
                                     expert_reps[2].cuda())**
                                    2).data.tolist() / rep_norm
                            else:
                                intentions_rep_loss = 0
                            rep_mse = g_conf.REP_LOSS_WEIGHT * (
                                perception_rep_loss + speed_rep_loss +
                                intentions_rep_loss)
                            accumulated_perception_rep_mse += perception_rep_loss
                            accumulated_speed_rep_mse += speed_rep_loss
                            accumulated_intentions_rep_mse += intentions_rep_loss
                            accumulated_rep_mse += rep_mse

                        # Log a random position (sample index inside the batch)
                        position = random.randint(0, output.shape[0] - 1)

                        # Logging
                        # NOTE(review): the progress string multiplies by a
                        # hard-coded 120 - presumably the batch size; confirm
                        # whether g_conf.BATCH_SIZE was intended.
                        if g_conf.USE_REPRESENTATION_LOSS:
                            total_mse = mse + rep_mse
                            total_error = mean_error + rep_error
                            coil_logger.add_message(
                                'Iterating', {
                                    'Checkpoint':
                                    latest,
                                    'Iteration':
                                    (str(iteration_on_checkpoint * 120) + '/' +
                                     str(len(dataset))),
                                    'MeanError':
                                    mean_error,
                                    'MSE':
                                    mse,
                                    'RepMeanError':
                                    rep_error,
                                    'RepMSE':
                                    rep_mse,
                                    'MeanTotalError':
                                    total_error,
                                    'TotalMSE':
                                    total_mse,
                                    'Output':
                                    output[position].data.tolist(),
                                    'GroundTruth':
                                    targets_to_use[position].data.tolist(),
                                    'Error':
                                    error[position].data.tolist(),
                                    'Inputs':
                                    dataset.extract_inputs(
                                        data)[position].data.tolist()
                                }, latest)
                        else:
                            coil_logger.add_message(
                                'Iterating', {
                                    'Checkpoint':
                                    latest,
                                    'Iteration':
                                    (str(iteration_on_checkpoint * 120) + '/' +
                                     str(len(dataset))),
                                    'MeanError':
                                    mean_error,
                                    'MSE':
                                    mse,
                                    'Output':
                                    output[position].data.tolist(),
                                    'GroundTruth':
                                    targets_to_use[position].data.tolist(),
                                    'Error':
                                    error[position].data.tolist(),
                                    'Inputs':
                                    dataset.extract_inputs(
                                        data)[position].data.tolist()
                                }, latest)
                        iteration_on_checkpoint += 1

                        if g_conf.USE_REPRESENTATION_LOSS:
                            print("Iteration %d  on Checkpoint %d : Error %f" %
                                  (iteration_on_checkpoint,
                                   checkpoint_iteration, total_error))
                        else:
                            print("Iteration %d  on Checkpoint %d : Error %f" %
                                  (iteration_on_checkpoint,
                                   checkpoint_iteration, mean_error))

                    ########
                    # Finish a round of validation, write results, wait for the next
                    ########

                    # Compute average L1 and L2 losses over whole round of validation and log them
                    checkpoint_average_mse = accumulated_mse / num_batches
                    checkpoint_average_error = accumulated_error / num_batches
                    coil_logger.add_scalar('L2 Loss', checkpoint_average_mse,
                                           latest, True)
                    coil_logger.add_scalar('Loss', checkpoint_average_error,
                                           latest, True)

                    if g_conf.USE_REPRESENTATION_LOSS:
                        checkpoint_average_perception_rep_mse = (
                            accumulated_perception_rep_mse / num_batches)
                        checkpoint_average_speed_rep_mse = (
                            accumulated_speed_rep_mse / num_batches)
                        checkpoint_average_intentions_rep_mse = (
                            accumulated_intentions_rep_mse / num_batches)
                        checkpoint_average_rep_mse = (
                            accumulated_rep_mse / num_batches)
                        checkpoint_average_total_mse = (
                            checkpoint_average_mse + checkpoint_average_rep_mse)

                        checkpoint_average_perception_rep_error = (
                            accumulated_perception_rep_error / num_batches)
                        checkpoint_average_speed_rep_error = (
                            accumulated_speed_rep_error / num_batches)
                        checkpoint_average_intentions_rep_error = (
                            accumulated_intentions_rep_error / num_batches)
                        checkpoint_average_rep_error = (
                            accumulated_rep_error / num_batches)
                        # Total L1 error = control L1 error + representation
                        # L1 error (previously the representation *MSE* was
                        # added here, mixing L1 and L2 terms and skewing the
                        # value used for best-checkpoint tracking and early
                        # stopping).
                        checkpoint_average_total_error = (
                            checkpoint_average_error +
                            checkpoint_average_rep_error)

                        # Log L1/L2 loss terms
                        coil_logger.add_scalar(
                            'Perception Rep Loss',
                            checkpoint_average_perception_rep_mse, latest,
                            True)
                        coil_logger.add_scalar(
                            'Speed Rep Loss', checkpoint_average_speed_rep_mse,
                            latest, True)
                        coil_logger.add_scalar(
                            'Intentions Rep Loss',
                            checkpoint_average_intentions_rep_mse, latest,
                            True)
                        coil_logger.add_scalar('Overall Rep Loss',
                                               checkpoint_average_rep_mse,
                                               latest, True)
                        coil_logger.add_scalar('Total L2 Loss',
                                               checkpoint_average_total_mse,
                                               latest, True)

                        coil_logger.add_scalar(
                            'Perception Rep Error',
                            checkpoint_average_perception_rep_error, latest,
                            True)
                        coil_logger.add_scalar(
                            'Speed Rep Error',
                            checkpoint_average_speed_rep_error, latest, True)
                        coil_logger.add_scalar(
                            'Intentions Rep Error',
                            checkpoint_average_intentions_rep_error, latest,
                            True)
                        coil_logger.add_scalar('Total Rep Error',
                                               checkpoint_average_rep_error,
                                               latest, True)
                        coil_logger.add_scalar('Total Loss',
                                               checkpoint_average_total_error,
                                               latest, True)
                    else:
                        # Without representation losses the totals are just
                        # the control losses.
                        checkpoint_average_total_mse = checkpoint_average_mse
                        checkpoint_average_total_error = checkpoint_average_error

                    if checkpoint_average_total_mse < best_mse:
                        best_mse = checkpoint_average_total_mse
                        best_mse_iter = latest

                    if checkpoint_average_total_error < best_error:
                        best_error = checkpoint_average_total_error
                        best_error_iter = latest

                    # Print for logging / to terminal validation results
                    if g_conf.USE_REPRESENTATION_LOSS:
                        coil_logger.add_message(
                            'Iterating', {
                                'Summary': {
                                    'Control Error': checkpoint_average_error,
                                    'Control Loss': checkpoint_average_mse,
                                    'Rep Error': checkpoint_average_rep_error,
                                    'Rep Loss': checkpoint_average_rep_mse,
                                    'Error': checkpoint_average_total_error,
                                    'Loss': checkpoint_average_total_mse,
                                    'BestError': best_error,
                                    'BestMSE': best_mse,
                                    'BestMSECheckpoint': best_mse_iter,
                                    'BestErrorCheckpoint': best_error_iter
                                },
                                'Checkpoint': latest
                            }, latest)
                    else:
                        coil_logger.add_message(
                            'Iterating', {
                                'Summary': {
                                    'Error': checkpoint_average_error,
                                    'Loss': checkpoint_average_mse,
                                    'BestError': best_error,
                                    'BestMSE': best_mse,
                                    'BestMSECheckpoint': best_mse_iter,
                                    'BestErrorCheckpoint': best_error_iter
                                },
                                'Checkpoint': latest
                            }, latest)

                    # Save validation loss history (validation loss is used for early stopping)
                    l1_window.append(checkpoint_average_total_error)
                    coil_logger.write_on_error_csv(
                        dataset_name, checkpoint_average_total_error)

                    # Early stopping: both the plain and the robust dlib
                    # counters must report more than 3 steps without decrease.
                    if g_conf.FINISH_ON_VALIDATION_STALE is not None:
                        if dlib.count_steps_without_decrease(l1_window) > 3 and \
                                dlib.count_steps_without_decrease_robust(l1_window) > 3:
                            coil_logger.write_stop(dataset_name, latest)
                            break

            else:
                # Next checkpoint not written yet: refresh the marker of the
                # last evaluated checkpoint and poll again in a second.
                latest = get_latest_evaluated_checkpoint()
                time.sleep(1)

                coil_logger.add_message('Loading',
                                        {'Message': 'Waiting Checkpoint'})
                print("Waiting for the next Validation")

        coil_logger.add_message('Finished', {})

    except KeyboardInterrupt:
        coil_logger.add_message('Error', {'Message': 'Killed By User'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)

    except RuntimeError as e:
        if latest is not None:
            coil_logger.erase_csv(latest)
        coil_logger.add_message('Error', {'Message': str(e)})

    except:
        # Deliberate catch-all boundary of the validation process: print the
        # traceback so the failure is visible, then report it to the logger.
        traceback.print_exc()
        coil_logger.add_message('Error', {'Message': 'Something Happened'})
        # We erase the output that was unfinished due to some process stop.
        if latest is not None:
            coil_logger.erase_csv(latest)