def punctuation_training(args, loss_fn):
    """Train the model for the punctuation task.

    Builds aligned train/valid dataloaders from the paths in ``args``, the
    validation loader reusing the training encoders so token ids match, then
    runs a BLEU-tracked training loop.
    """
    encoder_type = _text_encoder_type(args.text_encoder)
    cache = _cache_dir(args)

    # Training loader: fits both input and target text encoders.
    train_dl = dataloader.AlignedDataloader(
        file_name_input=args.src_train,
        file_name_target=args.target_train,
        vocab_size=args.vocab_size,
        text_encoder_type=encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=cache,
    )

    # Validation loader: shares the training encoders (no refitting).
    valid_dl = dataloader.AlignedDataloader(
        file_name_input=args.src_valid,
        file_name_target=args.target_valid,
        vocab_size=args.vocab_size,
        text_encoder_type=encoder_type,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        max_seq_length=args.max_seq_length,
        cache_dir=cache,
    )

    model = models.find(
        args,
        train_dl.encoder_input.vocab_size,
        train_dl.encoder_target.vocab_size,
    )
    optim = _create_optimizer(model.embedding_size, args)

    training = Training(model, train_dl, valid_dl, [base.Metrics.BLEU])
    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )
def _load_encoders(settings):
    """Return the ``(encoder_input, encoder_target)`` pair for *settings*.

    Builds an aligned training dataloader; when ``settings.pretrained`` is
    set, the source-side encoder is taken from a dataloader fitted on that
    unaligned corpus instead of being refit on the aligned data.
    """
    loader_kwargs = {
        "file_name_input": settings.src_train,
        "file_name_target": settings.target_train,
        "text_encoder_type": settings.text_encoder,
        "vocab_size": settings.vocab_size,
        "max_seq_length": settings.max_seq_length,
        "cache_dir": None,
    }

    if settings.pretrained is not None:
        # Reuse the encoder fitted on the unaligned pretraining corpus
        # for the input side of the aligned loader.
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=settings.pretrained,
            vocab_size=settings.vocab_size,
            text_encoder_type=settings.text_encoder,
            max_seq_length=settings.max_seq_length,
            cache_dir=None,
        )
        loader_kwargs["encoder_input"] = pretrained_dl.encoder

    train_dl = dataloader.AlignedDataloader(**loader_kwargs)
    return train_dl.encoder_input, train_dl.encoder_target
def default_training(args, loss_fn):
    """Train the model.

    When ``args.pretrained`` is set, the source-side text encoder is taken
    from an encoder fitted on that unaligned corpus; otherwise both encoders
    are fitted on the aligned training files. Validation always reuses the
    training encoders so vocabularies line up.
    """
    text_encoder_type = _text_encoder_type(args.text_encoder)
    cache_dir = _cache_dir(args)

    # Shared kwargs for the aligned training loader (both branches were
    # identical except for the optional pretrained input encoder).
    train_kwargs = {
        "file_name_input": args.src_train,
        "file_name_target": args.target_train,
        "text_encoder_type": text_encoder_type,
        "vocab_size": args.vocab_size,
        "max_seq_length": args.max_seq_length,
        "cache_dir": cache_dir,
    }
    if args.pretrained is not None:
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=args.pretrained,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=cache_dir,
        )
        train_kwargs["encoder_input"] = pretrained_dl.encoder
    train_dl = dataloader.AlignedDataloader(**train_kwargs)

    valid_dl = dataloader.AlignedDataloader(
        file_name_input=args.src_valid,
        file_name_target=args.target_valid,
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        max_seq_length=args.max_seq_length,
        cache_dir=cache_dir,
    )

    # Fix: label the vocab-size debug output (bare integers were
    # indistinguishable in the log); lazy %-args avoid eager formatting.
    logger.debug("valid target vocab size: %s", valid_dl.encoder_target.vocab_size)
    logger.debug("valid input vocab size: %s", valid_dl.encoder_input.vocab_size)
    logger.debug("train target vocab size: %s", train_dl.encoder_target.vocab_size)
    logger.debug("train input vocab size: %s", train_dl.encoder_input.vocab_size)

    model = models.find(
        args,
        train_dl.encoder_input.vocab_size,
        train_dl.encoder_target.vocab_size,
    )
    optim = _create_optimizer(model.embedding_size, args)
    training = Training(model, train_dl, valid_dl, [base.Metrics.BLEU])
    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )
def translate(args):
    """Translate the user's input message and log the result.

    Rebuilds the training text encoders from the default training corpus,
    loads the model weights from ``args.checkpoint``, and translates
    ``args.message``.
    """
    # Consistency fix: every other entry point resolves the encoder type
    # through the shared helper; calling TextEncoderType directly here
    # bypassed whatever normalization the helper applies.
    text_encoder_type = _text_encoder_type(args.text_encoder)

    # Used only to recover the encoders fitted at training time.
    # NOTE(review): paths are hard-coded to the default training corpus —
    # consider exposing them as CLI arguments.
    # Consistency fix: pass max_seq_length/cache_dir like every other
    # AlignedDataloader call so the recovered encoders match training.
    train_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_train_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_train_token.fr",
        vocab_size=args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=_cache_dir(args),
    )
    encoder_input = train_dl.encoder_input
    encoder_target = train_dl.encoder_target

    # Load the trained weights.
    model = models.find(args, encoder_input.vocab_size, encoder_target.vocab_size)
    model.load(str(args.checkpoint))

    # Wrap the message with start/end tokens and encode it as a batch of one.
    message = preprocessing.add_start_end_token([args.message])[0]
    x = tf.convert_to_tensor([encoder_input.encode(message)])

    # Translate and decode the prediction back to text.
    translated = model.translate(x, encoder_target, args.max_seq_length)
    translated_message = model.predictions(translated, encoder_target, logit=False)
    logger.info(f"Translation is {translated_message}")
def test(args, loss_fn):
    """Test the model on the held-out test set.

    Rebuilds the training encoders (optionally seeding the source-side
    encoder from ``args.pretrained``), then evaluates the checkpointed model
    on the test files.
    """
    text_encoder_type = _text_encoder_type(args.text_encoder)
    cache_dir = _cache_dir(args)

    # Used to rebuild the train-time text encoders; both branches share
    # these kwargs, the pretrained case only overrides the input encoder.
    train_kwargs = {
        "file_name_input": args.src_train,
        "file_name_target": args.target_train,
        "text_encoder_type": text_encoder_type,
        "vocab_size": args.vocab_size,
        "max_seq_length": args.max_seq_length,
        "cache_dir": cache_dir,
    }
    if args.pretrained is not None:
        pretrained_dl = dataloader.UnalignedDataloader(
            file_name=args.pretrained,
            vocab_size=args.vocab_size,
            text_encoder_type=text_encoder_type,
            max_seq_length=args.max_seq_length,
            cache_dir=cache_dir,
        )
        train_kwargs["encoder_input"] = pretrained_dl.encoder
    train_dl = dataloader.AlignedDataloader(**train_kwargs)

    # Generalization: allow the test files to be overridden from args while
    # keeping the previous hard-coded paths as backward-compatible defaults.
    test_dl = dataloader.AlignedDataloader(
        file_name_input=getattr(
            args, "src_test", "data/splitted_data/test/test_token10000.en"),
        file_name_target=getattr(
            args, "target_test", "data/splitted_data/test/test_token10000.fr"),
        vocab_size=args.vocab_size,
        encoder_input=train_dl.encoder_input,
        encoder_target=train_dl.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=cache_dir,
    )

    model = models.find(
        args,
        train_dl.encoder_input.vocab_size,
        train_dl.encoder_target.vocab_size,
    )
    base.test(model, loss_fn, test_dl, args.batch_size, args.checkpoint)
def back_translation_training(args, loss_fn):
    """Train the model with back translation.

    Fits monolingual encoders on the unaligned EN/FR corpora, builds the
    forward and reversed aligned train/valid loaders (all sharing those
    encoders), and runs the back-translation training loop.
    """
    text_encoder_type = _text_encoder_type(args.text_encoder)
    cache_dir = _cache_dir(args)

    logger.info("Creating training unaligned dataloader ...")
    train_dl = dataloader.UnalignedDataloader(
        "data/unaligned.en",
        args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
    )
    logger.info(f"English vocab size: {train_dl.encoder.vocab_size}")

    logger.info("Creating reversed training unaligned dataloader ...")
    train_dl_reverse = dataloader.UnalignedDataloader(
        "data/unaligned.fr",
        args.vocab_size,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
    )
    logger.info(f"French vocab size: {train_dl_reverse.encoder.vocab_size}")

    logger.info("Creating training aligned dataloader ...")
    # Fix: the French side must be the TRAIN file. The original paired
    # sorted_train_token.en with the *_val_token.fr validation file, which
    # misaligns the sentence pairs; translate() confirms the train pair is
    # sorted_train_token.en / *_train_token.fr.
    aligned_train_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_train_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_train_token.fr",
        vocab_size=args.vocab_size,
        encoder_input=train_dl.encoder,
        encoder_target=train_dl_reverse.encoder,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=cache_dir,
    )

    logger.info("Creating reversed training aligned dataloader ...")
    # Same fix on the reversed direction (FR -> EN); encoders are swapped
    # relative to the forward loader.
    aligned_train_dl_reverse = dataloader.AlignedDataloader(
        file_name_input=
        "data/splitted_data/sorted_nopunctuation_lowercase_train_token.fr",
        file_name_target="data/splitted_data/sorted_train_token.en",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl.encoder_target,
        encoder_target=aligned_train_dl.encoder_input,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=cache_dir,
    )

    logger.info("Creating valid aligned dataloader ...")
    aligned_valid_dl = dataloader.AlignedDataloader(
        file_name_input="data/splitted_data/sorted_val_token.en",
        file_name_target=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl.encoder_input,
        encoder_target=aligned_train_dl.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=cache_dir,
    )

    logger.info("Creating reversed valid aligned dataloader ...")
    # Fix: extension typo ".frs" -> ".fr" (the validation French file used
    # two lines above).
    aligned_valid_dl_reverse = dataloader.AlignedDataloader(
        file_name_input=
        "data/splitted_data/sorted_nopunctuation_lowercase_val_token.fr",
        file_name_target="data/splitted_data/sorted_val_token.en",
        vocab_size=args.vocab_size,
        encoder_input=aligned_train_dl_reverse.encoder_input,
        encoder_target=aligned_train_dl_reverse.encoder_target,
        text_encoder_type=text_encoder_type,
        max_seq_length=args.max_seq_length,
        cache_dir=cache_dir,
    )

    model = models.find(
        args,
        aligned_train_dl.encoder_input.vocab_size,
        aligned_train_dl.encoder_target.vocab_size,
    )
    optim = _create_optimizer(model.embedding_size, args)
    model_reverse = models.find(
        args,
        aligned_train_dl_reverse.encoder_input.vocab_size,
        aligned_train_dl_reverse.encoder_target.vocab_size,
    )
    training = BackTranslationTraining(
        model,
        model_reverse,
        train_dl,
        train_dl_reverse,
        aligned_train_dl,
        aligned_train_dl_reverse,
        aligned_valid_dl,
        aligned_valid_dl_reverse,
    )
    training.run(
        loss_fn,
        optim,
        batch_size=args.batch_size,
        num_epoch=args.epochs,
        checkpoint=args.checkpoint,
    )