Example #1
def main():
    dataset = Dataset(transform=transform, n_datas=10000,
                      seed=None)  # generate 10,000 samples so that every character appears at least once
    model = Transformer(n_head=2)
    try:
        trained_epoch = sl.find_last_checkpoint('./checkpoint')
        print('load model %d' % (trained_epoch))
    except Exception as e:
        print('no trained model found, {}'.format(e))
        return
    model = sl.load_model('./checkpoint', -1, model)
    model.eval()

    x, y, extra = dataset.__getitem__(0)  # only the 0th feature vector of y is used, i.e. the one-hot vector for <pad>
    # print(x.shape, y.shape)
    # pred = model(torch.from_numpy(x).unsqueeze(0), torch.from_numpy(y).unsqueeze(0)).squeeze()
    pred = translate(model, x,
                     y[0])  # for date-format conversion the whole input sequence is known, but of the output sequence only the leading <pad> token is known
    # print(pred.shape)
    pred = np.argmax(pred.detach().numpy(), axis=1)[1:]
    # print(extra['machine_readable'])
    pred = [dataset.inv_machine_vocab[p] for p in pred]
    pred_str = ''.join(pred)
    human_readable = extra['human_readable']
    machine_readable = extra['machine_readable']
    print('[%s] --> [%s], answer: [%s]' %
          (human_readable, pred_str, list(machine_readable)))

    dec_scores = model.decoder.scores_for_paint
    # print(dec_scores.shape)
    paint_score(dec_scores[0], human_readable, pred)  # [0] selects the 0th sample in the batch
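
# `translate` itself is not shown above; this is a plausible greedy-decoding
# sketch, assuming the (src, tgt) one-hot batch interface seen in the
# commented-out call inside main(). It is an illustration, not the original
# helper; max_len=11 (10 date characters plus the start token) is an assumption.
import torch

def translate(model, x, y0, max_len=11):
    src = torch.from_numpy(x).float().unsqueeze(0)         # (1, T_src, vocab_src)
    dec = torch.from_numpy(y0).float().reshape(1, 1, -1)   # (1, 1, vocab_tgt): start token only
    for _ in range(max_len - 1):
        out = model(src, dec).squeeze(0)                   # (t, vocab_tgt) scores decoded so far
        nxt = torch.zeros(1, 1, dec.size(-1))
        nxt[0, 0, out[-1].argmax()] = 1.0                  # one-hot of the newest argmax token
        dec = torch.cat([dec, nxt], dim=1)                 # feed the model's own prediction back in
    return model(src, dec).squeeze(0)                      # (max_len, vocab_tgt)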
Example #2
def generate(
    x: str,
    beam_width: int,
    device: torch.device,
    max_seq_len: int,
    model: Transformer,
    tokenizer: Tokenizer
) -> str:
    model.eval()
    seq = torch.LongTensor([[tokenizer.bos_id]]).to(device)  # (1, 1): one beam holding <bos>; must be 2-D for the indexing below
    x = torch.LongTensor([tokenizer.encode(x, max_len=-1)]).to(device)

    accum_prob = torch.zeros(beam_width).to(device)

    for _ in range(max_seq_len):
        pred_y = model.predict(x, seq)

        top_k_in_all_beams = []
        for out_beams in range(seq.size(0)):
            top_k_prob_in_beam, top_k_index_in_beam = \
                pred_y[out_beams, -1].topk(
                    k=beam_width,
                    dim=-1
                )
            for in_beam in range(beam_width):

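                # beam score = accumulated negative log-likelihood
                # (assumes the rows of pred_y are probabilities)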
                prob = accum_prob[out_beams] -\
                    top_k_prob_in_beam[in_beam].log()
                prob = prob.unsqueeze(0)

                temp_seq = torch.cat([
                    seq[out_beams],
                    top_k_index_in_beam[in_beam].unsqueeze(0)
                ], dim=-1).unsqueeze(0)

                top_k_in_all_beams.append({
                    'prob': prob,
                    'seq': temp_seq
                })

        # keep the beam_width candidates with the LOWEST accumulated NLL;
        # without largest=False, topk would keep the worst beams instead
        _, top_k_index_in_all_beams = torch.cat([
            beam['prob'] for beam in top_k_in_all_beams
        ]).topk(k=beam_width, dim=0, largest=False)

        seq = torch.cat([
            top_k_in_all_beams[index]['seq']
            for index in top_k_index_in_all_beams
        ], dim=0)

        accum_prob = torch.cat([
            top_k_in_all_beams[index]['prob']
            for index in top_k_index_in_all_beams
        ], dim=0)

        if x.size(0) != seq.size(0):
            x = x.repeat(seq.size(0) // x.size(0), 1)

    decoded = tokenizer.batch_decode(seq.tolist())
    for i in decoded:
        print(i)
    return decoded[0]  # beams are sorted by ascending NLL, so index 0 is the best; the function is declared to return str
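
# A self-contained toy version of the same accumulate-NLL-and-prune loop,
# run against random probabilities instead of a real model. It only
# illustrates the bookkeeping: lower accumulated NLL means a better beam.
import torch

def toy_beam_search(beam_width: int = 2, steps: int = 3, vocab: int = 5) -> torch.Tensor:
    torch.manual_seed(0)
    seq = torch.zeros(1, 1, dtype=torch.long)   # one beam holding a <bos>-like token 0
    accum_nll = torch.zeros(1)
    for _ in range(steps):
        # stand-in for model.predict(): one probability row per live beam
        probs = torch.softmax(torch.randn(seq.size(0), vocab), dim=-1)
        top_p, top_i = probs.topk(beam_width, dim=-1)
        # expand every live beam by its top-k continuations
        cand_nll = (accum_nll.unsqueeze(1) - top_p.log()).flatten()
        cand_seq = torch.cat([seq.repeat_interleave(beam_width, dim=0),
                              top_i.flatten().unsqueeze(1)], dim=1)
        # prune back to the beam_width candidates with the lowest NLL
        accum_nll, keep = cand_nll.topk(beam_width, largest=False)
        seq = cand_seq[keep]
    return seq  # (beam_width, steps + 1), best beam first

print(toy_beam_search())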
Example #3
def greedy_test(args):
    """ Test function """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")
    # ['<BOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
    # each example needs its own list; `len(test_data) * [[...]]` would make
    # every entry alias one shared list, so the writes to pred[i + 1] below would leak
    pred_data = [[
        constants.PAD_WORD if i else constants.BOS_WORD
        for i in range(args.decode_max_steps)
    ] for _ in test_data]

    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test, pred in zip(test_data, pred_data):
        pred_output = [constants.PAD_WORD] * args.decode_max_steps
        test_var = to_input_variable([test], vocab.src, cuda=args.cuda)

        # the encoder only needs to run once per source sentence
        enc_output = translator.encode(test_var[0], test_var[1])
        for i in range(args.decode_max_steps):
            pred_var = to_input_variable([pred[:i + 1]],
                                         vocab.tgt,
                                         cuda=args.cuda)

            scores = translator.translate(enc_output, test_var[0], pred_var)

            _, argmax_idxs = torch.max(scores, dim=-1)
            one_step_idx = argmax_idxs[-1].item()

            pred_output[i] = vocab.tgt.id2word[one_step_idx]
            if (one_step_idx
                    == constants.EOS) or (i == args.decode_max_steps - 1):
                print("[Source] %s" % " ".join(test))
                print("[Predict] %s" % " ".join(pred_output[:i]))
                print()

                output_file.write(" ".join(pred_output[:i]) + "\n")
                output_file.flush()
                break
            pred[i + 1] = vocab.tgt.id2word[one_step_idx]

    output_file.close()
Example #4
def gen_soft_labels(c):
    c.setdefault(hebbian=False, distributed=False)
    net = Transformer(c)
    net, step = c.init_model(net, step='max', train=False)

    print('generating soft labels...')
    data_gen_tr = SequentialIterator(c, 1, 'train')
    net.eval()
    with torch.no_grad():
        i = 0
        for batch in tqdm(data_gen_tr):
            x = to_torch(batch, c.device).t()
            inputs, labels = x[:-1], x[1:]
            probs, _ = net(inputs, labels)

            values, indices = torch.topk(probs, c.topk, dim=1)

            indices_ = indices.cpu().numpy()
            values_ = values.cpu().numpy()
            labels_ = labels.cpu().numpy()

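            # the model can return more rows than this batch's inputs
            # (e.g. extra context rows); keep only the trailing aligned rows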
            if probs.size(0) != inputs.size(0):
                indices_ = indices_[-inputs.size(0):, :]
                values_ = values_[-inputs.size(0):, :]

            if i == 0:
                all_soft_indices = indices_
                all_soft_values = values_
            else:
                all_soft_indices = np.concatenate((all_soft_indices, indices_),
                                                  axis=0)
                all_soft_values = np.concatenate((all_soft_values, values_),
                                                 axis=0)

            i += 1
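    # duplicate the first row so the soft-label arrays line up with the raw
    # token stream (inputs/labels are shifted by one step relative to the data)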
    all_soft_indices = np.concatenate(
        (all_soft_indices[0:1, :], all_soft_indices), axis=0)
    all_soft_values = np.concatenate(
        (all_soft_values[0:1, :], all_soft_values), axis=0)

    np.save(Cache / 'wikitext-103' / 'train_soft_labels.npy', all_soft_indices)
    np.save(Cache / 'wikitext-103' / 'train_soft_probs.npy', all_soft_values)
    print('Saved %s' % (Cache / 'wikitext-103' / 'train_soft_labels.npy'))
    print('Saved %s' % (Cache / 'wikitext-103' / 'train_soft_probs.npy'))

    cnt = 0.
    for k in range(len(data_gen_tr.tokens)):
        if data_gen_tr.tokens[k] in all_soft_indices[k]:
            cnt += 1
    print('%s%% of the tokens are predicted within the top %s logits' %
          (100 * cnt / len(data_gen_tr.tokens), c.topk))
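
# The saved arrays pair each training position with its top-k teacher token
# ids and probabilities. A student model would typically consume them with a
# sparse cross-entropy term; this is a minimal sketch under that assumption
# (the original training loop that reads these files is not shown here).
import numpy as np
import torch
import torch.nn.functional as F

def soft_label_loss(student_logits: torch.Tensor,
                    topk_idx: np.ndarray,
                    topk_prob: np.ndarray) -> torch.Tensor:
    # student_logits: (seq, vocab); topk_idx / topk_prob: (seq, k) as saved above
    idx = torch.from_numpy(topk_idx).long()
    tgt = torch.from_numpy(topk_prob).float()
    tgt = tgt / tgt.sum(dim=1, keepdim=True)      # renormalize the truncated top-k distribution
    logp = F.log_softmax(student_logits, dim=-1)
    return -(tgt * logp.gather(1, idx)).sum(dim=1).mean()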
Example #5
    def __init__(self, model_source, rewrite_len=30, beam_size=4, debug=False):
        self.beam_size = beam_size
        self.rewrite_len = rewrite_len
        self.debug = debug

        model_source = torch.load(model_source,
                                  map_location=lambda storage, loc: storage)
        self.dict = model_source["word2idx"]
        self.idx2word = {v: k for k, v in model_source["word2idx"].items()}
        self.args = args = model_source["settings"]
        torch.manual_seed(args.seed)
        model = Transformer(args)
        model.load_state_dict(model_source['model'])
        self.model = model.eval()
Example #6
def test(hp):
    # Loading hyper params
    load_hparams(hp, hp.ckpt)

    logging.info("# Prepare test batches")
    test_batches, num_test_batches, num_test_samples = get_batch(
        hp.test1,
        hp.test1,
        100000,
        100000,
        hp.vocab,
        hp.test_batch_size,
        shuffle=False)
    iter = tf.data.Iterator.from_structure(test_batches.output_types,
                                           test_batches.output_shapes)
    xs, ys = iter.get_next()

    test_init_op = iter.make_initializer(test_batches)

    logging.info("# Load model")
    model = Transformer(hp)

    logging.info("# Session")
    with tf.Session() as sess:
        ckpt_ = tf.train.latest_checkpoint(hp.ckpt)
        ckpt = ckpt_ if ckpt_ else hp.ckpt
        saver = tf.train.Saver()

        saver.restore(sess, ckpt)

        y_hat, mean_loss = model.eval(sess, test_init_op, xs, ys,
                                      num_test_batches)

        logging.info("# get hypotheses")
        hypotheses = get_hypotheses(num_test_samples, y_hat, model.idx2token)

        logging.info("# write results")
        model_output = os.path.split(ckpt)[-1]
        if not os.path.exists(hp.testdir):
            os.makedirs(hp.testdir)
        translation = os.path.join(hp.testdir, model_output)
        with open(translation, 'w', encoding="utf-8") as fout:
            fout.write("\n".join(hypotheses))

        logging.info("# calc bleu score and append it to translation")
        calc_bleu_nltk(hp.test2, translation)
Example #7
    def __init__(self, model_source, cuda=False, beam_size=3):
        self.torch = torch.cuda if cuda else torch
        self.cuda = cuda
        self.beam_size = beam_size

        if self.cuda:
            model_source = torch.load(model_source)
        else:
            model_source = torch.load(
                model_source, map_location=lambda storage, loc: storage)
        self.src_dict = model_source["src_dict"]
        self.tgt_dict = model_source["tgt_dict"]
        self.tgt_idx2word = {v: k for k, v in model_source["tgt_dict"].items()}  # maps output ids back to target words
        self.args = args = model_source["settings"]
        model = Transformer(args)
        model.load_state_dict(model_source['model'])

        if self.cuda:
            model = model.cuda()
        else:
            model = model.cpu()
        self.model = model.eval()
Example #8
                                                             hp.batch_size,
                                                             shuffle=False)

# create an iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types,
                                       train_batches.output_shapes)
xs, ys = iter.get_next()
# print('x0, x1 =', xs[0].shape, xs[1].shape)

train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries, eval_loss = m.eval(xs, ys)
# y_hat = m.infer(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)
Example #9
class Trainer:
    def __init__(self, args, train_loader, test_loader, tokenizer_src, tokenizer_tgt):
        self.args = args
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.src_vocab_size = tokenizer_src.vocab_size
        self.tgt_vocab_size = tokenizer_tgt.vocab_size
        self.pad_id = tokenizer_src.pad_token_id # pad_token_id in tokenizer_tgt's vocab should be the same as this
        self.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'

        self.model = Transformer(src_vocab_size = self.src_vocab_size,
                                 tgt_vocab_size = self.tgt_vocab_size,
                                 seq_len        = args.max_seq_len,
                                 d_model        = args.hidden,
                                 n_layers       = args.n_layers,
                                 n_heads        = args.n_attn_heads,
                                 p_drop         = args.dropout,
                                 d_ff           = args.ffn_hidden,
                                 pad_id         = self.pad_id)
        if args.multi_gpu:
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        self.optimizer = ScheduledOptim(optim.Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-9),
                                        init_lr=2.0, d_model=args.hidden)
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)

    def train(self, epoch):
        losses = 0
        n_batches, n_samples = len(self.train_loader), len(self.train_loader.dataset)
        
        self.model.train()
        for i, batch in enumerate(self.train_loader):
            encoder_inputs, decoder_inputs, decoder_outputs = map(lambda x: x.to(self.device), batch)
            # |encoder_inputs| : (batch_size, seq_len), |decoder_inputs| : (batch_size, seq_len-1), |decoder_outputs| : (batch_size, seq_len-1)

            outputs, encoder_attns, decoder_attns, enc_dec_attns = self.model(encoder_inputs, decoder_inputs)
            # |outputs| : (batch_size, seq_len-1, tgt_vocab_size)
            # |encoder_attns| : [(batch_size, n_heads, seq_len, seq_len)] * n_layers
            # |decoder_attns| : [(batch_size, n_heads, seq_len-1, seq_len-1)] * n_layers
            # |enc_dec_attns| : [(batch_size, n_heads, seq_len-1, seq_len)] * n_layers
            
            loss = self.criterion(outputs.view(-1, self.tgt_vocab_size), decoder_outputs.view(-1))
            losses += loss.item()
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.update_learning_rate()
            self.optimizer.step()

            if i % (n_batches//5) == 0 and i != 0:
                print('Iteration {} ({}/{})\tLoss: {:.4f}\tlr: {:.4f}'.format(i, i, n_batches, losses/i, self.optimizer.get_current_lr))
        
        print('Train Epoch: {}\t>\tLoss: {:.4f}'.format(epoch, losses/n_batches))
            
    def validate(self, epoch):
        losses = 0
        n_batches, n_samples = len(self.test_loader), len(self.test_loader.dataset)
        
        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(self.test_loader):
                encoder_inputs, decoder_inputs, decoder_outputs = map(lambda x: x.to(self.device), batch)
                # |encoder_inputs| : (batch_size, seq_len), |decoder_inputs| : (batch_size, seq_len-1), |decoder_outputs| : (batch_size, seq_len-1)

                outputs, encoder_attns, decoder_attns, enc_dec_attns = self.model(encoder_inputs, decoder_inputs)
                # |outputs| : (batch_size, seq_len-1, tgt_vocab_size)
                # |encoder_attns| : [(batch_size, n_heads, seq_len, seq_len)] * n_layers
                # |decoder_attns| : [(batch_size, n_heads, seq_len-1, seq_len-1)] * n_layers
                # |enc_dec_attns| : [(batch_size, n_heads, seq_len-1, seq_len)] * n_layers
                
                loss = self.criterion(outputs.view(-1, self.tgt_vocab_size), decoder_outputs.view(-1))
                losses += loss.item()

        print('Valid Epoch: {}\t>\tLoss: {:.4f}'.format(epoch, losses/n_batches))

    def save(self, epoch, model_prefix='model', root='.model'):
        path = Path(root) / (model_prefix + '.ep%d' % epoch)
        if not path.parent.exists():
            path.parent.mkdir()
        
        torch.save(self.model, path)
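
# ScheduledOptim above (and NoamDecay in the next example) implement the
# warmup schedule from "Attention Is All You Need". A minimal standalone
# version of the usual formula; the exact scaling inside those classes may
# differ slightly.
def noam_lr(step: int, d_model: int, warmup_steps: int = 4000, init_lr: float = 2.0) -> float:
    step = max(step, 1)
    return init_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# rises linearly during warmup, then decays as 1/sqrt(step)
print([round(noam_lr(s, d_model=512), 6) for s in (1, 1000, 4000, 16000)])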
Example #10
def do_train(args):
    if args.use_cuda:
        trainer_count = fluid.dygraph.parallel.Env().nranks
        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
        ).dev_id) if trainer_count > 1 else fluid.CUDAPlace(0)
    else:
        trainer_count = 1
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=trainer_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if args.validation_file:
        val_processor = reader.DataProcessor(
            fpattern=args.validation_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            device_count=trainer_count,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=False,
            shuffle_batch=False,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2],
            max_length=args.max_length,
            n_head=args.n_head)
        val_batch_generator = val_processor.data_generator(phase="train")
    if trainer_count > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    with fluid.dygraph.guard(place):
        # set seed for CE (args.random_seed may arrive as the string "None")
        random_seed = None if str(args.random_seed) == "None" else int(args.random_seed)
        if random_seed is not None:
            fluid.default_main_program().random_seed = random_seed
            fluid.default_startup_program().random_seed = random_seed

        # define data loader
        train_loader = fluid.io.DataLoader.from_generator(capacity=10)
        train_loader.set_batch_generator(batch_generator, places=place)
        if args.validation_file:
            val_loader = fluid.io.DataLoader.from_generator(capacity=10)
            val_loader.set_batch_generator(val_batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        # define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps)

        # define optimizer
        optimizer = fluid.optimizer.Adam(
            learning_rate=NoamDecay(args.d_model, args.warmup_steps,
                                    args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters())

        ## init from some checkpoint, to resume the previous training
        if args.init_from_checkpoint:
            model_dict, opt_dict = fluid.load_dygraph(
                os.path.join(args.init_from_checkpoint, "transformer"))
            transformer.load_dict(model_dict)
            optimizer.set_dict(opt_dict)
        ## init from some pretrain models, to better solve the current task
        if args.init_from_pretrain_model:
            model_dict, _ = fluid.load_dygraph(
                os.path.join(args.init_from_pretrain_model, "transformer"))
            transformer.load_dict(model_dict)

        if trainer_count > 1:
            strategy = fluid.dygraph.parallel.prepare_context()
            transformer = fluid.dygraph.parallel.DataParallel(transformer,
                                                              strategy)

        # the best cross-entropy value with label smoothing
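        # i.e. the entropy of the smoothed target distribution:
        #   H = -(1 - eps) * log(1 - eps) - eps * log(eps / (V - 1))
        # subtracting it from the training loss shows the gap to this floor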
        loss_normalizer = -(
            (1. - args.label_smooth_eps) * np.log(
                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

        ce_time = []
        ce_ppl = []
        step_idx = 0

        # train loop
        for pass_id in range(args.epoch):
            epoch_start = time.time()

            batch_id = 0
            batch_start = time.time()
            interval_word_num = 0.0
            for input_data in train_loader():
                if args.max_iter and step_idx == args.max_iter:  #NOTE: used for benchmark
                    return
                batch_reader_end = time.time()

                (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
                 trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
                 lbl_weight) = input_data

                logits = transformer(src_word, src_pos, src_slf_attn_bias,
                                     trg_word, trg_pos, trg_slf_attn_bias,
                                     trg_src_attn_bias)

                sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
                                                          lbl_weight)

                if trainer_count > 1:
                    avg_cost = transformer.scale_loss(avg_cost)
                    avg_cost.backward()
                    transformer.apply_collective_grads()
                else:
                    avg_cost.backward()

                optimizer.minimize(avg_cost)
                transformer.clear_gradients()

                interval_word_num += np.prod(src_word.shape)
                if step_idx % args.print_step == 0:
                    total_avg_cost = avg_cost.numpy() * trainer_count

                    if step_idx == 0:
                        logger.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                    else:
                        train_avg_batch_cost = args.print_step / (
                            time.time() - batch_start)
                        word_speed = interval_word_num / (
                            time.time() - batch_start)
                        logger.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                            "words speed: %0.2f words/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             train_avg_batch_cost, word_speed))
                    batch_start = time.time()
                    interval_word_num = 0.0

                if step_idx % args.save_step == 0 and step_idx != 0:
                    # validation
                    if args.validation_file:
                        transformer.eval()
                        total_sum_cost = 0
                        total_token_num = 0
                        for input_data in val_loader():
                            (src_word, src_pos, src_slf_attn_bias, trg_word,
                             trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
                             lbl_word, lbl_weight) = input_data
                            logits = transformer(
                                src_word, src_pos, src_slf_attn_bias, trg_word,
                                trg_pos, trg_slf_attn_bias, trg_src_attn_bias)
                            sum_cost, avg_cost, token_num = criterion(
                                logits, lbl_word, lbl_weight)
                            total_sum_cost += sum_cost.numpy()
                            total_token_num += token_num.numpy()
                            total_avg_cost = total_sum_cost / total_token_num
                        logger.info("validation, step_idx: %d, avg loss: %f, "
                                    "normalized loss: %f, ppl: %f" %
                                    (step_idx, total_avg_cost,
                                     total_avg_cost - loss_normalizer,
                                     np.exp([min(total_avg_cost, 100)])))
                        transformer.train()

                    if args.save_model and (
                            trainer_count == 1 or
                            fluid.dygraph.parallel.Env().dev_id == 0):
                        model_dir = os.path.join(args.save_model,
                                                 "step_" + str(step_idx))
                        if not os.path.exists(model_dir):
                            os.makedirs(model_dir)
                        fluid.save_dygraph(
                            transformer.state_dict(),
                            os.path.join(model_dir, "transformer"))
                        fluid.save_dygraph(
                            optimizer.state_dict(),
                            os.path.join(model_dir, "transformer"))

                batch_id += 1
                step_idx += 1

            train_epoch_cost = time.time() - epoch_start
            ce_time.append(train_epoch_cost)
            logger.info("train epoch: %d, epoch_cost: %.5f s" %
                        (pass_id, train_epoch_cost))

        if args.save_model:
            model_dir = os.path.join(args.save_model, "step_final")
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            fluid.save_dygraph(transformer.state_dict(),
                               os.path.join(model_dir, "transformer"))
            fluid.save_dygraph(optimizer.state_dict(),
                               os.path.join(model_dir, "transformer"))

        if args.enable_ce:
            _ppl = 0
            _time = 0
            try:
                _time = ce_time[-1]
                _ppl = ce_ppl[-1]
            except IndexError:
                print("ce info error")
            print("kpis\ttrain_duration_card%s\t%s" % (trainer_count, _time))
            print("kpis\ttrain_ppl_card%s\t%f" % (trainer_count, _ppl))
Example #11
    # tmp_text = "我 是 何 杰"
    # tmp_pieces = sp.EncodeAsPieces(tmp_text)
    # tmp_batches, num_tmp_batches, num_tmp_samples = get_batch_single(" ".join(tmp_pieces) + "\n",
    #                                                                     hp.vocab, 1, shuffle=False)
    # tmp_iter = tf.data.Iterator.from_structure(tmp_batches.output_types, tmp_batches.output_shapes)
    # tmp_x, tmp_y = tmp_iter.get_next()
    # print(f"xs is {tmp_x}")
    # print(f"ys is {tmp_y}")
    xs = (tf.placeholder(dtype=tf.int32, shape=(None, None)),
          tf.placeholder(dtype=tf.int32, shape=(None,)))
    ys = (tf.placeholder(dtype=tf.int32, shape=(None, None)),
          tf.placeholder(dtype=tf.int32, shape=(None, None)),
          tf.placeholder(dtype=tf.int32, shape=(None,)))

    m = Transformer(hp)
    y_hat, _ = m.eval(xs, ys)
    saver = tf.train.Saver()

    saver.restore(sess, ckpt)

    while True:
        raw_text = input("Model prompt: ")
        if raw_text == 'EOF':
            break
        pieces = sp.EncodeAsPieces(raw_text)
        (x0, x1), (y0, y1, y2) = get_batch_single(" ".join(pieces) + "\n",
                                                  hp.vocab,
                                                  1,
                                                  shuffle=False)
        import numpy as np
        x = [[x0], [x1]]
Example #12
    1000,
    1000,
    hp.vocab,
    hp.test_paraphrased,
    hp.test_batch_size,
    shuffle=False,
    paraphrase_type=hp.paraphrase_type)
iter = tf.data.Iterator.from_structure(test_batches.output_types,
                                       test_batches.output_shapes)
xs, ys, x_paraphrased_dict, _ = iter.get_next()

test_init_op = iter.make_initializer(test_batches)

logging.info("# Load model")
m = Transformer(hp)
y_hat, _ = m.eval(xs, ys, x_paraphrased_dict)

logging.info("# Session")
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    ckpt_ = tf.train.latest_checkpoint(hp.logdir)
    ckpt = hp.logdir if ckpt_ is None else ckpt_  # fall back to hp.logdir when no checkpoint file is found
    saver = tf.train.Saver()

    saver.restore(sess, ckpt)

    sess.run(test_init_op)

    logging.info("# get hypotheses")
    hypotheses = get_hypotheses(num_test_batches, num_test_samples, sess,
Example #13
    paraphrase_type=hp.paraphrase_type)

# create an iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types,
                                       train_batches.output_shapes)
xs, ys, x_paraphrased_dict, synonym_label = iter.get_next()

train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys,
                                                       x_paraphrased_dict,
                                                       synonym_label)
y_hat, eval_summaries = m.eval(xs, ys, x_paraphrased_dict)
# y_hat = m.infer(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    from_scratch = ckpt is None
    if from_scratch:
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)
Example #14
en_wordEmbedding = None
if hp.en_pretreat_path is not None:
    _, en_wordEmbedding = getWordEmbedding(hp, en_words, hp.en_w2v_model_path)

de_wordEmbedding = None
if hp.de_pretreat_path is not None:
    _, de_wordEmbedding = getWordEmbedding(hp, de_words, hp.de_w2v_model_path)

m = Transformer(hp, en_wordEmbedding, de_wordEmbedding)

X1 = tf.placeholder(tf.int32, shape=(None, hp.sequence_length), name="X_en")
X2 = tf.placeholder(tf.int32, shape=(None, hp.sequence_length), name="X_de")
y1 = tf.placeholder(tf.int32, shape=[None], name="y_en")
y2 = tf.placeholder(tf.int32, shape=[None], name="y_de")
loss, train_op, global_step, predictions, shuffle_y, en_memory, de_memory = m.train(
    X1, X2, y1, y2)
eval_loss, eval_prediction, sentence = m.eval(X1, X2, y1, y2)

saver = tf.train.Saver(max_to_keep=hp.epoch)

with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.ckpt_path)
    if ckpt is None:
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, ckpt)

    for epoch in range(hp.epoch):
        all_prediction = np.array([])
        all_label = np.array([])
        all_loss = 0.0
        print('--------Epoch {}--------'.format(epoch))
Example #15
biggest_score = 0.13

for epoch in range(num_epochs):
    k += 1
    k = k % 10
    print(f"[Epoch {(epoch + 1)} / {num_epochs}]")

    if save_model and prev_score > biggest_score:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)
        biggest_score = prev_score

    model.eval()
    generated_sentence = create_sentence(model,
                                         sentence,
                                         input_text,
                                         output_text,
                                         device,
                                         max_length=50)

    print(f"Generated sentence example: \n {generated_sentence}")
    model.train()

    losses = []
    means = []

    temp_boi = big_iterator.copy()  # I stopped caring after this point
    del temp_boi[k]
Example #16
def do_predict(args):
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    # define the data generator
    '''
    # old reader
    processor = reader.DataProcessor(fpattern=args.predict_file,
                                     src_vocab_fpath=args.src_vocab_fpath,
                                     trg_vocab_fpath=args.trg_vocab_fpath,
                                     token_delimiter=args.token_delimiter,
                                     use_token_batch=False,
                                     batch_size=args.batch_size,
                                     device_count=1,
                                     pool_size=args.pool_size,
                                     sort_type=reader.SortType.NONE,
                                     shuffle=False,
                                     shuffle_batch=False,
                                     start_mark=args.special_token[0],
                                     end_mark=args.special_token[1],
                                     unk_mark=args.special_token[2],
                                     max_length=args.max_length,
                                     n_head=args.n_head)
    '''
    processor = reader.DataProcessor(fpattern=args.predict_file,
                                     src_vocab_fpath=args.src_vocab_fpath,
                                     trg_vocab_fpath=args.trg_vocab_fpath,
                                     token_delimiter=args.token_delimiter,
                                     use_token_batch=False,
                                     batch_size=args.batch_size,
                                     device_count=1,
                                     pool_size=args.pool_size,
                                     sort_type=reader.SortType.NONE,
                                     shuffle=False,
                                     shuffle_batch=False,
                                     only_src=args.only_src,
                                     start_mark=args.special_token[0],
                                     end_mark=args.special_token[1],
                                     unk_mark=args.special_token[2],
                                     max_length=args.max_length,
                                     n_head=args.n_head,
                                     stream=args.stream,
                                     src_bpe_dict=args.src_bpe_dict)
    batch_generator = processor.data_generator(phase="predict", place=place)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()
    trg_idx2word = reader.DataProcessor.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    with fluid.dygraph.guard(place):
        # define data loader
        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
        test_loader.set_batch_generator(batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        # load the trained model
        assert args.init_from_params, (
            "Please set init_from_params to load the infer model.")
        model_dict, _ = fluid.load_dygraph(
            os.path.join(args.init_from_params, "transformer"))
        # to avoid a longer length than training, reset the size of position
        # encoding to max_length
        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model)
        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model)
        transformer.load_dict(model_dict)

        # set evaluate mode
        transformer.eval()

        f = open(args.output_file, "wb")

        detok = MosesDetokenizer(lang='en')
        detc = MosesDetruecaser()

        for input_data in test_loader():
            if args.stream:
                (src_word, src_pos, src_slf_attn_bias, trg_word,
                 trg_src_attn_bias, real_read) = input_data
            else:
                (src_word, src_pos, src_slf_attn_bias, trg_word,
                 trg_src_attn_bias) = input_data

            finished_seq, finished_scores = transformer.beam_search(
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
                bos_id=args.bos_idx,
                eos_id=args.eos_idx,
                beam_size=args.beam_size,
                max_len=args.max_out_len,
                waitk=args.waitk,
                stream=args.stream)
            finished_seq = finished_seq.numpy()
            finished_scores = finished_scores.numpy()
            for idx, ins in enumerate(finished_seq):
                for beam_idx, beam in enumerate(ins):
                    if beam_idx >= args.n_best:
                        break
                    id_list = post_process_seq(beam, args.bos_idx,
                                               args.eos_idx)
                    word_list = [trg_idx2word[id] for id in id_list]

                    if args.stream:
                        if args.waitk > 0:
                            # for wait-k models, wait k words in the beginning
                            word_list = [b''] * (args.waitk - 1) + word_list
                        else:
                            # for full sentence model, wait until the end
                            word_list = [b''] * (len(real_read[idx].numpy()) -
                                                 1) + word_list

                        final_output = []
                        real_output = []
                        _read = real_read[idx].numpy()
                        sent = ''
                        bpe_flag = False

                        for j in range(max(len(_read), len(word_list))):
                            # append number of reads at step j
                            r = _read[j] if j < len(_read) else 0
                            if r > 0:
                                final_output += [b''] * (r - 1)

                            # append number of writes at step j
                            w = word_list[j] if j < len(word_list) else b''
                            w = w.decode('utf-8')
                            real_output.append(w)

                            # if bpe_flag:
                            #     _sent = ('%s@@ %s'%(sent, w)).strip()
                            # else:
                            #     _sent = ('%s %s'%(sent, w)).strip()

                            _sent = ' '.join(real_output)

                            if len(_sent) > 0:

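                                # sentinel token: presumably keeps the
                                # detokenizer from collapsing the trailing
                                # boundary; it is stripped again below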
                                _sent += ' a'
                                _sent = ' '.join(_sent.split())

                                # if _sent.endswith('@@ a'):
                                #     bpe_flag = True
                                # else:
                                #     bpe_flag = False

                                _sent = _sent.replace('@@ ', '')
                                _sent = detok.detokenize(_sent.split())
                                _sent = detc.detruecase(_sent)
                                _sent = ' '.join(_sent)
                                _sent = _sent[:-1].strip()

                            incre = _sent[len(sent):]
                            #print('_sent0:', _sent)
                            sent = _sent
                            #print('sent:', sent)

                            if r > 0:
                                # if there is read, append a word to write
                                # final_output.append(w)
                                final_output.append(str.encode(incre))
                            else:
                                # if there is no read, append word to the final write
                                if j >= len(word_list):
                                    break
                                # final_output[-1] += b' '+w
                                final_output[-1] += str.encode(incre)

                            #print(final_output)
                            #print('incre:', incre)
                            #print('_sent1:', _sent)
                            # f.write(bytes('part:'+_sent+'\n'))

                        sequence = b"\n".join(final_output) + b" \n"
                        f.write(sequence)
                        # embed()
                    else:
                        sequence = b" ".join(word_list) + b"\n"
                        f.write(sequence)
                    f.flush()
Example #17
def main() -> None:
    """Entry point."""
    print("Start!!!")
    sys.stdout.flush()
    if args.run_mode == "train":
        train_data = MultiAlignedDataMultiFiles(config_data.train_data_params,
                                                device=device)
        #train_data = tx.data.MultiAlignedData(config_data.train_data_params, device=device)
        print("will data_iterator")
        data_iterator = tx.data.DataIterator({"train": train_data})
        print("data_iterator done")

        # Create model and optimizer
        model = Transformer(config_model, config_data, train_data.vocab('src'))
        model.to(device)
        print("device:", device)
        print("vocab src1:", train_data.vocab('src').id_to_token_map_py)
        print("vocab src2:", train_data.vocab('src').token_to_id_map_py)

        model = ModelWrapper(model, config_model.beam_width)
        if torch.cuda.device_count() > 1:
            #model = nn.DataParallel(model.cuda(), device_ids=[0, 1]).to(device)
            #model = MyDataParallel(model.cuda(), device_ids=[0, 1]).to(device)
            model = MyDataParallel(model.cuda()).to(device)

        lr_config = config_model.lr_config
        if lr_config["learning_rate_schedule"] == "static":
            init_lr = lr_config["static_lr"]
            scheduler_lambda = lambda x: 1.0
        else:
            init_lr = lr_config["lr_constant"]
            scheduler_lambda = functools.partial(
                get_lr_multiplier, warmup_steps=lr_config["warmup_steps"])
        optim = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 betas=(0.9, 0.997),
                                 eps=1e-9)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optim, scheduler_lambda)

        output_dir = Path(args.output_dir)
        if not output_dir.exists():
            output_dir.mkdir()

        def _save_epoch(epoch):

            checkpoint_name = f"checkpoint{epoch}.pt"
            print(f"saving model...{checkpoint_name}")
            torch.save(model.state_dict(), output_dir / checkpoint_name)

        def _train_epoch(epoch):
            data_iterator.switch_to_dataset('train')
            model.train()
            #model.module.train()
            #print("after model.module.train")
            sys.stdout.flush()
            step = 0
            num_steps = len(data_iterator)
            loss_stats = []
            for batch in data_iterator:
                #print("batch:", batch)
                #batch = batch.to(device)
                return_dict = model(batch)
                #return_dict = model.module.forward(batch)
                loss = return_dict['loss']
                #print("loss:", loss)
                loss = loss.mean()
                #print("loss:", loss)
                #print("loss.item():", loss.item())
                loss_stats.append(loss.item())

                optim.zero_grad()
                loss.backward()
                optim.step()
                scheduler.step()

                config_data.display = 1
                if step % config_data.display == 0:
                    avr_loss = sum(loss_stats) / len(loss_stats)
                    ppl = utils.get_perplexity(avr_loss)
                    print(
                        f"epoch={epoch}, step={step}/{num_steps}, loss={avr_loss:.4f}, ppl={ppl:.4f}, lr={scheduler.get_lr()[0]}"
                    )
                    sys.stdout.flush()
                step += 1

        print("will train")
        for i in range(config_data.num_epochs):
            print("epoch i:", i)
            sys.stdout.flush()
            _train_epoch(i)
            _save_epoch(i)

    elif args.run_mode == "test":
        test_data = tx.data.MultiAlignedData(config_data.test_data_params,
                                             device=device)
        data_iterator = tx.data.DataIterator({"test": test_data})
        print("test_data vocab src1 before load:",
              test_data.vocab('src').id_to_token_map_py)

        # Create model and optimizer
        model = Transformer(config_model, config_data, test_data.vocab('src'))

        model = ModelWrapper(model, config_model.beam_width)
        #print("state_dict:", model.state_dict())
        model_loaded = torch.load(args.load_checkpoint)
        #print("model_loaded state_dict:", model_loaded)
        model_loaded = rm_begin_str_in_keys("module.", model_loaded)
        #print("model_loaded2 state_dict:", model_loaded)

        model.load_state_dict(model_loaded)
        #model.load_state_dict(torch.load(args.load_checkpoint))
        model.to(device)

        data_iterator.switch_to_dataset('test')
        model.eval()
        print("will predict !!!")
        sys.stdout.flush()

        fo = open(args.pred_output_file, "w")
        print("test_data vocab src1:",
              test_data.vocab('src').id_to_token_map_py)
        print("test_data vocab src2:",
              test_data.vocab('src').token_to_id_map_py)
        with torch.no_grad():
            for batch in data_iterator:
                print("batch:", batch)
                return_dict = model.predict(batch)
                preds = return_dict['preds'].cpu()
                print("preds:", preds)
                pred_words = tx.data.map_ids_to_strs(preds,
                                                     test_data.vocab('src'))
                #src_words = tx.data.map_ids_to_strs(batch['src_text'], test_data.vocab('src'))
                src_words = [" ".join(sw) for sw in batch['src_text']]
                for swords, words in zip(src_words, pred_words):
                    print(str(swords) + "\t" + str(words))
                    fo.write(str(words) + "\n")
                #print(" ".join(batch.src_text) + "\t" + pred_words)
                #print(batch.src_text, pred_words)
                #fo.write(str(pred_words) + "\n")
                fo.flush()
        fo.close()

    else:
        raise ValueError(f"Unknown mode: {args.run_mode}")
Example #18
def main():
    parser = argparse.ArgumentParser(description='Commonsense Dataset Dev')

    # Experiment params
    parser.add_argument('--mode', type=str, help='train or test mode', required=True, choices=['train', 'test'])
    parser.add_argument('--expt_dir', type=str, help='root directory to save model & summaries')
    parser.add_argument('--expt_name', type=str, help='expt_dir/expt_name: organize experiments')
    parser.add_argument('--run_name', type=str, help='expt_dir/expt_name/run_name: organize training runs')
    parser.add_argument('--test_file', type=str, default='test',
                        help='The file containing test data to evaluate in test mode.')

    # Model params
    parser.add_argument('--model', type=str, help='transformer model (e.g. roberta-base)', required=True)
    parser.add_argument('--num_layers', type=int,
                        help='Number of hidden layers in transformers (default number if not provided)', default=-1)
    parser.add_argument('--seq_len', type=int, help='tokenized input sequence length', default=256)
    parser.add_argument('--num_cls', type=int, help='model number of classes', default=2)
    parser.add_argument('--ckpt', type=str, help='path to model checkpoint .pth file')

    # Data params
    parser.add_argument('--pred_file', type=str, help='address of prediction csv file, for "test" mode',
                        default='results.csv')
    parser.add_argument('--dataset', type=str, default='com2sense')
    # Training params
    parser.add_argument('--lr', type=float, help='learning rate', default=1e-5)
    parser.add_argument('--epochs', type=int, help='number of epochs', default=100)
    parser.add_argument('--batch_size', type=int, help='batch size', default=8)
    parser.add_argument('--acc_step', type=int, help='gradient accumulation steps', default=1)
    parser.add_argument('--log_interval', type=int, help='interval size for logging training summaries', default=100)
    parser.add_argument('--save_interval', type=int, help='save model after `n` weight update steps', default=30000)
    parser.add_argument('--val_size', type=int, help='validation set size for evaluating metrics; '
                                                     'it needs to be even to get pairwise accuracy', default=2048)

    # GPU params
    parser.add_argument('--gpu_ids', type=str, help='GPU IDs (0,1,2,..) separated by comma', default='0')
    parser.add_argument('-data_parallel',
                        help='Whether to use nn.DataParallel (currently available for BERT-based models)',
                        action='store_true')
    parser.add_argument('--use_amp', type=str2bool, help='Automatic-Mixed Precision (T/F)', default='T')
    parser.add_argument('-cpu', help='use cpu only (for test)', action='store_true')

    # Misc params
    parser.add_argument('--num_workers', type=int, help='number of worker threads for Dataloader', default=1)

    # Parse Args
    args = parser.parse_args()

    # Dataset list
    dataset_names = csv2list(args.dataset)
    print()

    # Multi-GPU
    device_ids = csv2list(args.gpu_ids, int)
    print('Selected GPUs: {}'.format(device_ids))

    # Device for loading dataset (batches)
    device = torch.device(device_ids[0])
    if args.cpu:
        device = torch.device('cpu')

    # Text-to-Text
    text2text = ('t5' in args.model)
    uniqa = ('unified' in args.model)

    # args.use_amp is already a bool after str2bool parsing, so compare it directly
    assert not (text2text and args.use_amp), 'use_amp should be F when using T5-based models.'
    # Train params
    n_epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    accumulation_steps = args.acc_step
    # Todo: Verify the grad-accum code (loss avging seems slightly incorrect)

    # Train
    if args.mode == 'train':
        # Ensure CUDA available for training
        assert torch.cuda.is_available(), 'No CUDA device for training!'

        # Setup train log directory
        log_dir = os.path.join(args.expt_dir, args.expt_name, args.run_name)

        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        # TensorBoard summaries setup  -->  /expt_dir/expt_name/run_name/
        writer = SummaryWriter(log_dir)

        # Train log file
        log_file = setup_logger(parser, log_dir)

        print('Training Log Directory: {}\n'.format(log_dir))

        # Dataset & Dataloader
        dataset = BaseDataset('train', tokenizer=args.model, max_seq_len=args.seq_len, text2text=text2text, uniqa=uniqa)
        train_datasets = ConcatDataset([dataset])

        dataset = BaseDataset('dev', tokenizer=args.model, max_seq_len=args.seq_len, text2text=text2text, uniqa=uniqa)
        val_datasets = ConcatDataset([dataset])

        train_loader = DataLoader(train_datasets, batch_size, shuffle=True, drop_last=True,
                                  num_workers=args.num_workers)
        val_loader = DataLoader(val_datasets, batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

        # In multi-dataset setups, also track dataset-specific loaders for validation metrics
        val_dataloaders = []
        if len(dataset_names) > 1:
            for val_dset in val_datasets.datasets:
                loader = DataLoader(val_dset, batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

                val_dataloaders.append(loader)

        # Tokenizer
        tokenizer = dataset.get_tokenizer()

        # Split sizes
        train_size = len(train_datasets)
        val_size = len(val_datasets)
        log_msg = 'Train: {} \nValidation: {}\n\n'.format(train_size, val_size)

        # Min of the total & subset size
        val_used_size = min(val_size, args.val_size)
        log_msg += 'Validation Accuracy is computed using {} samples. See --val_size\n'.format(val_used_size)

        log_msg += 'No. of Classes: {}\n'.format(args.num_cls)
        print_log(log_msg, log_file)

        # Build Model
        model = Transformer(args.model, args.num_cls, text2text, device_ids, num_layers=args.num_layers)
        if args.data_parallel and not args.ckpt:
            model = nn.DataParallel(model, device_ids=device_ids)
            device = torch.device(f'cuda:{model.device_ids[0]}')

        if not text2text:
            model.to(device)

        model.train()

        # Loss & Optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr)
        optimizer.zero_grad()

        scaler = GradScaler(enabled=args.use_amp)

        # Step & Epoch
        start_epoch = 1
        curr_step = 1
        best_val_acc = 0.0

        # Load model checkpoint file (if specified)
        if args.ckpt:
            checkpoint = torch.load(args.ckpt, map_location=device)

            # Load model & optimizer
            model.load_state_dict(checkpoint['model_state_dict'])
            if args.data_parallel:
                model = nn.DataParallel(model, device_ids=device_ids)
                device = torch.device(f'cuda:{model.device_ids[0]}')
            model.to(device)

            curr_step = checkpoint['curr_step']
            start_epoch = checkpoint['epoch']
            prev_loss = checkpoint['loss']

            log_msg = 'Resuming Training...\n'
            log_msg += 'Model successfully loaded from {}\n'.format(args.ckpt)
            log_msg += 'Training loss: {:2f} (from ckpt)\n'.format(prev_loss)

            print_log(log_msg, log_file)

        steps_per_epoch = len(train_loader)
        start_time = time()

        for epoch in range(start_epoch, start_epoch + n_epochs):
            for batch in tqdm(train_loader):
                # Load batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                with autocast(args.use_amp):
                    if text2text:
                        # Forward + Loss
                        output = model(batch)
                        loss = output[0]

                    else:
                        # Forward Pass
                        label_logits = model(batch)
                        label_gt = batch['label']

                        # Compute Loss
                        loss = criterion(label_logits, label_gt)

                if args.data_parallel:
                    loss = loss.mean()
                # Backward Pass
                loss /= accumulation_steps
                scaler.scale(loss).backward()

                if curr_step % accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                # Print Results - Loss value & Validation Accuracy
                if curr_step % args.log_interval == 0:
                    # Validation set accuracy
                    if val_datasets:
                        val_metrics = compute_eval_metrics(model, val_loader, device, val_used_size, tokenizer,
                                                           text2text, parallel=args.data_parallel)

                        # Reset the mode to training
                        model.train()

                        log_msg = 'Validation Accuracy: {:.2f} %  || Validation Loss: {:.4f}'.format(
                            val_metrics['accuracy'], val_metrics['loss'])

                        print_log(log_msg, log_file)

                        # Add summaries to TensorBoard
                        writer.add_scalar('Val/Loss', val_metrics['loss'], curr_step)
                        writer.add_scalar('Val/Accuracy', val_metrics['accuracy'], curr_step)

                    # Add summaries to TensorBoard
                    writer.add_scalar('Train/Loss', loss.item(), curr_step)

                    # Compute elapsed & remaining time for training to complete
                    time_elapsed = (time() - start_time) / 3600

                    log_msg = 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f} | time elapsed: {:.2f}h |'.format(
                        epoch, n_epochs, curr_step, steps_per_epoch, loss.item(), time_elapsed)

                    print_log(log_msg, log_file)

                # Save the model
                if curr_step % args.save_interval == 0:
                    path = os.path.join(log_dir, 'model_' + str(curr_step) + '.pth')

                    if args.data_parallel:
                        model_state_dict = model.module.state_dict()
                    else:
                        model_state_dict = model.state_dict()
                    state_dict = {'model_state_dict': model_state_dict,
                                  'curr_step': curr_step, 'loss': loss.item(),
                                  'epoch': epoch, 'val_accuracy': best_val_acc}

                    torch.save(state_dict, path)

                    log_msg = 'Saving the model at step {} to directory: {}'.format(curr_step, log_dir)
                    print_log(log_msg, log_file)

                curr_step += 1

            # Validation accuracy on the entire set
            if val_datasets:
                log_msg = '-------------------------------------------------------------------------\n'
                val_metrics = compute_eval_metrics(model, val_loader, device, val_size, tokenizer, text2text,
                                                   parallel=args.data_parallel)

                log_msg += '\nAfter epoch {}:\n'.format(epoch)
                log_msg += 'Validation Accuracy: {:.2f} %  || Validation Loss: {:.4f}\n'.format(
                    val_metrics['accuracy'], val_metrics['loss'])

                # For Multi-Dataset setup:
                if len(dataset_names) > 1:
                    # compute validation set metrics on each dataset independently
                    for loader in val_dataloaders:
                        metrics = compute_eval_metrics(model, loader, device, val_size, tokenizer, text2text,
                                                       parallel=args.data_parallel)

                        log_msg += '\n --> {}\n'.format(loader.dataset.get_classname())
                        log_msg += 'Validation Accuracy: {:.2f} %  || Validation Loss: {:.4f}\n'.format(
                            metrics['accuracy'], metrics['loss'])

                # Save best model after every epoch
                if val_metrics["accuracy"] > best_val_acc:
                    best_val_acc = val_metrics["accuracy"]

                    step = '{:.1f}k'.format(curr_step / 1000) if curr_step > 1000 else '{}'.format(curr_step)
                    filename = 'ep_{}_stp_{}_acc_{:.4f}_{}.pth'.format(
                        epoch, step, best_val_acc, args.model.replace('-', '_').replace('/', '_'))

                    path = os.path.join(log_dir, filename)
                    if args.data_parallel:
                        model_state_dict = model.module.state_dict()
                    else:
                        model_state_dict = model.state_dict()
                    state_dict = {'model_state_dict': model_state_dict,
                                  'curr_step': curr_step, 'loss': loss.item(),
                                  'epoch': epoch, 'val_accuracy': best_val_acc}

                    torch.save(state_dict, path)

                    log_msg += "\n** Best Performing Model: {:.2f} ** \nSaving weights at {}\n".format(best_val_acc,
                                                                                                       path)

                log_msg += '-------------------------------------------------------------------------\n\n'
                print_log(log_msg, log_file)

                # Reset the mode to training
                model.train()

        writer.close()
        log_file.close()

    elif args.mode == 'test':

        # Dataloader
        dataset = BaseDataset(args.test_file, tokenizer=args.model, max_seq_len=args.seq_len, text2text=text2text,
                              uniqa=uniqa)

        loader = DataLoader(dataset, batch_size, num_workers=args.num_workers)

        tokenizer = dataset.get_tokenizer()

        model = Transformer(args.model, args.num_cls, text2text, num_layers=args.num_layers)
        model.eval()
        model.to(device)

        # Load model weights
        if args.ckpt:
            checkpoint = torch.load(args.ckpt, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        data_len = len(dataset)
        print('Total Samples: {}'.format(data_len))

        is_pairwise = 'com2sense' in dataset_names

        # Inference
        metrics = compute_eval_metrics(model, loader, device, data_len, tokenizer, text2text, is_pairwise=is_pairwise,
                                       is_test=True, parallel=args.data_parallel)

        df = pd.DataFrame(metrics['meta'])
        df.to_csv(args.pred_file)

        print(f'Results for model {args.model}')
        print(f'Results evaluated on file {args.test_file}')
        print('Sentence Accuracy: {:.4f}'.format(metrics['accuracy']))
        if is_pairwise:
            print('Pairwise Accuracy: {:.4f}'.format(metrics['pair_acc']))
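The training branch of this example combines mixed precision (autocast/GradScaler) with gradient accumulation, so the effective batch size is batch_size * accumulation_steps. A minimal standalone sketch of the same pattern, with hypothetical model/loader/criterion names:

import torch
from torch.cuda.amp import GradScaler, autocast

def train_amp_accum(model, loader, optimizer, criterion, accumulation_steps=4):
    # effective batch size = loader batch size * accumulation_steps
    scaler = GradScaler()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(loader, start=1):
        with autocast():
            # scale the loss so the accumulated gradients average correctly
            loss = criterion(model(x), y) / accumulation_steps
        scaler.scale(loss).backward()  # gradients accumulate across steps
        if step % accumulation_steps == 0:
            scaler.step(optimizer)  # unscales gradients, then steps
            scaler.update()
            optimizer.zero_grad()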
Beispiel #19
0
def train(hp):
    save_hparams(hp, hp.checkpoints_dir)
    # Data generator
    logging.info("Prepare Train/Eval batches...")
    train_batches, num_train_batches, num_train_samples = get_batch(
        hp.train1,
        hp.train2,
        hp.maxlen1,
        hp.maxlen2,
        hp.vocab,
        hp.batch_size,
        shuffle=True)
    eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1,
                                                                 hp.eval2,
                                                                 10000,
                                                                 10000,
                                                                 hp.vocab,
                                                                 hp.batch_size,
                                                                 shuffle=False)

    # Batch iterator
    iter = tf.data.Iterator.from_structure(train_batches.output_types,
                                           train_batches.output_shapes)
    xs, ys = iter.get_next()

    train_init_op = iter.make_initializer(train_batches)
    eval_init_op = iter.make_initializer(eval_batches)

    # Build model
    logging.info("Build model...")
    model = Transformer(hp)
    logging.info("Model is built!")

    # Session
    logging.info("Session initialize")
    saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:
        # Check & Load latest version model checkpoint
        ckpt = tf.train.latest_checkpoint(hp.checkpoints_dir)
        if ckpt is None:
            logging.info("Initializing from scratch")
            sess.run(tf.global_variables_initializer())
            save_variable_specs(os.path.join(hp.checkpoints_dir, "specs"))
        else:
            saver.restore(sess, ckpt)

        summary_writer = tf.summary.FileWriter(hp.checkpoints_dir, sess.graph)

        sess.run(train_init_op)
        total_steps = hp.num_epochs * num_train_batches
        _gs = sess.run(model.global_step)

        k = 5  # window size for the early-stopping criteria
        min_dev_loss = 0  # overwritten with the first eval loss below
        stop_alpha = 20.0  # stop once GL(t) exceeds this threshold
        eval_losses = []
        # Start training
        for i in tqdm(range(_gs, total_steps + 1)):
            _input_x, _decoder_input, _target = sess.run([xs[0], ys[0], ys[1]])
            _, _gs, _summary = sess.run(
                [model.train_op, model.global_step, model.summaries],
                feed_dict={
                    model.input_x: _input_x,
                    model.decoder_input: _decoder_input,
                    model.target: _target,
                    model.is_training: True
                })
            epoch = math.ceil(_gs / num_train_batches)
            summary_writer.add_summary(_summary, _gs)

            # Evaluation
            if _gs and _gs % num_train_batches == 0:
                logging.info("Epoch {} is done".format(epoch))
                _loss = sess.run(model.loss,
                                 feed_dict={
                                     model.input_x: _input_x,
                                     model.decoder_input: _decoder_input,
                                     model.target: _target,
                                     model.is_training: False
                                 })

                # evaluation
                y_hat, mean_loss = model.eval(sess, eval_init_op, xs, ys,
                                              num_eval_batches)

                # id to token
                logging.info("# Get hypotheses")
                hypotheses = get_hypotheses(num_eval_samples, y_hat,
                                            model.idx2token)

                # save translation results
                if not os.path.exists(hp.evaldir):
                    os.makedirs(hp.evaldir)
                logging.info("# Write results")
                model_output = "translation_E{:02d}L{:.2f}EL{:.2f}".format(
                    epoch, _loss, mean_loss)
                translation = os.path.join(hp.evaldir, model_output)
                with open(translation, 'w', encoding="utf-8") as fout:
                    fout.write("\n".join(hypotheses))
                logging.info(
                    "# Calculate bleu score and append it to translation")

                # bleu
                calc_bleu_nltk(hp.eval2, translation)

                # save model
                logging.info("# Save models")
                ckpt_name = os.path.join(hp.checkpoints_dir, model_output)
                saver.save(sess, ckpt_name, global_step=_gs)
                logging.info(
                    "After training of {} epochs, {} has been saved.".format(
                        epoch, ckpt_name))

                # calculate early-stopping criteria
                if len(eval_losses) == 0:
                    min_dev_loss = mean_loss
                eval_losses.append(mean_loss)
                gl, p_k, pq_alpha = calculate_earlystop_baseline(
                    mean_loss, min_dev_loss, eval_losses, k)
                min_dev_loss = mean_loss if mean_loss < min_dev_loss else min_dev_loss
                eval_losses = eval_losses[-k:]
                logging.info(
                    "GL(t): {:.4f}, P_k: {:.4f}, PQ_alpha: {:.4f}".format(
                        gl, p_k, pq_alpha))
                if gl > stop_alpha:
                    logging.info(
                        "No optimization for a long time, auto-stopping...")
                    break

                # change data iterator back to train iterator
                sess.run(train_init_op)

        summary_writer.close()

    logging.info("Done")
Beispiel #20
0
def main(args, hparams):

    # prepare data
    testset = TextMelLoader(hparams.test_files, hparams, shuffle=False)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)
    test_loader = DataLoader(
        testset,
        num_workers=1,
        shuffle=False,
        batch_size=1,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    # prepare model
    model = Transformer(hparams).cuda("cuda:0")
    checkpoint_restore = load_avg_checkpoint(args.checkpoint_path)
    model.load_state_dict(checkpoint_restore)
    model.eval()
    print("# total parameters:", sum(p.numel() for p in model.parameters()))

    # infer
    duration_add = 0
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_loader)):
            x, y = parse_batch(batch)

            # the start time
            start = time.perf_counter()
            (
                mel_output,
                mel_output_postnet,
                _,
                enc_attn_list,
                dec_attn_list,
                dec_enc_attn_list,
            ) = model.inference(x)

            # the end time
            duration = time.perf_counter() - start
            duration_add += duration

            # denormalize the feats and save the mels and attention plots
            mel_predict = mel_output_postnet[0]
            mel_denorm = denormalize_feats(mel_predict, hparams.dump)
            mel_path = os.path.join(args.output_infer,
                                    "{:0>3d}".format(i) + ".pt")
            torch.save(mel_denorm, mel_path)

            plot_data(
                (
                    mel_output.detach().cpu().numpy()[0],
                    mel_output_postnet.detach().cpu().numpy()[0],
                    mel_denorm.numpy(),
                ),
                i,
                args.output_infer,
            )

            plot_attn(
                enc_attn_list,
                dec_attn_list,
                dec_enc_attn_list,
                i,
                args.output_infer,
            )

        duration_avg = duration_add / (i + 1)
        print("The average inference time is: %f" % duration_avg)
Beispiel #21
0
def test(args):
    """ Decode with beam search """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")
    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test in test_data:
        test_seq, test_pos = to_input_variable([test],
                                               vocab.src,
                                               cuda=args.cuda)
        test_seq_beam = test_seq.expand(args.decode_beam_size,
                                        test_seq.size(1))

        enc_output = translator.encode(test_seq, test_pos)
        enc_output_beam = enc_output.expand(args.decode_beam_size,
                                            enc_output.size(1),
                                            enc_output.size(2))

        beam = Beam_Search_V2(beam_size=args.decode_beam_size,
                              tgt_vocab=vocab.tgt,
                              length_alpha=args.decode_alpha)
        for i in range(args.decode_max_steps):

            # first step of the beam search
            if i == 0:
                # <BOS>
                pred_var = to_input_variable(beam.candidates[:1],
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_output, test_seq, pred_var)
            else:
                pred_var = to_input_variable(beam.candidates,
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_output_beam, test_seq_beam,
                                              pred_var)

            log_softmax_scores = F.log_softmax(scores, dim=-1)
            log_softmax_scores = log_softmax_scores.view(
                pred_var[0].size(0), -1, log_softmax_scores.size(-1))
            log_softmax_scores = log_softmax_scores[:, -1, :]

            is_done = beam.advance(log_softmax_scores)
            beam.update_status()

            if is_done:
                break

        print("[Source] %s" % " ".join(test))
        print("[Predict] %s" % beam.get_best_candidate())
        print()

        output_file.write(beam.get_best_candidate() + "\n")
        output_file.flush()

    output_file.close()
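Beam_Search_V2 takes a length_alpha argument, which typically controls a length-normalization term so longer candidates are not unfairly penalized. A sketch of the common GNMT-style scoring rule, assuming that is what length_alpha implements:

def length_penalty(length, alpha):
    # GNMT length penalty: lp(Y) = ((5 + |Y|) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

def candidate_score(log_prob_sum, length, alpha):
    # divide the accumulated log-probability by the length penalty;
    # alpha = 0 means no normalization, larger alpha favors longer outputs
    return log_prob_sum / length_penalty(length, alpha)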
Beispiel #22
0
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)
    net = Transformer(vsize_src, vsize_tar)

    if not args.test:

        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

        best_valid_loss = float('inf')  # first validation loss becomes the baseline
        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))
            net.train()
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))
            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))
            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        net = torch.load('data/ckpt/best_model')
        net.to(device)
        net.eval()

        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        iter_cnt = 0
        for src_batch, tgt_batch in test_loader:
            source, src_mask = make_tensor(src_batch)
            source = source.to(device)
            src_mask = src_mask.to(device)
            res = net.decode(source, src_mask)
            pred_batch = res.tolist()
            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token.
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)
            iter_cnt += 1
            #print(pred_batch)

        with open('data/results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok'
        )
Beispiel #23
0
                                                             hp.vocab,
                                                             hp.batch_size,
                                                             shuffle=False)

# create an iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)
xs, ys = iter.get_next()

logging.info('# init data')
train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train_multi_gpu(xs, ys)
y_hat, eval_summaries, sent2, pred = m.eval(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)

    sess.run(train_init_op)
Beispiel #24
0
eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1,
                                                             hp.eval2,
                                                             100000,
                                                             100000,
                                                             hp.vocab,
                                                             hp.batch_size,
                                                             shuffle=False)
iter = tf.data.Iterator.from_structure(eval_batches.output_types,
                                       eval_batches.output_shapes)
xs, ys = iter.get_next()
decoder_inputs, y, y_seqlen, sents2 = ys
eval_init_op = iter.make_initializer(eval_batches)

logging.info("# Load model")
m = Transformer(hp)
y_mask = m.y_masks(y)
y_hat, eval_summaries = m.eval(xs, ys, y_mask)
saver = tf.train.Saver()

with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    saver.restore(sess, ckpt)
    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)
    sess.run(eval_init_op)
    _y_hat, _y = sess.run([y_hat, y])
    print(_y_hat)
    print(_y)
    print(acc(_y_hat, _y))
    #hypotheses = get_hypotheses(1, 128, sess, y_hat, m.idx2token)

#print(hypotheses)
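acc is not defined in this snippet; a padding-aware token accuracy would look roughly like this (treating id 0 as the padding token is an assumption):

import numpy as np

def acc(y_hat, y, pad_id=0):
    # compare predictions and targets only where the target is not padding
    y_hat, y = np.asarray(y_hat), np.asarray(y)
    mask = y != pad_id
    return (y_hat[mask] == y[mask]).sum() / max(mask.sum(), 1)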
Beispiel #25
0
def main(args):

    #create a writer
    writer = SummaryWriter('loss_plot_' + args.mode, comment='test')
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    val_length = len(os.listdir(args.image_dir_val))

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    data_loader_val = get_loader(args.image_dir_val,
                                 args.caption_path_val,
                                 vocab,
                                 transform,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers)

    # Build the model
    # if no-attention model is chosen:
    if args.model_type == 'no_attention':
        encoder = Encoder(args.embed_size).to(device)
        decoder = Decoder(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers).to(device)
        criterion = nn.CrossEntropyLoss()

    # if attention model is chosen:
    elif args.model_type == 'attention':
        encoder = EncoderAtt(encoded_image_size=9).to(device)
        decoder = DecoderAtt(vocab, args.encoder_dim, args.hidden_size,
                             args.attention_dim, args.embed_size,
                             args.dropout_ratio, args.alpha_c).to(device)

    # if transformer model is chosen:
    elif args.model_type == 'transformer':
        model = Transformer(len(vocab), args.embed_size,
                            args.transformer_layers, 8,
                            args.dropout_ratio).to(device)

        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, model.encoder.parameters()),
                                             lr=args.learning_rate_enc)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, model.decoder.parameters()),
                                             lr=args.learning_rate_dec)
        criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])

    else:
        print('Select model_type: attention, no_attention, or transformer')

    # for non-transformer models an additional encoder step is needed: set ResNet layer freezing according to args.fine_tune
    if args.model_type != 'transformer':
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args.learning_rate_dec)
        encoder.fine_tune(args.fine_tune)
        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, encoder.parameters()),
                                             lr=args.learning_rate_enc)

    # initialize lists to store results:
    loss_train = []
    loss_val = []
    loss_val_epoch = []
    loss_train_epoch = []

    bleu_res_list = []
    cider_res_list = []
    rouge_res_list = []

    results = {}

    # calculate total steps for train and validation
    total_step = len(data_loader)
    total_step_val = len(data_loader_val)

    #For each epoch
    for epoch in tqdm(range(args.num_epochs)):

        loss_val_iter = []
        loss_train_iter = []

        # set model to train mode
        if args.model_type != 'transformer':
            encoder.train()
            decoder.train()
        else:
            model.train()

        # for each entry in data_loader
        for i, (images, captions, lengths) in tqdm(enumerate(data_loader)):
            # load images and captions to device
            images = images.to(device)
            captions = captions.to(device)
            # Forward, backward and optimize

            # the forward and backward passes differ depending on the model type:
            if args.model_type == 'no_attention':
                # get features from encoder
                features = encoder(images)
                # pack the padded target sequences
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]
                # get output from decoder
                outputs = decoder(features, captions, lengths)
                # calculate loss
                loss = criterion(outputs, targets)

                # optimizer and backward step
                decoder_optimizer.zero_grad()
                encoder_optimizer.zero_grad()
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()

            elif args.model_type == 'attention':

                # get features from encoder
                features = encoder(images)

                # get targets, starting from the 2nd word in each caption
                # (the model is not sequential, so targets are predicted in parallel; there is no need to predict the first word)

                targets = captions[:, 1:]
                # decode length = length-1 for each caption
                decode_lengths = [length - 1 for length in lengths]
                #flatten targets
                targets = targets.reshape(targets.shape[0] * targets.shape[1])

                sampled_caption = []

                # get scores and alphas from decoder
                scores, alphas = decoder(features, captions, decode_lengths)

                scores = scores.view(-1, scores.shape[-1])

                #predicted = prediction with maximum score
                _, predicted = torch.max(scores, dim=1)

                # calculate loss
                loss = decoder.loss(scores, targets, alphas)

                # optimizer and backward step
                decoder_optimizer.zero_grad()
                encoder_optimizer.zero_grad()
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()

            elif args.model_type == 'transformer':

                # input is captions without last word
                trg_input = captions[:, :-1]
                # create mask
                trg_mask = create_masks(trg_input)

                # get scores from model
                scores = model(images, trg_input, trg_mask)
                scores = scores.view(-1, scores.shape[-1])

                # get targets, starting from the 2nd word in each caption
                targets = captions[:, 1:]

                #predicted = prediction with maximum score
                _, predicted = torch.max(scores, dim=1)

                # calculate loss
                loss = criterion(
                    scores,
                    targets.reshape(targets.shape[0] * targets.shape[1]))

                # backward pass and optimizer step
                decoder_optimizer.zero_grad()
                encoder_optimizer.zero_grad()
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()

            else:
                print('Select model_type: attention, no_attention, or transformer')

            # append results to loss lists and writer
            loss_train_iter.append(loss.item())
            loss_train.append(loss.item())
            writer.add_scalar('Loss/train/iterations', loss.item(), i + 1)

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss.item(),
                            np.exp(loss.item())))

        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'.
              format(epoch, args.num_epochs, i, total_step, loss.item(),
                     np.exp(loss.item())))

        #append mean of last 10 batches as approximate epoch loss
        loss_train_epoch.append(np.mean(loss_train_iter[-10:]))

        writer.add_scalar('Loss/train/epoch', np.mean(loss_train_iter[-10:]),
                          epoch + 1)

        #save model
        if args.model_type != 'transformer':
            torch.save(
                decoder.state_dict(),
                os.path.join(
                    args.model_path,
                    'decoder_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
            torch.save(
                encoder.state_dict(),
                os.path.join(
                    args.model_path,
                    'encoder_' + args.mode + '_{}.ckpt'.format(epoch + 1)))

        else:
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.model_path,
                    'model_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
        np.save(
            os.path.join(args.predict_json,
                         'loss_train_temp_' + args.mode + '.npy'), loss_train)

        #validate model:
        # set model to eval mode:
        if args.model_type != 'transformer':
            encoder.eval()
            decoder.eval()
        else:
            model.eval()
        total_step = len(data_loader_val)

        # set no_grad mode:
        with torch.no_grad():
            # for each entry in data_loader
            for i, (images, captions,
                    lengths) in tqdm(enumerate(data_loader_val)):
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]
                images = images.to(device)
                captions = captions.to(device)

                # the forward and backward passes differ depending on the model type:
                if args.model_type == 'no_attention':
                    features = encoder(images)
                    outputs = decoder(features, captions, lengths)
                    loss = criterion(outputs, targets)

                elif args.model_type == 'attention':

                    features = encoder(images)
                    sampled_caption = []
                    targets = captions[:, 1:]
                    decode_lengths = [length - 1 for length in lengths]
                    targets = targets.reshape(targets.shape[0] *
                                              targets.shape[1])

                    scores, alphas = decoder(features, captions,
                                             decode_lengths)

                    _, predicted = torch.max(scores, dim=1)

                    scores = scores.view(-1, scores.shape[-1])

                    sampled_caption = []

                    loss = decoder.loss(scores, targets, alphas)

                elif args.model_type == 'transformer':

                    trg_input = captions[:, :-1]
                    trg_mask = create_masks(trg_input)
                    scores = model(images, trg_input, trg_mask)
                    scores = scores.view(-1, scores.shape[-1])
                    targets = captions[:, 1:]

                    _, predicted = torch.max(scores, dim=1)

                    loss = criterion(
                        scores,
                        targets.reshape(targets.shape[0] * targets.shape[1]))

                #display results
                if i % args.log_step == 0:
                    print(
                        'Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Validation Perplexity: {:5.4f}'
                        .format(epoch, args.num_epochs, i, total_step_val,
                                loss.item(), np.exp(loss.item())))

                # append results to loss lists and writer
                loss_val.append(loss.item())
                loss_val_iter.append(loss.item())

                writer.add_scalar('Loss/validation/iterations', loss.item(),
                                  i + 1)

        np.save(
            os.path.join(args.predict_json, 'loss_val_' + args.mode + '.npy'),
            loss_val)

        print(
            'Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Validation Perplexity: {:5.4f}'
            .format(epoch, args.num_epochs, i, total_step_val, loss.item(),
                    np.exp(loss.item())))

        # results: epoch validation loss

        loss_val_epoch.append(np.mean(loss_val_iter))
        writer.add_scalar('Loss/validation/epoch', loss_val_epoch[-1],
                          epoch + 1)

        #predict captions:
        filenames = os.listdir(args.image_dir_val)

        predicted = {}

        for file in tqdm(filenames):
            if file == '.DS_Store':
                continue
            # Prepare an image
            image = load_image(os.path.join(args.image_dir_val, file),
                               transform)
            image_tensor = image.to(device)

            # Generate caption starting with <start> word

            # procedure is different for each model type
            if args.model_type == 'attention':

                features = encoder(image_tensor)
                sampled_ids, _ = decoder.sample(features)
                sampled_ids = sampled_ids[0].cpu().numpy()
                #start sampled_caption with <start>
                sampled_caption = ['<start>']

            elif args.model_type == 'no_attention':
                features = encoder(image_tensor)
                sampled_ids = decoder.sample(features)
                sampled_ids = sampled_ids[0].cpu().numpy()
                sampled_caption = ['<start>']

            elif args.model_type == 'transformer':

                e_outputs = model.encoder(image_tensor)
                max_seq_length = 20
                sampled_ids = torch.zeros(max_seq_length, dtype=torch.long)
                sampled_ids[0] = torch.LongTensor([[vocab.word2idx['<start>']]
                                                   ]).to(device)

                for i in range(1, max_seq_length):

                    trg_mask = np.triu(np.ones((1, i, i)), k=1).astype('uint8')
                    trg_mask = Variable(
                        torch.from_numpy(trg_mask) == 0).to(device)

                    out = model.decoder(sampled_ids[:i].unsqueeze(0),
                                        e_outputs, trg_mask)

                    out = model.out(out)
                    out = F.softmax(out, dim=-1)
                    val, ix = out[:, -1].data.topk(1)
                    sampled_ids[i] = ix[0][0]

                sampled_ids = sampled_ids.cpu().numpy()
                sampled_caption = []

            # Convert word_ids to words
            for word_id in sampled_ids:
                word = vocab.idx2word[word_id]
                sampled_caption.append(word)
                # break at <end> of the sentence
                if word == '<end>':
                    break
            sentence = ' '.join(sampled_caption)

            predicted[file] = sentence

        # save predictions to json file:
        json.dump(
            predicted,
            open(
                os.path.join(
                    args.predict_json,
                    'predicted_' + args.mode + '_' + str(epoch) + '.json'),
                'w'))

        #validate model
        with open(args.caption_path_val, 'r') as file:
            captions = json.load(file)

        res = {}
        for r in predicted:
            # strip the <start>/<end> markers from each predicted sentence
            res[r] = [predicted[r].replace('<start>', '').replace('<end>', '').strip()]

        images = captions['images']
        caps = captions['annotations']
        gts = {}
        for image in images:
            image_id = image['id']
            file_name = image['file_name']
            list_cap = []
            for cap in caps:
                if cap['image_id'] == image_id:
                    list_cap.append(cap['caption'])
            gts[file_name] = list_cap

        # calculate BLEU, CIDEr, and ROUGE metrics from reference and predicted captions
        bleu_res = bleu(gts, res)
        cider_res = cider(gts, res)
        rouge_res = rouge(gts, res)

        # append results to the result lists
        bleu_res_list.append(bleu_res)
        cider_res_list.append(cider_res)
        rouge_res_list.append(rouge_res)

        # write results to writer
        writer.add_scalar('BLEU1/validation/epoch', bleu_res[0], epoch + 1)
        writer.add_scalar('BLEU2/validation/epoch', bleu_res[1], epoch + 1)
        writer.add_scalar('BLEU3/validation/epoch', bleu_res[2], epoch + 1)
        writer.add_scalar('BLEU4/validation/epoch', bleu_res[3], epoch + 1)
        writer.add_scalar('CIDEr/validation/epoch', cider_res, epoch + 1)
        writer.add_scalar('ROUGE/validation/epoch', rouge_res, epoch + 1)

    results['bleu'] = bleu_res_list
    results['cider'] = cider_res_list
    results['rouge'] = rouge_res_list

    json.dump(
        results,
        open(os.path.join(args.predict_json, 'results_' + args.mode + '.json'),
             'w'))
    np.save(
        os.path.join(args.predict_json, 'loss_train_' + args.mode + '.npy'),
        loss_train)
    np.save(os.path.join(args.predict_json, 'loss_val_' + args.mode + '.npy'),
            loss_val)
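create_masks is not shown; judging from the np.triu mask built in the inference loop above, it presumably produces the standard "no peeking ahead" target mask. A sketch under that assumption (padding handling omitted):

import numpy as np
import torch

def create_masks(trg_input):
    # upper-triangular mask: position i may only attend to positions <= i
    size = trg_input.size(1)
    nopeak = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    return (torch.from_numpy(nopeak) == 0).to(trg_input.device)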
Beispiel #26
0
def main(tokenizer, src_tok_file, tgt_tok_file, train_file, val_file,
         test_file, num_epochs, batch_size, d_model, nhead, num_encoder_layers,
         num_decoder_layers, dim_feedforward, dropout, learning_rate,
         data_path, checkpoint_file, do_train):
    logging.info('Using tokenizer: {}'.format(tokenizer))

    src_tokenizer = TokenizerWrapper(tokenizer, BLANK_WORD, SEP_TOKEN,
                                     CLS_TOKEN, PAD_TOKEN, MASK_TOKEN)
    src_tokenizer.train(src_tok_file, 20000, SPECIAL_TOKENS)

    tgt_tokenizer = TokenizerWrapper(tokenizer, BLANK_WORD, SEP_TOKEN,
                                     CLS_TOKEN, PAD_TOKEN, MASK_TOKEN)
    tgt_tokenizer.train(tgt_tok_file, 20000, SPECIAL_TOKENS)

    SRC = ttdata.Field(tokenize=src_tokenizer.tokenize, pad_token=BLANK_WORD)
    TGT = ttdata.Field(tokenize=tgt_tokenizer.tokenize,
                       init_token=BOS_WORD,
                       eos_token=EOS_WORD,
                       pad_token=BLANK_WORD)

    logging.info('Loading training data...')
    train_ds, val_ds, test_ds = ttdata.TabularDataset.splits(
        path=data_path,
        format='tsv',
        train=train_file,
        validation=val_file,
        test=test_file,
        fields=[('src', SRC), ('tgt', TGT)])

    test_src_sentence = val_ds[0].src
    test_tgt_sentence = val_ds[0].tgt

    MIN_FREQ = 2
    SRC.build_vocab(train_ds.src, min_freq=MIN_FREQ)
    TGT.build_vocab(train_ds.tgt, min_freq=MIN_FREQ)

    logging.info(f'''SRC vocab size: {len(SRC.vocab)}''')
    logging.info(f'''TGT vocab size: {len(TGT.vocab)}''')

    train_iter = ttdata.BucketIterator(train_ds,
                                       batch_size=batch_size,
                                       repeat=False,
                                       sort_key=lambda x: len(x.src))
    val_iter = ttdata.BucketIterator(val_ds,
                                     batch_size=1,
                                     repeat=False,
                                     sort_key=lambda x: len(x.src))
    test_iter = ttdata.BucketIterator(test_ds,
                                      batch_size=1,
                                      repeat=False,
                                      sort_key=lambda x: len(x.src))

    source_vocab_length = len(SRC.vocab)
    target_vocab_length = len(TGT.vocab)

    model = Transformer(d_model=d_model,
                        nhead=nhead,
                        num_encoder_layers=num_encoder_layers,
                        num_decoder_layers=num_decoder_layers,
                        dim_feedforward=dim_feedforward,
                        dropout=dropout,
                        source_vocab_length=source_vocab_length,
                        target_vocab_length=target_vocab_length)
    optim = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             betas=(0.9, 0.98),
                             eps=1e-9)
    model = model.cuda()

    if do_train:
        train_losses, valid_losses = train(train_iter, val_iter, model, optim,
                                           num_epochs, batch_size,
                                           test_src_sentence,
                                           test_tgt_sentence, SRC, TGT,
                                           src_tokenizer, tgt_tokenizer,
                                           checkpoint_file)
    else:
        logging.info('Skipped training.')

    # Load best model and score test set
    logging.info('Loading best model.')
    model.load_state_dict(torch.load(checkpoint_file))
    model.eval()
    logging.info('Scoring the test set...')
    score_start = time.time()
    test_bleu, test_chrf = score(test_iter, model, tgt_tokenizer, SRC, TGT)
    score_time = time.time() - score_start
    logging.info(f'''Scoring complete in {score_time/60:.3f} minutes.''')
    logging.info(f'''BLEU : {test_bleu}''')
    logging.info(f'''CHRF : {test_chrf}''')
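score returns corpus-level BLEU and CHRF; with sacrebleu, which exposes both metrics, the scoring core of such a function might look like this (decoding of the test iterator omitted):

import sacrebleu

def score_corpus(hypotheses, references):
    # hypotheses: detokenized strings; references: one reference per hypothesis
    bleu = sacrebleu.corpus_bleu(hypotheses, [references])
    chrf = sacrebleu.corpus_chrf(hypotheses, [references])
    return bleu.score, chrf.score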
Beispiel #27
0
    hp.test_source,
    hp.test_target,
    100000,
    100000,
    hp.vocab,
    hp.test_batch_size,
    shuffle=False)
iter = tf.data.Iterator.from_structure(test_batches.output_types,
                                       test_batches.output_shapes)
xs, ys = iter.get_next()

test_init_op = iter.make_initializer(test_batches)

logging.info("# Load model")
m = Transformer(hp)
y_hat, _, refs = m.eval(xs, ys)

logging.info("# Session")
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.modeldir)
    saver = tf.train.Saver()

    saver.restore(sess, ckpt)

    sess.run(test_init_op)

    logging.info("# get hypotheses")
    hypotheses, refs_result = get_hypotheses(num_test_batches,
                                             num_test_samples, sess, y_hat,
                                             refs, m.idx2token)
Beispiel #28
0
def do_predict(args):
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(fpattern=args.predict_file,
                                     src_vocab_fpath=args.src_vocab_fpath,
                                     trg_vocab_fpath=args.trg_vocab_fpath,
                                     token_delimiter=args.token_delimiter,
                                     use_token_batch=False,
                                     batch_size=args.batch_size,
                                     device_count=1,
                                     pool_size=args.pool_size,
                                     sort_type=reader.SortType.NONE,
                                     shuffle=False,
                                     shuffle_batch=False,
                                     start_mark=args.special_token[0],
                                     end_mark=args.special_token[1],
                                     unk_mark=args.special_token[2],
                                     max_length=args.max_length,
                                     n_head=args.n_head)
    batch_generator = processor.data_generator(phase="predict", place=place)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()
    trg_idx2word = reader.DataProcessor.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    with fluid.dygraph.guard(place):
        # define data loader
        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
        test_loader.set_batch_generator(batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        # load the trained model
        assert args.init_from_params, (
            "Please set init_from_params to load the infer model.")
        model_dict, _ = fluid.load_dygraph(
            os.path.join(args.init_from_params, "transformer"))
        # to avoid a longer length than training, reset the size of position
        # encoding to max_length
        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model)
        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model)
        transformer.load_dict(model_dict)

        # set evaluate mode
        transformer.eval()

        f = open(args.output_file, "wb")
        for input_data in test_loader():
            (src_word, src_pos, src_slf_attn_bias, trg_word,
             trg_src_attn_bias) = input_data
            finished_seq, finished_scores = transformer.beam_search(
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
                bos_id=args.bos_idx,
                eos_id=args.eos_idx,
                beam_size=args.beam_size,
                max_len=args.max_out_len)
            finished_seq = finished_seq.numpy()
            finished_scores = finished_scores.numpy()
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    if beam_idx >= args.n_best: break
                    id_list = post_process_seq(beam, args.bos_idx,
                                               args.eos_idx)
                    word_list = [trg_idx2word[id] for id in id_list]
                    sequence = b" ".join(word_list) + b"\n"
                    f.write(sequence)
        f.close()
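post_process_seq trims a decoded beam into a clean token-id list; a plausible implementation, assuming it cuts at the first <eos> and drops the <bos>/<eos> markers themselves:

def post_process_seq(seq, bos_idx, eos_idx):
    # keep everything up to (but excluding) the first <eos>,
    # dropping any <bos> along the way
    out = []
    for idx in seq:
        if idx == eos_idx:
            break
        if idx != bos_idx:
            out.append(idx)
    return out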
Beispiel #29
0
                                                             hp.vocab,
                                                             hp.batch_size,
                                                             shuffle=False)

# create an iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types,
                                       train_batches.output_shapes)
xs, ys = iter.get_next()

train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)
# y_hat = m.infer(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)
Beispiel #30
0
def main(args):

    # 0. initial setting

    # set environmet
    cudnn.benchmark = True

    if not os.path.isdir('./ckpt'):
        os.mkdir('./ckpt')
    if not os.path.isdir('./results'):
        os.mkdir('./results')
    if not os.path.isdir(os.path.join('./ckpt', args.name)):
        os.mkdir(os.path.join('./ckpt', args.name))
    if not os.path.isdir(os.path.join('./results', args.name)):
        os.mkdir(os.path.join('./results', args.name))
    if not os.path.isdir(os.path.join('./results', args.name, "log")):
        os.mkdir(os.path.join('./results', args.name, "log"))

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler("results/{}/log/{}.log".format(
        args.name, time.strftime('%c', time.localtime(time.time()))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data

    args.logger.info("loading data...")
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # 2. setup

    args.logger.info("setting up...")

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # transformer config
    d_e = 512  # embedding size
    d_q = 64  # query size (= key, value size)
    d_h = 2048  # hidden layer size in feed forward network
    num_heads = 8
    num_layers = 6  # number of layers in the encoder and the decoder

    args.sos_idx = sos_idx
    args.eos_idx = eos_idx
    args.pad_idx = pad_idx
    args.max_length = max_length
    args.src_vocab_size = src_vocab_size
    args.tgt_vocab_size = tgt_vocab_size
    args.d_e = d_e
    args.d_q = d_q
    args.d_h = d_h
    args.num_heads = num_heads
    args.num_layers = num_layers

    model = Transformer(args)
    model.to(args.device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    if args.load:
        model.load_state_dict(load(args, args.ckpt))

    # 3. train / test

    if not args.test:
        # train
        args.logger.info("starting training")
        acc_val_meter = AverageMeter(name="Acc-Val (%)",
                                     save_all=True,
                                     save_dir=os.path.join(
                                         'results', args.name))
        train_loss_meter = AverageMeter(name="Loss",
                                        save_all=True,
                                        save_dir=os.path.join(
                                            'results', args.name))
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter()
            for src_batch, tgt_batch in tqdm(train_loader):
                # src_batch: (batch x source_length), tgt_batch: (batch x target_length)
                optimizer.zero_grad()
                src_batch, tgt_batch = torch.LongTensor(src_batch).to(
                    args.device), torch.LongTensor(tgt_batch).to(args.device)
                batch = src_batch.shape[0]
                # split target batch into input and output
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                pred = model(src_batch.to(args.device),
                             tgt_batch_i.to(args.device))
                loss = loss_fn(pred.contiguous().view(-1, tgt_vocab_size),
                               tgt_batch_o.contiguous().view(-1))
                loss.backward()
                optimizer.step()

                train_loss_tmp_meter.update(loss.item() / batch, weight=batch)

            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] train loss: {:.3f} took {:.1f} seconds".format(
                    epoch, train_loss_tmp_meter.avg, spent_time))

            # validation
            model.eval()
            acc_val_tmp_meter = AverageMeter()
            spent_time = time.time()

            for src_batch, tgt_batch in tqdm(valid_loader):
                src_batch, tgt_batch = torch.LongTensor(
                    src_batch), torch.LongTensor(tgt_batch)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                with torch.no_grad():
                    pred = model(src_batch.to(args.device),
                                 tgt_batch_i.to(args.device))

                corrects, total = val_check(
                    pred.max(dim=-1)[1].cpu(), tgt_batch_o)
                acc_val_tmp_meter.update(100 * corrects / total, total)

            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] validation accuracy: {:.1f} %, took {} seconds".format(
                    epoch, acc_val_tmp_meter.avg, spent_time))
            acc_val_meter.update(acc_val_tmp_meter.avg)

            if epoch % args.save_period == 0:
                save(args, "epoch_{}".format(epoch), model.state_dict())
                acc_val_meter.save()
                train_loss_meter.save()
    else:
        # test
        args.logger.info("starting test")
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)
        pred_list = []
        model.eval()

        for src_batch, tgt_batch in test_loader:
            #src_batch: (batch x source_length)
            src_batch = torch.Tensor(src_batch).long().to(args.device)
            batch = src_batch.shape[0]
            pred_batch = torch.zeros(batch, 1).long().to(args.device)
            pred_mask = torch.zeros(batch, 1).bool().to(
                args.device)  # mask marking whether each sentence has ended

            with torch.no_grad():
                for _ in range(args.max_length):
                    pred = model(
                        src_batch,
                        pred_batch)  # (batch x length x tgt_vocab_size)
                    pred[:, :, pad_idx] = float('-inf')  # never predict <pad>
                    pred = pred.max(dim=-1)[1][:, -1].unsqueeze(
                        -1)  # next word prediction: (batch x 1)
                    pred = pred.masked_fill(
                        pred_mask,
                        2).long()  # fill out <pad> for ended sentences
                    pred_mask = torch.gt(pred.eq(1) + pred.eq(2), 0)
                    pred_batch = torch.cat([pred_batch, pred], dim=1)
                    if torch.prod(pred_mask) == 1:
                        break

            pred_batch = torch.cat([
                pred_batch,
                torch.ones(batch, 1).long().to(args.device) + pred_mask.long()
            ],
                                   dim=1)  # close all sentences
            pred_list += seq2sen(pred_batch.cpu().numpy().tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred_list:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
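seq2sen, used by both this script and Beispiel #22 to turn id matrices back into sentences, is not shown; a sketch that skips the special tokens (indices as set up above: <sos>=0, <eos>=1, <pad>=2), assuming the Vocab object maps ids back to tokens via idx2word:

def seq2sen(batch, vocab):
    # batch: list of id lists; returns one space-joined sentence per row
    sentences = []
    for seq in batch:
        words = [vocab.idx2word[idx] for idx in seq
                 if idx not in (0, 1, 2)]  # skip <sos>, <eos>, <pad>
        sentences.append(' '.join(words))
    return sentences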