Example #1
def run(set_path, entry_path, _set):
    entryset = parser.run_parser(set_path)  # parse
    entryset = order.run(entryset, 'en')  # order the English part
    nmt = NMT(entryset, _set)  # translate to German
    nmt.preprocess()

    p.dump(entryset, open(entry_path, 'wb'))  # pickle in binary mode (matches the 'rb' load in Example #2)
Example #2
def run(entry_path, set_path, en_path, de_path, _set):
    entryset = p.load(open(entry_path, 'rb'))
    if de_path != '':
        nmt = NMT(entryset, _set)  # translate to German
        entryset = nmt.postprocess()
        entryset = order.run(entryset, 'de')  # order the German part

    # referring expressions
    entryset = reg.run(entryset, 'en')

    # lexicalization
    template = TemplateExtraction()
    entryset = template(entryset, 'en')
    template.close()

    lexsize, templates, templates_de, entities, references = stats.run(entryset)

    # run xml generator
    parser.run_generator(entryset=entryset, input_dir=set_path, output_dir=en_path, lng='en')
    if de_path != '':
        parser.run_generator(entryset=entryset, input_dir=set_path, output_dir=de_path, lng='de')

    # extract and generate templates based on sentence segmentation
    # en_temp = template.run(entryset)
    # json.dump(en_temp, open(os.path.join(en_path, 'templates.json'), 'w'), indent=4, separators=(',', ': '))
    #
    # de_temp = template.run(entryset, 'de')
    # json.dump(de_temp, open(os.path.join(de_path, 'templates.json'), 'w'), indent=4, separators=(',', ': '))
    return lexsize, templates, templates_de, entities, references
Example #3
def beam_search(model: NMT, test_data_src: List[List[str]], beam_size: int, max_decoding_time_step: int) -> List[List[Hypothesis]]:
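    """Decode each sentence in test_data_src with beam search and return the hypotheses for each."""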
    was_training = model.training
    model.eval()

    hypotheses = []
    with torch.no_grad():
        for src_sent in tqdm(test_data_src, desc='Decoding', file=sys.stdout):
            example_hyps = model.beam_search(src_sent, beam_size=beam_size, max_decoding_time_step=max_decoding_time_step)

            hypotheses.append(example_hyps)

    if was_training:
        model.train()

    return hypotheses
Example #4
def decode(args):  # args behaves like an argparse.Namespace (attribute access below)
    """
    performs decoding on a test set, and save the best-scoring decoding results.
    If the target gold-standard sentences are given, the function also computes
    corpus-level BLEU score.
    """
    test_src_dir = os.path.join(args.test_dir, args.input_col.lower())
    test_tgt_dir = os.path.join(args.test_dir, args.output_col.lower())

    print(f"load test source sentences from [{test_src_dir}]", file=sys.stderr)
    test_data_src = read_corpus(test_src_dir, source='src')
    if test_tgt_dir:
        print(f"load test target sentences from [{test_tgt_dir}]", file=sys.stderr)
        test_data_tgt = read_corpus(test_tgt_dir, source='tgt')

    model_path = os.path.join(args.model_dir, 'model.bin')
    print(f"load model from {model_path}", file=sys.stderr)
    model = NMT.load(model_path)

    if args.cuda:
        model = model.to(torch.device("cuda:0"))

    hypotheses = beam_search(model, test_data_src,
                             beam_size=int(args.beam_size),
                             max_decoding_time_step=int(args.max_decoding_time_step))

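    # keep only the top-scoring hypothesis per sentence (beam output assumed sorted best-first)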
    top_hypotheses = [hyps[0] for hyps in hypotheses]
    bleu_score = compute_corpus_level_bleu_score(test_data_tgt, top_hypotheses)
    print(f'Corpus BLEU: {bleu_score}', file=sys.stderr)

    output_path = os.path.join(args.eval_dir, 'decode.txt')
    with open(output_path, 'w') as f:
        f.write(str(bleu_score))
Example #5
def main(argv):

    if FLAGS.reparse_vocab:
        logging.info('Loading data...')
        en_full, ch_full = read_txt(FLAGS.data_file_path)

    logging.info('Building/Loading vocab...')
    vocab_hub = VocabHub()
    if FLAGS.reparse_vocab:
        vocab_hub.build(en_full, ch_full)
        vocab_hub.save()
    else:
        vocab_hub.load()

    ch_num_vocab = len(vocab_hub.ch.word_dict)
    en_num_vocab = len(vocab_hub.en.word_dict)
    enc_vocab_size = en_num_vocab if FLAGS.is_en_to_ch else ch_num_vocab
    dec_vocab_size = ch_num_vocab if FLAGS.is_en_to_ch else en_num_vocab
    enc_type = 'English' if FLAGS.is_en_to_ch else 'Chinese'
    dec_type = 'Chinese' if FLAGS.is_en_to_ch else 'English'
    logging.info('({}) Encoder vocab size: {}'.format(enc_type, enc_vocab_size))
    logging.info('({}) Decoder vocab size: {}'.format(dec_type, dec_vocab_size))

    logging.info('Preprocessing datasets...')
    dataset_generator = VocabDataset()
    if FLAGS.reparse_vocab:
        dataset_generator.build(en_full, ch_full, vocab_hub)
        dataset_generator.save()
    else:
        dataset_generator.load()

    logging.info('Performing train/test/val split...')
    if FLAGS.is_en_to_ch:
        enc_processed, dec_processed = (dataset_generator.en_processed,
                                        dataset_generator.ch_processed)
    else:
        enc_processed, dec_processed = (dataset_generator.ch_processed,
                                        dataset_generator.en_processed)
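    # drop the first and last tokens of each encoder sequence
    # (presumably the start/end markers added during preprocessing)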
    enc_processed = [s[1:-1] for s in enc_processed]
    enc_dec_ds = list(zip(enc_processed, dec_processed))
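    # hold out a tiny test split (0.3% of the sentence pairs)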
    train_ds, test_ds = train_test_split(enc_dec_ds,
                                         test_size=0.003,
                                         random_state=FLAGS.seed)
    enc_train_ds, dec_train_ds = split_enc_dec_ds(train_ds)
    enc_test_ds, dec_test_ds = split_enc_dec_ds(test_ds)
    logging.info('Number of train obs: {}'.format(len(enc_train_ds)))
    logging.info('Number of test obs: {}'.format(len(enc_test_ds)))

    logging.info('Training neural network model...')
    nmt_model = NMT(FLAGS, enc_vocab_size, dec_vocab_size)
    keras_trainer = Trainer(FLAGS, nmt_model)
    keras_trainer.train([enc_train_ds, dec_train_ds],
                        [enc_test_ds, dec_test_ds])
Example #6
def nmt():

    if request.method == 'POST':
        text = request.form['sent']

        EMBED_SIZE = 256
        HIDDEN_SIZE = 512
        DROPOUT_RATE = 0.2
        BATCH_SIZE = 256
        NUM_TRAIN_STEPS = 10

        VOCAB = Vocab.load('VOCAB_FILE')

        vocab_inp_size = len(VOCAB.src) + 1
        vocab_tar_size = len(VOCAB.tgt) + 1

        model = NMT(vocab_inp_size, vocab_tar_size, EMBED_SIZE, HIDDEN_SIZE,
                    BATCH_SIZE)
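        # run one dummy batch through the encoder and decoder so the
        # subclassed Keras model builds its weights before load_weights()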
        sample_hidden = model.encoder.initialize_hidden_state()
        sample_output, sample_hidden = model.encoder(
            tf.random.uniform((BATCH_SIZE, 1)), sample_hidden)
        sample_decoder_output, _, _ = model.decoder(
            tf.random.uniform((BATCH_SIZE, 1)), sample_hidden, sample_output)
        model.load_weights('es_en')

        pred = decode_sentence(model, text, VOCAB)

        return render_template('home.html', result=pred)
Example #7
def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    lms = {}

    # load model params & training data
    for i in range(2):
        model_id = ['A', 'B'][i]
        print('loading pieces, part {:s}'.format(model_id))

        print('  load model{:s}     from [{:s}]'.format(model_id, args.nmt[i]), file=sys.stderr)
        params = torch.load(args.nmt[i], map_location=lambda storage, loc: storage)  # load model onto CPU
        vocabs[model_id] = params['vocab']
        opts[model_id] = params['args']
        state_dicts[model_id] = params['state_dict']

        print('  load train_src{:s} from [{:s}]'.format(model_id, args.src[i]), file=sys.stderr)
        train_srcs[model_id] = read_corpus(args.src[i], source='src')

        print('  load lm{:s}        from [{:s}]'.format(model_id, args.lm[i]), file=sys.stderr)
        lms[model_id] = LMProb(args.lm[i], args.dict[i])

    models = {}
    optimizers = {}

    for m in ['A', 'B']:
        # build model
        opts[m].cuda = args.cuda

        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()

        if args.cuda:
            models[m] = models[m].cuda()

        random.shuffle(train_srcs[m])

        # optimizer
        # optimizers[m] = torch.optim.Adam(models[m].parameters())
        optimizers[m] = torch.optim.SGD(models[m].parameters(), lr=1e-3, momentum=0.9)

    # loss function
    loss_nll = torch.nn.NLLLoss()
    loss_ce = torch.nn.CrossEntropyLoss()

    epoch = 0
    start = args.start_iter

    while True:
        epoch += 1
        print('\nstart of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = iter(train_srcs['A'])
        data['B'] = iter(train_srcs['B'])

        start += (epoch - 1) * len(train_srcs['A']) + 1

        for t in range(start, start + len(train_srcs['A'])):
            show_log = False
            if t % args.log_every == 0:
                show_log = True

            if show_log:
                print('\nstep', t)

            for m in ['A', 'B']:
                lm_probs = []

                NLL_losses = []
                CE_losses = []

                modelA = models[m]
                modelB = models[change(m)]
                lmB = lms[change(m)]
                optimizerA = optimizers[m]
                optimizerB = optimizers[change(m)]
                vocabB = vocabs[change(m)]
                s = next(data[m])

                if show_log:
                    print('\n{:s} -> {:s}'.format(m, change(m)))
                    print('[s]', ' '.join(s))

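                # sample beam_size candidate translations of s into the other language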
                hyps = modelA.beam(s, beam_size=5)

                for ids, smid, dist in hyps:
                    if show_log:
                        print('[smid]', ' '.join(smid))

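                    # NLL of the sampled ids under modelA's own distribution;
                    # weighted by the rewards below, this acts as the
                    # policy-gradient surrogate loss for the forward model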
                    var_ids = Variable(torch.LongTensor(ids[1:]), requires_grad=False)
                    NLL_losses.append(loss_nll(dist, var_ids).cpu())

                    lm_probs.append(lmB.get_prob(smid))

                    src_sent_var = to_input_variable([smid], vocabB.src, cuda=args.cuda)
                    tgt_sent_var = to_input_variable([['<s>'] + s + ['</s>']], vocabB.tgt, cuda=args.cuda)
                    src_sent_len = [len(smid)]

                    score = modelB(src_sent_var, src_sent_len, tgt_sent_var[:-1]).squeeze(1)

                    CE_losses.append(loss_ce(score, tgt_sent_var[1:].view(-1)).cpu())

                # losses on target language
                fw_losses = torch.cat(NLL_losses)

                # losses on reconstruction
                bw_losses = torch.cat(CE_losses)

                # r1, language model reward
                r1s = Variable(torch.FloatTensor(lm_probs), requires_grad=False)
                r1s = (r1s - torch.mean(r1s)) / torch.std(r1s)

                # r2, communication reward
                r2s = Variable(bw_losses.data, requires_grad=False)
                r2s = (torch.mean(r2s) - r2s) / torch.std(r2s)

                # rk = alpha * r1 + (1 - alpha) * r2
                rks = r1s * args.alpha + r2s * (1 - args.alpha)

                # averaging loss over samples
                A_loss = torch.mean(fw_losses * rks)
                B_loss = torch.mean(bw_losses * (1 - args.alpha))

                if show_log:
                    for r1, r2, rk, fw_loss, bw_loss in zip(
                            r1s.data.numpy(), r2s.data.numpy(), rks.data.numpy(),
                            fw_losses.data.numpy(), bw_losses.data.numpy()):
                        print('r1={:7.4f}\t r2={:7.4f}\t rk={:7.4f}\t '
                              'fw_loss={:7.4f}\t bw_loss={:7.4f}'.format(
                                  r1, r2, rk, fw_loss, bw_loss))
                    print('A loss = {:.7f} \t B loss = {:.7f}'.format(A_loss.data.numpy().item(), B_loss.data.numpy().item()))

                optimizerA.zero_grad()
                optimizerB.zero_grad()

                A_loss.backward()
                B_loss.backward()

                optimizerA.step()
                optimizerB.step()

            if t % args.save_n_iter == 0:
                print('\nsaving model')
                models['A'].save('{}.iter{}.bin'.format(args.model[0], t))
                models['B'].save('{}.iter{}.bin'.format(args.model[1], t))
Example #8
def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    lms = {}

    # load model params & training data
    for i in range(2):
        model_id = ['A', 'B'][i]
        print('loading pieces, part {:s}'.format(model_id))

        print('  load model{:s}     from [{:s}]'.format(model_id, args.nmt[i]),
              file=sys.stderr)
        params = torch.load(
            args.nmt[i],
            map_location=lambda storage, loc: storage)  # load model onto CPU
        vocabs[model_id] = params['vocab']
        opts[model_id] = params['args']
        state_dicts[model_id] = params['state_dict']

        print('  load train_src{:s} from [{:s}]'.format(model_id, args.src[i]),
              file=sys.stderr)
        train_srcs[model_id] = read_corpus(args.src[i], source='src')

        print('  load lm{:s}        from [{:s}]'.format(model_id, args.lm[i]),
              file=sys.stderr)
        lms[model_id] = LMProb(args.lm[i], args.dict[i])

    models = {}
    optimizers = {}

    for m in ['A', 'B']:
        # build model
        opts[m].cuda = args.cuda

        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()

        if args.cuda:
            models[m] = models[m].cuda()

        random.shuffle(train_srcs[m])

        # optimizer
        # optimizers[m] = torch.optim.Adam(models[m].parameters())
        optimizers[m] = torch.optim.SGD(models[m].parameters(),
                                        lr=1e-3,
                                        momentum=0.9)

    # loss function
    loss_nll = torch.nn.NLLLoss()
    loss_ce = torch.nn.CrossEntropyLoss()
    f_lossA = open(args.model[0] + ".losses", "w")
    f_lossB = open(args.model[1] + ".losses", "w")

    epoch = 0
    start = args.start_iter

    while True:
        epoch += 1
        print('\nstart of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = iter(train_srcs['A'])
        data['B'] = iter(train_srcs['B'])

        start += (epoch - 1) * len(train_srcs['A']) + 1

        for t in range(start, start + len(train_srcs['A'])):
            show_log = False
            if t % args.log_every == 0:
                show_log = True

            if show_log:
                print('\nstep', t)

            for m in ['A', 'B']:
                lm_probsA = []
                lm_probsB = []

                NLL_lossesA = []
                NLL_lossesB = []

                modelA = models[m]
                modelB = models[change(m)]
                lmA = lms[m]
                lmB = lms[change(m)]
                optimizerA = optimizers[m]
                optimizerB = optimizers[change(m)]
                vocabA = vocabs[m]
                vocabB = vocabs[change(m)]
                s = next(data[m])

                if show_log:
                    print('\n{:s} -> {:s}'.format(m, change(m)))
                    print('[s]', ' '.join(s))

                hyps = modelA.beam(s, beam_size=5)

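                # encode the source sentence once and mean-pool over time;
                # this fixed vector anchors the round-trip similarity reward below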
                src_sents_var = to_input_variable([s],
                                                  modelA.vocab.src,
                                                  cuda=args.cuda,
                                                  is_test=True)
                src_encoding, _ = modelA.encode(src_sents_var, [len(s)])
                src_encoding = src_encoding.squeeze(1)
                src_encoding = torch.mean(src_encoding, dim=0)

                tb_encodings = []

                for ids, smid, dist in hyps:
                    if show_log:
                        print('[smid]', ' '.join(smid))

                    var_ids = torch.LongTensor(ids[1:]).detach()
                    NLL_lossesB.append(
                        loss_nll(dist, var_ids).unsqueeze(0).cpu())
                    lm_probsB.append(lmB.get_prob(smid))

                    idback, sback, distback = modelB.beam(smid, beam_size=1)[0]
                    var_idback = torch.LongTensor(idback[1:]).detach()
                    NLL_lossesA.append(
                        loss_nll(distback, var_idback).unsqueeze(0).cpu())
                    lm_probsA.append(lmA.get_prob(sback))

                    tb_sents_var = to_input_variable([sback],
                                                     modelA.vocab.src,
                                                     cuda=args.cuda,
                                                     is_test=True)
                    tb_encoding, _ = modelA.encode(tb_sents_var, [len(sback)])
                    tb_encoding = tb_encoding.squeeze(1)
                    tb_encoding = torch.mean(tb_encoding, dim=0, keepdim=True)
                    tb_encodings.append(tb_encoding)

                # losses on target language
                fw_losses = torch.cat(NLL_lossesB)

                # losses on reconstruction
                bw_losses = torch.cat(NLL_lossesA)

                # r1, language model reward
                r1s = torch.FloatTensor(lm_probsB).detach()
                r1s = (r1s - torch.mean(r1s)) / torch.std(r1s)

                # r2, communication reward
                r2s = torch.FloatTensor(lm_probsA).detach()
                r2s = (r2s - torch.mean(r2s)) / torch.std(r2s)

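                # similarity between the back-translations and the source:
                # mean dot product of the pooled encodings, squashed by a
                # sigmoid; note the 1 - sigmoid(...) inversion of the score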
                tb_encodings = torch.cat(tb_encodings).detach()
                cossim = torch.matmul(tb_encodings, src_encoding)
                cossim = 1 - torch.nn.Sigmoid()(torch.mean(cossim)).item()

                # rab = alpha * cossim + (1 - alpha) * r1
                # rba = beta  * cossim + (1 - beta ) * r2
                rkab = cossim * args.alpha + r1s * (1 - args.alpha)
                rkba = cossim * args.beta + r2s * (1 - args.beta)

                # averaging loss over samples
                A_loss = torch.mean(fw_losses * rkab)
                B_loss = torch.mean(bw_losses * rkba)

                if show_log:
                    for r1, r2, rab, rba, fw_loss, bw_loss in zip(
                            r1s.data.numpy(), r2s.data.numpy(),
                            rkab.data.numpy(), rkba.data.numpy(),
                            fw_losses.data.numpy(), bw_losses.data.numpy()):
                        print(
                            'r1={:7.4f}\t r2={:7.4f}\t rab={:7.4f}\t rba={:7.4f}\t fw_loss={:7.4f}\t bw_loss={:7.4f}'
                            .format(r1, r2, rab, rba, fw_loss, bw_loss))
                    print('A loss = {:.7f} \t B loss = {:.7f}'.format(
                        A_loss.data.numpy().item(),
                        B_loss.data.numpy().item()))
                    f_lossA.write(
                        str(t) +
                        ' {:.7f}\n'.format(A_loss.data.numpy().item()))
                    f_lossB.write(
                        str(t) +
                        ' {:.7f}\n'.format(B_loss.data.numpy().item()))

                optimizerA.zero_grad()
                optimizerB.zero_grad()

                A_loss.backward()
                B_loss.backward()

                optimizerA.step()
                optimizerB.step()

            if t % args.save_n_iter == 0:
                print('\nsaving model')
                models['A'].save('{}.iter{}.bin'.format(args.model[0], t))
                models['B'].save('{}.iter{}.bin'.format(args.model[1], t))
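    # NOTE: never reached -- the while-loop above has no break, so the loss
    # files are only closed when the process exits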
    f_lossA.close()
    f_lossB.close()
Example #9
def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    train_tgt = {}
    lm_scores = {}
    dev_data = {}

    # load model params & training data
    for i in range(2):
        model_id = ['A', 'B'][i]
        print('loading pieces, part {:s}'.format(model_id))
        print('  load model{:s}     from [{:s}]'.format(model_id, args.nmt[i]))
        params = torch.load(
            args.nmt[i],
            map_location=lambda storage, loc: storage)  # load model onto CPU
        vocabs[model_id] = params['vocab']
        print('==' * 10)
        print(vocabs[model_id])
        opts[model_id] = params['args']
        state_dicts[model_id] = params['state_dict']
        print('done')

    for i in range(2):
        model_id = ['A', 'B'][i]
        print('  load train_src{:s} from [{:s}]'.format(model_id, args.src[i]))
        train_srcs[model_id], lm_scores[model_id] = read_corpus_for_dsl(
            args.src[i], source='src')
        train_tgt[model_id], _ = read_corpus_for_dsl(args.src[(i + 1) % 2],
                                                     source='tgt')

    dev_data_src1 = read_corpus(args.val[0], source='src')
    dev_data_tgt1 = read_corpus(args.val[1], source='tgt')
    dev_data['A'] = list(zip(dev_data_src1, dev_data_tgt1))
    dev_data_src2 = read_corpus(args.val[1], source='src')
    dev_data_tgt2 = read_corpus(args.val[0], source='tgt')
    dev_data['B'] = list(zip(dev_data_src2, dev_data_tgt2))

    models = {}
    optimizers = {}
    nll_loss = {}
    cross_entropy_loss = {}

    for m in ['A', 'B']:
        # build model
        opts[m].cuda = args.cuda

        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()

        if args.cuda:
            models[m] = models[m].cuda()

        optimizers[m] = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                               models[m].parameters()),
                                        lr=args.lr)
    for m in ['A', 'B']:
        vocab_mask = torch.ones(len(vocabs[m].tgt))
        vocab_mask[vocabs[m].tgt['<pad>']] = 0
        nll_loss[m] = torch.nn.NLLLoss(weight=vocab_mask, size_average=False)
        cross_entropy_loss[m] = torch.nn.CrossEntropyLoss(weight=vocab_mask,
                                                          reduce=False,
                                                          size_average=False)
        models[m].eval()
        if args.cuda:
            nll_loss[m] = nll_loss[m].cuda()
            cross_entropy_loss[m] = cross_entropy_loss[m].cuda()
    epoch = 0

    train_data = list(
        zip(train_srcs['A'], train_tgt['A'], lm_scores['A'], lm_scores['B']))
    cum_lossA = cum_lossB = 0
    att_loss = 0
    ce_lossA_log = 0
    ce_lossB_log = 0
    t = 0
    hist_valid_scores = {}
    hist_valid_scores['A'] = []
    hist_valid_scores['B'] = []

    patience = {}
    patience['A'] = patience['B'] = 0
    decay = {}
    decay['A'] = 0
    decay['B'] = 0
    while True:
        epoch += 1
        print('\nstart of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = data_iter_for_dual(train_data,
                                       batch_size=args.batch_size,
                                       shuffle=False)

        for batchA in data['A']:
            src_sentsA, tgt_sentsA, src_scoresA, src_scoresB = batchA
            tgt_sents_forA = [['<s>'] + sent + ['</s>'] for sent in tgt_sentsA]

            src_sents_varA, masksA = to_input_variable(src_sentsA,
                                                       vocabs['A'].src,
                                                       cuda=args.cuda)
            tgt_sents_varA, _ = to_input_variable(tgt_sents_forA,
                                                  vocabs['A'].tgt,
                                                  cuda=args.cuda)
            src_scores_varA = Variable(torch.FloatTensor(src_scoresA),
                                       requires_grad=False)

            src_sents_len_A = [len(s) for s in src_sentsA]
            # print(src_sents_varA, src_sents_len_A, tgt_sents_varA[:-1], masksA)
            scoresA, feature_A, att_sim_A = models['A'](src_sents_varA,
                                                        src_sents_len_A,
                                                        tgt_sents_varA[:-1],
                                                        masksA)

            ce_lossA = cross_entropy_loss['A'](scoresA.view(
                -1, scoresA.size(2)), tgt_sents_varA[1:].view(-1)).cpu()

            batch_data = (src_sentsA, tgt_sentsA, src_scoresA, src_scoresB)
            src_sentsA, tgt_sentsA, src_scoresA, src_scoresB = get_new_batch(
                batch_data)
            tgt_sents_forB = [['<s>'] + sent + ['</s>'] for sent in src_sentsA]

            src_sents_varB, masksB = to_input_variable(tgt_sentsA,
                                                       vocabs['B'].src,
                                                       cuda=args.cuda)
            tgt_sents_varB, _ = to_input_variable(tgt_sents_forB,
                                                  vocabs['B'].tgt,
                                                  cuda=args.cuda)
            src_scores_varB = Variable(torch.FloatTensor(src_scoresB),
                                       requires_grad=False)

            src_sents_len = [len(s) for s in tgt_sentsA]
            scoresB, feature_B, att_sim_B = models['B'](src_sents_varB,
                                                        src_sents_len,
                                                        tgt_sents_varB[:-1],
                                                        masksB)

            ce_lossB = cross_entropy_loss['B'](scoresB.view(
                -1, scoresB.size(2)), tgt_sents_varB[1:].view(-1)).cpu()

            optimizerA = optimizers['A']
            optimizerB = optimizers['B']

            optimizerA.zero_grad()
            optimizerB.zero_grad()
            # print (ce_lossA.size(), src_scores_varA.size(), tgt_sents_varA[1:].size(0))
            ce_lossA = ce_lossA.view(tgt_sents_varA[1:].size(0),
                                     tgt_sents_varA[1:].size(1)).mean(0)
            ce_lossB = ce_lossB.view(tgt_sents_varB[1:].size(0),
                                     tgt_sents_varB[1:].size(1)).mean(0)

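            # attention-agreement term: stack each model's per-step attention
            # maps and mask out padding before comparing the two directions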
            att_sim_A = torch.cat(att_sim_A, 1)

            masksA = masksA.transpose(1, 0).unsqueeze(1)
            masksA = masksA.expand(masksA.size(0), att_sim_A.size(1),
                                   masksA.size(2))
            assert att_sim_A.size() == masksA.size(), '{} {}'.format(
                att_sim_A.size(), masksA.size())
            att_sim_B = torch.cat(att_sim_B, 1)
            masksB = masksB.transpose(1, 0).unsqueeze(1)
            masksB = masksB.expand(masksB.size(0), att_sim_B.size(1),
                                   masksB.size(2))
            assert att_sim_B.size() == masksB.size(), '{} {}'.format(
                att_sim_B.size(), masksB.size())
            att_sim_B = att_sim_B.transpose(2, 1)
            loss_att_A = loss_att(att_sim_A, att_sim_B, masksB.transpose(1, 0),
                                  src_sents_len)
            loss_att_B = loss_att(att_sim_A.transpose(2, 1),
                                  att_sim_B.transpose(2, 1), masksB,
                                  src_sents_len_A)

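            # probability-duality regularizer: with the CE losses standing in
            # for -log P(y|x) and -log P(x|y), and the LM scores for log P(x)
            # and log P(y), this is the squared violation of
            # log P(x) + log P(y|x) = log P(y) + log P(x|y)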
            dual_loss = (src_scores_varA - ce_lossA - src_scores_varB +
                         ce_lossB)**2
            att_loss_ = (loss_att_A + loss_att_B)

            lossA = ce_lossA + args.beta1 * dual_loss + args.beta3 * att_loss_
            lossB = ce_lossB + args.beta2 * dual_loss + args.beta4 * att_loss_

            lossA = torch.mean(lossA)
            lossB = torch.mean(lossB)

            cum_lossA += lossA.data[0]
            cum_lossB += lossB.data[0]

            ce_lossA_log += torch.mean(ce_lossA).data[0]
            ce_lossB_log += torch.mean(ce_lossB).data[0]
            att_loss += (torch.mean(loss_att_A).data[0] +
                         torch.mean(loss_att_B).data[0])

            optimizerA.zero_grad()
            lossA.backward(retain_graph=True)
            grad_normA = torch.nn.utils.clip_grad_norm(
                models['A'].parameters(), args.clip_grad)
            optimizerA.step()
            optimizerB.zero_grad()
            lossB.backward()
            grad_normB = torch.nn.utils.clip_grad_norm(
                models['B'].parameters(), args.clip_grad)
            optimizerB.step()
            if t % args.log_n_iter == 0 and t != 0:
                print(
                    'epoch %d, avg. loss A %.3f, avg. word loss A %.3f, avg, loss B %.3f, avg. word loss B %.3f, avg att loss %.3f'
                    % (epoch, cum_lossA / args.log_n_iter, ce_lossA_log /
                       args.log_n_iter, cum_lossB / args.log_n_iter,
                       ce_lossB_log / args.log_n_iter,
                       att_loss / args.log_n_iter))
                cum_lossA = 0
                cum_lossB = 0
                att_loss = 0
                ce_lossA_log = 0
                ce_lossB_log = 0
            if t % args.val_n_iter == 0 and t != 0:
                print('Validation begins ...')
                for i, model_id in enumerate(['A', 'B']):
                    models[model_id].eval()

                    tmp_dev_data = dev_data[model_id]
                    dev_hyps = decode(models[model_id], tmp_dev_data)
                    dev_hyps = [hyps[0] for hyps in dev_hyps]
                    valid_metric = get_bleu([tgt for src, tgt in tmp_dev_data],
                                            dev_hyps, 'test')
                    models[model_id].train()
                    hist_scores = hist_valid_scores[model_id]
                    print('Model_id {} Sentence bleu : {}'.format(
                        model_id, valid_metric))

                    is_better = len(
                        hist_scores) == 0 or valid_metric > max(hist_scores)
                    hist_scores.append(valid_metric)

                    if not is_better:
                        patience[model_id] += 1
                        print('hit patience %d' % patience[model_id])
                        if patience[model_id] > 0:
                            if abs(optimizers[model_id].param_groups[0]
                                   ['lr']) < 1e-8:
                                exit(0)
                            if decay[model_id] < 1:
                                lr = optimizers[model_id].param_groups[0][
                                    'lr'] * 0.5
                                print('Decay learning rate to %f' % lr)
                                optimizers[model_id].param_groups[0]['lr'] = lr
                                patience[model_id] = 0
                                decay[model_id] += 1
                            else:
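                                # freeze the vocabulary-sized parameter
                                # matrices (the hard-coded sizes appear to be
                                # the two vocab sizes) before decaying further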
                                for param in models[model_id].parameters():
                                    if param.size()[0] == 50000 or param.size(
                                    )[0] == 27202:
                                        param.requires_grad = False

                                lr = optimizers[model_id].param_groups[0][
                                    'lr'] * 0.95
                                print('Decay learning rate to %f' % lr)
                                optimizers[model_id].param_groups[0]['lr'] = lr
                                decay[model_id] += 1

                    else:
                        patience[model_id] = 0
                        if model_id == 'A':
                            np.save('{}.iter{}'.format(args.model[i], t),
                                    att_sim_A[0].cpu().data.numpy())
                        if model_id == 'B':
                            np.save('{}.iter{}'.format(args.model[i], t),
                                    att_sim_B[0].cpu().data.numpy())
                        models[model_id].save('{}.iter{}.bin'.format(
                            args.model[i], t))

            t += 1
Example #10
def init():
    global model
    model_dir = Model.get_model_path('arxiv-nmt-pipeline')
    model_path = os.path.join(model_dir, 'model.bin')
    model = NMT.load(model_path)
    model.eval()
Example #11
        dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

        print("deleting files not required...")
        del src_sents, tgt_sents, src_pad, tgt_pad

        print("beginning training...")

        train(dataset, EMBED_SIZE, HIDDEN_SIZE, DROPOUT_RATE, BATCH_SIZE,
              NUM_TRAIN_STEPS, BUFFER_SIZE, steps_per_epoch, vocab_inp_size,
              vocab_tar_size, VOCAB)
        print("training complete!")

    if args['decode']:

        print('reading vocabulary file: %s' % args['--vocab'])
        VOCAB = Vocab.load(args['--vocab'])

        vocab_inp_size = len(VOCAB.src) + 1
        vocab_tar_size = len(VOCAB.tgt) + 1

        print('restoring pre-trained model')
        model = NMT(vocab_inp_size, vocab_tar_size, EMBED_SIZE, HIDDEN_SIZE,
                    BATCH_SIZE)
        sample_hidden = model.encoder.initialize_hidden_state()
        sample_output, sample_hidden = model.encoder(
            tf.random.uniform((BATCH_SIZE, 1)), sample_hidden)
        sample_decoder_output, _, _ = model.decoder(
            tf.random.uniform((BATCH_SIZE, 1)), sample_hidden, sample_output)
        model.load_weights(args['--model-file'])
        print("beginning decoding...")
        decode(model, args['--sent-file'], VOCAB, 0)
Example #12
def train_mle(args: Dict):
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    notify_slack_every = int(args['--notify-slack-every'])
    model_save_path = args['--save-to']

    vocab = Vocab.load(args['--vocab'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                input_feed=args['--input-feed'],
                label_smoothing=float(args['--label-smoothing']),
                vocab=vocab)
    model.train()

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    log_data = {'args': args}  # for logging; used later to inspect how training converged

    _info = f"""
        begin Maximum Likelihood training
        ・学習:{len(train_data)}ペア
        ・テスト:{len(dev_data)}ペア, {valid_niter}iter毎
        ・バッチサイズ:{train_batch_size}
        ・1epoch = {len(train_data)}ペア = {int(len(train_data)/train_batch_size)}iter
        ・max epoch:{args['--max-epoch']}
    """
    print(_info)
    print(_info, file=sys.stderr)

    _notify_slack_if_need(f"""
    {_info}
    {args}
    """, args)

    while True:
        epoch += 1

        for src_sents, tgt_sents in batch_iter(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True):
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(src_sents)

            # (batch_size,) per-example negative log-likelihoods
            example_losses = -model(src_sents, tgt_sents)
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(),
                                                      clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0 or train_iter % notify_slack_every == 0:
                _report = ('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                           'cum. examples %d, speed %.2f words/sec, '
                           'time elapsed %.2f sec'
                           % (epoch, train_iter,
                              report_loss / report_examples,
                              math.exp(report_loss / report_tgt_words),
                              cum_examples,
                              report_tgt_words / (time.time() - train_time),
                              time.time() - begin_time))
                print(_report, file=sys.stderr)

                _list_dict_update(
                    log_data, {
                        'epoch': epoch,
                        'train_iter': train_iter,
                        'loss': report_loss / report_examples,
                        'ppl': math.exp(report_loss / report_tgt_words),
                        'examples': cum_examples,
                        'speed': report_tgt_words / (time.time() - train_time),
                        'elapsed': time.time() - begin_time
                    }, 'train')

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

                if train_iter % notify_slack_every == 0:
                    _notify_slack_if_need(_report, args)

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                dev_ppl, dev_loss = evaluate_ppl(
                    model, dev_data,
                    batch_size=128)  # dev batch size can be a bit larger
                valid_metric, eval_info = evaluate_valid_metric(
                    model, dev_data, dev_ppl, args)

                _report = 'validation: iter %d, dev. ppl %f, dev. %s %f , time elapsed %.2f sec' % (
                    train_iter, dev_ppl, args['--valid-metric'], valid_metric,
                    eval_info['elapsed'])
                print(_report, file=sys.stderr)
                _notify_slack_if_need(_report, args)

                if 'dev_data' not in log_data:
                    log_data['dev_data'] = dev_data[:int(
                        args['--dev-decode-limit'])]

                _list_dict_update(log_data, {
                    'epoch': epoch,
                    'train_iter': train_iter,
                    'loss': dev_loss,
                    'ppl': dev_ppl,
                    args['--valid-metric']: valid_metric,
                    **eval_info,
                },
                                  'validation',
                                  is_save=True)

                is_better = len(hist_valid_scores
                                ) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            _report = 'early stop!'
                            _notify_slack_if_need(_report, args)
                            print(_report, file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(
                            args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    _report = 'reached maximum number of epochs!'
                    _notify_slack_if_need(_report, args)
                    print(_report, file=sys.stderr)
                    exit(0)
Example #13
def train_raml(args: Dict):
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    notify_slack_every = int(args['--notify-slack-every'])
    model_save_path = args['--save-to']

    vocab = Vocab.load(args['--vocab'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                input_feed=args['--input-feed'],
                label_smoothing=float(args['--label-smoothing']),
                vocab=vocab)
    model.train()

    # NOTE: RAML
    tau = float(args['--raml-temp'])
    raml_sample_mode = args['--raml-sample-mode']
    raml_sample_size = int(args['--raml-sample-size'])

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    # NOTE: RAML
    report_weighted_loss = cum_weighted_loss = 0

    # NOTE: RAML -- load or generate the samples
    if raml_sample_mode == 'pre_sample':
        # dict of (src, [tgt: (sent, prob)])
        print('read in raml training data...', file=sys.stderr, end='')
        begin_time = time.time()
        raml_samples = read_raml_train_data(args['--raml-sample-file'],
                                            temp=tau)
        print('done[%d s].' % (time.time() - begin_time))
    else:
        raise Exception(f'sampling mode {raml_sample_mode} is not implemented yet')

    log_data = {'args': args}  # for logging; used later to inspect how training converged

    _info = f"""
        begin RAML training
        ・学習:{len(train_data)}ペア
        ・テスト:{len(dev_data)}ペア, {valid_niter}iter毎
        ・バッチサイズ:{train_batch_size}
        ・1epoch = {len(train_data)}ペア = {int(len(train_data)/train_batch_size)}iter
        ・max epoch:{args['--max-epoch']}
    """
    print(_info)
    print(_info, file=sys.stderr)

    _notify_slack_if_need(f"""
    {_info}
    {args}
    """, args)

    while True:
        epoch += 1

        for src_sents, tgt_sents in batch_iter(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True):
            train_iter += 1

            # NOTE: RAML
            # fetch the pre-generated samples tied to each sent in src_sents and use them as training data
            raml_src_sents = []
            raml_tgt_sents = []
            raml_tgt_weights = []
            if raml_sample_mode == 'pre_sample':
                for src_sent in src_sents:
                    sent = ' '.join(src_sent)
                    tgt_samples_all = raml_samples[sent]
                    # random choice from candidate samples
                    if raml_sample_size >= len(tgt_samples_all):
                        tgt_samples = tgt_samples_all
                    else:
                        tgt_samples_id = np.random.choice(
                            range(1, len(tgt_samples_all)),
                            size=raml_sample_size - 1,
                            replace=False)
                        # [ground truth y*] + samples
                        tgt_samples = [tgt_samples_all[0]] + [
                            tgt_samples_all[i] for i in tgt_samples_id
                        ]

                    raml_src_sents.extend([src_sent] * len(tgt_samples))
                    raml_tgt_sents.extend([['<s>'] + sent.split(' ') +
                                           ['</s>']
                                           for sent, weight in tgt_samples])
                    raml_tgt_weights.extend(
                        [weight for sent, weight in tgt_samples])
            else:
                raise Exception(f'sampling mode {raml_sample_mode} is not implemented yet')

            optimizer.zero_grad()

            # NOTE: RAML
            weights_var = torch.tensor(raml_tgt_weights,
                                       dtype=torch.float,
                                       device=device)
            batch_size = len(raml_src_sents)

            # (batch_size,) per-sample negative log-likelihoods
            unweighted_loss = -model(raml_src_sents, raml_tgt_sents)
            batch_loss = weighted_loss = (unweighted_loss * weights_var).sum()
            loss = batch_loss / batch_size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(),
                                                      clip_grad)

            optimizer.step()

            # NOTE: RAML
            weighted_loss_val = weighted_loss.item()
            batch_losses_val = unweighted_loss.sum().item()

            report_weighted_loss += weighted_loss_val
            cum_weighted_loss += weighted_loss_val
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0 or train_iter % notify_slack_every == 0:
                _report = ('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                           'cum. examples %d, speed %.2f words/sec, '
                           'time elapsed %.2f sec'
                           % (epoch, train_iter,
                              report_weighted_loss / report_examples,
                              math.exp(report_loss / report_tgt_words),
                              cum_examples,
                              report_tgt_words / (time.time() - train_time),
                              time.time() - begin_time))
                print(_report, file=sys.stderr)

                _list_dict_update(
                    log_data, {
                        'epoch': epoch,
                        'train_iter': train_iter,
                        'loss': report_loss / report_examples,
                        'ppl': math.exp(report_loss / report_tgt_words),
                        'examples': cum_examples,
                        'speed': report_tgt_words / (time.time() - train_time),
                        'elapsed': time.time() - begin_time
                    }, 'train')

                train_time = time.time()
                report_loss = report_weighted_loss = report_tgt_words = report_examples = 0.

                if train_iter % notify_slack_every == 0:
                    _notify_slack_if_need(_report, args)

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_weighted_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_weighted_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                dev_ppl, dev_loss = evaluate_ppl(
                    model, dev_data,
                    batch_size=128)  # dev batch size can be a bit larger
                valid_metric, eval_info = evaluate_valid_metric(
                    model, dev_data, dev_ppl, args)

                _report = 'validation: iter %d, dev. ppl %f, dev. %s %f , time elapsed %.2f sec' % (
                    train_iter, dev_ppl, args['--valid-metric'], valid_metric,
                    eval_info['elapsed'])
                print(_report, file=sys.stderr)
                _notify_slack_if_need(_report, args)

                if 'dev_data' not in log_data:
                    log_data['dev_data'] = dev_data[:int(
                        args['--dev-decode-limit'])]

                _list_dict_update(log_data, {
                    'epoch': epoch,
                    'train_iter': train_iter,
                    'loss': dev_loss,
                    'ppl': dev_ppl,
                    args['--valid-metric']: valid_metric,
                    **eval_info,
                },
                                  'validation',
                                  is_save=True)

                is_better = len(hist_valid_scores
                                ) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            _report = 'early stop!'
                            _notify_slack_if_need(_report, args)
                            print(_report, file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(
                            args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    _report = 'reached maximum number of epochs!'
                    _notify_slack_if_need(_report, args)
                    print(_report, file=sys.stderr)
                    exit(0)