def run(args):
    """Build the dataset and model, then train and validate for ``args.epochs`` epochs.

    Args:
        args: Namespace providing ``device``, ``dropout``, ``checkpoint``
            (a path or ``None``), ``lr`` and ``epochs``. It is also forwarded
            unchanged to ``build_dataset``, ``train`` and ``validate``.

    Side effects:
        Saves one ``models/model_<epoch>_<acc>.pth`` state dict per epoch and
        logs through a ``SummaryWriter``.
    """
    import os  # local import: only needed to create the checkpoint directory

    writer = SummaryWriter()
    try:
        src, tgt, train_iterator, val_iterator = build_dataset(args)

        src_vocab_size = len(src.vocab.itos)
        tgt_vocab_size = len(tgt.vocab.itos)

        print('Instantiating model...')
        device = args.device
        model = Transformer(src_vocab_size,
                            tgt_vocab_size,
                            device,
                            p_dropout=args.dropout)
        model = model.to(device)

        if args.checkpoint is not None:
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be restored onto whatever device this run is using.
            model.load_state_dict(
                torch.load(args.checkpoint, map_location=device))
        else:
            # Fresh model: Xavier-initialise every parameter with more than
            # one dimension; 1-D parameters keep their default init.
            for p in model.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

        print('Model instantiated!')

        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.98),
                               eps=1e-9)

        # torch.save does not create missing parent directories.
        os.makedirs('models', exist_ok=True)

        print('Starting training...')
        for epoch in range(args.epochs):
            # train/validate receive a 1-based epoch number, while the
            # checkpoint file name below keeps the 0-based loop index.
            acc = train(model, epoch + 1, train_iterator, optimizer,
                        src.vocab, tgt.vocab, args, writer)
            model_file = 'models/model_' + str(epoch) + '_' + str(acc) + '.pth'
            torch.save(model.state_dict(), model_file)
            print('Saved model to ' + model_file)
            validate(model, epoch + 1, val_iterator, src.vocab, tgt.vocab,
                     args, writer)
        print('Finished training.')
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()
Ejemplo n.º 2
0
def main():
    """Load the data, build the Transformer and hand everything to ``train``.

    All settings are read from the module-level ``arg`` mapping.
    """
    pprint(arg)

    # Data loaders for the three splits.
    train_loader, valid_loader, test_loader = prepare_dataloaders(arg)
    split_sizes = (len(train_loader), len(valid_loader), len(test_loader))
    print("Data loaded. Instances: {} train / {} dev / {} test".format(
        *split_sizes))

    # Target device for the model.
    use_cuda = arg["cuda"] == True
    device = torch.device('cuda' if use_cuda else 'cpu')

    print()
    # Source and target share one vocabulary, taken from the training
    # dataset; the maximum sequence length comes from the dataset's config.
    vocab_size = len(train_loader.dataset.w2i)
    max_token_seq_len = train_loader.dataset.conf["max_sequence_len"]
    transformer_network = Transformer(
        vocab_size,
        vocab_size,
        max_token_seq_len,
        tgt_emb_prj_weight_sharing=True,
        emb_src_tgt_weight_sharing=True,
        d_k=arg["d_k"],
        d_v=arg["d_v"],
        d_model=arg["d_model"],
        d_word_vec=arg["d_model"],
        d_inner=arg["d_inner_hid"],
        n_layers=arg["n_layers"],
        n_head=arg["n_head"],
        dropout=arg["dropout"])
    transformer_network = transformer_network.to(device)

    print("Transformer model initialized.")
    print()

    # Adam over the trainable parameters only, wrapped in the scheduled
    # optimizer that takes d_model and the warm-up step count.
    trainable_params = (p for p in transformer_network.parameters()
                        if p.requires_grad)
    adam = optim.Adam(trainable_params, betas=(0.9, 0.98), eps=1e-09)
    optimizer = transformer.optimizers.ScheduledOptim(
        adam, arg["d_model"], arg["n_warmup_steps"])

    train(transformer_network, train_loader, valid_loader, test_loader,
          optimizer, device, arg)
Ejemplo n.º 3
0
def do_train():
    """Train a Transformer on the Multi30k iterators for ten epochs.

    After every epoch the model's state dict is written to
    ``./checkpoint/transformer/model_<epoch>.pt`` and the losses and
    accuracies are printed. When all epochs are done the per-epoch losses
    are dumped to ``result.json`` in the same directory.
    """
    (train_iterator, valid_iterator, _test_iterator,
     SRC, TGT) = prepare_data_multi30k()

    # Padding indices and vocabulary sizes for both languages.
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    tgt_pad_idx = TGT.vocab.stoi[TGT.pad_token]
    model = Transformer(n_src_vocab=len(SRC.vocab),
                        n_trg_vocab=len(TGT.vocab),
                        src_pad_idx=src_pad_idx,
                        trg_pad_idx=tgt_pad_idx,
                        d_word_vec=256,
                        d_model=256,
                        d_inner=512,
                        n_layer=3,
                        n_head=8,
                        dropout=0.1,
                        n_position=200)
    model.cuda()

    optimizer = Adam(model.parameters(), lr=5e-4)

    num_epoch = 10
    model_dir = "./checkpoint/transformer"
    history = []
    for epoch in range(num_epoch):
        train_loss, train_accuracy = train_epoch(
            model, optimizer, train_iterator, tgt_pad_idx, smoothing=False)
        eval_loss, eval_accuracy = eval_epoch(
            model, valid_iterator, tgt_pad_idx, smoothing=False)

        # Checkpoint the weights reached at the end of this epoch.
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(model_dir, f"model_{epoch}.pt"))

        history.append(dict(epoch=epoch,
                            train_loss=train_loss,
                            eval_loss=eval_loss))
        print("[TIME] --- {} --- [TIME]".format(time.ctime()))
        print("epoch: {}, train_loss: {}, eval_loss: {}".format(
            epoch, train_loss, eval_loss))
        print("epoch: {}, train_accuracy: {}, eval_accuracy: {}".format(
            epoch, train_accuracy, eval_accuracy))

    # Persist the loss history next to the checkpoints.
    with open(os.path.join(model_dir, "result.json"), "w",
              encoding="utf-8") as handle:
        json.dump(history, handle, ensure_ascii=False, indent=4)
Ejemplo n.º 4
0
def train_net(args):
    """Train the encoder/decoder Transformer, optionally resuming from a checkpoint.

    Args:
        args: Namespace with the model hyper-parameters (``n_layers_enc``,
            ``n_layers_dec``, ``n_head``, ``d_k``, ``d_v``, ``d_model``,
            ``d_inner``, ``d_word_vec``, ``dropout``, ``pe_maxlen``,
            ``tgt_emb_prj_weight_sharing``), the loader settings
            (``batch_size``, ``num_workers``), ``epochs`` and ``checkpoint``
            (a path or ``None``).

    Every epoch runs one training pass and one validation pass, logs the
    losses and learning rate to the summary writer, and saves a checkpoint.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    best_loss = float('inf')
    writer = SummaryWriter()

    if checkpoint is not None:
        # Resume: the checkpoint stores the model and optimizer objects.
        state = torch.load(checkpoint)
        start_epoch = state['epoch'] + 1
        epochs_since_improvement = state['epochs_since_improvement']
        model = state['model']
        optimizer = state['optimizer']
    else:
        # Fresh run: build encoder, decoder and the wrapped Adam optimizer.
        start_epoch = 0
        epochs_since_improvement = 0
        encoder = Encoder(n_src_vocab,
                          args.n_layers_enc,
                          args.n_head,
                          args.d_k,
                          args.d_v,
                          args.d_model,
                          args.d_inner,
                          dropout=args.dropout,
                          pe_maxlen=args.pe_maxlen)
        decoder = Decoder(
            sos_id,
            eos_id,
            n_tgt_vocab,
            args.d_word_vec,
            args.n_layers_dec,
            args.n_head,
            args.d_k,
            args.d_v,
            args.d_model,
            args.d_inner,
            dropout=args.dropout,
            tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
            pe_maxlen=args.pe_maxlen)
        model = Transformer(encoder, decoder)
        adam = torch.optim.Adam(model.parameters(),
                                betas=(0.9, 0.98),
                                eps=1e-09)
        optimizer = TransformerOptimizer(adam)

    # Place the model on the module-level device.
    model = model.to(device)

    def _make_loader(split, shuffle):
        # Both splits use the same batch size, collate function and workers.
        return torch.utils.data.DataLoader(AiChallenger2017Dataset(split),
                                           batch_size=args.batch_size,
                                           collate_fn=pad_collate,
                                           shuffle=shuffle,
                                           num_workers=args.num_workers)

    train_loader = _make_loader('train', True)
    valid_loader = _make_loader('valid', False)

    for epoch in range(start_epoch, args.epochs):
        # Training pass.
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger,
                           writer=writer)
        writer.add_scalar('epoch/train_loss', train_loss, epoch)
        writer.add_scalar('epoch/learning_rate', optimizer.lr, epoch)

        print('\nLearning rate: {}'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        # Validation pass.
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)
        writer.add_scalar('epoch/valid_loss', valid_loss, epoch)

        # Track the best validation loss and how long since it improved.
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))

        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_loss, is_best)
Ejemplo n.º 5
0
def main():
    """Parse CLI options, build the speech Transformer and run NSML training.

    Steps, in order: parse arguments, load the label vocabulary (character
    based, or word based with ``--word``), publish the vocabulary and special
    token ids as module globals, seed the RNGs, build the model and optimizer,
    bind them to NSML, then (in ``train`` mode only) read the data list, split
    it and run the epoch loop.

    NOTE(review): the transformer CLI options (``--d_model``, ``--n_head``,
    ``--num_encoder_layers``, ...) and ``--lr`` are parsed but the model and
    optimizer below are built from hard-coded literals.
    """
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(description='Speech hackathon Baseline')

    # General training options.
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='batch size in training (default: 32)')
    parser.add_argument(
        '--workers',
        type=int,
        default=4,
        help='number of workers in dataset loader (default: 4)')
    parser.add_argument('--max_epochs',
                        type=int,
                        default=10,
                        help='number of max epochs in training (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--teacher_forcing',
                        type=float,
                        default=0.5,
                        help='teacher forcing ratio in decoder (default: 0.5)')
    parser.add_argument('--max_len',
                        type=int,
                        default=WORD_MAXLEN,
                        help='maximum characters of sentence (default: 80)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='random seed (default: 1)')
    parser.add_argument('--save_name',
                        type=str,
                        default='model',
                        help='the name of model in nsml or local')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument("--pause", type=int, default=0)
    parser.add_argument(
        '--word',
        action='store_true',
        help='Train/Predict model using word based label (default: False)')
    parser.add_argument('--gen_label_index',
                        action='store_true',
                        help='Generate word label index map(default: False)')
    parser.add_argument('--iteration', type=str, help='Iteratiom')
    parser.add_argument('--premodel_session',
                        type=str,
                        help='Session name of premodel')

    # transformer model parameter
    parser.add_argument('--d_model',
                        type=int,
                        default=128,
                        help='transformer_d_model')
    parser.add_argument('--n_head',
                        type=int,
                        default=8,
                        help='transformer_n_head')
    parser.add_argument('--num_encoder_layers',
                        type=int,
                        default=4,
                        help='num_encoder_layers')
    parser.add_argument('--num_decoder_layers',
                        type=int,
                        default=4,
                        help='transformer_num_decoder_layers')
    parser.add_argument('--dim_feedforward',
                        type=int,
                        default=2048,
                        help='transformer_d_model')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.1,
                        help='transformer_dropout')

    # transformer warmup parameter
    parser.add_argument('--warmup_multiplier',
                        type=int,
                        default=3,
                        help='transformer_warmup_multiplier')
    parser.add_argument('--warmup_epoch',
                        type=int,
                        default=10,
                        help='transformer_warmup_epoch')

    args = parser.parse_args()

    # The character-level label map is always loaded; it is the default
    # vocabulary and is also an input to the word-label generators below.
    char_loader = CharLabelLoader()
    char_loader.load_char2index('./hackathon.labels')
    label_loader = char_loader
    if args.word:
        if args.gen_label_index:
            generate_word_label_index_file(char_loader, TRAIN_LABEL_CHAR_PATH)
            from subprocess import call
            # Dumps the char label file to stdout via the shell.
            call(f'cat {TRAIN_LABEL_CHAR_PATH}', shell=True)
        # (original comment was unreadable) Switch to the word/POS-based
        # label map loaded from './hackathon.pos.labels'.
        word_loader = CharLabelLoader()
        word_loader.load_char2index('./hackathon.pos.labels')
        label_loader = word_loader
        # The word label file is only generated when the char label file
        # exists on disk.
        if os.path.exists(TRAIN_LABEL_CHAR_PATH):
            generate_word_label_file(char_loader, word_loader,
                                     TRAIN_LABEL_POS_PATH,
                                     TRAIN_LABEL_CHAR_PATH)
    # Publish the selected vocabulary and special token ids as module globals.
    char2index = label_loader.char2index
    index2char = label_loader.index2char
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']
    # Seed Python's and torch's (CPU and CUDA) random number generators.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # CUDA is used only when it is not disabled and is actually available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if args.cuda else 'cpu')

    # ---- model ----
    print("model: transformer")
    # Earlier constructor call that was driven by the CLI options (disabled):
    # model = Transformer(d_model= args.d_model, n_head= args.n_head, num_encoder_layers= args.num_encoder_layers, num_decoder_layers= args.num_decoder_layers,
    #                     dim_feedforward= args.dim_feedforward, dropout= args.dropout, vocab_size= len(char2index), sound_maxlen= SOUND_MAXLEN, word_maxlen= WORD_MAXLEN)

    # Encoder and decoder hyper-parameters are fixed literals here.
    encoder = Encoder(d_input=128,
                      n_layers=6,
                      n_head=4,
                      d_k=128,
                      d_v=128,
                      d_model=128,
                      d_inner=2048,
                      dropout=0.1,
                      pe_maxlen=SOUND_MAXLEN)
    decoder = Decoder(sos_id=SOS_token,
                      eos_id=EOS_token,
                      n_tgt_vocab=len(char2index),
                      d_word_vec=128,
                      n_layers=6,
                      n_head=4,
                      d_k=128,
                      d_v=128,
                      d_model=128,
                      d_inner=2048,
                      dropout=0.1,
                      tgt_emb_prj_weight_sharing=True,
                      pe_maxlen=SOUND_MAXLEN)
    model = Transformer(encoder, decoder)

    # Adam wrapped in the project's TransformerOptimizer; lr is a literal,
    # args.lr is not consulted.
    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(),
                         lr=0.0004,
                         betas=(0.9, 0.98),
                         eps=1e-09))

    # ---- end of model construction ----

    # Re-initialise every parameter uniformly in [-0.08, 0.08]. This runs
    # after the optimizer was created but mutates the tensors in place.
    for param in model.parameters():
        param.data.uniform_(-0.08, 0.08)

    model = nn.DataParallel(model).to(device)
    # The string literal below is a disabled alternative optimizer/scheduler
    # setup kept as a bare expression; it has no effect at runtime.
    """
    optimizer = optim.Adam(model.module.parameters(), lr=args.lr)

    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epochs)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=args.warmup_multiplier, total_epoch=args.warmup_epoch, after_scheduler=scheduler_cosine)
    
    
    criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_token).to(device)
    """

    # Hand the model and optimizer to the project's bind_model helper.
    bind_model(model, optimizer)

    # locals() is handed to NSML here, so the local variable names in this
    # function are visible outside it.
    if args.pause == 1:
        nsml.paused(scope=locals())

    # Everything below is training-only.
    if args.mode != "train":
        return

    data_list = os.path.join(DATASET_PATH, 'train_data', 'data_list.csv')
    wav_paths = list()
    script_paths = list()

    # Collect the absolute wav and script paths listed in data_list.csv.
    with open(data_list, 'r') as f:
        for line in f:
            # line: "aaa.wav,aaa.label"

            wav_path, script_path = line.strip().split(',')
            wav_paths.append(os.path.join(DATASET_PATH, 'train_data',
                                          wav_path))
            script_paths.append(
                os.path.join(DATASET_PATH, 'train_data', script_path))

    best_loss = 1e10
    begin_epoch = 0

    # load all target scripts for reducing disk i/o
    # target_path = os.path.join(DATASET_PATH, 'train_label')
    target_path = TRAIN_LABEL_CHAR_PATH
    if args.word:
        target_path = TRAIN_LABEL_POS_PATH
    load_targets(target_path)

    # 5% of the data is requested as the validation split.
    train_batch_num, train_dataset_list, valid_dataset = split_dataset(
        args, wav_paths, script_paths, valid_ratio=0.05)

    # Optionally restore a previously saved NSML checkpoint, either from
    # another session or from the current one.
    if args.iteration:
        if args.premodel_session:
            nsml.load(args.iteration, session=args.premodel_session)
            logger.info(f'Load {args.premodel_session} {args.iteration}')
        else:
            nsml.load(args.iteration)
            logger.info(f'Load {args.iteration}')
    logger.info('start')

    train_begin = time.time()

    for epoch in range(begin_epoch, args.max_epochs):
        # No learning-rate scheduler call is made here.

        # A fresh bounded queue and loader are created for every epoch.
        train_queue = queue.Queue(args.workers * 2)

        train_loader = MultiLoader(train_dataset_list, train_queue,
                                   args.batch_size, args.workers)
        train_loader.start()

        train_loss, train_cer = train(model, train_batch_num, train_queue,
                                      optimizer, device, train_begin,
                                      args.workers, 10, args.teacher_forcing)
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' %
                    (epoch, train_loss, train_cer))

        train_loader.join()

        print("~~~~~~~~~~~~")

        # Validation, reporting and saving happen only at epoch 10 and at
        # epochs 49, 59, 69, ...
        # NOTE(review): with the default --max_epochs of 10 the loop ends at
        # epoch 9, so this branch never runs and nothing is saved.
        if epoch == 10 or (epoch > 48 and epoch % 10 == 9):
            valid_queue = queue.Queue(args.workers * 2)
            valid_loader = BaseDataLoader(valid_dataset, valid_queue,
                                          args.batch_size, 0)
            valid_loader.start()

            eval_loss, eval_cer = evaluate(model, valid_loader, valid_queue,
                                           device, args.max_len,
                                           args.batch_size)
            logger.info('Epoch %d (Evaluate) Loss %0.4f CER %0.4f' %
                        (epoch, eval_loss, eval_cer))

            valid_loader.join()

            nsml.report(False,
                        step=epoch,
                        train_epoch__loss=train_loss,
                        train_epoch__cer=train_cer,
                        eval__loss=eval_loss,
                        eval__cer=eval_cer)

            # Always save under the configured name; additionally save as
            # 'best' when the evaluation loss improved.
            best_model = (eval_loss < best_loss)
            nsml.save(args.save_name)

            if best_model:
                nsml.save('best')
                best_loss = eval_loss
Ejemplo n.º 6
0
def train_net(args):
    """Train the translation Transformer, optionally resuming from a checkpoint.

    Args:
        args: Namespace with the model hyper-parameters (``n_layers_enc``,
            ``n_layers_dec``, ``n_head``, ``d_k``, ``d_v``, ``d_model``,
            ``d_inner``, ``d_word_vec``, ``dropout``, ``pe_maxlen``,
            ``tgt_emb_prj_weight_sharing``), the loader settings
            (``batch_size``, ``num_workers``), ``epochs`` and ``checkpoint``
            (a path or ``None``).

    Each epoch's training loss is appended to ``loss_epoch.txt`` and logged
    to the summary writer, and a checkpoint is saved. After the last epoch
    all losses from this run are written to ``loss.txt``.
    """
    # Fix the random seeds so that repeated runs give the same results.
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    writer = SummaryWriter()

    if checkpoint is not None:
        # Resume: the checkpoint stores the model and optimizer objects.
        state = torch.load(checkpoint)
        start_epoch = state['epoch'] + 1
        model = state['model']
        optimizer = state['optimizer']
    else:
        # Fresh run: build encoder, decoder and the wrapped Adam optimizer.
        start_epoch = 0
        encoder = Encoder(Config.vocab_size,
                          args.n_layers_enc,
                          args.n_head,
                          args.d_k,
                          args.d_v,
                          args.d_model,
                          args.d_inner,
                          dropout=args.dropout,
                          pe_maxlen=args.pe_maxlen)
        decoder = Decoder(
            Config.sos_id,
            Config.eos_id,
            Config.vocab_size,
            args.d_word_vec,
            args.n_layers_dec,
            args.n_head,
            args.d_k,
            args.d_v,
            args.d_model,
            args.d_inner,
            dropout=args.dropout,
            tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
            pe_maxlen=args.pe_maxlen)
        model = Transformer(encoder, decoder)
        adam = torch.optim.Adam(model.parameters(),
                                betas=(0.9, 0.98),
                                eps=1e-09)
        optimizer = TransformerOptimizer(adam)

    # Place the model on the configured device.
    model = model.to(Config.device)

    # Data loading; collate_fn=pad_collate is the project's batch collation
    # function (the original comment says the batches need padding).
    train_loader = torch.utils.data.DataLoader(TranslateDataset(),
                                               batch_size=args.batch_size,
                                               collate_fn=pad_collate,
                                               shuffle=True,
                                               num_workers=args.num_workers)

    loss_history = []
    for epoch in range(start_epoch, args.epochs):
        # One epoch of training.
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger,
                           writer=writer)

        # Record the loss in memory and append it to the running log file.
        loss_text = str(train_loss)
        loss_history.append(loss_text)
        with open('loss_epoch.txt', 'a+') as epoch_log:
            epoch_log.write(loss_text + '\n')

        writer.add_scalar('epoch/train_loss', train_loss, epoch)
        writer.add_scalar('epoch/learning_rate', optimizer.lr, epoch)

        print('\nLearning rate: {}'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        save_checkpoint(epoch, model, optimizer, train_loss)

    # Final summary: one loss per line for every epoch run in this call.
    with open('loss.txt', 'w') as summary:
        summary.write('\n'.join(loss_history))
Ejemplo n.º 7
0
# Encoder input size, taken from the "numcep" hyper-parameter.
d_input = hp["numcep"]
# One label per entry in train_speaker_list.
label_shape = len(train_speaker_list)

# model: an Encoder and a SelfAttentionPooling module, both sized by d_m,
# are combined by Transformer together with the label count.
d_m = hp["d_m"]
encoder = Encoder(d_input=d_input,
                  n_layers=2,
                  d_k=d_m,
                  d_v=d_m,
                  d_m=d_m,
                  d_ff=hp["d_ff"],
                  dropout=0.1).to(device)
pooling = SelfAttentionPooling(d_m, dropout=0.1).to(device)
model = Transformer(encoder, pooling, d_m, label_shape, dropout=0.2).to(device)

# Adam with learning rate and weight decay from the hyper-parameters.
opt = torch.optim.Adam(model.parameters(),
                       lr=hp["lr"],
                       weight_decay=hp["weight_decay"])

loss_func = torch.nn.CrossEntropyLoss()

# Initial value for the best EER seen so far.
# NOTE(review): presumably EER is expressed in percent, making 99 a
# worse-than-anything sentinel -- confirm against test().
best_eer = 99.
if hp["comet"]:
    with experiment.train():
        for epoch in tqdm(range(epochs)):
            cce_loss = fit(model, loss_func, opt, train_ds_gen, device)
            experiment.log_metric("cce", cce_loss, epoch=epoch)

            val_eer = test(model,
                           val_ds_gen,
                           val_utt,
Ejemplo n.º 8
0
        eos_id,
        vocab_size,
        args.d_word_vec,
        args.n_layers_dec,
        args.n_head,
        args.d_k,
        args.d_v,
        args.d_model,
        args.d_inner,
        dropout=args.dropout,
        tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
        pe_maxlen=args.pe_maxlen)
    model = Transformer(encoder, decoder)

    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        args.k, args.d_model, args.warmup_steps)

    print(args.k)
    print(args.d_model)
    print(args.warmup_steps)

    lr_list = []
    for step_num in range(1, 50000):
        # print(step_num)
        lr_1 = k * init_lr * min(step_num**(-0.5),
                                 step_num * (warmup_steps**(-1.5)))
        optimizer.step()
        lr_2 = optimizer.lr
        # print(lr_1)
        # print(lr_2)
Ejemplo n.º 9
0
        vocab_size,
        args.d_word_vec,
        args.n_layers_dec,
        args.n_head,
        args.d_k,
        args.d_v,
        args.d_model,
        args.d_inner,
        dropout=args.dropout,
        tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
        pe_maxlen=args.pe_maxlen)
    model = Transformer(encoder, decoder)

    for i in range(3):
        print("\n***** Utt", i + 1)
        Ti = i + 20
        input = torch.randn(Ti, D)
        length = torch.tensor([Ti], dtype=torch.int)
        nbest_hyps = model.recognize(input, length, char_list, args)

    file_path = "./temp.pth"
    optimizer = torch.optim.Adam(model.parameters())
    torch.save(model.serialize(model, optimizer, 1, LFR_m=1, LFR_n=1),
               file_path)
    model, LFR_m, LFR_n = Transformer.load_model(file_path)
    print(model)

    import os

    os.remove(file_path)
Ejemplo n.º 10
0
class Translator(nn.Module):
    """A ``Transformer`` bundled with its loss, optimizer and LR schedule.

    Args:
        vocabulary_size_in: Source vocabulary size, forwarded to ``Transformer``.
        vocabulary_size_out: Target vocabulary size, forwarded to ``Transformer``.
        constants: Object exposing ``DEVICE``, ``BOS_IDX``, ``PADDING_IDX``,
            ``PADDING_WORD`` and ``WEIGHTS_FILE``.
        hyperparams: Object exposing ``D_MODEL``, ``WARMUP_STEPS`` and
            ``EVAL_EVERY_TIMESTEPS``.
    """

    def __init__(self, vocabulary_size_in, vocabulary_size_out, constants,
                 hyperparams):
        super(Translator, self).__init__()
        self.Transformer = Transformer(vocabulary_size_in, vocabulary_size_out,
                                       constants, hyperparams)
        self.criterion = nn.CrossEntropyLoss()
        # No lr is given here: fit() overwrites it from the scheduler before
        # every optimizer step.
        self.optimizer = optim.Adam(self.Transformer.parameters(),
                                    betas=(0.9, 0.98),
                                    eps=1e-9)
        self.scheduler = Scheduler(d_model=hyperparams.D_MODEL,
                                   warmup_steps=hyperparams.WARMUP_STEPS)
        self.constants = constants
        self.hyperparams = hyperparams

    def count_parameters(self):
        """Return the number of trainable scalar parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _subwords_to_string(self, subwords):
        """Join subword tokens into a string.

        A token ending in ``"@@"`` is glued to the following token, padding
        words are dropped, and every other token is followed by a space.
        """
        pieces = []
        for subword in subwords:
            if subword[-2:] == "@@":
                pieces.append(subword[:-2])
            elif subword != self.constants.PADDING_WORD:
                pieces.append(subword + " ")
        return "".join(pieces)

    def _evaluate(self, writer, data_eval, step):
        """Translate ``data_eval`` and log corpus BLEU plus a few examples at ``step``."""
        references, hypotheses = [], []
        for X_batch, Y_batch in data_eval:
            for row in range(Y_batch.shape[0]):
                references.append(data_eval.itotok(Y_batch[row]))
            for hypothesis in self.translate(X_batch):
                hypotheses.append(data_eval.itotok(hypothesis))

        references = [self._subwords_to_string(r) for r in references]
        hypotheses = [self._subwords_to_string(h) for h in hypotheses]

        # Show at most the first five reference/prediction pairs.
        ex_phrases = ''.join(
            "\n truth: " + ref + "\n prediction: " + hyp + "\n"
            for ref, hyp in zip(references[:5], hypotheses[:5]))

        # corpus_bleu expects, per hypothesis, a list of tokenised references
        # and a tokenised hypothesis -- not raw strings.
        bleu = nltk.translate.bleu_score.corpus_bleu(
            [[ref.split()] for ref in references],
            [hyp.split() for hyp in hypotheses])
        writer.add_scalar('1_eval_set/BLEU', bleu, step)
        writer.add_text('examples', ex_phrases, step)

    def fit(self, training_steps, data_training, data_eval=None):
        '''
        Train for ``training_steps`` optimizer steps.

        Arg:
            training_steps: number of batches to draw from ``data_training``
            data_training: iterator which gives two batches: one of source language and one for target language
            data_eval: optional iterable of (source, target) batches that also
                provides ``itotok``; when given, BLEU and example translations
                are logged every ``EVAL_EVERY_TIMESTEPS`` steps

        Weights are saved to ``constants.WEIGHTS_FILE`` and the mean loss,
        mean gradient norm and learning rate are logged at the same interval.
        '''
        writer = SummaryWriter()
        training_loss, gradient_norm = [], []
        # translate() puts the module in eval mode; make sure dropout is
        # active whenever we are training.
        self.train(True)

        try:
            for step in tqdm(range(training_steps)):
                X, Y = next(data_training)
                batch_size = X.shape[0]
                # Decoder input is the target shifted right, starting with BOS.
                bos = torch.full((batch_size, 1),
                                 self.constants.BOS_IDX,
                                 dtype=torch.long,
                                 device=self.constants.DEVICE)
                translation = torch.cat((bos, Y[:, :-1]), dim=1)
                output = self.Transformer(X, translation)
                output = output.contiguous().view(-1, output.size(-1))
                target = Y.contiguous().view(-1)

                # The scheduler owns the learning rate; push it into Adam.
                lr = self.scheduler.step()
                for group in self.optimizer.param_groups:
                    group['lr'] = lr

                self.optimizer.zero_grad()
                loss = self.criterion(output, target)
                training_loss.append(loss.item())
                loss.backward()
                self.optimizer.step()

                # Global L2 norm of the gradients; parameters that received
                # no gradient are skipped instead of raising.
                squared = [
                    torch.sum(p.grad.detach()**2)
                    for p in self.Transformer.parameters()
                    if p.grad is not None
                ]
                if squared:
                    gradient_norm.append(
                        torch.sqrt(torch.stack(squared).sum()).item())
                else:
                    gradient_norm.append(0.0)

                if ((step + 1) % self.hyperparams.EVAL_EVERY_TIMESTEPS) == 0:
                    torch.save(self.state_dict(), self.constants.WEIGHTS_FILE)
                    writer.add_scalar('0_training_set/loss',
                                      np.mean(training_loss), step)
                    writer.add_scalar('0_training_set/gradient_norm',
                                      np.mean(gradient_norm), step)
                    writer.add_scalar('2_other/lr', lr, step)
                    training_loss, gradient_norm = [], []

                    if data_eval:
                        self._evaluate(writer, data_eval, step)
                        # Evaluation switched to eval mode; switch back.
                        self.train(True)
        finally:
            writer.close()

    def translate(self, X):
        '''
        Greedy-decode a batch and leave the module in eval mode.

        Arg:
            X: batch of phrases to translate: tensor(nb_texts, nb_tokens)

        Returns:
            One list per input row holding the generated token ids (0-d
            tensors) after BOS, cut at the first padding index.
        '''
        self.train(False)
        batch_size, max_seq = X.shape
        max_seq += 10  #TODO: remove hard code
        decoded = torch.zeros(batch_size,
                              max_seq,
                              dtype=torch.long,
                              device=self.constants.DEVICE)
        decoded[:, 0] = self.constants.BOS_IDX
        # Inference only: do not record an autograd graph while decoding.
        with torch.no_grad():
            enc = self.Transformer.forward_encoder(X)
            for j in range(1, max_seq):
                output = self.Transformer.forward_decoder(X, enc, decoded)
                output = torch.argmax(output, dim=-1)
                decoded[:, j] = output[:, j - 1]
        # Remove padding: keep tokens after BOS up to the first padding index.
        translations = []
        for sequence in decoded:
            tokens = []
            for position in range(max_seq):
                if sequence[position] == self.constants.PADDING_IDX:
                    break
                if position != 0:
                    tokens.append(sequence[position])
            translations.append(tokens)
        return translations