Code example #1
File: translate.py  Project: Kaixin-Wu/myTransformer
def greedy_test(args):
    """ Test function """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")
    # each prediction buffer starts as ['<BOS>', '<PAD>', '<PAD>', ..., '<PAD>'];
    # built per sentence, since replicating one nested list would alias it
    pred_data = [[
        constants.PAD_WORD if i else constants.BOS_WORD
        for i in range(args.decode_max_steps)
    ] for _ in range(len(test_data))]

    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test, pred in zip(test_data, pred_data):
        pred_output = [constants.PAD_WORD] * args.decode_max_steps
        test_var = to_input_variable([test], vocab.src, cuda=args.cuda)

        # the source sentence only needs to be encoded once
        enc_output = translator.encode(test_var[0], test_var[1])
        for i in range(args.decode_max_steps):
            pred_var = to_input_variable([pred[:i + 1]],
                                         vocab.tgt,
                                         cuda=args.cuda)

            scores = translator.translate(enc_output, test_var[0], pred_var)

            _, argmax_idxs = torch.max(scores, dim=-1)
            one_step_idx = argmax_idxs[-1].item()

            pred_output[i] = vocab.tgt.id2word[one_step_idx]
            if (one_step_idx
                    == constants.EOS) or (i == args.decode_max_steps - 1):
                print("[Source] %s" % " ".join(test))
                print("[Predict] %s" % " ".join(pred_output[:i]))
                print()

                output_file.write(" ".join(pred_output[:i]) + "\n")
                output_file.flush()
                break
            pred[i + 1] = vocab.tgt.id2word[one_step_idx]

    output_file.close()
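
A side note on the `pred_data` construction: replicating a nested list with `n * [[...]]` makes every element reference the same inner list, so writing into one prediction buffer would corrupt all of them, which is why it is built with a comprehension here. A two-line demonstration of the pitfall:

rows = 3 * [[0]]   # three references to the SAME inner list
rows[0][0] = 9     # rows is now [[9], [9], [9]]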
Code example #2
File: train.py  Project: Kaixin-Wu/myTransformer
def init_training(args):
    """ Initialize training process """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    transformer = Transformer(args, vocab)

    # if finetune
    if args.finetune:
        print("[Finetune] %s" % args.finetune_model_path)
        transformer.load_state_dict(torch.load(args.finetune_model_path))

    # vocab_mask for masking padding
    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt[constants.PAD_WORD]] = 0

    # loss object (size_average=False is deprecated; reduction="sum" is the
    # modern equivalent)
    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask,
                                             reduction="sum")

    if args.cuda:
        transformer = transformer.cuda()
        cross_entropy_loss = cross_entropy_loss.cuda()

    if args.optimizer == "Warmup_Adam":
        optimizer = ScheduledOptim(
            torch.optim.Adam(transformer.get_trainable_parameters(),
                             betas=(0.9, 0.98),
                             eps=1e-09), args.d_model, args.n_warmup_steps)

    if args.optimizer == "Adam":
        optimizer = torch.optim.Adam(
            params=transformer.get_trainable_parameters(),
            lr=args.lr,
            betas=(0.9, 0.98),
            eps=1e-8)

    elif args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            params=transformer.get_trainable_parameters(), lr=args.lr)
    else:
        raise ValueError("unknown optimizer: %s" % args.optimizer)

    # multi gpus
    if torch.cuda.device_count() > 1:
        print("[Multi GPU] using", torch.cuda.device_count(), "GPUs\n")
        transformer = nn.DataParallel(transformer)

    return vocab, transformer, optimizer, cross_entropy_loss
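
For context, ScheduledOptim wraps Adam with the warm-up learning-rate schedule from the Transformer paper. A minimal sketch of that schedule (the project's actual class may differ in details):

def noam_lr(step, d_model, n_warmup_steps):
    # inverse-square-root schedule from "Attention Is All You Need"
    # (step counts from 1): lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return (d_model ** -0.5) * min(step ** -0.5, step * n_warmup_steps ** -1.5)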
Code example #3
File: transform.py  Project: zxsted/torch_light
    def __init__(self, model_source, cuda=False, beam_size=3):
        self.torch = torch.cuda if cuda else torch
        self.cuda = cuda
        self.beam_size = beam_size

        if self.cuda:
            model_source = torch.load(model_source)
        else:
            model_source = torch.load(
                model_source, map_location=lambda storage, loc: storage)
        self.src_dict = model_source["src_dict"]
        self.tgt_dict = model_source["tgt_dict"]
        # note: despite its name, this maps *target* ids back to words
        self.src_idx2word = {v: k for k, v in model_source["tgt_dict"].items()}
        self.args = args = model_source["settings"]
        model = Transformer(args)
        model.load_state_dict(model_source['model'])

        if self.cuda:
            model = model.cuda()
        else:
            model = model.cpu()
        self.model = model.eval()
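
As an aside, the `map_location=lambda storage, loc: storage` idiom above remaps every tensor in the checkpoint to the CPU; in current PyTorch the same effect can be written more plainly (a sketch, assuming `model_source` is the checkpoint path):

import torch

checkpoint = torch.load(model_source, map_location="cpu")  # remap all tensors to CPU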
Code example #4
def main(gpu_id=None):
    dataset = Dataset(transform=transform, n_datas=10000)
    pad_vec = np.zeros(len(dataset.human_vocab))
    pad_vec[dataset.human_vocab['<pad>']] = 1
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=6,
                                             shuffle=True,
                                             num_workers=6,
                                             collate_fn=partial(
                                                 collate_fn, pad_vec))

    model = Transformer(n_head=2)
    if gpu_id is not None:
        print('use gpu')
        # note: CUDA_VISIBLE_DEVICES only takes effect if it is set before
        # CUDA is first initialized in the process
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
        n_gpus = torch.cuda.device_count()
        # print('use %d gpu [%s]' % (n_gpus, gpu_id))
        model = model.cuda()
        # model = torch.nn.DataParallel(model, device_ids=[i for i in range(n_gpus)])
    # loss_fn = torch.nn.CrossEntropyLoss()
    loss_fn = torch.nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters())

    model = sl.load_model('./checkpoint', -1, model)
    optimizer = sl.load_optimizer('./checkpoint', -1, optimizer)

    try:
        trained_epoch = sl.find_last_checkpoint('./checkpoint')
        print('train from epoch %d' % (trained_epoch + 1))
    except Exception as e:
        print('train from the very beginning, {}'.format(e))
        trained_epoch = -1
    for epoch in range(trained_epoch + 1, 20):
        train(model,
              loss_fn,
              optimizer,
              dataloader,
              epoch,
              use_gpu=gpu_id is not None)
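
`sl` is a project-local save/load helper; a minimal sketch of what its `find_last_checkpoint` plausibly does (the file-name pattern is a guess, not this repo's code):

import os
import re

def find_last_checkpoint(ckpt_dir):
    # Return the largest epoch index parsed from checkpoint file names;
    # raise if none exist, matching the try/except in the example above.
    epochs = []
    for name in os.listdir(ckpt_dir):
        match = re.search(r'(\d+)', name)
        if match:
            epochs.append(int(match.group(1)))
    if not epochs:
        raise FileNotFoundError('no checkpoints in %s' % ckpt_dir)
    return max(epochs)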
Code example #5
File: translate.py  Project: Kaixin-Wu/myTransformer
def test(args):
    """ Decode with beam search """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")
    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test in test_data:
        test_seq, test_pos = to_input_variable([test],
                                               vocab.src,
                                               cuda=args.cuda)
        test_seq_beam = test_seq.expand(args.decode_beam_size,
                                        test_seq.size(1))

        enc_output = translator.encode(test_seq, test_pos)
        enc_output_beam = enc_output.expand(args.decode_beam_size,
                                            enc_output.size(1),
                                            enc_output.size(2))

        beam = Beam_Search_V2(beam_size=args.decode_beam_size,
                              tgt_vocab=vocab.tgt,
                              length_alpha=args.decode_alpha)
        for i in range(args.decode_max_steps):

            # first decoding step: the beam holds only the single <BOS> candidate
            if i == 0:
                pred_var = to_input_variable(beam.candidates[:1],
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_output, test_seq, pred_var)
            else:
                pred_var = to_input_variable(beam.candidates,
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_output_beam, test_seq_beam,
                                              pred_var)

            log_softmax_scores = F.log_softmax(scores, dim=-1)
            log_softmax_scores = log_softmax_scores.view(
                pred_var[0].size(0), -1, log_softmax_scores.size(-1))
            log_softmax_scores = log_softmax_scores[:, -1, :]

            is_done = beam.advance(log_softmax_scores)
            beam.update_status()

            if is_done:
                break

        print("[Source] %s" % " ".join(test))
        print("[Predict] %s" % beam.get_best_candidate())
        print()

        output_file.write(beam.get_best_candidate() + "\n")
        output_file.flush()

    output_file.close()
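
`Beam_Search_V2` is project code, and `length_alpha` is presumably a length-normalization exponent. A sketch of the GNMT-style length penalty such a parameter usually controls (the class's exact formula may differ):

def length_penalty(length, alpha):
    # GNMT length penalty: ((5 + |Y|) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

# candidates are then ranked by total_log_prob / length_penalty(num_tokens, alpha)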
Code example #6
# build the model
if not args.universal:
    model = Transformer(SRC, TRG, args)
else:
    model = UniversalTransformer(SRC, TRG, args)

# logger.info(str(model))
if args.load_from is not None:
    with torch.cuda.device(args.gpu):
        model.load_state_dict(
            torch.load(args.models_dir + '/' + args.load_from + '.pt',
                       map_location=lambda storage, loc: storage.cuda()))  # load the pretrained model


# use cuda
if args.gpu > -1:
    model.cuda(args.gpu)

# additional information
args.__dict__.update({'model_name': model_name, 'hp_str': hp_str,  'logger': logger})

# show the args:
arg_str = "args:\n"
for w in sorted(args.__dict__.keys()):
    if w not in ("U", "V", "Freq"):
        arg_str += "{}:\t{}\n".format(w, args.__dict__[w])
logger.info(arg_str)

if args.tensorboard and (not args.debug):
    from tensorboardX import SummaryWriter
    writer = SummaryWriter('{}/{}'.format(args.runs_dir, args.prefix + args.hp_str))
Code example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-data',
        type=str,
        default='./data/data.pt',
        help=
        'Path to the source data. The default is ./data/data.pt, which is the output of preprocessing.'
    )
    parser.add_argument('-epoch', type=int, default=10000)
    parser.add_argument('-log_step', type=int, default=5)
    parser.add_argument('-save_model_epoch', type=int, default=1)
    parser.add_argument('-save_model_path', default='./saved_model/')
    args = parser.parse_args()

    dataset = torch.load(args.data)

    batch_size = 4
    src_vocab = dataset['dict']['src']
    tgt_vocab = dataset['dict']['tgt']
    print("\n\nBatch Size = %d" % batch_size)
    print("Source Vocab Size = %d" % len(src_vocab))
    print("Target Vocab Size = %d" % len(tgt_vocab))

    print("\nLoading Training Data ... ")
    training_batches = get_loader(src=dataset['train']['src'],
                                  tgt=dataset['train']['tgt'],
                                  src_vocabs=dataset['dict']['src'],
                                  tgt_vocabs=dataset['dict']['tgt'],
                                  batch_size=batch_size,
                                  use_cuda=True,
                                  shuffle=True)

    # print("\nLoading Validation Data ... ")
    # validation_data = get_loader(
    #     src=dataset['valid']['src'],
    #     tgt=dataset['valid']['tgt'],
    #     src_vocabs=dataset['dict']['src'],
    #     tgt_vocabs=dataset['dict']['tgt'],
    #     batch_size=batch_size,
    #     use_cuda=False,
    #     shuffle=False
    # )

    # For python 2
    transformer_config = [
        6, 512, 512, 8, batch_size,
        len(src_vocab),
        len(tgt_vocab), 100, 0.1, True
    ]

    # For python 3
    # transformer_config = {
    #     'N': 6,
    #     'd_model': int(512),
    #     'd_ff': 512,
    #     'H': 8,
    #     'batch_size': batch_size,
    #     'src_vocab_size': int(len(src_vocab)),
    #     'tgt_vocab_size': int(len(tgt_vocab)),
    #     'max_seq': 100,
    #     'dropout': 0.1,
    #     'use_cuda': True
    # }

    transformer = Transformer(transformer_config)
    if torch.cuda.is_available():
        print("CUDA enabled.")
        transformer.cuda()

    optimizer = optim.Adam(
        transformer.parameters(),
        lr=0.001,
        # betas=(0.9, 0.98),
        # eps=1e-09
    )

    criterion = nn.CrossEntropyLoss()

    # Prepare a txt file to print training log
    if not os.path.exists(args.save_model_path):
        print(
            "\nCreated a directory (%s) for saving model since it does not exist.\n"
            % args.save_model_path)
        os.makedirs(args.save_model_path)

    f = open('%s/train_log.txt' % args.save_model_path, 'w')

    # Train the model
    for e in range(args.epoch):
        for i, batch in enumerate(
                tqdm(training_batches,
                     mininterval=2,
                     desc='  Training  ',
                     leave=False)):
            # print ("BATCH")
            # print(batch[0][0])
            # exit()
            sources = to_var(batch[0])
            targets = to_var(batch[1])
            src_seq_len = sources.size()[1]
            tgt_seq_len = targets.size()[1]

            if torch.cuda.is_available():
                sources = sources.cuda()
                targets = targets.cuda()

            optimizer.zero_grad()
            outputs = transformer(sources, targets)

            # print("\n\n\n########### OUTPUT ###########")
            # print(len(outputs))
            # print(outputs.max(1)[1].data.tolist() )
            # exit()
            #
            # print("\n\n\n########### TARGET ###########")
            # print(len(targets))
            # print(targets)

            # print(" \n\n TARGETS %d " %i)
            # print(targets)
            # print(targets.contiguous().view(-1).long())
            # exit()

            targets = targets.contiguous().view(-1).long()
            loss = criterion(outputs, targets)

            # backprop
            loss.backward()

            # optimize params
            optimizer.step()

            # Print log info to both console and file
            if i % args.log_step == 0:
                print(
                    "\n\n\n\n#################################################################################"
                )
                log = (
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f\n'
                    % (e, args.epoch, i, len(training_batches), loss.item(),
                       np.exp(loss.item())))
                print(log)
                f.write("{}".format(log))

                # Print the first sentence of the batch
                src_indices = sources.data.tolist(
                )[0][:src_seq_len]  # Variable -> Tensor -> List
                src_sentence = convert2text(src_indices,
                                            src_vocab)  # Get sentence

                pred_indices = outputs.max(
                    1)[1].data.tolist()  # Variable -> Tensor -> List
                # use a distinct loop name so the batch counter `i`, reused in
                # the checkpoint filename below, is not shadowed under Python 2
                pred_indices = [
                    idx[0] for idx in pred_indices[:tgt_seq_len]
                ]  # keep the indices up to the target sequence length
                pred_sentence = convert2text(pred_indices,
                                             tgt_vocab)  # Get sentence

                tgt_indices = targets.data.tolist(
                )[:tgt_seq_len]  # Variable -> Tensor -> List
                tgt_sentence = convert2text(tgt_indices,
                                            tgt_vocab)  # Get sentence

                original = ("ORIGINAL:  {}\n".format(src_sentence))
                predicted = ("PREDICTED: {}\n".format(pred_sentence))
                truth = ("TRUTH:     {}\n\n".format(tgt_sentence))
                print(original)
                print(predicted)
                print(truth)
                f.write("{}".format(original))
                f.write("{}".format(predicted))
                f.write("{}".format(truth))

        # Save the models
        if e % args.save_model_epoch == 0:
            torch.save(
                transformer.state_dict(),
                os.path.join(args.save_model_path,
                             'transformer-%d-%d.pkl' % (e + 1, i + 1)))
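
One thing to watch in this example: the same `targets` tensor is fed to the decoder and used as the loss target. Standard teacher forcing shifts the two by one position; a sketch of the usual arrangement (variable names match the loop above):

# decoder input: <BOS> y1 ... y_{n-1}; loss target: y1 ... y_n <EOS>
decoder_input = targets[:, :-1]
gold = targets[:, 1:].contiguous().view(-1).long()
outputs = transformer(sources, decoder_input)
loss = criterion(outputs, gold)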
Code example #8
if (args.share_encoder) and (args.load_from is None):
    model.encoder = copy.deepcopy(teacher_model.encoder)
    for params in model.encoder.parameters():
        params.requires_grad = True

# use cuda
if args.gpu > -1:
    model.cuda(args.gpu)
    if align_table is not None:
        align_table = torch.LongTensor(align_table).cuda(args.gpu)
        align_table = Variable(align_table)
        model.alignment = align_table

    if args.teacher is not None:
        teacher_model.cuda(args.gpu)

# additional information
args.__dict__.update({
    'model_name': model_name,
    'hp_str': hp_str,
    'logger': logger
})

# ----------------------------------------------------------------------------------------------------------------- #
if args.mode == 'train':
    logger.info('starting training')
    train_model(args, model, train_real, dev_real, teacher_model)

elif args.mode == 'test':
    logger.info(
Code example #9
def main(tokenizer, src_tok_file, tgt_tok_file, train_file, val_file,
         test_file, num_epochs, batch_size, d_model, nhead, num_encoder_layers,
         num_decoder_layers, dim_feedforward, dropout, learning_rate,
         data_path, checkpoint_file, do_train):
    logging.info('Using tokenizer: {}'.format(tokenizer))

    src_tokenizer = TokenizerWrapper(tokenizer, BLANK_WORD, SEP_TOKEN,
                                     CLS_TOKEN, PAD_TOKEN, MASK_TOKEN)
    src_tokenizer.train(src_tok_file, 20000, SPECIAL_TOKENS)

    tgt_tokenizer = TokenizerWrapper(tokenizer, BLANK_WORD, SEP_TOKEN,
                                     CLS_TOKEN, PAD_TOKEN, MASK_TOKEN)
    tgt_tokenizer.train(tgt_tok_file, 20000, SPECIAL_TOKENS)

    SRC = ttdata.Field(tokenize=src_tokenizer.tokenize, pad_token=BLANK_WORD)
    TGT = ttdata.Field(tokenize=tgt_tokenizer.tokenize,
                       init_token=BOS_WORD,
                       eos_token=EOS_WORD,
                       pad_token=BLANK_WORD)

    logging.info('Loading training data...')
    train_ds, val_ds, test_ds = ttdata.TabularDataset.splits(
        path=data_path,
        format='tsv',
        train=train_file,
        validation=val_file,
        test=test_file,
        fields=[('src', SRC), ('tgt', TGT)])

    test_src_sentence = val_ds[0].src
    test_tgt_sentence = val_ds[0].tgt

    MIN_FREQ = 2
    SRC.build_vocab(train_ds.src, min_freq=MIN_FREQ)
    TGT.build_vocab(train_ds.tgt, min_freq=MIN_FREQ)

    logging.info(f'''SRC vocab size: {len(SRC.vocab)}''')
    logging.info(f'''TGT vocab size: {len(TGT.vocab)}''')

    train_iter = ttdata.BucketIterator(train_ds,
                                       batch_size=batch_size,
                                       repeat=False,
                                       sort_key=lambda x: len(x.src))
    val_iter = ttdata.BucketIterator(val_ds,
                                     batch_size=1,
                                     repeat=False,
                                     sort_key=lambda x: len(x.src))
    test_iter = ttdata.BucketIterator(test_ds,
                                      batch_size=1,
                                      repeat=False,
                                      sort_key=lambda x: len(x.src))

    source_vocab_length = len(SRC.vocab)
    target_vocab_length = len(TGT.vocab)

    model = Transformer(d_model=d_model,
                        nhead=nhead,
                        num_encoder_layers=num_encoder_layers,
                        num_decoder_layers=num_decoder_layers,
                        dim_feedforward=dim_feedforward,
                        dropout=dropout,
                        source_vocab_length=source_vocab_length,
                        target_vocab_length=target_vocab_length)
    optim = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             betas=(0.9, 0.98),
                             eps=1e-9)
    model = model.cuda()

    if do_train:
        train_losses, valid_losses = train(train_iter, val_iter, model, optim,
                                           num_epochs, batch_size,
                                           test_src_sentence,
                                           test_tgt_sentence, SRC, TGT,
                                           src_tokenizer, tgt_tokenizer,
                                           checkpoint_file)
    else:
        logging.info('Skipped training.')

    # Load best model and score test set
    logging.info('Loading best model.')
    model.load_state_dict(torch.load(checkpoint_file))
    model.eval()
    logging.info('Scoring the test set...')
    score_start = time.time()
    test_bleu, test_chrf = score(test_iter, model, tgt_tokenizer, SRC, TGT)
    score_time = time.time() - score_start
    logging.info(f'''Scoring complete in {score_time/60:.3f} minutes.''')
    logging.info(f'''BLEU : {test_bleu}''')
    logging.info(f'''CHRF : {test_chrf}''')
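
`score` is a helper from this project; for reference, a minimal corpus-level BLEU/chrF computation with sacrebleu looks like this (a stand-in sketch, not the repo's implementation):

import sacrebleu

def corpus_scores(hypotheses, references):
    # one reference per hypothesis; sacrebleu expects a list of reference streams
    bleu = sacrebleu.corpus_bleu(hypotheses, [references]).score
    chrf = sacrebleu.corpus_chrf(hypotheses, [references]).score
    return bleu, chrf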
Code example #10
    # maximum input and output sequence lengths
    src_len = 5  # enc_input max sequence length
    tgt_len = 6  # dec_input(=dec_output) max sequence length

    # convert the data to id sequences
    enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

    loader = DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs),
                        batch_size=2,
                        shuffle=True)

    model = Transformer()

    # use multiple GPUs when available
    if torch.cuda.is_available():
        model.cuda()

    if torch.cuda.device_count() > 1:
        args.n_gpu = torch.cuda.device_count()
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # this single line enables data parallelism
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

    for epoch in range(30):
        # train for 30 epochs
        for enc_inputs, dec_inputs, dec_outputs in loader:
            '''
            enc_inputs: [batch_size, src_len]
Code example #11
                                depth=1,
                                num_classes=2,
                                char_vocab_size=len(c2idx),
                                char_embed_dim=50)

transformer_parameters = sum(p.numel() for p in Transformer_model.parameters()
                             if p.requires_grad)
rnn_parameters = sum(p.numel() for p in RNNseq_model.parameters()
                     if p.requires_grad)
total_parameters = transformer_parameters + rnn_parameters
print(f'Number of parameters: {total_parameters}')

# Move the model to the GPU if available
if using_GPU:
    RNNseq_model = RNNseq_model.cuda()
    Transformer_model = Transformer_model.cuda()

# Set up criterion for calculating loss
weight_tensor = torch.Tensor([1.0, 2.0])
if using_GPU:
    weight_tensor = weight_tensor.cuda()
loss_criterion = nn.NLLLoss(weight=weight_tensor)

rnn_optimizer = optim.Adam(RNNseq_model.parameters(), lr=0.005)
trans_optimizer = optim.Adam(Transformer_model.parameters(), lr=0.0001)

rnn_scheduler = optim.lr_scheduler.MultiStepLR(rnn_optimizer,
                                               milestones=[2, 5],
                                               gamma=0.3)
trans_scheduler = optim.lr_scheduler.MultiStepLR(trans_optimizer,
                                                 milestones=[2, 5],
                                                 gamma=0.3)
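
With the milestones above, MultiStepLR multiplies each learning rate by gamma=0.3 when the given epochs are reached, so the transformer's rate goes 1e-4 → 3e-5 at epoch 2 and → 9e-6 at epoch 5, assuming the schedulers are stepped once per epoch:

for epoch in range(num_epochs):   # num_epochs assumed defined elsewhere
    # ... run one training epoch here ...
    rnn_scheduler.step()          # lr *= 0.3 when `epoch` crosses a milestone (2, 5)
    trans_scheduler.step()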
Code example #12
File: train.py  Project: wangleiai/transformer
import numpy as np
import torch.nn.functional as F
import os
from tensorboardX import SummaryWriter
import mask
from performance import Performance
from my_optim import ScheduledOptim

SRC, TRG, train_iter, test_iter = preparedData(params.data_path,
                                               params.batch_size)
src_pad = SRC.vocab.stoi['<pad>']
trg_pad = TRG.vocab.stoi['<pad>']
model = Transformer(len(SRC.vocab), len(TRG.vocab), params.d_model,
                    params.n_layers, params.heads, params.dropout)
if params.is_cuda:
    model = model.cuda()

# print(model)
print('trg_vocab_len: ', len(TRG.vocab))
print('src_vocab_len: ', len(SRC.vocab))

vocab_to_json(TRG.vocab, params.word_json_file, params.trg_lang)
vocab_to_json(SRC.vocab, params.word_json_file, params.src_lang)
print("write data to json finished !")

# optimizer = torch.optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.98), eps=1e-9)
optimizer = ScheduledOptim(
    torch.optim.Adam(model.parameters(),
                     lr=params.lr,
                     betas=(0.9, 0.98),
                     eps=1e-09), params.d_model, params.n_warmup_steps)
Code example #13
def main() -> None:
    """Entry point.
    """
    print("Start!!!")
    sys.stdout.flush()
    if args.run_mode == "train":
        train_data = MultiAlignedDataMultiFiles(config_data.train_data_params,
                                                device=device)
        #train_data = tx.data.MultiAlignedData(config_data.train_data_params, device=device)
        print("will data_iterator")
        data_iterator = tx.data.DataIterator({"train": train_data})
        print("data_iterator done")

        # Create model and optimizer
        model = Transformer(config_model, config_data, train_data.vocab('src'))
        model.to(device)
        print("device:", device)
        print("vocab src1:", train_data.vocab('src').id_to_token_map_py)
        print("vocab src2:", train_data.vocab('src').token_to_id_map_py)

        model = ModelWrapper(model, config_model.beam_width)
        if torch.cuda.device_count() > 1:
            #model = nn.DataParallel(model.cuda(), device_ids=[0, 1]).to(device)
            #model = MyDataParallel(model.cuda(), device_ids=[0, 1]).to(device)
            model = MyDataParallel(model.cuda()).to(device)

        lr_config = config_model.lr_config
        if lr_config["learning_rate_schedule"] == "static":
            init_lr = lr_config["static_lr"]
            scheduler_lambda = lambda x: 1.0
        else:
            init_lr = lr_config["lr_constant"]
            scheduler_lambda = functools.partial(
                get_lr_multiplier, warmup_steps=lr_config["warmup_steps"])
        optim = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 betas=(0.9, 0.997),
                                 eps=1e-9)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optim, scheduler_lambda)

        output_dir = Path(args.output_dir)
        if not output_dir.exists():
            output_dir.mkdir()

        def _save_epoch(epoch):

            checkpoint_name = f"checkpoint{epoch}.pt"
            print(f"saveing model...{checkpoint_name}")
            torch.save(model.state_dict(), output_dir / checkpoint_name)

        def _train_epoch(epoch):
            data_iterator.switch_to_dataset('train')
            model.train()
            #model.module.train()
            #print("after model.module.train")
            sys.stdout.flush()
            step = 0
            num_steps = len(data_iterator)
            loss_stats = []
            for batch in data_iterator:
                #print("batch:", batch)
                #batch = batch.to(device)
                return_dict = model(batch)
                #return_dict = model.module.forward(batch)
                loss = return_dict['loss']
                #print("loss:", loss)
                loss = loss.mean()
                #print("loss:", loss)
                #print("loss.item():", loss.item())
                loss_stats.append(loss.item())

                optim.zero_grad()
                loss.backward()
                optim.step()
                scheduler.step()

                config_data.display = 1  # force a log line every step
                if step % config_data.display == 0:
                    avr_loss = sum(loss_stats) / len(loss_stats)
                    ppl = utils.get_perplexity(avr_loss)
                    # get_last_lr() replaces the deprecated get_lr() accessor
                    print(
                        f"epoch={epoch}, step={step}/{num_steps}, loss={avr_loss:.4f}, ppl={ppl:.4f}, lr={scheduler.get_last_lr()[0]}"
                    )
                    sys.stdout.flush()
                step += 1

        print("will train")
        for i in range(config_data.num_epochs):
            print("epoch i:", i)
            sys.stdout.flush()
            _train_epoch(i)
            _save_epoch(i)

    elif args.run_mode == "test":
        test_data = tx.data.MultiAlignedData(config_data.test_data_params,
                                             device=device)
        data_iterator = tx.data.DataIterator({"test": test_data})
        print("test_data vocab src1 before load:",
              test_data.vocab('src').id_to_token_map_py)

        # Create model and optimizer
        model = Transformer(config_model, config_data, test_data.vocab('src'))

        model = ModelWrapper(model, config_model.beam_width)
        #print("state_dict:", model.state_dict())
        model_loaded = torch.load(args.load_checkpoint)
        #print("model_loaded state_dict:", model_loaded)
        model_loaded = rm_begin_str_in_keys("module.", model_loaded)
        #print("model_loaded2 state_dict:", model_loaded)

        model.load_state_dict(model_loaded)
        #model.load_state_dict(torch.load(args.load_checkpoint))
        model.to(device)

        data_iterator.switch_to_dataset('test')
        model.eval()
        print("will predict !!!")
        sys.stdout.flush()

        fo = open(args.pred_output_file, "w")
        print("test_data vocab src1:",
              test_data.vocab('src').id_to_token_map_py)
        print("test_data vocab src2:",
              test_data.vocab('src').token_to_id_map_py)
        with torch.no_grad():
            for batch in data_iterator:
                print("batch:", batch)
                return_dict = model.predict(batch)
                preds = return_dict['preds'].cpu()
                print("preds:", preds)
                pred_words = tx.data.map_ids_to_strs(preds,
                                                     test_data.vocab('src'))
                #src_words = tx.data.map_ids_to_strs(batch['src_text'], test_data.vocab('src'))
                src_words = [" ".join(sw) for sw in batch['src_text']]
                for swords, words in zip(src_words, pred_words):
                    print(str(swords) + "\t" + str(words))
                    fo.write(str(words) + "\n")
                #print(" ".join(batch.src_text) + "\t" + pred_words)
                #print(batch.src_text, pred_words)
                #fo.write(str(pred_words) + "\n")
                fo.flush()
        fo.close()

    else:
        raise ValueError(f"Unknown mode: {args.run_mode}")