Example #1
    def setup_train(self):
        logging.info('Create model')

        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        if config.multi_gpu:
            self.model = torch.nn.DataParallel(self.model, device_ids=[0,1,2,3])
        # self.model = get_cuda(self.model)
        self.m_args['embedding_weights'] = None
        optim_param = (self.model.module.learnable_parameters()
                       if isinstance(self.model, torch.nn.DataParallel)
                       else self.model.learnable_parameters())

        self.optim = Optim(config.learning_rate, config.max_grad_norm,
                           lr_decay=config.learning_rate_decay,
                           start_decay_at=config.start_decay_at)
        self.optim.set_parameters(optim_param)
        #self.optim = Adam(optim_param, lr=config.learning_rate, amsgrad=True, betas=[0.9, 0.98], eps=1e-9)
        start_iter = 0

        if self.args.load_model is not None:
            load_model_path = os.path.join(config.save_model_path, self.args.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.module.load_state_dict(checkpoint["model_dict"])
            self.optim = checkpoint['trainer_dict']
            logging.info("Loaded model at " + load_model_path)

        self.print_args()

        return start_iter
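
The `isinstance(..., DataParallel)` check above recurs throughout these examples whenever parameters or state dicts are accessed. A small helper keeps that access uniform across single- and multi-GPU setups; this is a sketch, not part of the original code:

import torch


def unwrap(model):
    # Return the module wrapped by DataParallel, or the model itself,
    # so calls like learnable_parameters() work in both modes.
    return model.module if isinstance(model, torch.nn.DataParallel) else model

# Hypothetical usage mirroring setup_train:
#   optim_param = unwrap(self.model).learnable_parameters()
#   unwrap(self.model).load_state_dict(checkpoint["model_dict"])
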
Example #2
    def setup_valid(self):
        if self.args.cuda:
            checkpoint = torch.load(
                os.path.join(eval_config.save_model_path,
                             self.args.load_model))
        else:
            checkpoint = torch.load(os.path.join(eval_config.save_model_path,
                                                 self.args.load_model),
                                    map_location=lambda storage, loc: storage)
        if self.args.fine_tune:
            embeddings = checkpoint["model_dict"][
                "positional_encoding.embedding.weight"]
            self.m_args["embedding_weights"] = embeddings
        #self.m_args['embedding_weights'] = self.embeddings
        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        self.model.load_state_dict(checkpoint["model_dict"])
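
The cuda/else branch above exists only to remap GPU-saved tensors when loading on a CPU-only machine. A shorter equivalent passes the target device directly; this is a sketch, and the checkpoint path is a placeholder:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps storages saved on GPU so the load also works on CPU.
checkpoint = torch.load("model_dumps/checkpoint.tar", map_location=device)
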
Example #3
    def setup_train(self):
        logging.info('Create model')

        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        if config.multi_gpu:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=[0, 1, 2, 3])
        # self.model = get_cuda(self.model)
        self.m_args['embedding_weights'] = None
        optim_param = self.model.module.learnable_parameters() if isinstance(
            self.model,
            torch.nn.DataParallel) else self.model.learnable_parameters()
        self.optimizer = Adam(optim_param,
                              lr=config.learning_rate,
                              amsgrad=True,
                              betas=[0.9, 0.98],
                              eps=1e-9)
        start_iter = 0

        if self.args.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.args.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.module.load_state_dict(checkpoint["model_dict"])
            self.optimizer.load_state_dict(checkpoint["trainer_dict"])
            logging.info("Loaded model at " + load_model_path)

        if self.args.new_lr is not None:
            self.optimizer = Adam(self.model.module.learnable_parameters(),
                                  lr=self.args.new_lr)
        '''
        model_dict = self.model.module.state_dict()
        for k, v in model_dict.items():
            print(k)
        '''

        return start_iter
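
Example #3 restores the optimizer with load_state_dict, which only works if the matching save stored optimizer.state_dict() (contrast Example #1, which pickles the whole Optim object and assigns it back directly). A minimal round-trip sketch, with a stand-in model in place of TransformerSummarizer:

import torch
from torch.optim import Adam

model = torch.nn.Linear(4, 2)  # stand-in for TransformerSummarizer
optimizer = Adam(model.parameters(), lr=1e-4)

torch.save({"iter": 0,
            "model_dict": model.state_dict(),
            "trainer_dict": optimizer.state_dict()}, "ckpt.tar")

ckpt = torch.load("ckpt.tar")
model.load_state_dict(ckpt["model_dict"])
optimizer.load_state_dict(ckpt["trainer_dict"])
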
Example #4
class Train(object):
    def __init__(self, args):
        logging.info('Loading dataset')
        #self.loader = CNNDailyMail(config.dataset, data_type, ['src', 'trg'], config.bpe_model_filename)
        self.loader = DataLoader(config.dataset, train_type, ['src', 'trg'], config.bpe_model_filename)
        self.eval_loader = DataLoader(config.dataset, eval_type, ['src', 'trg'], config.bpe_model_filename)
        logging.info('Dataset has been loaded. Total size: %s, maxlen: %s', self.loader.data_length, self.loader.max_len)

        self.device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
        if config.multi_gpu:
            self.n_gpu = torch.cuda.device_count()
        else:
            self.n_gpu = 0
            os.environ["CUDA_VISIBLE_DEVICES"] = "3"
        self.writer = SummaryWriter(config.log + config.prefix)

        if args.pretrain_emb:
            self.embeddings = torch.from_numpy(np.load(config.emb_filename)).float()
            logging.info('Use vocabulary and embedding sizes from embedding dump.')
            self.vocab_size, self.emb_size = self.embeddings.shape
        else:
            self.embeddings = None
            self.vocab_size, self.emb_size = config.vocab_size, config.emb_size

        self.args = args
        self.m_args = {'max_seq_len': 638, 'vocab_size': self.vocab_size,
                       'enc_n_layers': config.enc_n_layers, 'dec_n_layers': config.dec_n_layers, 'global_layers': config.global_layers,
                       'batch_size': config.train_bs, 'emb_size': self.emb_size, 'dim_m': config.model_dim,
                       'n_heads': config.n_heads, 'dim_i': config.inner_dim, 'dropout': config.dropout,
                       'embedding_weights': self.embeddings, 'lexical_switch': args.lexical,
                       'emb_share_prj': args.emb_share_prj, 'global_encoding':config.global_encoding,
                       'encoding_gate': config.encoding_gate, 'stack': config.stack}
        #if config.multi_gpu:
        #    self.m_args['batch_size'] *= n_gpu
        self.error_file = open(config.error_filename, "w")

    def setup_train(self):
        logging.info('Create model')

        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        if config.multi_gpu:
            self.model = torch.nn.DataParallel(self.model, device_ids=[0,1,2,3])
        # self.model = get_cuda(self.model)
        self.m_args['embedding_weights'] = None
        optim_param = self.model.module.learnable_parameters() if isinstance(self.model, torch.nn.DataParallel) else self.model.learnable_parameters()

        self.optim = Optim(config.learning_rate, config.max_grad_norm,
                             lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at)
        self.optim.set_parameters(optim_param)
        #self.optim = Adam(optim_param, lr=config.learning_rate, amsgrad=True, betas=[0.9, 0.98], eps=1e-9)
        start_iter = 0

        if self.args.load_model is not None:
            load_model_path = os.path.join(config.save_model_path, self.args.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.module.load_state_dict(checkpoint["model_dict"])
            self.optim = checkpoint['trainer_dict']
            logging.info("Loaded model at " + load_model_path)

        self.print_args()

        return start_iter

    def print_args(self):
        print("encoder layer num: ", config.enc_n_layers)
        print("decoder layer num: ", config.dec_n_layers)
        print("emb_size: ", config.emb_size)
        print("model_dim: ", config.model_dim)
        print("inner_dim: ", config.inner_dim)
        print("mle_weight: ", config.mle_weight)


    def save_model(self, iter):
        save_path = config.dump_filename + "/%07d.tar" % iter
        torch.save({
            "iter": iter + 1,
            "model_dict": self.model.module.state_dict() if config.multi_gpu else self.model.state_dict(),
            "trainer_dict": self.optim
        }, save_path)
        with open(config.args_filename, 'w') as f:
            f.write(json.dumps(self.m_args))
            f.write("\nlearning rate: " + str(config.learning_rate)+"\n")
            f.write("iters_ml/rl: " + str(config.mle_epoch) + "  " + str(config.rl_epoch)+"\n")
            f.write("dump_model_path: " + str(config.dump_filename)+"\n")
            f.write("topic_flag: "+str(self.args.topic)+"\n")
            f.write("args: "+str(self.args)+"\n")
        logging.info('Model has been saved at %s', save_path)

    def train_RL(self, src, trg, src_lens, optim, topic_seq=None):
        batch_size = src.shape[0]
        mle_weight = config.mle_weight
        if self.args.topic:
            output = self.model.forward(src, trg, src_lens, topic_seq=topic_seq, eval_flag=True)
        else:
            output = self.model.forward(src, trg, src_lens, topic_seq=None, eval_flag=True)
        probs = torch.cat((self.model.module.initial_probs.to(src.device).repeat(batch_size, 1, 1), output[:, :-1, :]), dim=1)

        #print(probs)
        greedy_seq = probs.argmax(-1)
        mle_loss = self.model.module.label_smooth(probs.view(-1, self.model.module.vocab_size), trg.view(-1)).cuda()

        # multinomial sampling
        probs = F.softmax(probs, dim=-1)
        multi_dist = Categorical(probs)
        sample_seq = multi_dist.sample()  # perform multinomial sampling
        index = sample_seq.view(batch_size, -1, 1)
        sample_prob = torch.gather(probs, 2, index).squeeze(-1)  # (batch, seq_len)

        # For multinomial sampling, compute the probabilities of the sampled
        # words and mask out everything after the first sampled EOS token.
        non_zero = (sample_seq == self.loader.eos_idx)
        mask = np.zeros_like(non_zero.cpu())
        for i in range(non_zero.shape[0]):
            index = torch.nonzero(non_zero[i])
            if index.shape[0] == 0:
                mask[i] = 1
            else:
                mask[i][:index[0]] = 1

        mask = torch.FloatTensor(mask).cuda()
        lens = torch.sum(mask, dim=1) + 1  # length of each sampled sentence
        RL_logs = torch.sum(mask * sample_prob, dim=1) / lens.cuda()

        # compute the normalized log probability of a sentence

        sample_seq = self.loader.decode(sample_seq)
        generated = self.loader.decode(greedy_seq)
        original = self.loader.decode(trg)
        sample_reward = self.reward_function(sample_seq, original)
        baseline_reward = self.reward_function(generated, original)
        sample_reward = torch.FloatTensor(sample_reward).cuda()
        baseline_reward = torch.FloatTensor(baseline_reward).cuda()

        rl_loss = -(sample_reward - baseline_reward) * RL_logs
        rl_loss = torch.mean(rl_loss)

        batch_reward = torch.mean(sample_reward).item()
        # ------------------------------------------------------------------------------------
        if config.multi_gpu:
            mle_loss = mle_loss.mean()
            rl_loss = rl_loss.mean()
        #print("mle loss, rl loss, reward:", mle_loss.item(), rl_loss.item(), batch_reward)
        self.model.zero_grad()
        (mle_weight * mle_loss + (1-mle_weight) * rl_loss).backward()
        optim.step()

        return mle_loss.item(), generated, original, batch_reward


    def reward_function(self, decoded_sents, original_sents):
        #print(decoded_sents)
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print("Rouge failed for multi sentence evaluation.. Finding exact pair")
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i], original_sents[i])
                except Exception:
                    #print("Error occured at:")
                    #print("decoded_sents:", decoded_sents[i])
                    #print("original_sents:", original_sents[i])
                    score = [{"rouge-l": {"f": 0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        return rouge_l_f1

    def eval_model(self):
        decoded_sents, ref_sents, article_sents = [], [], []
        self.model.module.eval()
        for i in range(0, self.eval_loader.data_length, config.eval_bs):
            start = i
            end = min(i + config.eval_bs, self.eval_loader.data_length)
            batch = self.eval_loader.eval_next_batch(start, end, self.device)
            lengths, indices = torch.sort(batch.src_length, dim=0, descending=True)
            src = torch.index_select(batch.src, dim=0, index=indices)
            trg = torch.index_select(batch.trg, dim=0, index=indices)
            topic_seq = None
            seq = self.model.module.evaluate(src, lengths, max_seq_len=config.max_seq_len)
            try:
                generated = self.loader.decode(seq)
                original = self.loader.decode(trg)
                #original = batch.trg_text
                article = self.loader.decode(src)

                decoded_sents.extend(generated)
                ref_sents.extend(original)
                article_sents.extend(article)
            except Exception:
                print("failed at batch %d" % i)
        scores = self.reward_function(decoded_sents, ref_sents)
        score = np.array(scores).mean()
        return score

    def trainIters(self):
        logging.info('Start training')
        start_iter = self.setup_train()
        if config.schedule:
            scheduler = L.CosineAnnealingLR(self.optim.optimizer, T_max=config.mle_epoch)

        batches = self.loader.data_length // config.train_bs

        rl_iters = config.rl_epoch * batches
        logging.info("%d steps for per epoch, there is %d epoches.", batches, config.rl_epoch)
        tmp_epoch = 1

        for i in range(start_iter, start_iter+rl_iters):
            train_batch = self.loader.next_batch(config.train_bs, self.device)
            lengths, indices = torch.sort(train_batch.src_length, dim=0, descending=True)
            src = torch.index_select(train_batch.src, dim=0, index=indices)
            trg = torch.index_select(train_batch.trg, dim=0, index=indices)

            loss, generated, original, reward = self.train_RL(src, trg, lengths, self.optim)
            #generated = self.loader.decode(seq)
            #original = self.loader.decode(train_batch.trg)

            scores = self.reward_function(generated, original)
            scores = np.array(scores).mean()

            if i % config.train_interval == 0:
                logging.info('Iteration %d; Loss: %f; rouge: %.4f', i, loss, scores)
                self.writer.add_scalar('Loss', loss, i)
                self.save_model(i)

            if (i-start_iter) % batches == 0:
                logging.info("Epoch %d finished!", tmp_epoch)
                self.optim.updateLearningRate(score=0, epoch=tmp_epoch)
                if config.schedule:
                    scheduler.step()
                    logging.info("Decaying learning rate to %g" % scheduler.get_lr()[0])
                tmp_epoch = tmp_epoch + 1

            logging.info("rl training finished!")
Example #5
#
# parser.add_argument('--inp', metavar='I', type=str, default='sample', help='name sample part of dataset')
# parser.add_argument('--out', metavar='O', type=str, default='./dataset/generated.txt', help='output file')
# parser.add_argument('--prefix', metavar='P', type=str, default='simple-summ', help='model prefix')
# parser.add_argument('--dataset', metavar='D', type=str, default='./dataset', help='dataset folder')
# parser.add_argument('--limit', metavar='L', type=int, default=15, help='generation limit')
#
# args = parser.parse_args()

dataset = 'data/news_commentary/mono/'
bpe_model_filename = 'model_dumps/sp_bpe.model'
model_filename = 'model_dumps/simple-summ/simple-summ.model'
model_args_filename = 'model_dumps/simple-summ/simple-summ.args'
# emb_filename = os.path.join('./models_dumps', args.prefix, 'embedding.npy')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loader = DataLoader(dataset, ['sample'], ['src'], bpe_model_filename)
with open(model_args_filename, 'rb') as f:
    args_m = pickle.load(f)
model = TransformerSummarizer(**args_m)
model.load_state_dict(torch.load(model_filename))
model.to(device)
model.eval()

with torch.no_grad():
    summ = []
    for batch in loader.sequential('sample', device):
        seq = model.sample(batch, 50)
        summ += loader.decode(seq)
    with open("sample_output", 'w', encoding="utf8") as f:
        f.write('\n'.join(summ))
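
Example #5 unpickles the saved model args, while the Train classes elsewhere in these examples write them as a JSON line. A loader tolerant of both formats might look like the sketch below; load_model_args is not part of the original code:

import json
import pickle


def load_model_args(path):
    try:
        with open(path, "rb") as f:
            return pickle.load(f)  # pickled dict, as in Example #5
    except (pickle.UnpicklingError, UnicodeDecodeError):
        with open(path, "r") as f:
            return json.loads(f.readline())  # JSON line, as in save_model
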
Example #6
class Train(object):
    def __init__(self, args):
        logging.info('Loading dataset')
        #self.loader = CNNDailyMail(config.dataset, data_type, ['src', 'trg'], config.bpe_model_filename)
        self.loader = DataLoader(config.dataset, train_type, ['src', 'trg'],
                                 config.bpe_model_filename)
        self.eval_loader = DataLoader(config.dataset, eval_type,
                                      ['src', 'trg'],
                                      config.bpe_model_filename)
        logging.info('Dataset has been loaded. Total size: %s, maxlen: %s',
                     self.loader.data_length, self.loader.max_len)

        self.device = torch.device(
            "cuda" if args.cuda and torch.cuda.is_available() else "cpu")
        if config.multi_gpu:
            self.n_gpu = torch.cuda.device_count()
        else:
            self.n_gpu = 0
            os.environ["CUDA_VISIBLE_DEVICES"] = "3"
        self.writer = SummaryWriter(config.dump_filename)

        if args.pretrain_emb:
            self.embeddings = torch.from_numpy(np.load(
                config.emb_filename)).float()
            logging.info(
                'Use vocabulary and embedding sizes from embedding dump.')
            self.vocab_size, self.emb_size = self.embeddings.shape
        else:
            self.embeddings = None
            self.vocab_size, self.emb_size = config.vocab_size, config.emb_size

        self.args = args
        self.m_args = {
            'max_seq_len': self.loader.max_len,
            'vocab_size': self.vocab_size,
            'enc_n_layers': config.enc_n_layers,
            'dec_n_layers': config.dec_n_layers,
            'global_layers': config.global_layers,
            'batch_size': config.train_bs,
            'emb_size': self.emb_size,
            'dim_m': config.model_dim,
            'n_heads': config.n_heads,
            'dim_i': config.inner_dim,
            'dropout': config.dropout,
            'embedding_weights': self.embeddings,
            'lexical_switch': args.lexical,
            'emb_share_prj': args.emb_share_prj,
            'global_encoding': config.global_encoding,
            'encoding_gate': config.encoding_gate,
            'stack': config.stack,
            'topic': config.topic,
            'inception': config.inception,
            'gtu': config.gtu
        }
        #if config.multi_gpu:
        #    self.m_args['batch_size'] *= n_gpu
        self.error_file = open(config.error_filename, "w")

    def setup_train(self):
        logging.info('Create model')
        if self.args.load_model is not None:
            with open(os.path.join(config.save_model_path, "model.args"),
                      "r") as f:
                self.m_args = json.loads(f.readline())
                self.m_args['embedding_weights'] = self.embeddings
        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        if config.multi_gpu:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=[0, 1, 2, 3])
        # self.model = get_cuda(self.model)
        self.m_args['embedding_weights'] = None
        optim_param = self.model.module.learnable_parameters() if isinstance(
            self.model,
            torch.nn.DataParallel) else self.model.learnable_parameters()

        self.optim = Optim(config.learning_rate,
                           config.max_grad_norm,
                           lr_decay=config.learning_rate_decay,
                           start_decay_at=config.start_decay_at)
        self.optim.set_parameters(optim_param)
        #self.optim = Adam(optim_param, lr=config.learning_rate, amsgrad=True, betas=[0.9, 0.98], eps=1e-9)
        start_iter = 0

        if self.args.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.args.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.module.load_state_dict(checkpoint["model_dict"])
            self.optim = checkpoint['trainer_dict']
            logging.info("Loaded model at " + load_model_path)

        self.print_args()

        return start_iter

    def print_args(self):
        print("encoder layer num: ", config.enc_n_layers)
        print("decoder layer num: ", config.dec_n_layers)
        print("emb_size: ", config.emb_size)
        print("model_dim: ", config.model_dim)
        print("inner_dim: ", config.inner_dim)
        print("mle_weight: ", config.mle_weight)

    def save_model(self, iter):
        save_path = config.dump_filename + "/%07d.tar" % iter
        torch.save(
            {
                "iter":
                iter + 1,
                "model_dict":
                self.model.module.state_dict()
                if config.multi_gpu else self.model.state_dict(),
                "trainer_dict":
                self.optim
            }, save_path)
        with open(config.args_filename, 'w') as f:
            f.write(json.dumps(self.m_args))
            f.write("\nlearning rate: " + str(config.learning_rate) + "\n")
            f.write("iters_ml/rl: " + str(config.mle_epoch) + "  " +
                    str(config.rl_epoch) + "\n")
            f.write("dump_model_path: " + str(config.dump_filename) + "\n")
            #f.write("topic_flag: "+str(self.args.topic)+"\n")
            f.write("args: " + str(self.args) + "\n")
        logging.info('Model has been saved at %s', save_path)

    def reward_function(self, decoded_sents, original_sents):
        #print(decoded_sents)
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print("Rouge failed for multi-sentence evaluation; "
                  "falling back to per-pair scoring")
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    #print("Error occured at:")
                    #print("decoded_sents:", decoded_sents[i])
                    #print("original_sents:", original_sents[i])
                    score = [{"rouge-1": {"f": 0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-1"]["f"] for score in scores]
        return rouge_l_f1

    def eval_model(self):
        decoded_sents, ref_sents, article_sents = [], [], []
        self.model.module.eval()
        for i in range(0, self.eval_loader.data_length, config.eval_bs):
            start = i
            end = min(i + config.eval_bs, self.eval_loader.data_length)
            batch = self.eval_loader.eval_next_batch(start, end, self.device)
            lengths, indices = torch.sort(batch.src_length,
                                          dim=0,
                                          descending=True)
            src = torch.index_select(batch.src, dim=0, index=indices)
            trg = torch.index_select(batch.trg, dim=0, index=indices)
            seq, loss = self.model.module.evaluate(
                src,
                lengths,
                max_seq_len=config.max_seq_len,
                beam_size=None,
                topic_seq=None,
                trg=trg)
            try:
                generated = self.loader.decode(seq)
                original = self.loader.decode(trg)
                #original = batch.trg_text
                article = self.loader.decode(src)

                decoded_sents.extend(generated)
                ref_sents.extend(original)
                article_sents.extend(article)
            except Exception:
                print("failed at batch %d" % i)
        scores = self.reward_function(decoded_sents, ref_sents)
        score = np.array(scores).mean()
        return score, loss

    def trainIters(self):
        logging.info('Start training')
        start_iter = self.setup_train()
        count = loss_total = rouge_total = 0
        record_rouge = 0.0
        if config.schedule:
            scheduler = L.CosineAnnealingLR(self.optim.optimizer,
                                            T_max=config.mle_epoch)

        batches = self.loader.data_length // config.train_bs

        mle_iter = config.mle_epoch * batches
        logging.info("%d steps for per epoch, there is %d epoches.", batches,
                     config.mle_epoch)
        tmp_epoch = (start_iter // batches) + 1
        for i in range(start_iter, mle_iter):
            self.model.train()
            self.model.zero_grad()
            train_batch = self.loader.next_batch(config.train_bs, self.device)
            lengths, indices = torch.sort(train_batch.src_length,
                                          dim=0,
                                          descending=True)
            src = torch.index_select(train_batch.src, dim=0, index=indices)
            trg = torch.index_select(train_batch.trg, dim=0, index=indices)

            #self.optim.zero_grad()
            if config.topic:
                #print(train_batch.topic)
                loss, seq = self.model.forward(src, trg, lengths,
                                               train_batch.topic)
            else:
                loss, seq = self.model.forward(src, trg, lengths)
            if self.n_gpu > 1:
                loss = loss.mean()
            loss.backward()
            self.optim.step()

            generated = self.loader.decode(seq)
            original = self.loader.decode(trg)
            scores = self.reward_function(generated, original)
            scores = np.array(scores).mean()

            loss_total += loss.item()  # use .item() so the graph is not kept alive
            rouge_total += scores
            count += 1
            if i % config.train_interval == 0:
                loss_avg = loss_total / count
                rouge_avg = rouge_total / count
                logging.info(
                    'Iteration %d; Loss: %f; rouge: %.4f; loss avg: %.4f; rouge avg: %.4f',
                    i, loss, scores, loss_avg, rouge_avg)
                self.writer.add_scalar('Loss', loss, i)
                loss_total = rouge_total = count = 0

            if i % config.train_sample_interval == 0 and i > 10000:
                score, loss = self.eval_model()
                self.writer.add_scalar('Test', loss, i)
                self.writer.add_scalar('TestRouge', score, i)
                logging.info("%s loss: %f", eval_type, loss)
                if score > record_rouge:
                    logging.info("%s score: %f", eval_type, score)
                    self.save_model(i)
                    record_rouge = score
                elif i % config.save_interval == 0:
                    self.save_model(i)

            if i % batches == 0 and i > 1:
                logging.info("Epoch %d finished!", tmp_epoch)
                self.optim.updateLearningRate(score=0, epoch=tmp_epoch)
                if config.schedule:
                    scheduler.step()
                    logging.info("Decaying learning rate to %g" %
                                 scheduler.get_lr()[0])
                tmp_epoch = tmp_epoch + 1

        logging.info("mle training finished!")
Example #7
    train_iter = BucketIterator(train_data,
                                BATCH_SIZE,
                                sort_key=lambda x: len(x.text),
                                sort_within_batch=True)
    val_iter = BucketIterator(val_data,
                              BATCH_SIZE,
                              sort_key=lambda x: len(x.text),
                              sort_within_batch=True)
    test_iter = BucketIterator(test_data,
                               BATCH_SIZE,
                               sort_key=lambda x: len(x.text),
                               sort_within_batch=True)

    if not args.evaluate_only:

        ff = FastText("en")
        embeddings = ff.get_vecs_by_tokens(SRC.vocab.itos)

        model = TransformerSummarizer(ATTENTION_HEADS, N_LAYERS, N_LAYERS,
                                      DIM_FEEDFORWARD, SEQ_LEN, VOCAB_SIZE,
                                      PAD_IDX, src_list,
                                      embeddings=embeddings).to(device)

        num_batches = math.ceil(len(train_data) / BATCH_SIZE)
        val_batches = math.ceil(len(val_data) / BATCH_SIZE)

        parameters = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(parameters)
        criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        print("Training Started")

        for epoch in range(N_EPOCHS):
            start_time = time.time()

            train_loss = train(model, train_iter, num_batches, optimizer,
                               criterion)
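
BucketIterator (from the legacy torchtext API this example appears to use) batches examples of similar length so each batch pads to roughly the same width. A dependency-free sketch of the idea behind sort_key and sort_within_batch, with toy data:

# Toy data; BATCH_SIZE = 2. Sorting by length first means each batch pads
# to a similar width, which is what BucketIterator automates.
data = ["a", "a b", "a b c", "a b c d e", "a b", "a b c d"]
BATCH_SIZE = 2
ordered = sorted(data, key=lambda s: len(s.split()))
batches = [ordered[i:i + BATCH_SIZE]
           for i in range(0, len(ordered), BATCH_SIZE)]
print(batches)
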
Example #8
class Train(object):
    def __init__(self, args):
        logging.info('Loading dataset')
        #self.loader = CNNDailyMail(config.dataset, data_type, ['src', 'trg'], config.bpe_model_filename)
        self.loader = DataLoader(config.dataset, data_type, ['src', 'trg'],
                                 config.bpe_model_filename)
        logging.info('Dataset has been loaded. Total size: %s, maxlen: %s',
                     self.loader.data_length, self.loader.max_len)

        self.device = torch.device(
            "cuda" if args.cuda and torch.cuda.is_available() else "cpu")
        if config.multi_gpu:
            self.n_gpu = torch.cuda.device_count()
        else:
            self.n_gpu = 0
            os.environ["CUDA_VISIBLE_DEVICES"] = "3"
        self.writer = SummaryWriter(config.log + config.prefix)

        if args.pretrain_emb:
            self.embeddings = torch.from_numpy(np.load(
                config.emb_filename)).float()
            logging.info(
                'Use vocabulary and embedding sizes from embedding dump.')
            self.vocab_size, self.emb_size = self.embeddings.shape
        else:
            self.embeddings = None
            self.vocab_size, self.emb_size = config.vocab_size, config.emb_size

        self.args = args
        self.m_args = {
            'max_seq_len': self.loader.max_len,
            'vocab_size': self.vocab_size,
            'n_layers': config.n_layers,
            'batch_size': config.train_bs,
            'emb_size': self.emb_size,
            'dim_m': config.model_dim,
            'n_heads': config.n_heads,
            'dim_i': config.inner_dim,
            'dropout': config.dropout,
            'embedding_weights': self.embeddings,
            'lexical_switch': args.lexical
        }
        #if config.multi_gpu:
        #    self.m_args['batch_size'] *= n_gpu
        self.error_file = open(config.error_filename, "w")

    def setup_train(self):
        logging.info('Create model')

        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        if config.multi_gpu:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=[0, 1, 2, 3])
        # self.model = get_cuda(self.model)
        self.m_args['embedding_weights'] = None
        optim_param = self.model.module.learnable_parameters() if isinstance(
            self.model,
            torch.nn.DataParallel) else self.model.learnable_parameters()
        self.optimizer = Adam(optim_param,
                              lr=config.learning_rate,
                              amsgrad=True,
                              betas=[0.9, 0.98],
                              eps=1e-9)
        start_iter = 0

        if self.args.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.args.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.module.load_state_dict(checkpoint["model_dict"])
            self.optimizer.load_state_dict(checkpoint["trainer_dict"])
            logging.info("Loaded model at " + load_model_path)

        if self.args.new_lr is not None:
            self.optimizer = Adam(self.model.module.learnable_parameters(),
                                  lr=self.args.new_lr)
        '''
        model_dict = self.model.module.state_dict()
        for k, v in model_dict.items():
            print(k)
        '''

        return start_iter

    def train_RL(self, batch, optim, mle_weight):

        batch_size = batch.src.shape[0]
        source_seq, target_seq = batch.src, batch.trg
        extra_zeros, enc_batch_extend_vocab, trg_extend = batch.extra_zeros, batch.enc_batch_extend_vocab, batch.trg_extend_vocab

        if self.args.topic:
            greedy_seq, out_oov, probs, out_probs, mle_loss = self.model.forward(
                source_seq,
                target_seq,
                extra_zeros,
                enc_batch_extend_vocab,
                trg_extend,
                topic_seq=batch.topic)
        else:
            greedy_seq, out_oov, probs, out_probs, mle_loss = self.model.forward(
                source_seq, target_seq, extra_zeros, enc_batch_extend_vocab,
                trg_extend)

        # multinomial sampling
        probs = F.softmax(probs, dim=-1)
        multi_dist = Categorical(probs)
        sample_seq = multi_dist.sample()  # perform multinomial sampling
        index = sample_seq.view(batch_size, -1, 1)
        sample_prob = torch.gather(probs, 2, index).squeeze(-1)  # (batch, seq_len)

        # Mask out everything after the first sampled EOS token.
        non_zero = (sample_seq == self.loader.eos_idx)
        mask = np.zeros_like(non_zero.cpu())
        for i in range(non_zero.shape[0]):
            index = torch.nonzero(non_zero[i])
            if index.shape[0] == 0:
                mask[i] = 1
            else:
                mask[i][:index[0]] = 1

        mask = torch.FloatTensor(mask).cuda()
        lens = torch.sum(mask, dim=1) + 1  # length of each sampled sentence
        RL_logs = torch.sum(mask * sample_prob, dim=1) / lens.cuda()

        # compute the normalized log probability of a sentence
        #print(sample_seq)
        sample_seq = self.loader.decode_oov(greedy_seq,
                                            source_oovs=batch.source_oov,
                                            oov=sample_seq)
        generated = self.loader.decode_oov(greedy_seq,
                                           source_oovs=batch.source_oov,
                                           oov=out_oov)
        original = batch.trg_text

        sample_reward = self.reward_function(sample_seq, original)
        baseline_reward = self.reward_function(generated, original)
        sample_reward = torch.FloatTensor(sample_reward).cuda()
        baseline_reward = torch.FloatTensor(baseline_reward).cuda()
        #print(RL_logs)
        #print("reward", sample_reward, baseline_reward)
        # if iter%200 == 0:
        #     self.write_to_file(sample_sents, greedy_sents, batch.original_abstracts, sample_reward, baseline_reward, iter)
        # Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
        rl_loss = -(sample_reward - baseline_reward) * RL_logs
        rl_loss = torch.mean(rl_loss)

        batch_reward = torch.mean(sample_reward).item()
        #rl_loss = T.FloatTensor([0]).cuda()
        #batch_reward = 0

        # ------------------------------------------------------------------------------------

        if config.multi_gpu:
            mle_loss = mle_loss.mean()
            rl_loss = rl_loss.mean()
        #print("mle loss, rl loss, reward:", mle_loss.item(), rl_loss.item(), batch_reward)
        optim.zero_grad()
        (mle_weight * mle_loss + (1 - mle_weight) * rl_loss).backward()
        optim.step()

        #mix_loss = mle_weight * mle_loss + (1-mle_weight) * rl_loss
        return mle_loss.item(), generated, original, batch_reward

    def reward_function(self, decoded_sents, original_sents):
        #print(decoded_sents)
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print("Rouge failed for multi-sentence evaluation; "
                  "falling back to per-pair scoring")
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    #print("Error occured at:")
                    #print("decoded_sents:", decoded_sents[i])
                    #print("original_sents:", original_sents[i])
                    score = [{"rouge-l": {"f": 0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        return rouge_l_f1

    def save_model(self, iter):
        save_path = config.dump_filename + "/%07d.tar" % iter
        torch.save(
            {
                "iter":
                iter + 1,
                "model_dict":
                self.model.module.state_dict()
                if config.multi_gpu else self.model.state_dict(),
                "trainer_dict":
                self.optimizer.state_dict()
            }, save_path)
        with open(config.args_filename, 'w') as f:
            f.write(json.dumps(self.m_args))
        logging.info('Model has been saved')

    def trainIters(self):
        logging.info('Start training')
        start_iter = self.setup_train()
        count = loss_total = rouge_total = 0
        record_rouge = 0.0
        self.model.train()

        for i in range(start_iter, config.iters):
            train_batch = self.loader.next_batch(config.train_bs, self.device)
            self.optimizer.zero_grad()
            source_seq, target_seq = train_batch.src, train_batch.trg
            extra_zeros, enc_batch_extend_vocab, trg_extend = train_batch.extra_zeros, train_batch.enc_batch_extend_vocab, train_batch.trg_extend_vocab
            if self.args.topic:
                seq, loss, out_oov, out_copy_probs, out_probs = self.model.forward(
                    source_seq,
                    target_seq,
                    extra_zeros,
                    enc_batch_extend_vocab,
                    trg_extend,
                    topic_seq=train_batch.topic)
            else:
                seq, loss, out_oov, out_copy_probs, out_probs = self.model.forward(
                    source_seq, target_seq, extra_zeros,
                    enc_batch_extend_vocab, trg_extend)

            if self.n_gpu > 1:
                loss = loss.mean()
            loss.backward()
            self.optimizer.step()
            #print(loss)
            #generated = self.loader.decode(seq)
            generated = self.loader.decode_oov(
                seq, source_oovs=train_batch.source_oov, oov=out_oov)
            original = train_batch.trg_text
            '''
            print(generated[0])
            print(original[0])
            print(train_batch.src_text[0])
            '''
            scores = self.reward_function(generated, original)
            scores = np.array(scores).mean()

            loss_total += loss.item()  # use .item() so the graph is not kept alive
            rouge_total += scores
            count += 1
            if i % config.train_interval == 0:
                loss_avg = loss_total / count
                rouge_avg = rouge_total / count
                logging.info(
                    'Iteration %d; Loss: %f; rouge: %.4f; loss avg: %.4f; rouge avg: %.4f',
                    i, loss, scores, loss_avg, rouge_avg)
                self.writer.add_scalar('Loss', loss, i)
                loss_total = rouge_total = count = 0
            if i % config.save_interval == 0 and i >= 25000:
                if scores > record_rouge:
                    self.save_model(i)
                    record_rouge = scores

        for i in range(config.iters, config.rl_iters):
            train_batch = self.loader.next_batch(config.train_bs, self.device)
            loss, generated, original, reward = self.train_RL(
                train_batch, self.optimizer, config.mle_weight)
            #generated = self.loader.decode(seq)
            #original = self.loader.decode(train_batch.trg)

            scores = self.reward_function(generated, original)
            scores = np.array(scores).mean()

            loss_total += loss
            rouge_total += reward
            count += 1
            if i % config.train_interval == 0:
                loss_avg = loss_total / count
                rouge_avg = rouge_total / count
                logging.info(
                    'Iteration %d; Loss: %f; rouge: %.4f; loss avg: %.4f; reward avg: %.4f',
                    i, loss, scores, loss_avg, rouge_avg)
                self.writer.add_scalar('Loss', loss, i)
                loss_total = rouge_total = count = 0
            if i % config.save_interval == 0 and i >= 25000:
                self.save_model(i)
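
decode_oov above reflects the pointer-generator convention: token ids past the fixed vocabulary index into the per-example list of source OOV words (the positions extra_zeros makes room for). A toy sketch of that mapping; vocab and source_oovs here are placeholders, not the original data:

vocab = ["<pad>", "<s>", "</s>", "the", "cat"]
source_oovs = ["tiberius"]  # OOV words copied from this example's source


def decode_ids(ids, vocab, source_oovs):
    words = []
    for i in ids:
        # ids >= len(vocab) point into the batch-local OOV list
        words.append(vocab[i] if i < len(vocab)
                     else source_oovs[i - len(vocab)])
    return " ".join(words)


print(decode_ids([3, 4, 5], vocab, source_oovs))  # -> "the cat tiberius"
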
Example #9
class Evaluate(object):
    def __init__(self, args):
        #self.loader = DataLoader(eval_config.dataset, data_type, ['src', 'trg'], eval_config.bpe_model_filename, eval_config.vocab_size)
        self.loader = EvalBatcher(eval_config.dataset, duc_name, duc_src_type,
                                  eval_config.bpe_model_filename)

        if args.pretrain_emb:
            self.embeddings = torch.from_numpy(
                np.load(eval_config.emb_filename)).float()
            self.vocab_size, self.emb_size = self.embeddings.shape
        else:
            self.embeddings = None
            self.vocab_size, self.emb_size = eval_config.vocab_size, eval_config.emb_size
        with open(eval_config.args_filename, "r") as f:
            self.m_args = json.loads(f.readline())
            #self.m_args['max_seq_len'] = eval_config.max_seq_len

        time.sleep(5)
        self.args = args
        self.device = torch.device(
            "cuda" if args.cuda and torch.cuda.is_available() else "cpu")
        #self.iters = int(self.loader.data_length / self.m_args['batch_size'])

    def setup_valid(self):
        if self.args.cuda:
            checkpoint = torch.load(
                os.path.join(eval_config.save_model_path,
                             self.args.load_model))
        else:
            checkpoint = torch.load(os.path.join(eval_config.save_model_path,
                                                 self.args.load_model),
                                    map_location=lambda storage, loc: storage)
        if self.args.fine_tune:
            embeddings = checkpoint["model_dict"][
                "positional_encoding.embedding.weight"]
            self.m_args["embedding_weights"] = embeddings
        #self.m_args['embedding_weights'] = self.embeddings
        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        self.model.load_state_dict(checkpoint["model_dict"])

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = data_type + "_" + loadfile.split(".")[0] + ".txt"

        with open(os.path.join(eval_config.save_model_path, filename),
                  "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i][0].strip().lower() + "\n")
                f.write("ref: " + ref_sents[i][1].strip().lower() + "\n")
                f.write("ref: " + ref_sents[i][2].strip().lower() + "\n")
                f.write("ref: " + ref_sents[i][3].strip().lower() + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def print_for_rouge(self, decoded_sents, ref_sents, corpus="giga"):
        assert len(decoded_sents) == len(ref_sents)
        ref_dir = os.path.join(eval_config.save_model_path, 'reference')
        cand_dir = os.path.join(eval_config.save_model_path, 'candidate')
        if not os.path.exists(ref_dir):
            os.mkdir(ref_dir)
        if not os.path.exists(cand_dir):
            os.mkdir(cand_dir)
        if corpus == "giga":
            for i in range(len(ref_sents)):
                with codecs.open(ref_dir + "/%06d_reference.txt" % i, 'w+',
                                 'utf-8') as f:
                    f.write(ref_sents[i])
                with codecs.open(cand_dir + "/%06d_candidate.txt" % i, 'w+',
                                 'utf-8') as f:
                    f.write(decoded_sents[i])
            r = pyrouge.Rouge155()
            r.model_filename_pattern = '#ID#_reference.txt'
            r.system_filename_pattern = r'(\d+)_candidate.txt'
        else:
            for i in range(len(ref_sents)):
                nickname = ['A', 'B', 'C', 'D']
                for task in range(len(ref_sents[0])):
                    ref_file_name = nickname[task] + ".%06d_reference.txt" % i
                    with codecs.open(os.path.join(ref_dir, ref_file_name),
                                     'w+', 'utf-8') as f:
                        f.write(ref_sents[i][task].strip().lower())
                with codecs.open(cand_dir + "/%06d_candidate.txt" % i, 'w+',
                                 'utf-8') as f:
                    f.write(decoded_sents[i])
            r = pyrouge.Rouge155()
            r.model_filename_pattern = '[A-Z].#ID#_reference.txt'
            r.system_filename_pattern = r'(\d+)_candidate.txt'

        r.model_dir = ref_dir
        r.system_dir = cand_dir
        logging.getLogger('global').setLevel(logging.WARNING)
        rouge_results = r.convert_and_evaluate()
        scores = r.output_to_dict(rouge_results)
        recall = [
            round(scores["rouge_1_recall"] * 100, 2),
            round(scores["rouge_2_recall"] * 100, 2),
            round(scores["rouge_l_recall"] * 100, 2)
        ]
        precision = [
            round(scores["rouge_1_precision"] * 100, 2),
            round(scores["rouge_2_precision"] * 100, 2),
            round(scores["rouge_l_precision"] * 100, 2)
        ]
        f_score = [
            round(scores["rouge_1_f_score"] * 100, 2),
            round(scores["rouge_2_f_score"] * 100, 2),
            round(scores["rouge_l_f_score"] * 100, 2)
        ]
        print("F_measure: %s Recall: %s Precision: %s\n" %
              (str(f_score), str(recall), str(precision)))

    def cal_rouge(self, decoded_sents, original_sents):
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print("Rouge failed for multi-sentence evaluation; "
                  "falling back to per-pair scoring")
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    #print("Error occured at:")
                    #print("decoded_sents:", decoded_sents[i])
                    #print("original_sents:", original_sents[i])
                    score = [{
                        "rouge-1": {
                            "f": 0.0
                        },
                        "rouge-2": {
                            "f": 0.0
                        },
                        "rouge-l": {
                            "f": 0.0
                        }
                    }]
                scores.append(score[0])
        rouge_l_f1 = [[
            score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]
        ] for score in scores]
        return rouge_l_f1

    def evaluate_batch(self, max_seq_len, beam_size=None, corpus_id=-1):
        decoded_sents = []
        ref_sents = []
        article_sents = []
        for i in range(0, self.loader.data_length, eval_config.test_bs):
            start = i
            end = min(i + eval_config.test_bs, self.loader.data_length)
            batch = self.loader.eval_next_batch(start, end, self.device)
            lengths, indices = torch.sort(batch.src_length,
                                          dim=0,
                                          descending=True)
            src = torch.index_select(batch.src, dim=0, index=indices)

            topic_seq = None
            seq, loss = self.model.evaluate(src,
                                            lengths,
                                            max_seq_len=max_seq_len,
                                            beam_size=beam_size,
                                            topic_seq=topic_seq)
            # print("success", i)
            with torch.autograd.no_grad():
                generated = self.loader.decode_eval(seq)
                article = self.loader.decode_eval(src)
                if corpus_id == -1:
                    trg = torch.index_select(batch.trg, dim=0, index=indices)
                    original = self.loader.decode_eval(trg)
                    ref_sents.extend(original)
                else:
                    ref_sents.extend((indices.cpu().numpy() + i).tolist())

                decoded_sents.extend(generated)
                article_sents.extend(article)

        return decoded_sents, ref_sents, article_sents

    def run(self,
            max_seq_len,
            beam_size=None,
            print_sents=False,
            corpus_name="giga"):
        self.setup_valid()

        load_file = self.args.load_model
        if corpus_name == 'duc':
            task_files = [
                'task1_ref0.txt', 'task1_ref1.txt', 'task1_ref2.txt',
                'task1_ref3.txt'
            ]
            self.duc_target = EvalTarget(directory, task_files)
            decoded_sents, ref_id, article_sents = self.evaluate_batch(
                max_seq_len, beam_size=beam_size, corpus_id=1)
            ref_sents = []

            for id in ref_id:
                trg_sents = []
                for i in range(len(task_files)):
                    trg_sents.append(self.duc_target.text[i][id])

                ref_sents.append(trg_sents)
            '''
            rouge = Rouge()
            max_scores = []
            ref_for_print = []
            for mine, labels in zip(decoded_sents, ref_sents):
                scores = []
                for sentence in labels:
                    scores.append(rouge.get_scores(mine, sentence, avg=True)['rouge-1']['f']) #
                maxx = max(scores)
                max_scores.append(maxx)
                ref_for_print.append(scores.index(maxx))
            #print(max_scores)
            all_score = np.array(max_scores).mean(axis=0)
            #print(all_score)
            '''
            if print_sents:
                self.print_for_rouge(decoded_sents, ref_sents, corpus="duc")
                #print(np.array(ref_sents).size)
                self.print_original_predicted(decoded_sents, ref_sents,
                                              article_sents, load_file)

        else:
            decoded_sents, ref_sents, article_sents = self.evaluate_batch(
                max_seq_len, beam_size=beam_size, corpus_id=-1)
            scores = self.cal_rouge(decoded_sents, ref_sents)
            score = np.array(scores).mean(axis=0)
            print("rouge-1: %.4f, rouge-2: %.4f, rouge-l: %.4f" %
                  (score[0], score[1], score[2]))

            if print_sents:
                self.print_for_rouge(decoded_sents, ref_sents)
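
cal_rouge wraps the rouge package's get_scores, which returns one score dict per hypothesis/reference pair. A minimal call with toy sentences, showing the keys the code indexes into:

from rouge import Rouge

rouge = Rouge()
scores = rouge.get_scores(["the cat sat on the mat"],
                          ["the cat sat down on the mat"])
print(scores[0]["rouge-1"]["f"], scores[0]["rouge-2"]["f"],
      scores[0]["rouge-l"]["f"])
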
Example #10
class Evaluate(object):
    def __init__(self, args):
        self.loader = DataLoader(eval_config.dataset, data_type,
                                 ['src', 'trg'],
                                 eval_config.bpe_model_filename,
                                 eval_config.vocab_size)
        #self.loader = EvalBatcher(eval_config.dataset, duc_name, duc_src_type, eval_config.bpe_model_filename)

        if args.pretrain_emb:
            self.embeddings = torch.from_numpy(
                np.load(eval_config.emb_filename)).float()
            self.vocab_size, self.emb_size = self.embeddings.shape
        else:
            self.embeddings = None
            self.vocab_size, self.emb_size = eval_config.vocab_size, eval_config.emb_size
        with open(eval_config.args_filename, "r") as f:
            self.m_args = json.loads(f.readline())
            #self.m_args['max_seq_len'] = eval_config.max_seq_len

        time.sleep(5)
        self.args = args
        self.device = torch.device(
            "cuda" if args.cuda and torch.cuda.is_available() else "cpu")
        #self.iters = int(self.loader.data_length / self.m_args['batch_size'])

    def setup_valid(self):
        if self.args.cuda:
            checkpoint = torch.load(
                os.path.join(eval_config.save_model_path,
                             self.args.load_model))
        else:
            checkpoint = torch.load(os.path.join(eval_config.save_model_path,
                                                 self.args.load_model),
                                    map_location=lambda storage, loc: storage)
        if self.args.fine_tune:
            embeddings = checkpoint["model_dict"][
                "positional_encoding.embedding.weight"]
            self.m_args["embedding_weights"] = embeddings
        #self.m_args['embedding_weights'] = self.embeddings
        self.model = TransformerSummarizer(**self.m_args).to(self.device)
        self.model.load_state_dict(checkpoint["model_dict"])

    def evaluate_batch(self):
        self.setup_valid()
        decoded_sents = []
        ref_sents = []
        article_sents = []
        # step with test_bs so the range matches the `end` computed below
        for i in range(0, self.loader.data_length, eval_config.test_bs):
            start = i
            end = min(i + eval_config.test_bs, self.loader.data_length)
            batch = self.loader.eval_next_batch(start, end, self.device)
            lengths, indices = torch.sort(batch.src_length,
                                          dim=0,
                                          descending=True)
            src = torch.index_select(batch.src, dim=0, index=indices)
            trg = torch.index_select(batch.trg, dim=0, index=indices)
            topic_seq = None

            loss, seq, att_en, att_rc = self.model.forward(src, trg, lengths)
            # print("success", i)
            with torch.autograd.no_grad():
                generated = self.loader.decode_raw(seq)
                article = self.loader.decode_raw(src)

                ref_sents.extend((indices.cpu().numpy() + i).tolist())

                decoded_sents.extend(generated)
                article_sents.extend(article)
        #print(att_rc.shape)
        #print(att_en.shape)
        #print(generated)
        #return decoded_sents,ref_sents, article_sents
        for i in range(8):
            #plt.figure(i)
            plot_heatmap(article[0], generated[0], att_en[i].cpu().data)

        for i in range(8):
            #plt.figure(i)
            plot_heatmap(article[0], generated[0], att_rc[i].cpu().data)
logging.info('Create model')

m_args = {
    'max_seq_len': loader.max_len,
    'vocab_size': vocab_size,
    'n_layers': args.n_layers,
    'emb_size': emb_size,
    'dim_m': args.model_dim,
    'n_heads': args.n_heads,
    'dim_i': args.inner_dim,
    'dropout': args.dropout,
    'embedding_weights': embeddings
}

model = TransformerSummarizer(**m_args).to(device)

m_args['embedding_weights'] = None

optimizer = Adam(model.learnable_parameters(),
                 lr=args.learning_rate,
                 amsgrad=True,
                 betas=[0.9, 0.98],
                 eps=1e-9)

logging.info('Start training')
for i in range(args.iters):
    try:
        train_batch = loader.next_batch(args.train_bs, 'train', device)
        loss, seq = model.train_step(train_batch, optimizer)