Example #1
    def __init__(self, opt):
        self.opt = opt

        model = Transformer(opt)

        checkpoint = torch.load(opt['model_path'])
        model.load_state_dict(checkpoint)
        print('Loaded pre-trained model_state..')

        self.model = model
        self.model.eval()
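
A note on the torch.load call above: tensors are restored to the device they were saved from, so a checkpoint written on a GPU fails to load on a CPU-only machine unless map_location is given. A minimal sketch of a device-safe loader (a generic helper, not part of the repo above):

import torch

def load_pretrained(model: torch.nn.Module, model_path: str) -> torch.nn.Module:
    # map_location lets a GPU-saved state_dict be restored on a CPU-only machine
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    return model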
Example #2
    def test_forward(self):
        """
        Test the model forward pass (i.e. not training it) on a batch of randomly created samples.

        """
        params = {
            'd_model': 512,
            'src_vocab_size': 27000,
            'tgt_vocab_size': 27000,
            'N': 6,
            'dropout': 0.1,
            'attention': {
                'n_head': 8,
                'd_k': 64,
                'd_v': 64,
                'dropout': 0.1
            },
            'feed-forward': {
                'd_ff': 2048,
                'dropout': 0.1
            },
        }

        # 1. test constructor
        transformer = Transformer(params)

        # 2. test forward pass
        batch_size = 64
        input_sequence_length = 10
        output_sequence_length = 13

        # create a batch of random samples
        src = torch.randint(low=1,
                            high=params["src_vocab_size"],
                            size=(batch_size, input_sequence_length))
        trg = torch.randint(low=1,
                            high=params["tgt_vocab_size"],
                            size=(batch_size, output_sequence_length))

        # create masks for src & trg: assume pad_token=0. Since we draw randomly in [1, vocab_size), the masks are all ones (a padded-batch sketch follows this test).
        src_mask = torch.ones_like(src).unsqueeze(-2)
        trg_mask = torch.ones_like(trg).unsqueeze(-2)

        logits = transformer(src_sequences=src,
                             src_mask=src_mask,
                             trg_sequences=trg,
                             trg_mask=trg_mask)

        self.assertIsInstance(logits, torch.Tensor)
        self.assertEqual(
            logits.shape,
            torch.Size(
                [batch_size, output_sequence_length,
                 params['tgt_vocab_size']]))
        # check no nan values
        self.assertEqual(torch.isnan(logits).sum(), 0)
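
The test above sidesteps padding by sampling ids in [1, vocab_size), so both masks are all ones. With real padded batches, the usual Transformer convention combines a padding mask for the source with a padding-plus-subsequent (causal) mask for the target. A minimal sketch, assuming pad index 0 and the (batch, 1, seq_len) mask layout used above:

import torch

def make_masks(src: torch.Tensor, trg: torch.Tensor, pad_idx: int = 0):
    # source mask: hide padding positions from attention, shape (batch, 1, src_len)
    src_mask = (src != pad_idx).unsqueeze(-2)
    # target mask: hide padding and future positions, shape (batch, trg_len, trg_len)
    trg_pad_mask = (trg != pad_idx).unsqueeze(-2)
    trg_len = trg.size(1)
    subsequent = torch.tril(torch.ones(1, trg_len, trg_len)).bool()
    trg_mask = trg_pad_mask & subsequent
    return src_mask, trg_mask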
Example #3
def main(opt):

    model = Transformer(opt)
    checkpoint = torch.load(opt['model_path'])
    model.load_state_dict(checkpoint)

    translator = Translator(opt)

    for i, batch in enumerate(valid_iterator):
        print(i)
        src = batch.src

        src_mask = (src != SRC.vocab.stoi["<pad>"]).unsqueeze(-2)
        translator.translate_batch(src, src_mask)

        print("Target:", end="\t")
        for k in range(src.size(1)):
            for j in range(1, batch.trg.size(1)):
                sym = TRG.vocab.itos[batch.trg[k, j]]
                if sym == KEY_EOS: break
                print(sym, end=" ")
            print('\n')
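
Translator.translate_batch is not shown in this example. A common greedy-decoding sketch for an encoder/decoder Transformer of this kind follows; the encode, decode and generator attribute names are assumptions for illustration, not the API of the repo above:

import torch

def greedy_decode(model, src, src_mask, max_len: int, bos_idx: int, eos_idx: int):
    memory = model.encode(src, src_mask)                    # assumed encoder entry point
    ys = torch.full((src.size(0), 1), bos_idx, dtype=torch.long)
    for _ in range(max_len - 1):
        trg_mask = torch.tril(torch.ones(1, ys.size(1), ys.size(1))).bool()
        out = model.decode(memory, src_mask, ys, trg_mask)  # assumed decoder entry point
        logits = model.generator(out[:, -1])                # project last step onto the vocab
        next_token = logits.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos_idx).all():                   # every sentence emitted EOS
            break
    return ys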
Example #4
def main(opt):
    model = Transformer(opt)

    optimizer = NoamOpt(
        opt['d_model'], opt['optim_factor'], opt['n_warmup_steps'],
        torch.optim.Adam(model.parameters(),
                         lr=opt['lr'],
                         betas=(0.9, 0.98),
                         eps=1e-9))

    criterion = LabelSmoothing(vocab_size=len(TRG.vocab),
                               pad_idx=PAD_IDX,
                               smoothing=opt['smoothing'])

    for epoch in range(opt['max_epochs']):
        model.train()
        single_epoch(train_iterator, model,
                     LossComputation(criterion, optimizer=optimizer))
        model.eval()
        loss = single_epoch(valid_iterator, model,
                            LossComputation(criterion, optimizer=None))
        print(loss)

    torch.save(model.state_dict(), opt['model_path'])
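
NoamOpt wraps Adam with the warmup schedule from Vaswani et al.: the learning rate grows linearly for n_warmup_steps and then decays with the inverse square root of the step. A minimal sketch of that rate function (generic names, not necessarily the exact NoamOpt implementation used here):

def noam_rate(step: int, d_model: int, factor: float, warmup: int) -> float:
    # linear warmup for `warmup` steps, then ~1/sqrt(step) decay; the peak is reached at step == warmup
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)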
Example #5
import tensorflow as tf
import argparse
import os
import tqdm

from config.hyperparams import hparams as hp
from transformer.feeder import get_batch_data, load_src_vocab, load_tgt_vocab
from transformer.model import Transformer


if __name__ == '__main__':      
    parser = argparse.ArgumentParser(description="Train model.")
    parser.add_argument("--gpu_index", default=0, type=int,
                            help="GPU index.")
    args = parser.parse_args()

    g = Transformer()
    
    sv = tf.train.Supervisor(graph=g.graph, 
                             logdir=hp.logdir,
                             save_model_secs=0)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_index)
    with sv.managed_session() as sess:
        for epoch in range(1, hp.num_epochs+1): 
            if sv.should_stop():
                break
            for step in tqdm.tqdm(range(g.num_batch), total=g.num_batch, ncols=70, leave=False, unit='b'):
                sess.run(g.train_op)
                
            gs = sess.run(g.global_step)   
            sv.saver.save(sess, hp.logdir + '/model_epoch_%02d_gs_%d' % (epoch, gs))
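
tf.train.Supervisor is a deprecated TF1 API; tf.train.MonitoredTrainingSession covers the same checkpoint and recovery bookkeeping. A rough equivalent of the loop above, reusing g and hp from this example (a sketch, not a drop-in replacement):

import tensorflow as tf

with g.graph.as_default():
    # MonitoredTrainingSession restores from / saves to checkpoint_dir on its own
    with tf.train.MonitoredTrainingSession(checkpoint_dir=hp.logdir,
                                           save_checkpoint_secs=600) as sess:
        while not sess.should_stop():
            sess.run(g.train_op)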
Example #6
def train(config_path: str = 'default',
          tqdm: Optional[Iterable] = None,
          **kwargs):
    """Trains a Transformer model on the config-specified dataset.

    Optionally tokenizes and restructures the input data for use with a batch generator.
    Any additional preprocessing (removing or replacing certain words, etc.) should be done
    beforehand, if required, using a custom preprocess.initial_cleanup for each dataset
    (a hypothetical sketch of such a cleanup follows this example).
    """
    config = get_config(config_path)
    tqdm = get_tqdm(tqdm or config.get('tqdm'))
    logger.setLevel(config.logging_level)

    logger('Start Training.')
    # ### Setup ### #
    if config.tokenize:
        logger('> creating tokenizer...')
        tokenizer = Tokenizer(
            input_paths=config.input_paths,
            tokenizer_output_path=config.tokenizer_output_path,
            vocab_size=config.vocab_size,
            lowercase=config.lowercase)
        logger('> creating tokens with tokenizer')
        tokenizer.encode_files(input_paths=config.input_paths,
                               tokens_output_dir=config.tokens_output_dir,
                               return_encodings=False,
                               tqdm=tqdm)
    else:
        tokenizer = Tokenizer.load(path=config.tokenizer_output_path)

    if config.create_dataset and config.load_tokens:
        logger('> loading tokens')
        english_tokens_path = os.path.join(config.tokens_output_dir,
                                           'train-en.pkl')
        german_tokens_path = os.path.join(config.tokens_output_dir,
                                          'train-de.pkl')

        logger('> loading tokens (1/2)')
        english_tokens = load_tokens(path=english_tokens_path)
        logger(f'>>> length of tokens: {len(english_tokens)}')

        logger('> loading tokens (2/2)')
        german_tokens = load_tokens(path=german_tokens_path)
        logger(f'>>> length of tokens: {len(german_tokens)}')

        logger.debug(f'GERMAN TOKENS:  {german_tokens[:3]}')
        logger.debug(f'ENGLISH TOKENS: {english_tokens[:3]}')

    if config.create_dataset:
        logger('> creating dataset for training')
        create_training_dataset(
            english_tokens=english_tokens,
            german_tokens=german_tokens,
            max_samples=config.max_samples,
            validation_split=config.validation_split,
            sample_length=config.sample_length,
            save_dataset=config.save_training_dataset,
            save_interval=config.save_interval,
            save_dir=config.processed_dir,
            save_dir_validation=config.processed_dir_validation,
            save_compression=config.compression,
            tqdm=tqdm)

    if config.retrain or not os.path.exists(config.model_output_path):
        logger('> creating Transformer model')
        model = Transformer(
            sequence_length=config.sequence_length,
            d_layers=config.d_layers,
            d_heads=config.d_heads,
            d_model=config.d_model,
            d_k=config.d_k,
            d_v=config.d_v,
            d_mlp_hidden=config.d_mlp_hidden,
            batch_size=config.batch_size,
            vocab_size=config.vocab_size,
            use_mask=config.use_mask,
            use_positional_encoding=config.use_positional_encoding)

        logger('>>> compiling Transformer model')
        model.compile(optimizer='Adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
    else:
        logger('> loading Transformer model')
        model = Transformer.load(config.model_output_path, compile=True)

    if config.verbose:
        model.summary(print_fn=logger)

    logger('>>> creating batch generator')
    generator = NextTokenBatchGenerator(data_dir=config.processed_dir,
                                        epoch_steps=config.train_steps,
                                        batch_size=config.batch_size,
                                        vocab_size=config.vocab_size,
                                        sample_length=config.sample_length)
    validation_generator = NextTokenBatchGenerator(
        data_dir=config.processed_dir_validation,
        epoch_steps=config.validation_steps,
        batch_size=config.batch_size,
        vocab_size=config.vocab_size,
        sample_length=config.sample_length)

    logger('>>> creating callbacks')
    use_callbacks = [
        callbacks.VaswaniLearningRate(
            steps_per_epoch=generator.steps_per_epoch,
            warmup_steps=config.warmup_steps,
            print_fn=logger.debug),
        callbacks.WriteLogsToFile(filepath=config.train_logs_output_path,
                                  overwrite_old_file=False),
        callbacks.SaveModel(filepath=config.model_output_path,
                            on_epoch_end=True),
        callbacks.PrintExamples(tokenizer=tokenizer,
                                generator=generator,
                                print_fn=logger)
    ]

    logger('> starting training of model')
    model.fit_generator(generator=generator(),
                        validation_data=validation_generator(),
                        steps_per_epoch=generator.steps_per_epoch,
                        validation_steps=validation_generator.steps_per_epoch,
                        epochs=config.epochs,
                        callbacks=use_callbacks,
                        shuffle=False)
    logger('Completed Training.')
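
The docstring above defers dataset-specific cleanup to a custom preprocess.initial_cleanup, which is not shown here. A purely hypothetical sketch of what such a cleanup might look like (name and behaviour are assumptions for illustration, not the repo's implementation):

import re

def initial_cleanup(lines, lowercase: bool = True):
    """Hypothetical per-dataset cleanup run before tokenization."""
    cleaned = []
    for line in lines:
        line = re.sub(r'\s+', ' ', line).strip()  # collapse whitespace
        if lowercase:
            line = line.lower()
        if line:                                  # drop empty lines
            cleaned.append(line)
    return cleaned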
Example #7
    _, val_iterator, _, src_vocab, trg_vocab = (IWSLTDatasetBuilder.build(
        language_pair=LanguagePair.fr_en,
        split=Split.Validation,
        max_length=40,
        batch_size_train=batch_size))
    print(f"Loading model from '{args.model_path}'...")
    model = Transformer.load_model_from_file(args.model_path,
                                             params={
                                                 'd_model': 512,
                                                 'N': 6,
                                                 'dropout': 0.1,
                                                 'src_vocab_size':
                                                 len(src_vocab),
                                                 'tgt_vocab_size':
                                                 len(trg_vocab),
                                                 'attention': {
                                                     'n_head': 8,
                                                     'd_k': 64,
                                                     'd_v': 64,
                                                     'dropout': 0.1
                                                 },
                                                 'feed-forward': {
                                                     'd_ff': 2048,
                                                     'dropout': 0.1
                                                 }
                                             })

    print("Computing loss on validation set...")
    loss_fn = LabelSmoothingLoss(size=len(trg_vocab),
                                 padding_token=src_vocab.stoi['<blank>'],
                                 smoothing=smoothing)
    val_loss = 0.
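
LabelSmoothingLoss here presumably follows the label-smoothing recipe from the Transformer paper: the one-hot target is replaced by a distribution that puts 1 - smoothing on the gold token and spreads the remainder over the vocabulary, and the KL divergence against the model's log-probabilities is minimised. A minimal sketch in the style of the Annotated Transformer, assuming log-probability inputs of shape (batch, vocab):

import torch
import torch.nn as nn

class SmoothedKLLoss(nn.Module):
    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, log_probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # spread `smoothing` mass over all tokens except the gold one and padding
        true_dist = torch.full_like(log_probs, self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        true_dist[target == self.padding_idx] = 0   # ignore padded target positions
        return self.criterion(log_probs, true_dist)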
Example #8
"""(object, scalar, scalar)
注意这个maxlen“100”,太大了的话图会变得非常大
"""
eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2, hp.maxlen1, hp.maxlen2,
                                                             hp.vocab, hp.eval_batch_size, shuffle=False)
"""(x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
eval_batches.output_types = ((tf.int32, tf.int32, tf.string), (tf.int32, tf.int32, tf.int32, tf.string))
eval_batches.output_shapes = (([None], (), ()), ([None], [None], (), ()))
"""
iterr = tf.data.Iterator.from_structure(eval_batches.output_types, eval_batches.output_shapes)
xs, ys = iterr.get_next()

eval_init_op = iterr.make_initializer(eval_batches)  # Create an op, but not run yet

context = Context(hp)
m = Transformer(context)
y_hat, eval_summaries = m.eval(xs, ys)  # y_hat holds token indices; a fetch target for sess.run (fetch loop sketched after this example)


logging.info("# Session")
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.warning("No checkpoint is found")
        exit(1)
    else:
        saver.restore(sess, ckpt)

    logging.info("# test evaluation")
    sess.run(eval_init_op)
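
After sess.run(eval_init_op), the usual TF1 pattern is to fetch y_hat batch by batch until the dataset iterator is exhausted. A short sketch continuing inside the with-block above, assuming a hypothetical idx2token lookup for mapping indices back to tokens:

    # continuing inside the `with tf.Session() as sess:` block above
    hypotheses = []
    while True:
        try:
            indices = sess.run(y_hat)                 # (batch, max_len) token indices
        except tf.errors.OutOfRangeError:             # eval iterator exhausted
            break
        for row in indices:
            # idx2token: hypothetical index-to-token lookup built from hp.vocab
            hypotheses.append(" ".join(idx2token[i] for i in row))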
Example #9
    hp.maxlen2,
    hp.vocab,
    hp.eval_batch_size,
    shuffle=False)
"""(x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
eval_batches.output_types = ((tf.int32, tf.int32, tf.string), (tf.int32, tf.int32, tf.int32, tf.string))
eval_batches.output_shapes = (([None], (), ()), ([None], [None], (), ()))
"""
iterr = tf.data.Iterator.from_structure(eval_batches.output_types,
                                        eval_batches.output_shapes)
xs, ys = iterr.get_next()

eval_init_op = iterr.make_initializer(
    eval_batches)  # Create an op, but not run yet

m = Transformer(hp)
y_hat, _ = m.debug(xs,
                   ys)  # y_hat holds token indices; a fetch target for sess.run

logging.info("# Session")
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.warning("No checkpoint is found")
        exit(1)
    else:
        saver.restore(sess, ckpt)
    # logging.info("# init profile")
    # run_metadata = tf.RunMetadata()
    # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
Example #10
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str)
parser.add_argument("--mode", type=str)
args = parser.parse_args()


if __name__ == "__main__":
    output_dir = os.path.expanduser(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    base_dir = Path(__file__).resolve().parent
    hparams = Hparams(base_dir.joinpath("hparams.json"))

    model = Transformer(
        num_layers=hparams.num_layers,
        num_heads=hparams.num_heads,
        d_model=hparams.d_model,
        d_ff=hparams.d_ff,
        encoder_vocab_size=hparams.encoder_vocab_size,
        decoder_vocab_size=hparams.decoder_vocab_size,
        dropout_rate=hparams.dropout_rate,
    )

    trainer = Trainer(hparams, model)
    trainer.train_and_evaluate(
        train_steps=hparams.train_steps,
        eval_steps=hparams.eval_steps,
        eval_frequency=hparams.eval_frequency,
        checkpoint_dir=output_dir,
    )
Example #11
    fields=(SRC, TGT),
    filter_pred=lambda x: len(vars(x)['src']) <= config['max_len'] and len(
        vars(x)['trg']) <= config['max_len'],
    root='./.data/')

# building shared vocabulary

test_iter = MyIterator(test,
                       batch_size=config['batch_size'],
                       device=torch.device(0),
                       repeat=False,
                       sort_key=lambda x: (len(x.src), len(x.trg)),
                       batch_size_fn=batch_size_fn,
                       train=False)

model = Transformer(len(SRC.vocab), len(TGT.vocab), N=config['num_layers'])
model.load_state_dict(state['model_state_dict'])
print('Model loaded.')

model.cuda()
model.eval()

test_bleu = validate(model,
                     test_iter,
                     SRC,
                     TGT,
                     BOS_WORD,
                     EOS_WORD,
                     BLANK_WORD,
                     config['max_len'],
                     logging=True)
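
validate above computes a corpus-level BLEU score internally. If a comparable number is needed outside that helper, a library such as sacrebleu can compute it from detokenized strings; a small sketch with placeholder data:

import sacrebleu

hypotheses = ["the cat sat on the mat"]     # model outputs, one string per sentence
references = ["the cat sat on the mat"]     # gold translations, aligned with hypotheses

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(f"BLEU: {bleu.score:.2f}")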
Example #12
    hp.vocab,
    hp.eval_batch_size,
    shuffle=False)

# create an iterator of the correct shape and type
iterr = tf.data.Iterator.from_structure(train_batches.output_types,
                                        train_batches.output_shapes)
xs, ys = iterr.get_next()

# Copy this as-is; not yet very familiar with these iterator APIs (a usage sketch follows this example)
train_init_op = iterr.make_initializer(train_batches)
eval_init_op = iterr.make_initializer(eval_batches)

logging.info("# Load model")

m = Transformer(context)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)
# y_hat = m.infer(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)
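
The boilerplate comment above refers to TF1's re-initializable iterator: one iterator structure is bound to either the training or the evaluation dataset by running the corresponding initializer, so the same xs, ys tensors feed both m.train and m.eval. A sketch of the usual epoch loop continuing inside the session above (num_train_batches is assumed to be known):

    # continuing inside the `with tf.Session() as sess:` block above
    for epoch in range(1, hp.num_epochs + 1):
        sess.run(train_init_op)                 # point the shared iterator at the training data
        for _ in range(num_train_batches):      # num_train_batches: assumed known from the feeder
            _loss, _, gs, _ = sess.run([loss, train_op, global_step, train_summaries])

        sess.run(eval_init_op)                  # re-point the iterator at the evaluation data
        try:
            while True:
                sess.run(y_hat)
        except tf.errors.OutOfRangeError:
            pass                                # evaluation set exhausted

        saver.save(sess, os.path.join(hp.logdir, "model_gs_%d" % gs))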
Example #13
        vars(x)['trg']) <= config['max_len'],
    root='./.data/')
print('Train set length: ', len(train))

print('Source vocab length: ', len(SRC.vocab.itos))
print('Target vocab length: ', len(TGT.vocab.itos))

test_iter = MyIterator(test,
                       batch_size=config['batch_size'],
                       device=torch.device(0),
                       repeat=False,
                       sort_key=lambda x: (len(x.src), len(x.trg)),
                       batch_size_fn=batch_size_fn,
                       train=False)

model = Transformer(len(SRC.vocab), len(TGT.vocab), N=config['num_layers'])

model.load_state_dict(torch.load('en-de__Sep-30-2019_09-19.pt'))
print('Last model loaded.')
save_model(model,
           None,
           loss=0,
           src_field=SRC,
           tgt_field=TGT,
           updates=0,
           epoch=0,
           prefix='last_model')

model.cuda()
model.eval()
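
Here a bare state_dict saved earlier is reloaded and then re-exported through save_model together with extra metadata. A generic PyTorch way to bundle such metadata into a single checkpoint file (a sketch, not necessarily what save_model does):

import torch

def save_checkpoint(model, optimizer, epoch: int, loss: float, path: str) -> None:
    # bundle weights and training bookkeeping in one file
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
        'epoch': epoch,
        'loss': loss,
    }, path)

def load_checkpoint(model, path: str) -> dict:
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    return state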
Example #15
class Trainer(object):
    """
    Represents a worker taking care of the training of an instance of the ``Transformer`` model.

    """
    def __init__(self, params: dict):
        """
        Constructor of the Trainer.
        Sets up the following:
            - Device available (e.g. if CUDA is present)
            - Initialize the model, dataset, loss, optimizer
            - log statistics (epoch, elapsed time, BLEU score etc.)
        """
        self.is_hyperparameter_tuning = HYPERTUNER is not None
        self.gcs_job_dir = params["settings"].get("save_dir", None)
        self.model_name = params["settings"].get("model_name", None)

        # configure all logging
        self.configure_logging(training_problem_name="IWSLT")

        # set all seeds
        self.set_random_seeds(pytorch_seed=params["settings"]["pytorch_seed"],
                              numpy_seed=params["settings"]["numpy_seed"],
                              random_seed=params["settings"]["random_seed"])

        # Initialize TensorBoard and statistics collection.
        self.initialize_statistics_collection()

        if self.is_hyperparameter_tuning:
            assert "save_dir" in params["settings"] and "model_name" in params["settings"], \
                "Expected parameters 'save_dir' and 'model_name'."
        else:
            # save the configuration as a json file in the experiments dir
            with open(self.log_dir + 'params.json', 'w') as fp:
                json.dump(params, fp)
            self.logger.info(
                'Configuration saved to {}.'.format(self.log_dir +
                                                    'params.json'))

        self.initialize_tensorboard(log_dir=self.gcs_job_dir)

        # initialize training Dataset class
        self.logger.info(
            "Creating the training & validation dataset, may take some time..."
        )
        (self.training_dataset_iterator, self.validation_dataset_iterator,
         self.test_dataset_iterator, self.src_vocab,
         self.trg_vocab) = (IWSLTDatasetBuilder.build(
             language_pair=LanguagePair.fr_en,
             split=Split.Train | Split.Validation | Split.Test,
             max_length=params["dataset"]["max_seq_length"],
             min_freq=params["dataset"]["min_freq"],
             start_token=params["dataset"]["start_token"],
             eos_token=params["dataset"]["eos_token"],
             blank_token=params["dataset"]["pad_token"],
             batch_size_train=params["training"]["train_batch_size"],
             batch_size_validation=params["training"]["valid_batch_size"],
         ))

        # get the size of the vocab sets
        self.src_vocab_size, self.trg_vocab_size = len(self.src_vocab), len(
            self.trg_vocab)

        # Find integer value of "padding" in the respective vocabularies
        self.src_padding = self.src_vocab.stoi[params["dataset"]["pad_token"]]
        self.trg_padding = self.trg_vocab.stoi[params["dataset"]["pad_token"]]

        # for safety, require that the padding index of the source vocab equals the one from the target vocab (for now)
        assert self.src_padding == self.trg_padding, (
            "the padding token ({}) for the source vocab is not equal "
            "to the one from the target vocab ({}).".format(
                self.src_padding, self.trg_padding))

        self.logger.info(
            "Created a training & a validation dataset, with src_vocab_size={} and trg_vocab_size={}"
            .format(self.src_vocab_size, self.trg_vocab_size))

        # pass the size of input & output vocabs to model's params
        params["model"]["src_vocab_size"] = self.src_vocab_size
        params["model"]["tgt_vocab_size"] = self.trg_vocab_size

        # can now instantiate model
        self.model = Transformer(params["model"])  # type: Transformer

        if params["training"].get("multi_gpu", False):
            self.model = torch.nn.DataParallel(self.model)
            self.logger.info(
                "Multi-GPU training activated, on devices: {}".format(
                    self.model.device_ids))
            self.multi_gpu = True
        else:
            self.multi_gpu = False

        if params["training"].get("load_trained_model", False):
            if self.multi_gpu:
                self.model.module.load(
                    checkpoint=params["training"]["trained_model_checkpoint"],
                    logger=self.logger)
            else:
                self.model.load(
                    checkpoint=params["training"]["trained_model_checkpoint"],
                    logger=self.logger)

        if torch.cuda.is_available():
            self.model = self.model.cuda()  # type: Transformer

        # whether to save the model at every epoch or not
        self.save_intermediate = params["training"].get(
            "save_intermediate", False)

        # instantiate loss
        if "smoothing" in params["training"]:
            self.loss_fn = LabelSmoothingLoss(
                size=self.trg_vocab_size,
                padding_token=self.src_padding,
                smoothing=params["training"]["smoothing"])
            self.logger.info(f"Using LabelSmoothingLoss with "
                             f"smoothing={params['training']['smoothing']}.")
        else:
            self.loss_fn = CrossEntropyLoss(pad_token=self.src_padding)
            self.logger.info("Using CrossEntropyLoss.")

        # instantiate optimizer
        self.optimizer = NoamOpt(model=self.model,
                                 model_size=params["model"]["d_model"],
                                 lr=params["optim"]["lr"],
                                 betas=params["optim"]["betas"],
                                 eps=params["optim"]["eps"],
                                 factor=params["optim"]["factor"],
                                 warmup=params["optim"]["warmup"],
                                 step=params["optim"]["step"])

        # get number of epochs and related hyper parameters
        self.epochs = params["training"]["epochs"]

        self.logger.info('Experiment setup done.')

    def train(self):
        """
        Main training loop.

            - Trains the Transformer model on the specified dataset for a given number of epochs
            - Logs statistics to logger for every batch per epoch

        """
        # Reset the counter.
        episode = -1
        val_loss = 0.

        for epoch in range(self.epochs):

            # Empty the statistics collectors.
            self.training_stat_col.empty()
            self.validation_stat_col.empty()

            # collect epoch index
            self.training_stat_col['epoch'] = epoch + 1
            self.validation_stat_col['epoch'] = epoch + 1

            # ensure train mode for the model
            self.model.train()

            for i, batch in enumerate(
                    IWSLTDatasetBuilder.masked(
                        IWSLTDatasetBuilder.transposed(
                            self.training_dataset_iterator))):

                # "Move on" to the next episode.
                episode += 1

                # 1. reset all gradients
                self.optimizer.zero_grad()

                # Convert batch to CUDA.
                if torch.cuda.is_available():
                    batch.cuda()

                # 2. Perform forward pass.
                logits = self.model(batch.src, batch.src_mask, batch.trg,
                                    batch.trg_mask)

                # 3. Evaluate loss function.
                loss = self.loss_fn(logits, batch.trg_shifted)

                # 4. Backward gradient flow.
                loss.backward()

                # 4.1. Export to csv - at every step.
                # collect loss, episode
                self.training_stat_col['loss'] = loss.item()
                self.training_stat_col['episode'] = episode
                self.training_stat_col['src_seq_length'] = batch.src.shape[1]
                self.training_stat_col.export_to_csv()

                # 4.2. Exports statistics to the logger.
                self.logger.info(self.training_stat_col.export_to_string())

                # 4.3 Exports to tensorboard
                self.training_stat_col.export_to_tensorboard()

                # 5. Perform optimization step.
                self.optimizer.step()

            # save model at end of each epoch if indicated:
            if self.save_intermediate:
                if self.multi_gpu:
                    self.model.module.save(self.model_dir, epoch, loss.item())
                else:
                    self.model.save(self.model_dir, epoch, loss.item())
                self.logger.info("Model exported to checkpoint.")

            # validate the model on the validation set
            self.model.eval()
            val_loss = 0.

            with torch.no_grad():
                for i, batch in enumerate(
                        IWSLTDatasetBuilder.masked(
                            IWSLTDatasetBuilder.transposed(
                                self.validation_dataset_iterator))):

                    # Convert batch to CUDA.
                    if torch.cuda.is_available():
                        batch.cuda()

                    # 1. Perform forward pass.
                    logits = self.model(batch.src, batch.src_mask, batch.trg,
                                        batch.trg_mask)

                    # 2. Evaluate loss function.
                    loss = self.loss_fn(logits, batch.trg_shifted)

                    # Accumulate loss
                    val_loss += loss.item()

            # 3.1 Collect loss, episode: Log only one point per validation (for now)
            self.validation_stat_col['loss'] = val_loss / (i + 1)
            self.validation_stat_col['episode'] = episode

            # 3.1. Export to csv.
            self.validation_stat_col.export_to_csv()

            # 3.2 Exports statistics to the logger.
            self.logger.info(
                self.validation_stat_col.export_to_string('[Validation]'))

            # 3.3 Export to Tensorboard
            self.validation_stat_col.export_to_tensorboard()

            # 3.4 Save model on GCloud
            if self.gcs_job_dir is not None:
                if self.multi_gpu:
                    filename = self.model.module.save(
                        self.model_dir,
                        epoch,
                        loss.item(),
                        model_name=self.model_name)
                else:
                    filename = self.model.save(self.model_dir,
                                               epoch,
                                               loss.item(),
                                               model_name=self.model_name)
                save_model(self.gcs_job_dir, filename, self.model_name)

            # 3.4.b Export to Hypertune
            if self.is_hyperparameter_tuning:
                assert HYPERTUNER is not None
                HYPERTUNER.report_hyperparameter_tuning_metric(
                    hyperparameter_metric_tag='validation_loss',
                    metric_value=val_loss / (i + 1),
                    global_step=epoch)

        # always save the model at end of training
        if self.multi_gpu:
            self.model.module.save(self.model_dir, epoch, loss.item())
        else:
            self.model.save(self.model_dir, epoch, loss.item())

        self.logger.info("Final model exported to checkpoint.")

        # training done, end statistics collection
        self.finalize_statistics_collection()
        self.finalize_tensorboard()

        return val_loss

    def configure_logging(self,
                          training_problem_name: str,
                          logger_config=None) -> None:
        """
        Takes care of the initialization of logging-related objects:

            - Sets up a logger with a specific configuration,
            - Sets up a logging directory
            - sets up a logging file in the log directory
            - Sets up a folder to store trained models

        :param training_problem_name: Name of the dataset / training task (e.g. "copy task", "IWSLT"). Used for the logging
        folder name.
        """
        # instantiate logger
        # Load the default logger configuration.
        if logger_config is None:
            logger_config = {
                'version': 1,
                'disable_existing_loggers': False,
                'formatters': {
                    'simple': {
                        'format':
                        '[%(asctime)s] - %(levelname)s - %(name)s >>> %(message)s',
                        'datefmt': '%Y-%m-%d %H:%M:%S'
                    }
                },
                'handlers': {
                    'console': {
                        'class': 'logging.StreamHandler',
                        'level': 'INFO',
                        'formatter': 'simple',
                        'stream': 'ext://sys.stdout'
                    }
                },
                'root': {
                    'level': 'DEBUG',
                    'handlers': ['console']
                }
            }
            if self.is_hyperparameter_tuning or self.gcs_job_dir is not None:
                # Running on GCloud, use shorter messages as time and debug level will be
                # saved elsewhere.
                logger_config['formatters']['simple'][
                    'format'] = "%(name)s >>> %(message)s"

        logging.config.dictConfig(logger_config)

        # Create the Logger, set its label and logging level.
        self.logger = logging.getLogger(name='Trainer')

        # Prepare the output path for logging
        time_str = '{0:%Y%m%d_%H%M%S}'.format(datetime.now())
        self.log_dir = 'experiments/' + training_problem_name + '/' + time_str + '/'

        os.makedirs(self.log_dir, exist_ok=False)
        self.logger.info('Folder {} created.'.format(self.log_dir))

        # Set log dir and add the handler for the logfile to the logger.
        self.log_file = self.log_dir + 'training.log'
        self.add_file_handler_to_logger(self.log_file)

        self.logger.info('Log File {} created.'.format(self.log_file))

        # Models dir: to store the trained models.
        self.model_dir = self.log_dir + 'models/'
        os.makedirs(self.model_dir, exist_ok=False)

        self.logger.info('Model folder {} created.'.format(self.model_dir))

    def add_file_handler_to_logger(self, logfile: str) -> None:
        """
        Add a ``logging.FileHandler`` to the logger.

        Specifies a ``logging.Formatter``:
            >>> logging.Formatter(fmt='[%(asctime)s] - %(levelname)s - %(name)s >>> %(message)s',
            ...                   datefmt='%Y-%m-%d %H:%M:%S')

        :param logfile: File used by the ``FileHandler``.

        """
        # create file handler which logs even DEBUG messages
        fh = logging.FileHandler(logfile)

        # set logging level for this file
        fh.setLevel(logging.DEBUG)

        # create formatter and add it to the handlers
        formatter = logging.Formatter(
            fmt='[%(asctime)s] - %(levelname)s - %(name)s >>> %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S')
        fh.setFormatter(formatter)

        # add the handler to the logger
        self.logger.addHandler(fh)

    def initialize_statistics_collection(self) -> None:
        """
        Initializes 2 :py:class:`StatisticsCollector` to track statistics for training and validation.

        Adds some default statistics, such as the loss, episode idx and the epoch idx.

        Also creates the output files (csv).
        """
        # TRAINING.
        # Create statistics collector for training.
        self.training_stat_col = StatisticsCollector()

        # add default statistics
        self.training_stat_col.add_statistic('epoch', '{:02d}')
        self.training_stat_col.add_statistic('loss', '{:12.10f}')
        self.training_stat_col.add_statistic('episode', '{:06d}')
        self.training_stat_col.add_statistic('src_seq_length', '{:02d}')

        # Create the csv file to store the training statistics.
        self.training_batch_stats_file = self.training_stat_col.initialize_csv_file(
            self.log_dir, 'training_statistics.csv')

        # VALIDATION.
        # Create statistics collector for validation.
        self.validation_stat_col = StatisticsCollector()

        # add default statistics
        self.validation_stat_col.add_statistic('epoch', '{:02d}')
        self.validation_stat_col.add_statistic('loss', '{:12.10f}')
        self.validation_stat_col.add_statistic('episode', '{:06d}')

        # Create the csv file to store the validation statistics.
        self.validation_batch_stats_file = self.validation_stat_col.initialize_csv_file(
            self.log_dir, 'validation_statistics.csv')

    def finalize_statistics_collection(self) -> None:
        """
        Finalizes the statistics collection by closing the csv files.
        """
        # Close all files.
        self.training_batch_stats_file.close()
        self.validation_batch_stats_file.close()

    def initialize_tensorboard(self, log_dir=None) -> None:
        """
        Initializes the TensorBoard writers, and log directories.
        """
        from tensorboardX import SummaryWriter

        if log_dir is None:
            log_dir = self.log_dir

        self.training_writer = SummaryWriter(
            join(log_dir, "tensorboard", 'training'))
        self.training_stat_col.initialize_tensorboard(self.training_writer)

        self.validation_writer = SummaryWriter(
            join(log_dir, "tensorboard", 'validation'))
        self.validation_stat_col.initialize_tensorboard(self.validation_writer)

    def finalize_tensorboard(self) -> None:
        """
        Finalizes the operation of TensorBoard writers by closing them.
        """
        # Close the TensorBoard writers.
        self.training_writer.close()
        self.validation_writer.close()

    def set_random_seeds(self, pytorch_seed: int, numpy_seed: int,
                         random_seed: int) -> None:
        """
        Set all random seeds to ensure the reproducibility of the experiments.
        Notably:

        - Set the random seed of PyTorch (seeds the RNG for all devices, both CPU and CUDA):

            >>> torch.manual_seed(pytorch_seed)

        - When running on the CuDNN backend, two further options must be set:

            >>> torch.backends.cudnn.deterministic = True
            >>> torch.backends.cudnn.benchmark = False

        - Set the random seed of numpy:

            >>> np.random.seed(numpy_seed)

        - Finally, we initialize the random number generator:

            >>> random.seed(random_seed)

        """

        # set pytorch seed
        torch.manual_seed(pytorch_seed)

        # set deterministic CuDNN
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # set numpy seed
        import numpy
        numpy.random.seed(numpy_seed)

        # set the state of the random number generator:
        random.seed(random_seed)

        self.logger.info("torch seed was set to {}".format(pytorch_seed))
        self.logger.info("numpy seed was set to {}".format(numpy_seed))
        self.logger.info("random seed was set to {}".format(random_seed))
Example #16
def main():
    devices = list(range(torch.cuda.device_count()))
    print('Selected devices: ', devices)

    def tokenize_bpe(text):
        return text.split()

    SRC = data.Field(tokenize=tokenize_bpe, pad_token=BLANK_WORD)
    TGT = data.Field(tokenize=tokenize_bpe,
                     init_token=BOS_WORD,
                     eos_token=EOS_WORD,
                     pad_token=BLANK_WORD)

    train, val, test = datasets.WMT14.splits(
        exts=('.en', '.de'),
        train='train.tok.clean.bpe.32000',
        # train='newstest2014.tok.bpe.32000',
        validation='newstest2013.tok.bpe.32000',
        test='newstest2014.tok.bpe.32000',
        fields=(SRC, TGT),
        filter_pred=lambda x: len(vars(x)['src']) <= config['max_len'] and len(
            vars(x)['trg']) <= config['max_len'],
        root='./.data/')
    print('Train set length: ', len(train))

    # building shared vocabulary
    TGT.build_vocab(train.src, train.trg, min_freq=config['min_freq'])
    SRC.vocab = TGT.vocab

    print('Source vocab length: ', len(SRC.vocab.itos))
    print('Target vocab length: ', len(TGT.vocab.itos))
    wandb.config.update({
        'src_vocab_length': len(SRC.vocab),
        'target_vocab_length': len(TGT.vocab)
    })

    pad_idx = TGT.vocab.stoi[BLANK_WORD]
    print('Pad index:', pad_idx)

    train_iter = MyIterator(train,
                            batch_size=config['batch_size'],
                            device=torch.device(0),
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=True)

    valid_iter = MyIterator(val,
                            batch_size=config['batch_size'],
                            device=torch.device(0),
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=False)

    test_iter = MyIterator(test,
                           batch_size=config['batch_size'],
                           device=torch.device(0),
                           repeat=False,
                           sort_key=lambda x: (len(x.src), len(x.trg)),
                           batch_size_fn=batch_size_fn,
                           train=False)

    model = Transformer(len(SRC.vocab), len(TGT.vocab), N=config['num_layers'])

    # weight tying
    model.src_embed[0].lookup_table.weight = model.tgt_embed[
        0].lookup_table.weight
    model.generator.lookup_table.weight = model.tgt_embed[
        0].lookup_table.weight

    model.cuda()

    model_size = model.src_embed[0].d_model
    print('Model created with size of', model_size)
    wandb.config.update({'model_size': model_size})

    criterion = LabelSmoothingKLLoss(
        size=len(TGT.vocab),
        padding_idx=pad_idx,
        smoothing=0.1,
        batch_multiplier=config['batch_multiplier'])
    criterion.cuda()

    eval_criterion = LabelSmoothingKLLoss(size=len(TGT.vocab),
                                          padding_idx=pad_idx,
                                          smoothing=0.1,
                                          batch_multiplier=1)
    eval_criterion.cuda()

    model_par = nn.DataParallel(model, device_ids=devices)

    model_opt = NoamOpt(warmup_init_lr=config['warmup_init_lr'],
                        warmup_end_lr=config['warmup_end_lr'],
                        warmup_updates=config['warmup'],
                        min_lr=config['warmup_init_lr'],
                        optimizer=torch.optim.Adam(model.parameters(),
                                                   lr=0,
                                                   betas=(config['beta_1'],
                                                          config['beta_2']),
                                                   eps=config['epsilon']))

    wandb.watch(model)

    current_steps = 0
    for epoch in range(1, config['max_epochs'] + 1):
        # training model
        model_par.train()
        loss_calculator = MultiGPULossCompute(model.generator,
                                              criterion,
                                              devices=devices,
                                              opt=model_opt)

        (_, steps) = run_epoch((rebatch(pad_idx, b) for b in train_iter),
                               model_par,
                               loss_calculator,
                               steps_so_far=current_steps,
                               batch_multiplier=config['batch_multiplier'],
                               logging=True)

        current_steps += steps

        # calculating validation loss and bleu score
        model_par.eval()
        loss_calculator_without_optimizer = MultiGPULossCompute(
            model.generator, eval_criterion, devices=devices, opt=None)

        (loss, _) = run_epoch((rebatch(pad_idx, b) for b in valid_iter),
                              model_par,
                              loss_calculator_without_optimizer,
                              steps_so_far=current_steps)

        if (epoch > 10) or current_steps > config['max_step']:
            # greedy decoding takes a while so Bleu won't be evaluated for every epoch
            print('Calculating BLEU score...')
            bleu = validate(model, valid_iter, SRC, TGT, BOS_WORD, EOS_WORD,
                            BLANK_WORD, config['max_len'])
            wandb.log({'Epoch bleu': bleu})
            print(f'Epoch {epoch} | Bleu score: {bleu} ')

        print(f"Epoch {epoch} | Loss: {loss}")
        wandb.log({'Epoch': epoch, 'Epoch loss': loss})
        if epoch > 10:
            save_model(model=model,
                       optimizer=model_opt,
                       loss=loss,
                       src_field=SRC,
                       tgt_field=TGT,
                       updates=current_steps,
                       epoch=epoch)
        if current_steps > config['max_step']:
            break

    save_model(model=model,
               optimizer=model_opt,
               loss=loss,
               src_field=SRC,
               tgt_field=TGT,
               updates=current_steps,
               epoch=epoch)

    test_bleu = validate(model,
                         test_iter,
                         SRC,
                         TGT,
                         BOS_WORD,
                         EOS_WORD,
                         BLANK_WORD,
                         config['max_len'],
                         logging=True)
    print(f"Test Bleu score: {test_bleu}")
    wandb.config.update({'Test bleu score': test_bleu})
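
The weight tying in main() above (src_embed, tgt_embed and generator sharing one lookup table) works because nn.Embedding(vocab, d_model).weight and nn.Linear(d_model, vocab, bias=False).weight both have shape (vocab, d_model). A minimal standalone sketch of the same idea with plain PyTorch modules (generic names, not the repo's classes):

import torch.nn as nn

vocab_size, d_model = 32000, 512

tgt_embed = nn.Embedding(vocab_size, d_model)            # weight: (vocab_size, d_model)
generator = nn.Linear(d_model, vocab_size, bias=False)   # weight: (vocab_size, d_model)

# share a single parameter tensor between the input embedding and the output projection
generator.weight = tgt_embed.weight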