def pretraining(args, loss_fn):
    """Pretrain the model on unaligned (monolingual) source data.

    Builds training and validation unaligned dataloaders that share a single
    text encoder, constructs the model and optimizer, and runs the
    pretraining loop.

    Args:
        args: parsed settings namespace (src_train, src_valid, vocab_size,
            text_encoder, max_seq_length, batch_size, epochs, checkpoint, ...).
        loss_fn: loss function forwarded to the pretraining loop.
    """
    text_encoder_type = _text_encoder_type(args.text_encoder)
    train_dl = dataloader.UnalignedDataloader(
        file_name=args.src_train,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    # Reuse the training encoder so validation shares the same vocabulary.
    valid_dl = dataloader.UnalignedDataloader(
        file_name=args.src_valid,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        encoder=train_dl.encoder,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    # Same encoder on both sides: pretraining is monolingual.
    model = models.find(args, train_dl.encoder.vocab_size,
                        train_dl.encoder.vocab_size)
    optim = _create_optimizer(model.embedding_size, args)
    # Renamed from `pretraining` to stop shadowing this function's own name.
    trainer = Pretraining(model, train_dl, valid_dl)
    trainer.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )
def _load_encoders(settings):
    """Rebuild the input/target text encoders exactly as at training time.

    Returns:
        A ``(encoder_input, encoder_target)`` pair from the recreated
        aligned training dataloader.
    """
    # Arguments shared by both branches; only the input encoder differs.
    aligned_kwargs = dict(
        file_name_input=settings.src_train,
        file_name_target=settings.target_train,
        text_encoder_type=settings.text_encoder,
        vocab_size=settings.vocab_size,
        max_seq_length=settings.max_seq_length,
        cache_dir=None,
    )
    if settings.pretrained is not None:
        # Seed the input side with the encoder built from the pretrained corpus.
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=settings.pretrained,
            vocab_size=settings.vocab_size,
            text_encoder_type=settings.text_encoder,
            max_seq_length=settings.max_seq_length,
            cache_dir=None,
        )
        aligned_kwargs["encoder_input"] = pretrained_dl.encoder
    train_dl = dataloader.AlignedDataloader(**aligned_kwargs)
    return train_dl.encoder_input, train_dl.encoder_target
def default_training(args, loss_fn):
    """Train the model on aligned data, optionally reusing a pretrained encoder."""
    text_encoder_type = _text_encoder_type(args.text_encoder)
    # Both branches build the same aligned loader; only encoder_input differs.
    train_kwargs = dict(
        file_name_input=args.src_train,
        file_name_target=args.target_train,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    if args.pretrained is not None:
        # Reuse the encoder learned during pretraining on the input side.
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=args.pretrained,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
        train_kwargs["encoder_input"] = pretrained_dl.encoder
    train_dl = dataloader.AlignedDataloader(**train_kwargs)
    # Validation must encode with the exact encoders used for training.
    valid_dl = dataloader.AlignedDataloader(
        file_name_input=args.src_valid,
        file_name_target=args.target_valid,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    logger.debug(valid_dl.encoder_target.vocab_size)
    logger.debug(valid_dl.encoder_input.vocab_size)
    logger.debug(train_dl.encoder_target.vocab_size)
    logger.debug(train_dl.encoder_input.vocab_size)
    model = models.find(args, train_dl.encoder_input.vocab_size,
                        train_dl.encoder_target.vocab_size)
    optim = _create_optimizer(model.embedding_size, args)
    training = Training(model, train_dl, valid_dl, [base.Metrics.BLEU])
    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )
def test(args, loss_fn,
         test_src="data/splitted_data/test/test_token10000.en",
         test_target="data/splitted_data/test/test_token10000.fr"):
    """Evaluate the model on a held-out test set.

    Rebuilds the training dataloaders only to recover the text encoders
    used at training time, then runs the test loop from a checkpoint.

    Args:
        args: parsed settings namespace (paths, vocab size, checkpoint, ...).
        loss_fn: loss function forwarded to ``base.test``.
        test_src: path to the test source corpus. Defaults to the previously
            hard-coded English test file, so existing callers are unaffected.
        test_target: path to the test target corpus (same default behavior).
    """
    text_encoder_type = _text_encoder_type(args.text_encoder)
    # Used to load the train text encoders.
    if args.pretrained is not None:
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=args.pretrained,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
        train_dl = dataloader.AlignedDataloader(
            file_name_input=args.src_train,
            file_name_target=args.target_train,
            text_encoder_type=text_encoder_type,
            vocab_size=args.vocab_size,
            encoder_input=pretrained_dl.encoder,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
    else:
        train_dl = dataloader.AlignedDataloader(
            file_name_input=args.src_train,
            file_name_target=args.target_train,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=_cache_dir(args),
        )
    test_dl = dataloader.AlignedDataloader(
        file_name_input=test_src,
        file_name_target=test_target,
        vocab_size=args.vocab_size,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    model = models.find(args, train_dl.encoder_input.vocab_size,
                        train_dl.encoder_target.vocab_size)
    base.test(model, loss_fn, test_dl, args.batch_size, args.checkpoint)
def generate_predictions(input_file_path: str, pred_file_path: str):
    """Generates predictions for the machine translation task (EN->FR).

    You are allowed to modify this function as needed, but once again, you
    cannot modify any other part of this file. We will be importing only
    this function in our final evaluation script. Since you will most
    definitely need to import modules for your code, you must import these
    inside the function itself.

    Args:
        input_file_path: the file path that contains the input data.
        pred_file_path: the file path where to store the predictions.

    Returns: None
    """
    logger.info(
        f"Generate predictions with input {input_file_path} {pred_file_path}")
    settings = BackTranslationPretrainedDemiBertTransformer()
    encoder_input, encoder_target = _load_encoders(settings)
    # Restore the trained model from its checkpoint.
    model = models.find(settings, encoder_input.vocab_size,
                        encoder_target.vocab_size)
    model.load(str(settings.checkpoint))
    # Encode the evaluation inputs with the training-time input encoder.
    dl = dataloader.UnalignedDataloader(
        file_name=input_file_path,
        vocab_size=settings.vocab_size,
        text_encoder_type=settings.text_encoder,
        max_seq_length=settings.max_seq_length,
        cache_dir=None,
        encoder=encoder_input,
    )
    predictions = _generate_predictions(model, dl, encoder_input,
                                        encoder_target, settings.batch_size)
    base.write_text(predictions, pred_file_path)
def back_translation_training(args, loss_fn):
    """Train the model with back translation.

    Builds monolingual (unaligned) loaders for both languages, forward and
    reversed aligned train/valid loaders sharing the same encoders, then runs
    the back-translation training loop with a forward and a reverse model.

    Args:
        args: parsed settings namespace (vocab_size, text_encoder,
            max_seq_length, batch_size, epochs, checkpoint, ...).
        loss_fn: loss function forwarded to the training loop.
    """
    text_encoder_type = _text_encoder_type(args.text_encoder)
    logger.info("Creating training unaligned dataloader ...")
    train_dl = dataloader.UnalignedDataloader(
        "data/unaligned.en",
        args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
    )
    logger.info(f"English vocab size: {train_dl.encoder.vocab_size}")
    logger.info("Creating reversed training unaligned dataloader ...")
    train_dl_reverse = dataloader.UnalignedDataloader(
        "data/unaligned.fr",
        args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
    )
    logger.info(f"French vocab size: {train_dl_reverse.encoder.vocab_size}")
    logger.info("Creating training aligned dataloader ...")
    # NOTE(review): the *train* target below is a "..._val_token.fr" file,
    # paired with "sorted_train_token.en" — looks suspicious but is used
    # consistently in both directions; confirm against the data layout.
    aligned_train_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_train_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        vocab_size=args.vocab_size,
        encoder_input=train_dl.encoder,
        encoder_target=train_dl_reverse.encoder,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    logger.info("Creating reversed training aligned dataloader ...")
    aligned_train_dl_reverse = dataloader.AlignedDataloader(
        file_name_input=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        file_name_target="data/splitted_data/sorted_train_token.en",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl.encoder_target,
        encoder_target=aligned_train_dl.encoder_input,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    logger.info("Creating valid aligned dataloader ...")
    aligned_valid_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_val_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl.encoder_input,
        encoder_target=aligned_train_dl.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    logger.info("Creating reversed valid aligned dataloader ...")
    aligned_valid_dl_reverse = dataloader.AlignedDataloader(
        # Bug fix: extension was ".frs" — every other reference to this
        # file uses ".fr", so the typo would fail to find the corpus.
        file_name_input=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        file_name_target="data/splitted_data/sorted_val_token.en",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl_reverse.encoder_input,
        encoder_target=aligned_train_dl_reverse.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    model = models.find(
        args,
        aligned_train_dl.encoder_input.vocab_size,
        aligned_train_dl.encoder_target.vocab_size,
    )
    optim = _create_optimizer(model.embedding_size, args)
    model_reverse = models.find(
        args,
        aligned_train_dl_reverse.encoder_input.vocab_size,
        aligned_train_dl_reverse.encoder_target.vocab_size,
    )
    training = BackTranslationTraining(
        model,
        model_reverse,
        train_dl,
        train_dl_reverse,
        aligned_train_dl,
        aligned_train_dl_reverse,
        aligned_valid_dl,
        aligned_valid_dl_reverse,
    )
    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )