from pathlib import Path
from time import time
from typing import Any, MutableMapping, MutableSequence, Union

from loguru import logger
from torch import load as pt_load, no_grad, randperm, save as pt_save
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Adam

# Project-local helpers; the module paths below are assumptions inferred from
# the names used in this file and may differ in the actual repository layout.
# `_decode_outputs` is expected to be defined elsewhere in this same module.
from data_handlers import get_clotho_loader
from eval_metrics import evaluate_metrics
from processes import module_epoch_passing
from tools import csv_functions
def _do_evaluation(model: Module,
                   settings_data: MutableMapping[str, Any],
                   settings_io: MutableMapping[str, Any],
                   indices_list: MutableSequence[str]) \
        -> None:
    """Evaluation of an optimized model.

    :param model: Model to use.
    :type model: torch.nn.Module
    :param settings_data: Data settings to use.
    :type settings_data: dict
    :param settings_io: Data I/O settings to use.
    :type settings_io: dict
    :param indices_list: Sequence with the words of the captions.
    :type indices_list: list[str]
    """
    model.eval()
    logger_main = logger.bind(is_caption=False, indent=1)

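    # Path to the extracted evaluation features; its (sorted) file listing
    # is paired with the decoded captions further below.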
    data_path_evaluation = Path(
        settings_io['root_dirs']['data'],
        settings_io['dataset']['features_dirs']['output'],
        settings_io['dataset']['features_dirs']['evaluation'])

    logger_main.info('Getting evaluation data')
    validation_data = get_clotho_loader(
        settings_io['dataset']['features_dirs']['evaluation'],
        is_training=False,
        settings_data=settings_data,
        settings_io=settings_io)
    logger_main.info('Done')

    text_sep = '-' * 100
    starting_text = 'Starting evaluation on evaluation data'

    logger_main.info(starting_text)
    logger.bind(is_caption=True,
                indent=0).info(f'{text_sep}\n{text_sep}\n{text_sep}\n\n')
    logger.bind(is_caption=True, indent=0).info(f'{starting_text}.\n\n')

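    # Single pass over the evaluation data with gradients disabled; with no
    # objective and no optimizer, module_epoch_passing should only run the
    # forward pass and collect the outputs.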
    with no_grad():
        evaluation_outputs = module_epoch_passing(data=validation_data,
                                                  module=model,
                                                  objective=None,
                                                  optimizer=None,
                                                  epoch=-1,
                                                  max_epoch=-1)

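    # Turn the predicted and ground-truth index sequences back into caption
    # strings, pairing each caption with its feature file name.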
    captions_pred, captions_gt = _decode_outputs(
        evaluation_outputs[1],
        evaluation_outputs[2],
        indices_object=indices_list,
        file_names=sorted(list(data_path_evaluation.iterdir())),
        eos_token='<eos>',
        print_to_console=False)

    logger_main.info('Evaluation done')

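    # Score predictions against the ground truth; evaluate_metrics is
    # assumed to return a mapping of the form
    # {metric_name: {'score': float, ...}}, as used below.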
    metrics = evaluate_metrics(captions_pred, captions_gt)

    for metric, values in metrics.items():
        logger_main.info(f'{metric:<7s}: {values["score"]:7.4f}')


def _do_training(model: Module,
                 settings_training: MutableMapping[
                     str, Union[Any, MutableMapping[str, Any]]],
                 settings_data: MutableMapping[
                     str, Union[Any, MutableMapping[str, Any]]],
                 settings_io: MutableMapping[
                     str, Union[Any, MutableMapping[str, Any]]],
                 model_file_name: str,
                 model_dir: Path,
                 indices_list: MutableSequence[str]) \
        -> None:
    """Optimization of the model.

    :param model: Model to optimize.
    :type model: torch.nn.Module
    :param settings_training: Training settings to use.
    :type settings_training: dict
    :param settings_data: Training data settings to use.
    :type settings_data: dict
    :param settings_io: Data I/O settings to use.
    :type settings_io: dict
    :param model_file_name: File name of the model.
    :type model_file_name: str
    :param model_dir: Directory to serialize the model to.
    :type model_dir: pathlib.Path
    :param indices_list: A sequence with the words.
    :type indices_list: list[str]
    """
    # Initialize variables for the training process
    prv_training_loss = 1e8
    patience: int = settings_training['patience']
    loss_thr: float = settings_training['loss_thr']
    patience_counter = 0
    best_epoch = 0
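    # patience and loss_thr implement early stopping: optimization halts
    # after `patience` consecutive epochs in which the training loss does
    # not improve by more than `loss_thr`.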

    # Initialize logger
    logger_main = logger.bind(is_caption=False, indent=1)

    # Inform that we start getting the data
    logger_main.info('Getting training data')

    # Get training data and count the amount of batches
    training_data = get_clotho_loader(
        settings_io['dataset']['features_dirs']['development'],
        is_training=True,
        settings_data=settings_data,
        settings_io=settings_io)

    logger_main.info('Done')

    # Initialize loss and optimizer objects
    objective = CrossEntropyLoss()
    optimizer = Adam(params=model.parameters(),
                     lr=settings_training['optimizer']['lr'])

    # Inform that we start training
    logger_main.info('Starting training')

    model.train()
    for epoch in range(settings_training['nb_epochs']):

        # Log starting time
        start_time = time()

        # Do a complete pass over our training data
        epoch_output = module_epoch_passing(
            data=training_data,
            module=model,
            objective=objective,
            optimizer=optimizer,
            grad_norm=settings_training['grad_norm']['norm'],
            grad_norm_val=settings_training['grad_norm']['value'])
        objective_output, output_y_hat, output_y, f_names = epoch_output

        # Get mean loss of training and print it with logger
        training_loss = objective_output.mean().item()

        logger_main.info(f'Epoch: {epoch:05d} -- '
                         f'Training loss: {training_loss:>7.4f} | '
                         f'Time: {time() - start_time:>5.3f}')

        # Check if we have to decode captions for the current epoch
        if (epoch + 1) % settings_training[
                'text_output_every_nb_epochs'] == 0:

            # Get the subset of files for decoding their captions
            sampling_indices = sorted(
                randperm(len(output_y_hat))
                [:settings_training['nb_examples_to_sample']].tolist())

        # Decode the sampled outputs into text; zip(*...) transposes the
        # [prediction, ground truth] pairs into two parallel sequences
            _decode_outputs(*zip(*[[output_y_hat[i], output_y[i]]
                                   for i in sampling_indices]),
                            indices_object=indices_list,
                            file_names=[
                                Path(f_names[i_f_name])
                                for i_f_name in sampling_indices
                            ],
                            eos_token='<eos>',
                            print_to_console=False)

        # Check improvement of loss
        if prv_training_loss - training_loss > loss_thr:
            # Keep the improved loss as the new reference
            prv_training_loss = training_loss

            # Keep the epoch of the best loss so far
            best_epoch = epoch

            # Serialize the model keeping the epoch
            pt_save(
                model.state_dict(),
                str(
                    model_dir.joinpath(
                        f'epoch_{best_epoch:05d}_{model_file_name}')))

            # Zero out the patience
            patience_counter = 0

        else:

            # Increase patience counter
            patience_counter += 1

        # Serialize the model and optimizer.
        for pt_obj, save_str in zip([model, optimizer], ['', '_optimizer']):
            pt_save(
                pt_obj.state_dict(),
                str(model_dir.joinpath(f'latest{save_str}_{model_file_name}')))

        # Check for stopping criteria
        if patience_counter >= patience:
            logger_main.info('No lower training loss for '
                             f'{patience_counter} epochs. '
                             'Training stops.')
            break

    # Inform that we are done
    logger_main.info('Training done')

    # Load best model
    model.load_state_dict(
        pt_load(
            str(model_dir.joinpath(
                f'epoch_{best_epoch:05d}_{model_file_name}'))))


def _do_testing(model: Module,
                settings_data: MutableMapping[str, Any],
                settings_io: MutableMapping[str, Any],
                indices_list: MutableSequence[str]) \
        -> None:
    """Testing of an optimized model.

    :param model: Model to use.
    :type model: torch.nn.Module
    :param settings_data: Data settings to use.
    :type settings_data: dict
    :param settings_io: Data I/O settings to use.
    :type settings_io: dict
    :param indices_list: Sequence with the words of the captions.
    :type indices_list: list[str]
    """
    model.eval()
    logger_main = logger.bind(is_caption=False, indent=1)

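    # Path to the extracted test features; its file listing provides the
    # file names paired with the predicted captions below.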
    data_path_test = Path(settings_io['root_dirs']['data'],
                          settings_io['dataset']['features_dirs']['output'],
                          settings_io['dataset']['features_dirs']['test'])

    logger_main.info('Getting test data')
    test_data = get_clotho_loader(
        settings_io['dataset']['features_dirs']['test'],
        is_training=False,
        settings_data=settings_data,
        settings_io=settings_io)
    logger_main.info('Done')

    text_sep = '-' * 100
    starting_text = 'Starting testing on test data'

    logger_main.info(starting_text)
    logger.bind(is_caption=True,
                indent=0).info(f'{text_sep}\n{text_sep}\n{text_sep}\n\n')
    logger.bind(is_caption=True, indent=0).info(f'{starting_text}.\n\n')

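    # Inference pass over the test data; as in evaluation, no objective or
    # optimizer is given, so only the model outputs are collected.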
    with no_grad():
        test_outputs = module_epoch_passing(data=test_data,
                                            module=model,
                                            objective=None,
                                            optimizer=None)

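    # Decode only the predictions (the ground-truth part is discarded);
    # file names are sorted so that their order matches the one used in
    # _do_evaluation above.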
    captions_pred, _ = _decode_outputs(test_outputs[1],
                                       test_outputs[2],
                                       indices_object=indices_list,
                                       file_names=sorted(
                                           data_path_test.iterdir()),
                                       eos_token='<eos>',
                                       print_to_console=False)

    # Rename each entry from clotho_file_{file_name} to {file_name}.wav
    # for the submission file.
    for entry in captions_pred:
        entry['file_name'] = entry['file_name'].replace(
            'clotho_file_', '') + '.wav'

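    # Write the predicted captions as a timestamped submission CSV file.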
    submission_dir = Path(
        settings_io['root_dirs']['outputs'],
        settings_io['submissions']['submissions_dir'])
    submission_dir.mkdir(parents=True, exist_ok=True)
    csv_functions.write_csv_file(captions_pred,
                                 settings_io['submissions']['submission_file'],
                                 submission_dir,
                                 add_timestamp=True)

    logger_main.info('Testing done')
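
# A minimal usage sketch (hypothetical): the actual entry point builds the
# model and the settings from configuration files; `load_settings`,
# `get_model`, and the settings keys below are assumptions, not part of
# this module.
#
#     settings = load_settings(Path('settings', 'main_settings.yaml'))
#     model = get_model(settings['model'], indices_list)
#     _do_training(model,
#                  settings['training'],
#                  settings['data'],
#                  settings['dirs_and_files'],
#                  model_file_name='model.pt',
#                  model_dir=Path('outputs', 'models'),
#                  indices_list=indices_list)
#     _do_evaluation(model, settings['data'], settings['dirs_and_files'],
#                    indices_list)
#     _do_testing(model, settings['data'], settings['dirs_and_files'],
#                 indices_list)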