Example 1
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab, config.train_data_path,
                               config.batch_size, single_pass=False, mode='train')
        time.sleep(10)  # give the batcher time to start filling its queue before training begins

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)
Example 2
    def __init__(self, model_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab,
                               config.eval_data_path,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_path, is_eval=True)
Example 3
    def __init__(self, model_file_path):

        model_name = os.path.basename(model_file_path)
        self._test_dir = os.path.join(config.log_root,
                                      'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._test_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._test_dir, 'rouge_dec')
        for p in [self._test_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path=config.decode_data_path,
                               vocab=self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)
Example 4
    def batcher(self):
        # Lazily build a single Batcher and cache it on the instance.
        if not hasattr(self, '_batcher'):
            self._batcher = Batcher(self.config)

        return self._batcher
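
Example 4 caches the Batcher on first use. A minimal, self-contained sketch of the same lazy-caching pattern exposed as a read-only property follows; the DataPipeline class, the stub Batcher and the example config are illustrative assumptions, not part of the code above.

class Batcher:
    """Stand-in for the real Batcher; only here to make the sketch runnable."""
    def __init__(self, config):
        self.config = config

class DataPipeline:
    def __init__(self, config):
        self.config = config

    @property
    def batcher(self):
        # Build the Batcher on first access, then return the cached instance.
        if not hasattr(self, '_batcher'):
            self._batcher = Batcher(self.config)
        return self._batcher

pipeline = DataPipeline(config={'batch_size': 8})
assert pipeline.batcher is pipeline.batcher  # same cached object on every access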
Example 5
class Evaluate(object):
    def __init__(self, model_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab,
                               config.eval_data_path,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
            get_output_from_batch(batch, use_cuda)

        enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
        s_t = self.model.reduce_state(enc_h)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t, extra_zeros,
                enc_batch_extend_vocab, coverage, di)
            tgt = tgt_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      tgt.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens
        loss = torch.mean(batch_avg_loss)

        return loss.item()

    def run(self):
        start = time.time()
        running_avg_loss, iter = 0, 0
        batch = self.batcher.next_batch()
        print_interval = 100
        while batch is not None:
            loss = self.eval_one_batch(batch)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % print_interval == 0:
                self.summary_writer.flush()
                print('step: %d, seconds: %.2f, loss: %f' %
                      (iter, time.time() - start, running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()

        return running_avg_loss
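
The training and evaluation loops in these examples call calc_running_avg_loss to smooth the reported loss before logging it. The helper itself is not shown; below is a minimal sketch of an exponential moving average with the same call signature, assuming TF1-style summaries to match the tf.summary.FileWriter used above (the decay value and the tag name are assumptions).

import tensorflow as tf

def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
    # Exponential moving average of the per-batch loss, also logged for TensorBoard.
    if running_avg_loss == 0:  # first batch: start from the raw loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    summary = tf.Summary()
    summary.value.add(tag='running_avg_loss', simple_value=running_avg_loss)
    summary_writer.add_summary(summary, step)
    return running_avg_loss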
Example 6
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab,
                               config.train_data_path,
                               config.batch_size,
                               single_pass=True,
                               mode='train')
        time.sleep(10)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoders_state_dict': self.model.encoders.state_dict(),
            'encoders_att_state_dict': self.model.encoders_att.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_path=None):
        self.model = Model(model_path, is_tran=config.tran)
        initial_lr = config.lr_coverage if config.is_coverage else config.lr

        params = list(self.model.encoders.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        total_params = sum(p.nelement() for p in params)
        print('Number of model parameters: %.3f million' %
              (total_params / 1e6))
        self.optimizer = optim.Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_path is not None:
            state = torch.load(model_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
        extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda)
        dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()
        enc_outs = []
        enc_feas = []
        enc_hs = []
        if not config.tran:
            for i, encoder in enumerate(self.model.encoders):
                enc_out, enc_fea, enc_h = encoder(enc_batch[i], enc_lens[i])
                enc_outs.append(enc_out)
                enc_feas.append(enc_fea)
                enc_hs.append(enc_h)

        # else:
        #     enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_pos)

        # note: enc_h here is the hidden state of the last encoder from the loop above
        s_t = self.model.reduce_state(enc_h)

        step_losses, cove_losses = [], []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t = dec_batch[:, di]  # Teacher forcing

            # modification of the original single-encoder framework: run the decoder once per encoder
            final_dist_0, s_t_0, c_t_0, attn_dist_0, p_gen, next_coverage = \
                self.model.decoder(y_t, s_t, enc_outs[0], enc_feas[0], enc_padding_mask[0], c_t,
                                   extra_zeros, enc_batch_extend_vocab[0], coverage, di)

            final_dist_1, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = \
                self.model.decoder(y_t, s_t, enc_outs[1], enc_feas[1], enc_padding_mask[1], c_t,
                                   extra_zeros, enc_batch_extend_vocab[1], coverage, di)

            y_t_emb = self.model.decoder.tgt_word_emb(y_t)
            encoders_att = self.model.encoders_att(enc_hs, y_t_emb)  # b x 1 x 2
            # mix the two output distributions with the encoder attention weights:
            # (b x 1 x 2) bmm (b x 2 x vocab) -> (b x 1 x vocab) -> (b x vocab)
            final_dist = torch.stack((final_dist_0, final_dist_1), dim=1)
            final_dist = torch.bmm(encoders_att, final_dist).squeeze()

            encoders_att_ = encoders_att.transpose(0,
                                                   1).contiguous()  # 1 x b x 2
            h = s_t_0[0] * encoders_att_[:, :, :1] + s_t_1[
                0] * encoders_att_[:, :, 1:]
            c = s_t_0[1] * encoders_att_[:, :, :1] + s_t_1[
                1] * encoders_att_[:, :, 1:]
            s_t = (h, c)

            tgt = tgt_batch[:, di]
            step_mask = dec_padding_mask[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      tgt.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                cove_losses.append(step_coverage_loss * step_mask)
                coverage = next_coverage

            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        clip_grad_norm_(self.model.encoders.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        if config.is_coverage:
            cove_losses = torch.sum(torch.stack(cove_losses, 1), 1)
            batch_cove_loss = cove_losses / dec_lens
            batch_cove_loss = torch.mean(batch_cove_loss)
            return loss.item(), batch_cove_loss.item()

        return loss.item(), 0.

    def run(self, n_iters, model_path=None):
        iter, running_avg_loss = self.setup_train(model_path)
        start = time.time()
        interval = 100

        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss, cove_loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % interval == 0:
                self.summary_writer.flush()
                print('step: %d, seconds: %.2f, loss: %f, cover_loss: %f' %
                      (iter, time.time() - start, loss, cove_loss))
                start = time.time()
            if iter % 5000 == 0:
                self.save_model(running_avg_loss, iter)
Example 7
class BeamSearch(object):
    def __init__(self, model_file_path):

        model_name = os.path.basename(model_file_path)
        self._test_dir = os.path.join(config.log_root,
                                      'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._test_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._test_dir, 'rouge_dec')
        for p in [self._test_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path=config.decode_data_path,
                               vocab=self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def beam_search(self, batch):
        # single example repeated across the batch
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t, coverage = \
            get_input_from_batch(batch, use_cuda)

        enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
        s_t = self.model.reduce_state(enc_h)

        dec_h, dec_c = s_t  # b x hidden_dim
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: there are beam_size hypotheses, all identical at the first step
        beams = [
            Beam(tokens=[self.vocab.word2id(config.BOS_TOKEN)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t[0],
                 coverage=(coverage[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        steps = 0
        results = []
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(config.UNK_TOKEN) \
                             for t in latest_tokens]
            y_t = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t = y_t.cuda()
            all_state_h = [h.state[0] for h in beams]
            all_state_c = [h.state[1] for h in beams]
            all_context = [h.context for h in beams]

            s_t = (torch.stack(all_state_h,
                               0).unsqueeze(0), torch.stack(all_state_c,
                                                            0).unsqueeze(0))
            c_t = torch.stack(all_context, 0)

            coverage_t = None
            if config.is_coverage:
                all_coverage = [h.coverage for h in beams]
                coverage_t = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t, extra_zeros,
                enc_batch_extend_vocab, coverage_t, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)  # carry forward the decoder-updated coverage

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(config.EOS_TOKEN):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]

    def run(self):

        counter = 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = utils.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(dataset.EOS_TOKEN)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no EOS token was produced; keep the full output
            # note: in "original_abstract_sents", 'original' means its data type is bytes-like
            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d seconds' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._test_dir)
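
beam_search in Example 7 manipulates Beam hypothesis objects through tokens, log_probs, state, context, coverage, latest_token, avg_log_prob and extend. The class itself is not included in the snippet; the sketch below is a minimal version consistent with those call sites, and anything beyond the attribute and method names used above is an assumption.

class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # token ids decoded so far
        self.log_probs = log_probs  # per-token log probabilities
        self.state = state          # decoder (h, c) for this hypothesis
        self.context = context      # last attention context vector
        self.coverage = coverage    # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new hypothesis with one more decoded token appended.
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state,
                    context=context,
                    coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalised score used by sort_beams.
        return sum(self.log_probs) / len(self.tokens)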
Example 8
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab,
                               config.train_data_path,
                               config.batch_size,
                               single_pass=False,
                               mode='train')
        time.sleep(10)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter, name='model_1'):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir,
                                       name)  # % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_path=None):
        self.model = Model(model_path, is_tran=config.tran)
        initial_lr = config.lr_coverage if config.is_coverage else config.lr

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        total_params = sum(p.nelement() for p in params)
        print('Number of model parameters: %.3f million' %
              (total_params / 1e6))
        self.optimizer = optim.Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_path is not None:
            state = torch.load(model_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
        extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda)
        dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        if not config.tran:
            enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
        else:
            enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_pos)

        s_t = self.model.reduce_state(enc_h)

        step_losses, cove_losses = [], []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = \
                self.model.decoder(y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t,
                                   extra_zeros, enc_batch_extend_vocab, coverage, di)
            tgt = tgt_batch[:, di]
            step_mask = dec_padding_mask[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      tgt.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                cove_losses.append(step_coverage_loss * step_mask)
                coverage = next_coverage

            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        if config.is_coverage:
            cove_losses = torch.sum(torch.stack(cove_losses, 1), 1)
            batch_cove_loss = cove_losses / dec_lens
            batch_cove_loss = torch.mean(batch_cove_loss)
            return loss.item(), batch_cove_loss.item()

        return loss.item(), 0.

    def run(self, n_iters, model_path=None):
        iter, running_avg_loss = self.setup_train(model_path)
        start = time.time()
        train_start = start  # `start` below is reset at every print interval, so keep the true start separately
        interval = 100
        prev_eval_loss = float("inf")
        # train for roughly 11 hours of wall-clock time instead of a fixed number of iterations
        while (time.time() - train_start) / 3600 <= 11.0:  # iter < n_iters:
            batch = self.batcher.next_batch()
            loss, cove_loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % interval == 0:
                self.summary_writer.flush()
                print('step: %d, seconds: %.2f, loss: %f, cover_loss: %f' %
                      (iter, time.time() - start, loss, cove_loss))
                start = time.time()
            if iter % 20000 == 0:
                self.save_model(running_avg_loss, iter, 'model_temp')
                eval_loss = Evaluate(os.path.join(self.model_dir,
                                                  'model_temp')).run()
                if eval_loss < prev_eval_loss:
                    print(
                        f"eval loss for iteration: {iter} is {eval_loss}, previous best eval loss = {prev_eval_loss}, saving checkpoint..."
                    )
                    prev_eval_loss = eval_loss
                    self.save_model(running_avg_loss, iter)
                else:
                    print(
                        f"eval loss for iteration: {iter}, previous best eval loss = {prev_eval_loss}, no improvement, skipping..."
                    )
Example 9
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab, config.train_data_path,
                               config.batch_size, single_pass=False, mode='train')
        time.sleep(10)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        model_state_dict = self.model.state_dict()

        state = {
            'iter': iter,
            'current_loss': running_avg_loss,
            'optimizer': self.optimizer._optimizer.state_dict(),
            "model": model_state_dict
        }
        model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_path):

        device = torch.device('cuda' if use_cuda else 'cpu')

        self.model = Model(
            config.vocab_size,
            config.vocab_size,
            config.max_enc_steps,
            config.max_dec_steps,
            d_k=config.d_k,
            d_v=config.d_v,
            d_model=config.d_model,
            d_word_vec=config.emb_dim,
            d_inner=config.d_inner_hid,
            n_layers=config.n_layers,
            n_head=config.n_head,
            dropout=config.dropout).to(device)

        self.optimizer = ScheduledOptim(
            optim.Adam(
                filter(lambda x: x.requires_grad, self.model.parameters()),
                betas=(0.9, 0.98), eps=1e-09),
            config.d_model, config.n_warmup_steps)


        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters())
        total_params = sum(p.nelement() for p in params)
        print('Number of model parameters: %.3f million' % (total_params / 1e6))

        start_iter, start_loss = 0, 0

        if model_path is not None:
            state = torch.load(model_path, map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer._optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer._optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
        extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda, transformer=True)
        dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
            get_output_from_batch(batch, use_cuda, transformer=True)

        self.optimizer.zero_grad()

        pred = self.model(enc_batch, enc_pos, dec_batch, dec_pos)  # b x t x vocab
        # probability assigned to each gold token, masked at padding positions
        gold_probs = torch.gather(pred, -1, tgt_batch.unsqueeze(-1)).squeeze()
        batch_loss = -torch.log(gold_probs + config.eps)
        batch_loss = batch_loss * dec_padding_mask

        sum_losses = torch.sum(batch_loss, 1)
        batch_avg_loss = sum_losses / dec_lens
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        # update parameters
        self.optimizer.step_and_update_lr()

        return loss.item(), 0.

    def run(self, n_iters, model_path=None):
        iter, running_avg_loss = self.setup_train(model_path)
        start = time.time()
        interval = 100

        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss, cove_loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

            if iter % interval == 0:
                self.summary_writer.flush()
                print(
                    'step: %d, seconds: %.2f, loss: %f, cover_loss: %f' % (iter, time.time() - start, loss, cove_loss))
                start = time.time()
            if iter % 5000 == 0:
                self.save_model(running_avg_loss, iter)
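
Example 9 drives training through a ScheduledOptim wrapper that exposes zero_grad, step_and_update_lr and the underlying _optimizer. The wrapper is not shown above; the sketch below follows the usual warmup-then-inverse-square-root ("Noam") schedule with that interface, but the exact formula and any attribute names beyond those used above are assumptions.

class ScheduledOptim(object):
    """Wraps an optimizer and rescales its learning rate with a warmup schedule."""

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _get_lr(self):
        # "Noam" schedule: linear warmup, then inverse-square-root decay.
        return (self.d_model ** -0.5) * min(self.n_current_steps ** -0.5,
                                            self.n_current_steps * self.n_warmup_steps ** -1.5)

    def step_and_update_lr(self):
        # Advance the step counter, set the new learning rate, then step the wrapped optimizer.
        self.n_current_steps += 1
        lr = self._get_lr()
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()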