def pretraining(args, loss_fn):
    """Pretraining the model."""
    text_encoder_type = _text_encoder_type(args.text_encoder)

    train_dl = dataloader.UnalignedDataloader(
        file_name=args.src_train,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    valid_dl = dataloader.UnalignedDataloader(
        file_name=args.src_valid,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        encoder=train_dl.encoder,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    model = models.find(args, train_dl.encoder.vocab_size,
                        train_dl.encoder.vocab_size)
    optim = _create_optimizer(model.embedding_size, args)
    # Name the trainer distinctly so it does not shadow the enclosing
    # pretraining() function.
    trainer = Pretraining(model, train_dl, valid_dl)
    trainer.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )
def _load_encoders(settings):
    if settings.pretrained is not None:
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=settings.pretrained,
            vocab_size=settings.vocab_size,
            text_encoder_type=settings.text_encoder,
            max_seq_length=settings.max_seq_length,
            cache_dir=None,
        )
        train_dl = dataloader.AlignedDataloader(
            file_name_input=settings.src_train,
            file_name_target=settings.target_train,
            text_encoder_type=settings.text_encoder,
            vocab_size=settings.vocab_size,
            encoder_input=pretrained_dl.encoder,
            max_seq_length=settings.max_seq_length,
            cache_dir=None,
        )
    else:
        train_dl = dataloader.AlignedDataloader(
            file_name_input=settings.src_train,
            file_name_target=settings.target_train,
            text_encoder_type=settings.text_encoder,
            vocab_size=settings.vocab_size,
            max_seq_length=settings.max_seq_length,
            cache_dir=None,
        )
    return train_dl.encoder_input, train_dl.encoder_target
def default_training(args, loss_fn):
    """Train the model."""
    text_encoder_type = _text_encoder_type(args.text_encoder)

    if args.pretrained is not None:
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=args.pretrained,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
        train_dl = dataloader.AlignedDataloader(
            file_name_input=args.src_train,
            file_name_target=args.target_train,
            text_encoder_type=text_encoder_type,
            vocab_size=args.vocab_size,
            encoder_input=pretrained_dl.encoder,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
    else:
        train_dl = dataloader.AlignedDataloader(
            file_name_input=args.src_train,
            file_name_target=args.target_train,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
    valid_dl = dataloader.AlignedDataloader(
        file_name_input=args.src_valid,
        file_name_target=args.target_valid,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    logger.debug(f"Valid target vocab size: {valid_dl.encoder_target.vocab_size}")
    logger.debug(f"Valid input vocab size: {valid_dl.encoder_input.vocab_size}")
    logger.debug(f"Train target vocab size: {train_dl.encoder_target.vocab_size}")
    logger.debug(f"Train input vocab size: {train_dl.encoder_input.vocab_size}")
    model = models.find(args, train_dl.encoder_input.vocab_size,
                        train_dl.encoder_target.vocab_size)
    optim = _create_optimizer(model.embedding_size, args)
    training = Training(model, train_dl, valid_dl, [base.Metrics.BLEU])
    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )
def test(args, loss_fn):
    """Test the model."""
    text_encoder_type = _text_encoder_type(args.text_encoder)
    # Used to load the train text encoders.
    if args.pretrained is not None:
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=args.pretrained,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
        train_dl = dataloader.AlignedDataloader(
            file_name_input=args.src_train,
            file_name_target=args.target_train,
            text_encoder_type=text_encoder_type,
            vocab_size=args.vocab_size,
            encoder_input=pretrained_dl.encoder,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
    else:
        train_dl = dataloader.AlignedDataloader(
            file_name_input=args.src_train,
            file_name_target=args.target_train,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
    test_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/test/test_token10000.en",
        file_name_target="data/splitted_data/test/test_token10000.fr",
        vocab_size=args.vocab_size,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    model = models.find(args, train_dl.encoder_input.vocab_size,
                        train_dl.encoder_target.vocab_size)
    base.test(model, loss_fn, test_dl, args.batch_size, args.checkpoint)
def generate_predictions(input_file_path: str, pred_file_path: str):
    """Generates predictions for the machine translation task (EN->FR).

    You are allowed to modify this function as needed, but once again, you
    cannot modify any other part of this file. We will be importing only this
    function in our final evaluation script. Since you will most definitely
    need to import modules for your code, you must import them inside the
    function itself.

    Args:
        input_file_path: the file path that contains the input data.
        pred_file_path: the file path where to store the predictions.

    Returns: None

    """
    logger.info(
        f"Generating predictions from {input_file_path} into {pred_file_path}")
    settings = BackTranslationPretrainedDemiBertTransformer()
    encoder_input, encoder_target = _load_encoders(settings)

    # Load the model.
    model = models.find(settings, encoder_input.vocab_size,
                        encoder_target.vocab_size)
    model.load(str(settings.checkpoint))

    dl = dataloader.UnalignedDataloader(
        file_name=input_file_path,
        vocab_size=settings.vocab_size,
        text_encoder_type=settings.text_encoder,
        max_seq_length=settings.max_seq_length,
        cache_dir=None,
        encoder=encoder_input,
    )

    predictions = _generate_predictions(model, dl, encoder_input,
                                        encoder_target, settings.batch_size)
    base.write_text(predictions, pred_file_path)
def back_translation_training(args, loss_fn):
    """Train the model with back translation."""
    text_encoder_type = _text_encoder_type(args.text_encoder)

    logger.info("Creating training unaligned dataloader ...")
    train_dl = dataloader.UnalignedDataloader(
        file_name="data/unaligned.en",
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
    )
    logger.info(f"English vocab size: {train_dl.encoder.vocab_size}")

    logger.info("Creating reversed training unaligned dataloader ...")
    train_dl_reverse = dataloader.UnalignedDataloader(
        file_name="data/unaligned.fr",
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
    )
    logger.info(f"French vocab size: {train_dl_reverse.encoder.vocab_size}")

    logger.info("Creating training aligned dataloader ...")
    aligned_train_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_train_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        vocab_size=args.vocab_size,
        encoder_input=train_dl.encoder,
        encoder_target=train_dl_reverse.encoder,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )

    logger.info("Creating reversed training aligned dataloader ...")
    aligned_train_dl_reverse = dataloader.AlignedDataloader(
        file_name_input=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        file_name_target="data/splitted_data/sorted_train_token.en",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl.encoder_target,
        encoder_target=aligned_train_dl.encoder_input,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )

    logger.info("Creating valid aligned dataloader ...")
    aligned_valid_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_val_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl.encoder_input,
        encoder_target=aligned_train_dl.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )

    logger.info("Creating reversed valid aligned dataloader ...")
    aligned_valid_dl_reverse = dataloader.AlignedDataloader(
        file_name_input=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.frs",
        file_name_target="data/splitted_data/sorted_val_token.en",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl_reverse.encoder_input,
        encoder_target=aligned_train_dl_reverse.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )

    model = models.find(
        args,
        aligned_train_dl.encoder_input.vocab_size,
        aligned_train_dl.encoder_target.vocab_size,
    )

    optim = _create_optimizer(model.embedding_size, args)
    model_reverse = models.find(
        args,
        aligned_train_dl_reverse.encoder_input.vocab_size,
        aligned_train_dl_reverse.encoder_target.vocab_size,
    )

    training = BackTranslationTraining(
        model,
        model_reverse,
        train_dl,
        train_dl_reverse,
        aligned_train_dl,
        aligned_train_dl_reverse,
        aligned_valid_dl,
        aligned_valid_dl_reverse,
    )

    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )