
if __name__ == '__main__':
    V = 11
    criterion = LabelSmoothing(
        size        = V,
        padding_idx = 0,
        smoothing   = 0.0
    )
    model = make_model(V, V, N = 2)
    model_opt = NoamOpt(
        model.src_embed[0].d_model,
        1,
        400,
        torch.optim.Adam(
            model.parameters(),
            lr = 0,
            betas = (0.9, 0.98),
            eps = 1e-9
        )
    )

    for epoch in range(10):
        model.train()
        run_epoch(
            data_gen(
                V, 30, 20
            ),
            model,
            SimpleLossCompute(
                model.generator,
                criterion,
                model_opt
            )
        )
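Every example on this page wraps a plain Adam optimizer in NoamOpt, which scales the learning rate by the inverse square root of the model dimension and of the step count after a linear warm-up. A minimal sketch of that schedule (following the formula popularized by the Annotated Transformer; the argument names mirror the NoamOpt constructor above, and step is assumed to start at 1):

def noam_rate(step, d_model=512, factor=1.0, warmup=400):
    # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)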
Example #2
def train(config):
    # os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id
    if config.gpu_id != '-1':
        device = 'cuda'
    else:
        device = 'cpu'
    if config.load_vocab is not None:
        train_data, val_data, en, vi = get_dataloader(
            config.data_dir,
            batch_size=config.batch_size,
            device=device,
            reload=config.load_vocab)
    else:
        train_data, val_data, en, vi = get_dataloader(
            config.data_dir,
            batch_size=config.batch_size,
            device=device,
            save_path=config.snapshots_folder)
    # train_data, val_data, en, vi = get_dataloader(config.data_dir, batch_size=config.batch_size, device=device)
    src_pad_idx = en.vocab.stoi[en.pad_token]
    trg_pad_idx = vi.vocab.stoi[vi.pad_token]
    print('vocab size: en:', len(en.vocab.stoi), 'vi:', len(vi.vocab.stoi))
    model = Transformer(len(en.vocab.stoi),
                        len(vi.vocab.stoi),
                        src_pad_idx,
                        trg_pad_idx,
                        device,
                        d_model=256,
                        n_layers=5)
    model = model.to(device)
    model.apply(initialize_weights)
    print('Model parameter: ', count_parameters(model))
    if config.pretrain_model != "":
        model.load_state_dict(torch.load(config.pretrain_model))
    criterion = torch.nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
    # TODO: tune the warm-up / cool-down learning-rate schedule
    # Note: model_size is hard-coded to 512 here, while the model above uses d_model=256.
    optimizer = NoamOpt(
        512, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, betas=(0.9, 0.98), eps=1e-9)
    best_loss = 100
    # count = 0
    # epoch_loss = 0
    for i in range(config.num_epochs):
        # model.train()
        # for j, batch in tqdm(enumerate(train_data)):
        train_loss = train_one_iter(model, train_data, optimizer, criterion,
                                    config.grad_clip_norm, device, en, vi)
        writer.add_scalar('train', train_loss, i)
        print('train_avg:', train_loss)

        # epoch_loss += train_loss
        # count += 1
        # if count % config.snapshot_iter == 0:
        torch.save(
            model.state_dict(),
            os.path.join(config.snapshots_folder, "Epoch_" + str(i) + '.pth'))
        val_loss = evaluate(model, val_data, criterion, device)
        writer.add_scalar('val loss', val_loss, i)
        print('val_loss:', val_loss)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(),
                       os.path.join(config.snapshots_folder, "best.pth"))
Example #3
File: show.py  Project: royyoung388/srl
import matplotlib.pyplot as plt
import numpy as np

from optimizer import NoamOpt

if __name__ == '__main__':
    opts = [
        NoamOpt(200, 1, 400, None),
        NoamOpt(200, 1, 800, None),
        NoamOpt(100, 1, 400, None)
    ]
    plt.subplot(221)
    plt.plot(np.arange(1, 2000),
             [[opt.rate(i) for opt in opts] for i in range(1, 2000)])
    plt.legend(["200:1:400", "200:1:800", "100:1:400"])

    opts = [
        NoamOpt(200, 4, 400, None),
        NoamOpt(200, 4, 800, None),
        NoamOpt(100, 4, 400, None)
    ]
    plt.subplot(222)
    plt.plot(np.arange(1, 2000),
             [[opt.rate(i) for opt in opts] for i in range(1, 2000)])
    plt.legend(["200:4:400", "200:4:800", "100:4:400"])

    opts = [
        NoamOpt(300, 1, 400, None),
        NoamOpt(300, 1, 800, None),
        NoamOpt(200, 1, 400, None)
    ]
    plt.subplot(223)
    plt.plot(np.arange(1, 2000),
             [[opt.rate(i) for opt in opts] for i in range(1, 2000)])
    plt.legend(["300:1:400", "300:1:800", "200:1:400"])

    plt.show()
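For reference, rate() is the only method these plots exercise, which is why the wrapped optimizer can be None. A sketch of the wrapper, assuming the same interface as the Annotated Transformer's NoamOpt:

class NoamOpt:
    """Noam learning-rate wrapper (sketch); optimizer may be None when only rate() is needed."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self.optimizer = optimizer
        self._step = 0

    def rate(self, step=None):
        # factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        # set the wrapped optimizer's learning rate to the current schedule value, then step it
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()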
Example #4
File: translate.py  Project: jinkilee/LaH
def main():
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu)

    # load dataset
    sent_pairs = load_dataset(
        path='/heavy_data/jkfirst/workspace/git/LaH/dataset/')
    #sent_pairs = list(map(lambda x: remove_bos_eos(x), sent_pairs))

    # split train/valid sentence pairs
    n_train = int(len(sent_pairs) * 0.8)
    valid_sent_pairs = sent_pairs[n_train:]

    # make dataloader with dataset
    # FIXME: RuntimeError: Internal: unk is not defined.
    src_spm, trg_spm = get_sentencepiece(src_prefix,
                                         trg_prefix,
                                         src_cmd=src_cmd,
                                         trg_cmd=trg_cmd)
    valid_dataset = TranslationDataset(valid_sent_pairs, src_spm, trg_spm)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=100,
                                  collate_fn=set_padding)

    # fix torch randomness
    fix_torch_randomness()

    # Train the simple copy task.
    args.inp_n_words = src_vocab_size
    args.out_n_words = trg_vocab_size
    model = make_model(args.inp_n_words, args.out_n_words)

    criterion = LabelSmoothing(size=args.out_n_words,
                               padding_idx=0,
                               smoothing=0.0)
    optimizer = NoamOpt(
        model.src_embed[0].d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))

    # load model
    model_name_full_path = './models/model-tmp.bin'
    device_pairs = zip([0], [args.gpu])
    map_location = {
        'cuda:{}'.format(x): 'cuda:{}'.format(y)
        for x, y in device_pairs
    }
    checkpoint = torch.load(model_name_full_path, map_location=map_location)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)

    sum_of_weight = sum([p[1].data.sum() for p in model.named_parameters()])
    log.info('model was successfully loaded: {:.4f}'.format(sum_of_weight))

    # make gpu-distributed model
    device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    model.to(device)

    original_input, translated_pred, translated_lbl = do_translate(
        valid_dataloader, model, SimpleTranslation(model.generator), device, 0)

    original_input = list(map(src_spm.DecodeIds, original_input))
    translated_pred = list(map(trg_spm.DecodeIds, translated_pred))
    translated_label = list(map(trg_spm.DecodeIds, translated_lbl))

    with open('output/model-tmp.out', 'w', encoding='utf-8') as out_f:
        for src, pred, trg in zip(original_input, translated_pred,
                                  translated_label):
            src = ''.join(src)
            pred = ''.join(pred)
            trg = ''.join(trg)
            out_f.write('input: {}\n'.format(src))
            out_f.write('pred : {}\n'.format(pred))
            out_f.write('label: {}\n'.format(trg))
            out_f.write('----------\n')

    #translated_pred_text = list(map(lambda x: ''.join(x), translated_pred))
    #translated_label_text = list(map(lambda x: ''.join(x), translated_label))
    translated_pred = list(
        map(lambda x: trg_spm.EncodeAsPieces(x), translated_pred))
    translated_label = list(
        map(lambda x: trg_spm.EncodeAsPieces(x), translated_label))

    min_pred_len = min([len(p) for p in translated_pred])
    min_label_len = min([len(p) for p in translated_label])
    log.debug('{} {}'.format(min_pred_len, min_label_len))
    log.debug('{}'.format(translated_pred[:4]))
    log.debug('{}'.format(translated_label[:4]))

    bleu = bleu_score(translated_pred, translated_label)
    log.info('bleu score: {:.4f}'.format(bleu))

    # Self-comparison (pred vs. pred), presumably a sanity check; this should give the maximum score.
    bleu = bleu_score(translated_pred, translated_pred)
    log.info('bleu score: {:.4f}'.format(bleu))
Example #5
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu
    torch.cuda.set_device(args.gpu)

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=args.gpu)

    # load dataset
    #sent_pairs = load_dataset_aihub()
    sent_pairs = load_dataset_aihub(path='data/')
    random.seed(100)
    random.shuffle(sent_pairs)
    log.debug('GPU#{} seeding with {}'.format(args.gpu, args.gpu))

    # make dataloader with dataset
    # FIXME: RuntimeError: Internal: unk is not defined.
    inp_lang, out_lang = get_sentencepiece(src_prefix, trg_prefix)
    log.info('loaded input sentencepiece model: {}'.format(src_prefix))
    log.info('loaded output sentencepiece model: {}'.format(trg_prefix))

    # split train/valid sentence pairs
    n_train = int(len(sent_pairs) * 0.8)
    n_split = int(n_train * 0.25)
    train_sent_pairs = sent_pairs[:n_train]
    log.info('train_sent_pairs: {}'.format(len(train_sent_pairs)))
    train_sent_pairs = (train_sent_pairs[:args.gpu * n_split] +
                        train_sent_pairs[(args.gpu + 1) * n_split:])
    valid_sent_pairs = sent_pairs[n_train:]
    train_sent_pairs = sorted(train_sent_pairs,
                              key=lambda x: (len(x[0]), len(x[1])))
    #log.info('train_sent_pairs: {}'.format(len(train_sent_pairs)))
    log.info('valid_sent_pairs: {}'.format(len(valid_sent_pairs)))

    # these are used for defining tokenize method and some reserved words
    SRC = KRENField(
        #tokenize=inp_lang.EncodeAsPieces,
        pad_token='<pad>')
    TRG = KRENField(
        #tokenize=out_lang.EncodeAsPieces,
        #init_token='<s>',
        #eos_token='</s>',
        pad_token='<pad>')

    # load SRC/TRG
    if not os.path.exists('spm/{}.model'.format(src_prefix)) or \
     not os.path.exists('spm/{}.model'.format(trg_prefix)):
        # build vocabulary
        SRC.build_vocab(train.src)
        TRG.build_vocab(train.trg)
        torch.save(SRC.vocab,
                   'spm/{}.spm'.format(src_prefix),
                   pickle_module=dill)
        torch.save(TRG.vocab,
                   'spm/{}.spm'.format(trg_prefix),
                   pickle_module=dill)
        log.info(
            'input vocab was created and saved: spm/{}.spm'.format(src_prefix))
        log.info('output vocab was created and saved: spm/{}.spm'.format(
            trg_prefix))
    else:
        src_vocab = torch.load('spm/{}.spm'.format(src_prefix),
                               pickle_module=dill)
        trg_vocab = torch.load('spm/{}.spm'.format(trg_prefix),
                               pickle_module=dill)
        SRC.vocab = src_vocab
        TRG.vocab = trg_vocab
        log.info('input vocab was loaded: spm/{}.spm'.format(src_prefix))
        log.info('output vocab was loaded: spm/{}.spm'.format(trg_prefix))
    # define tokenizer
    #SRC.tokenize = src_tokenize
    #TRG.tokenize = trg_tokenize

    # make dataloader from KRENDataset
    train, valid, test = KRENDataset.splits(sent_pairs, (SRC, TRG),
                                            inp_lang,
                                            out_lang,
                                            encoding_type='pieces')
    # output -> ['<s>', '▁', 'Central', '▁Asian', '▁c', 'u', 'is', ... '▁yesterday', '.', '</s>']
    train_iter = MyIterator(train,
                            batch_size=1024,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=True)
    valid_iter = MyIterator(valid,
                            batch_size=100,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=False)
    # fix torch randomness
    fix_torch_randomness()

    # define input/output size
    args.inp_n_words = src_vocab_size
    args.out_n_words = trg_vocab_size
    log.info('inp_n_words: {} out_n_words: {}'.format(args.inp_n_words,
                                                      args.out_n_words))

    # define model
    model = make_model(args.inp_n_words, args.out_n_words)

    # FIXME: need some good condition for multi distributed computing
    if True:
        model.cuda()
        model = DDP(model, device_ids=[args.gpu])

    # define model
    criterion = LabelSmoothing(size=args.out_n_words,
                               padding_idx=0,
                               smoothing=0.0)
    # FIXME: need some good condition for multi distributed computing
    if True:
        criterion.cuda()

    # define optimizer
    optimizer = NoamOpt(model_size=model.module.src_embed[0].d_model,
                        factor=1,
                        warmup=400,
                        optimizer=torch.optim.Adam(model.parameters(),
                                                   lr=0,
                                                   betas=(0.9, 0.98),
                                                   eps=1e-9))

    # initial best loss
    best_val_loss = np.inf

    # initialize visdom graph
    vis_train = Visdom()
    vis_valid = Visdom()

    train_loss_list = []
    valid_loss_list = []
    for epoch in range(args.epochs):
        train_losses = do_train(
            (rebatch(pad_id, b) for b in train_iter), model,
            SimpleLossCompute(model.module.generator, criterion,
                              opt=optimizer), epoch, (SRC, TRG))
        valid_loss = do_valid((rebatch(pad_id, b) for b in valid_iter), model,
                              SimpleLossCompute(model.module.generator,
                                                criterion,
                                                opt=optimizer), epoch)
        # collect the losses so the visdom plots below have data to draw
        train_loss_list.extend(train_losses)
        valid_loss_list.append(valid_loss)

        if args.gpu == 0:
            if valid_loss >= best_val_loss:
                log.info('Try again. Current best is still {:.4f}'.format(
                    best_val_loss))
            else:
                log.info('New record. from {:.4f} to {:.4f}'.format(
                    best_val_loss, valid_loss))
                best_val_loss = valid_loss
                do_save(model, optimizer, epoch, best_val_loss)
    train_loss_list = np.array(train_loss_list)
    valid_loss_list = np.array(valid_loss_list)

    # draw visdom graph
    vis_train.line(Y=train_loss_list, X=np.arange(len(train_loss_list)) * 50)
    vis_valid.line(Y=valid_loss_list, X=np.arange(len(valid_loss_list)))
Example #6
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

model.to(device)
model.train()
model.zero_grad()
criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = NoamOpt(model_size=args.d_emb,
                    factor=2,
                    warmup=4000,
                    optimizer=torch.optim.Adam(model.parameters(),
                                               lr=0,
                                               betas=(0.9, 0.98),
                                               eps=1e-9))

epoch = args.epoch
total_loss = 0
iteration = 0
for i in range(epoch):
    epoch_iterator = tqdm(data_loader, desc=f'epoch: {i}, loss: {0:.6f}')
    for src, tgt, y in epoch_iterator:
        iteration += 1
        src = torch.LongTensor(src).to(device)
        tgt = torch.LongTensor(tgt).to(device)
        y = torch.LongTensor(y).to(device)

        # Sketch of the usual continuation (the snippet is truncated here);
        # assumes model(src, tgt) returns logits of shape (batch, seq_len, vocab).
        logits = model(src, tgt)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.optimizer.zero_grad()

        total_loss += loss.item()
        epoch_iterator.set_description(
            f'epoch: {i}, loss: {total_loss / iteration:.6f}')
Example #7
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()
    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)
    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-lr_mul', type=float, default=2.0)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-output_dir', type=str, default=None)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    opt = parser.parse_args()

    english = Field(sequential=True,
                    use_vocab=True,
                    tokenize=tokenize_eng,
                    lower=True,
                    pad_token='<blank>',
                    init_token='<s>',
                    eos_token='</s>')

    german = Field(sequential=True,
                   use_vocab=True,
                   tokenize=tokenize_ger,
                   lower=True,
                   pad_token='<blank>',
                   init_token='<s>',
                   eos_token='</s>')

    fields = {'English': ('eng', english), 'German': ('ger', german)}
    train_data, test_data = TabularDataset.splits(path='',
                                                  train='train.json',
                                                  test='test.json',
                                                  format='json',
                                                  fields=fields)

    english.build_vocab(train_data, max_size=1000, min_freq=1)
    print('[Info] Get source language vocabulary size:', len(english.vocab))

    german.build_vocab(train_data, max_size=1000, min_freq=1)
    print('[Info] Get target language vocabulary size:', len(german.vocab))

    batch_size = opt.batch_size
    # data = pickle.load(open(opt.data_file, 'rb'))

    opt.src_pad_idx = english.vocab.stoi['<blank>']
    opt.trg_pad_idx = german.vocab.stoi['<blank>']

    opt.src_vocab_size = len(english.vocab)
    opt.trg_vocab_size = len(german.vocab)

    devices = [0, 1, 2, 3]
    pad_idx = opt.trg_pad_idx
    model = make_model(len(english.vocab), len(german.vocab), N=6)
    model.cuda()
    criterion = LabelSmoothing(size=len(german.vocab),
                               padding_idx=pad_idx,
                               smoothing=0.1)
    criterion.cuda()
    BATCH_SIZE = 12000
    train_iter = MyIterator(train_data,
                            batch_size=BATCH_SIZE,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.eng), len(x.ger)),
                            batch_size_fn=batch_size_fn,
                            train=True)
    valid_iter = MyIterator(test_data,
                            batch_size=BATCH_SIZE,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.eng), len(x.ger)),
                            batch_size_fn=batch_size_fn,
                            train=False)
    model_par = nn.DataParallel(model, device_ids=devices)

    model_opt = NoamOpt(
        model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))
    for epoch in range(opt.epoch):
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter), model_par,
                  MultiGPULossCompute(model.generator,
                                      criterion,
                                      devices=devices,
                                      opt=model_opt))
        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model_par,
                         MultiGPULossCompute(model.generator,
                                             criterion,
                                             devices=devices,
                                             opt=None))
        print(loss)

    for i, batch in enumerate(valid_iter):
        src = batch.src.transpose(0, 1)[:1]
        src_mask = (src != english.vocab.stoi["<blank>"]).unsqueeze(-2)
        out = greedy_decode(model,
                            src,
                            src_mask,
                            max_len=60,
                            start_symbol=german.vocab.stoi["<s>"])
        print("Translation:", end="\t")
        for i in range(1, out.size(1)):
            sym = german.vocab.itos[out[0, i]]
            if sym == "</s>": break
            print(sym, end=" ")
        print()
        print("Target:", end="\t")
        for i in range(1, batch.trg.size(0)):
            sym = german.vocab.itos[batch.trg.data[i, 0]]
            if sym == "</s>": break
            print(sym, end=" ")
        print()
        break

def train(arg_parser):
    logs_path = os.path.join(arg_parser.log_dir, arg_parser.subdir)
    if not os.path.isdir(logs_path):
        os.makedirs(logs_path)
    file_name_epoch_indep = get_model_name(arg_parser)
    recombination = arg_parser.recombination_method

    vocab = Vocab(f'bert-{arg_parser.BERT}-uncased')
    model_type = TSP if arg_parser.TSP_BSP else BSP
    model = model_type(input_vocab=vocab,
                       target_vocab=vocab,
                       d_model=arg_parser.d_model,
                       d_int=arg_parser.d_int,
                       d_k=arg_parser.d_k,
                       h=arg_parser.h,
                       n_layers=arg_parser.n_layers,
                       dropout_rate=arg_parser.dropout,
                       max_len_pe=arg_parser.max_len_pe,
                       bert_name=arg_parser.BERT)

    file_path = os.path.join(
        arg_parser.models_path,
        f"{file_name_epoch_indep}_epoch_{arg_parser.epoch_to_load}.pt")
    if arg_parser.train_load:
        train_dataset = get_dataset_finish_by(arg_parser.data_folder, 'train',
                                              f"600_entity_recomb.tsv")
        test_dataset = get_dataset_finish_by(arg_parser.data_folder, 'dev',
                                             f"100_entity_recomb.tsv")
        load_model(file_path=file_path, model=model)
        #load_model(file_path=os.path.join('models_to_keep', 'BSP_d_model256_layers4_recombentity+nesting+concat2_extrastrain1800_extrasdev300_epoch_75.pt'), model=model)
        print('loaded model')
    else:
        train_dataset = get_dataset_finish_by(
            arg_parser.data_folder, 'train',
            f"{600 + arg_parser.extras_train}_{recombination}_recomb.tsv")
        test_dataset = get_dataset_finish_by(
            arg_parser.data_folder, 'dev',
            f"{100 + arg_parser.extras_dev}_{recombination}_recomb.tsv")
    model.train()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if arg_parser.optimizer:
        optimizer = NoamOpt(model_size=arg_parser.d_model, factor=1, warmup=arg_parser.warmups_steps, \
                            optimizer=torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=arg_parser.lr)
    model.device = device
    summary_writer = SummaryWriter(
        log_dir=logs_path) if arg_parser.log else None

    n_train = len(train_dataset)
    n_test = len(test_dataset)

    for epoch in range(arg_parser.epochs):
        running_loss = 0.0
        last_log_time = time.time()

        # Training
        train_loss = 0.0
        for batch_idx, batch_examples in enumerate(
                data_iterator(train_dataset,
                              batch_size=arg_parser.batch_size,
                              shuffle=arg_parser.shuffle)):
            if ((batch_idx % 100) == 0) and batch_idx > 1:
                print(
                    "epoch {} | batch {} | mean running loss {:.2f} | {:.2f} batch/s"
                    .format(epoch, batch_idx, running_loss / 100,
                            100 / (time.time() - last_log_time)))
                last_log_time = time.time()
                running_loss = 0.0

            sources, targets = batch_examples[0], batch_examples[1]
            example_losses = -model(sources, targets)  # (batch_size,)
            batch_loss = example_losses.sum()
            loss = batch_loss / arg_parser.batch_size

            if arg_parser.optimizer:
                loss.backward()
                # clip gradients only after backward has computed them
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                           arg_parser.clip_grad)
                optimizer.step()
                optimizer.optimizer.zero_grad()
            else:
                optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                           arg_parser.clip_grad)
                optimizer.step()

            # add loss
            running_loss += loss.item()
            train_loss += loss.item()

        print("Epoch train loss : {}".format(
            math.sqrt(train_loss /
                      math.ceil(n_train / arg_parser.batch_size))))

        if summary_writer is not None:
            summary_writer.add_scalar(
                "train/loss",
                train_loss / math.ceil(n_train / arg_parser.batch_size),
                global_step=epoch)
        if (epoch % arg_parser.save_every
                == 0) and arg_parser.log and epoch > 0:
            save_model(
                arg_parser.models_path,
                f"{file_name_epoch_indep}_epoch_{epoch + arg_parser.epoch_to_load}.pt",
                model, device)
        ## TEST
        test_loss = 0.0

        for batch_idx, batch_examples in enumerate(
                data_iterator(test_dataset,
                              batch_size=arg_parser.batch_size,
                              shuffle=arg_parser.shuffle)):
            with torch.no_grad():
                sources, targets = batch_examples[0], batch_examples[1]
                example_losses = -model(sources, targets)  # (batch_size,)
                batch_loss = example_losses.sum()
                loss = batch_loss / arg_parser.batch_size

                test_loss += loss.item()

        if summary_writer is not None:
            summary_writer.add_scalar(
                "test/loss",
                test_loss / math.ceil(n_test / arg_parser.batch_size),
                global_step=epoch)
        print("TEST loss | epoch {} | {:.2f}".format(
            epoch, test_loss / math.ceil(n_test / arg_parser.batch_size)))

    return None
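Note the asymmetry the NoamOpt branch above has to handle: NoamOpt.step() advances the schedule and steps the wrapped Adam, but the Annotated Transformer's version exposes no zero_grad(), so gradients are cleared through the inner optimizer. A minimal usage sketch with the same names as in the training loop above:

loss.backward()
optimizer.step()                 # NoamOpt: set lr from the schedule, then call Adam.step()
optimizer.optimizer.zero_grad()  # clear gradients via the wrapped Adam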