예제 #1
0
 def forward(self, data_dict):
     """Encode the source batch, aggregate it, and emit word probabilities."""
     # Move the word ids (and their OOV-extended counterparts) to the device.
     src_wids = data_dict[DK_SRC_WID].to(device())
     src_oov_wids = data_dict[DK_SRC_OOV_WID].to(device())
     src_mask = data_dict[DK_SRC_WID_MASK].to(device())
     embedded = self.src_embed(src_wids, src_oov_wids)
     enc_out = self.dropout(self.encoder(embedded, mask=src_mask))
     # Some aggregators need the padding mask; they advertise it via
     # a require_mask attribute.
     if hasattr(self.aggr, "require_mask"):
         aggregated = self.aggr(enc_out, mask=src_mask)
     else:
         aggregated = self.aggr(enc_out)
     return self.word_generator(aggregated)
예제 #2
0
 def forward(self, x, target):
     """Compute the label-smoothed loss between predictions x and targets."""
     assert x.size(1) == self.size
     x = x.to(device())
     target = target.to(device())
     # Start from a uniform distribution over the non-special classes...
     smoothed = x.data.clone()
     smoothed.fill_(self.smoothing / (self.size - 2))
     # ...then put the confidence mass on each row's true class.
     smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)
     if self.padding_idx is not None:
         # Never assign probability to the padding class, and zero out
         # rows whose target itself is padding.
         smoothed[:, self.padding_idx] = 0
         pad_rows = torch.nonzero(target.data == self.padding_idx)
         if pad_rows.shape[0] > 0:
             smoothed.index_fill_(0, pad_rows.squeeze(), 0.0)
     return self.criterion(x, smoothed)
예제 #3
0
 def forward(self, data_dict):
     """Encode the batch and decode it, optionally with teacher forcing."""
     tgt = data_dict[DK_TGT_GEN_WID].to(device())
     src_mask = data_dict[DK_SRC_WID_MASK].to(device())
     target_len = tgt.shape[1]
     encoder_op, encoder_hidden = self.encode(data_dict)
     decoder_hidden = self.prep_enc_hidden_for_dec(encoder_hidden)
     # Decide per batch whether to feed ground-truth tokens to the decoder.
     teacher_force = random.random() < self.params.s2s_teacher_forcing_ratio
     g_probs, attns, c_probs = self.decode(
         encoder_op,
         decoder_hidden,
         src_mask,
         target_len,
         tgt=tgt if teacher_force else None)
     return g_probs, attns, c_probs
예제 #4
0
    def forward(self, embedded, hidden=None, lens=None, mask=None):
        """Run the wrapped RNN over pre-embedded inputs.

        When a padding mask (or explicit lens) is given, the sequence is
        packed before the RNN and re-padded after.  Depending on the
        configured flags, returns either the last valid output per
        sequence, all outputs, or (outputs, hidden).
        """
        if mask is not None:
            # Per-sequence lengths derived from the padding mask.
            # NOTE(review): .squeeze() would collapse a batch of size 1
            # to a 0-d tensor — confirm callers never pass batch_size 1.
            lens = mask.sum(dim=-1).squeeze()
        self.rnn.flatten_parameters()
        if lens is None:
            outputs, hidden = self.rnn(embedded, hidden)
            rv_lens = None
        else:
            # NOTE(review): pack_padded_sequence defaults to
            # enforce_sorted=True, i.e. it assumes batches are sorted by
            # decreasing length — verify against the data loader.
            packed = nn.utils.rnn.pack_padded_sequence(embedded,
                                                       lens,
                                                       batch_first=True)
            outputs, hidden = self.rnn(packed, hidden)
            outputs, rv_lens = nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)

        if self.output_resize_layer is not None:
            outputs = self.output_resize_layer(outputs)

        if self.return_aggr_vector_only:
            if rv_lens is not None:
                # Gather the output at the last valid (non-pad) position
                # of each sequence.
                rv_lens = rv_lens.to(device())
                rv_lens = rv_lens - 1
                rv = torch.gather(
                    outputs, 1,
                    rv_lens.view(-1, 1).unsqueeze(2).repeat(
                        1, 1, outputs.size(-1)))
            else:
                rv = outputs
            return rv
        elif self.return_output_vector_only:
            return outputs
        else:
            return outputs, hidden
예제 #5
0
def make_std_mask(tgt, pad):
    """Build the decoder self-attention mask hiding padding and future tokens."""
    if pad is None:
        # No padding id: every position is considered valid.
        pad_mask = torch.ones(tgt.size()).type(torch.ByteTensor).unsqueeze(-2)
    else:
        pad_mask = (tgt != pad).unsqueeze(-2)
    # Combine with the causal (no-peek-ahead) mask.
    causal = subsequent_mask(tgt.size(-1)).type_as(pad_mask.data)
    return (pad_mask & causal).to(device())
예제 #6
0
 def __getitem__(self, word):
     """Return the embedding vector for *word* as a numpy array."""
     with torch.no_grad():
         if word not in self._model_w2i:
             # Out of the model vocab: fall back to the original word2vec.
             return self._original_w2v[word]
         idx = self._model_w2i[word]
         tsr = torch.ones(1, 1, 1).fill_(idx).type(torch.LongTensor).to(device())
         return self._model(tsr).squeeze().cpu().detach().numpy()
예제 #7
0
 def load_state_dict(self, state_dict):
     """Restore scheduler fields and the wrapped optimizer's state."""
     for attr in ("_step", "warmup", "factor", "model_size", "_rate"):
         setattr(self, attr, state_dict[attr])
     self.optimizer.load_state_dict(state_dict["opt_state_dict"])
     # Optimizer tensors come back on CPU; move them to the active device.
     for state in self.optimizer.state.values():
         for k, v in state.items():
             if isinstance(v, torch.Tensor):
                 state[k] = v.to(device())
예제 #8
0
def mean_of_w2v(word_seg_list, w2v, to_torch_tensor=False):
    """Average the w2v vectors of the in-vocabulary words in word_seg_list.

    Raises ValueError when none of the words are covered by w2v.
    Returns a numpy vector, or a float32 torch tensor on the active
    device when to_torch_tensor is True.
    """
    in_vocab = [w for w in word_seg_list if w in w2v]
    if not in_vocab:
        raise ValueError("Input is empty")
    stacked = np.concatenate([w2v[w].reshape(1, -1) for w in in_vocab],
                             axis=0)
    mean_vec = np.mean(stacked, axis=0, keepdims=False)
    if not to_torch_tensor:
        return mean_vec
    return torch.from_numpy(mean_vec).type(torch.FloatTensor).to(device())
예제 #9
0
 def encode(self, data_dict):
     """Embed and encode the source sequence; return (outputs, hidden)."""
     src = data_dict[DK_SRC_WID].to(device())
     src_mask = data_dict[DK_SRC_WID_MASK]
     # Sequence lengths derived from the padding mask.
     src_lens = torch.sum(src_mask.squeeze(1), dim=1)
     hidden = None
     if self.params.s2s_encoder_type.lower() == "lstm":
         # LSTM encoders take an (h, c) pair; both start uninitialized.
         hidden = (None, None)
     embedded = self.src_embed(src)
     encoder_op, encoder_hidden = self.encoder(embedded, hidden, src_lens)
     return encoder_op, encoder_hidden
예제 #10
0
 def update(self,
            probs,
            next_vals,
            next_wids,
            next_words,
            dec_hs,
            ctx=None):
     """Advance the beam search by one decoding step.

     Expands every current candidate with its proposed next tokens,
     applies diversity and length penalties, moves finished (EOS)
     hypotheses into completed_insts, and keeps the best beam_width
     unfinished candidates.
     """
     assert len(next_wids) == len(self.curr_candidates)
     next_candidates = []
     for i, tup in enumerate(self.curr_candidates):
         # Candidate tuple layout: (tgt, score, prob_list, words, hidden, ctx).
         score = tup[1]
         prev_prob_list = [t for t in tup[2]]
         prev_words = [t for t in tup[3]]
         decoder_hidden = dec_hs[i]
         context = ctx[i] if ctx is not None else None
         preds = next_wids[i]
         vals = next_vals[i]
         pred_words = next_words[i]
         prev_prob_list.append(probs)
         for bi in range(len(preds)):
             wi = preds[bi]
             val = vals[bi]
             word = pred_words[bi]
             # Diversity penalty grows with the expansion rank bi.
             # NOTE(review): the penalty is gated on candidate index
             # i > 0, not on bi — confirm this is the intended gating.
             div_penalty = 0.0
             if i > 0: div_penalty = self.gamma * (bi + 1)
             new_score = score + val - div_penalty
             new_tgt = torch.ones(1, 1).long().fill_(wi).to(device())
             new_words = [w for w in prev_words]
             new_words.append(word)
             if wi == self.eos_idx:
                 # Finished hypothesis: length-normalize its score.
                 if self.len_norm > 0:
                     length_penalty = (self.len_norm + new_tgt.shape[1]) / (
                         self.len_norm + 1)
                     new_score /= length_penalty**self.len_norm
                 else:
                     new_score = new_score / new_tgt.shape[
                         1] if new_tgt.shape[1] > 0 else new_score
                 ppl = 0  # TODO: add perplexity later
                 self.completed_insts.append(
                     (new_tgt, new_score, ppl, new_words))
             else:
                 next_candidates.append(
                     (new_tgt, new_score, prev_prob_list, new_words,
                      decoder_hidden, context))
     # Keep only the top beam_width unfinished candidates by score.
     next_candidates = sorted(next_candidates,
                              key=lambda t: t[1],
                              reverse=True)
     next_candidates = next_candidates[:self.beam_width]
     self.curr_candidates = next_candidates
     self.done = len(self.curr_candidates) == 0
예제 #11
0
 def decode(self, enc_hs, dec_hidden, src_mask, target_length, tgt=None):
     """Decode target_length steps, teacher-forced when tgt is given.

     Returns the concatenated generation probabilities, attentions and
     copy probabilities, each joined along the time dimension.
     """
     batch_size = enc_hs.shape[0]
     dec_input = torch.ones(batch_size,
                            1).fill_(self.params.sos_idx).long().to(device())
     context = self.make_init_att(dec_hidden)
     precomp = None
     gen_probs, dec_attns, cpy_probs = [], [], []
     for di in range(target_length):
         dec_output, dec_attn, cpy_prob, dec_hidden, context, precomp = self.decode_step(
             dec_input, dec_hidden, enc_hs, src_mask, context, precomp)
         if tgt is None:
             # Free-running: feed back the decoder's own argmax token.
             _, next_wi = dec_output.topk(1)
             dec_input = next_wi.squeeze(2).detach().long().to(device())
         else:
             # Teacher forcing: feed the ground-truth token at step di.
             dec_input = tgt[:, di].unsqueeze(1)
         gen_probs.append(dec_output)
         cpy_probs.append(cpy_prob)
         dec_attns.append(dec_attn)
     return (torch.cat(gen_probs, dim=1), torch.cat(dec_attns, dim=1),
             torch.cat(cpy_probs, dim=1))
예제 #12
0
 def load_state_dict(self, state_dict):
     """Restore LR-scheduler fields and the wrapped optimizer's state."""
     for attr in ("curr_lr", "shrink_factor", "past_scores_considered",
                  "verbose", "min_lr", "past_scores_list", "score_method",
                  "max_fail_limit", "curr_fail_count"):
         setattr(self, attr, state_dict[attr])
     self.optimizer.load_state_dict(state_dict["opt_sd"])
     # Optimizer tensors come back on CPU; move them to the active device.
     for state in self.optimizer.state.values():
         for k, v in state.items():
             if isinstance(v, torch.Tensor):
                 state[k] = v.to(device())
예제 #13
0
 def __init__(self,
              enc_h,
              i2w,
              idx_in_batch,
              src_seg_list,
              beam_width=4,
              sos_idx=2,
              eos_idx=3,
              ctx=None,
              gamma=0.0,
              len_norm=0.0):
     """Set up per-instance beam-search state with one SOS-seeded candidate."""
     self.idx_in_batch = idx_in_batch
     self.beam_width = beam_width
     self.gamma = gamma
     self.len_norm = len_norm
     self.eos_idx = eos_idx
     self.src_seg_list = src_seg_list
     # Candidate tuple layout:
     # (tgt_tensor, score, prob_list, words, dec_hidden, ctx).
     sos_tsr = torch.ones(1, 1).fill_(sos_idx).long().to(device())
     self.curr_candidates = [(sos_tsr, 0.0, [], [i2w[sos_idx]], enc_h, ctx)]
     self.completed_insts = []
     self.done = False
예제 #14
0
def dv_seq2seq_beam_decode_batch(model,
                                 batch,
                                 start_idx,
                                 i2w,
                                 max_len,
                                 gamma=0.0,
                                 oov_idx=1,
                                 beam_width=4,
                                 eos_idx=3,
                                 len_norm=1.0,
                                 topk=1):
    """Batched beam-search decoding for the dual-vocab (generate/copy) model.

    Maintains one S2SBeamSearchResult per batch instance, steps all
    active beams jointly through the decoder for up to max_len steps,
    and returns each instance's top-k collected hypotheses.
    """
    batch_size = batch[DK_SRC_WID].shape[0]
    model = model.to(device())
    encoder_op, encoder_hidden = model.encode(batch)
    src_mask = batch[DK_SRC_WID_MASK].to(device())
    encoder_hidden = model.prep_enc_hidden_for_dec(encoder_hidden)
    context = model.make_init_att(encoder_hidden)
    # One beam-search bookkeeping object per instance in the batch.
    batch_results = [
        S2SBeamSearchResult(idx_in_batch=bi,
                            i2w=i2w,
                            src_seg_list=batch[DK_SRC_SEG_LISTS][bi],
                            enc_h=encoder_hidden[:, bi, :].unsqueeze(0),
                            beam_width=beam_width,
                            sos_idx=start_idx,
                            eos_idx=eos_idx,
                            ctx=context[bi, :].unsqueeze(0),
                            gamma=gamma,
                            len_norm=len_norm) for bi in range(batch_size)
    ]
    final_ans = []
    for i in range(max_len):
        # Only step instances whose beams are not yet finished.
        curr_actives = [b for b in batch_results if not b.done]
        if len(curr_actives) == 0: break
        # Flatten all candidates of all active instances into one
        # decoder batch.
        b_tgt_list = [b.get_curr_tgt() for b in curr_actives]
        b_tgt = torch.cat(b_tgt_list, dim=0)
        b_hidden_list = [b.get_curr_dec_hidden() for b in curr_actives]
        b_hidden = torch.cat(b_hidden_list, dim=1).to(device())
        b_ctx_list = [b.get_curr_context() for b in curr_actives]
        b_context = torch.cat(b_ctx_list, dim=0).to(device())
        b_cand_size_list = [b.get_curr_candidate_size() for b in curr_actives]
        b_src_seg_list = [b.get_curr_src_seg_list() for b in curr_actives]
        # Repeat each instance's encoder outputs / mask once per candidate.
        enc_op = torch.cat([
            encoder_op[b.idx_in_batch, :, :].unsqueeze(0).repeat(
                b.get_curr_candidate_size(), 1, 1) for b in curr_actives
        ],
                           dim=0)
        s_mask = torch.cat([
            src_mask[b.idx_in_batch, :, :].unsqueeze(0).repeat(
                b.get_curr_candidate_size(), 1, 1) for b in curr_actives
        ],
                           dim=0)
        gen_wid_probs, cpy_wid_probs, cpy_gate_probs, b_hidden, b_context, _ = model.decode_step(
            b_tgt, b_hidden, enc_op, s_mask, b_context, None)
        gen_wid_probs, cpy_wid_probs = get_gen_cpy_log_probs(
            gen_wid_probs, cpy_wid_probs, cpy_gate_probs)
        # Generation and copy distributions concatenated along the vocab
        # axis: indices >= len(i2w) refer to copied source positions.
        comb_prob = torch.cat([gen_wid_probs, cpy_wid_probs], dim=2)
        beam_i = 0
        # Un-flatten the decoder outputs back per instance and update
        # each instance's beam.
        for bi, size in enumerate(b_cand_size_list):
            g_probs = comb_prob[beam_i:beam_i + size, :].view(
                size, -1, comb_prob.size(-1))
            hiddens = b_hidden[:, beam_i:beam_i + size, :].view(
                size, -1, b_hidden.size(-1))
            ctxs = b_context[beam_i:beam_i + size, :].view(
                size, -1, b_context.size(-1))
            vt, it = g_probs.topk(beam_width)
            next_vals, next_wids, next_words, dec_hs, ctx = [], [], [], [], []
            for ci in range(size):
                vals, wis, words = [], [], []
                for idx in range(beam_width):
                    vals.append(vt[ci, 0, idx].item())
                    wi = it[ci, 0, idx].item()
                    if wi in i2w:  # generate
                        word = i2w[wi]
                    else:  # copy
                        # Map the copy index back to a source position;
                        # out-of-range copies degrade to the OOV token.
                        c_wi = wi - len(i2w)
                        if c_wi < len(b_src_seg_list[bi][ci]):
                            word = b_src_seg_list[bi][ci][c_wi]
                        else:
                            word = i2w[oov_idx]
                        wi = oov_idx
                    wis.append(wi)
                    words.append(word)
                next_vals.append(vals)
                next_wids.append(wis)
                next_words.append(words)
            dec_hs = [
                hiddens[j, :].unsqueeze(1) for j in range(hiddens.shape[0])
            ]
            ctx = [ctxs[j, :].unsqueeze(0) for j in range(ctxs.shape[0])]
            curr_actives[bi].update(g_probs, next_vals, next_wids, next_words,
                                    dec_hs, ctx)
            beam_i += size
    for b in batch_results:
        final_ans.append(b.collect_results(topk=topk))
    return final_ans
예제 #15
0
def run_dv_seq2seq_epoch(model,
                         loader,
                         criterion_gen,
                         criterion_cpy,
                         curr_epoch=0,
                         max_grad_norm=5.0,
                         optimizer=None,
                         desc="Train",
                         pad_idx=0,
                         model_name="dv_seq2seq",
                         logs_dir=""):
    """Run one epoch of the dual-vocab seq2seq model over loader.

    Trains when optimizer is given, otherwise only evaluates.  Returns
    (per-token loss, token accuracy).
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    total_correct = 0
    for batch in tqdm(loader,
                      mininterval=2,
                      desc=desc,
                      leave=False,
                      ascii=True):
        g_wid_probs, c_wid_probs, c_gate_probs = model(batch)
        gen_targets = batch[DK_TGT_GEN_WID].to(device())
        cpy_targets = batch[DK_TGT_CPY_WID].to(device())
        cpy_truth_gates = batch[DK_TGT_CPY_GATE].to(device())
        n_tokens = batch[DK_TGT_N_TOKENS].item()
        g_log_wid_probs, c_log_wid_probs = get_gen_cpy_log_probs(
            g_wid_probs, c_wid_probs, c_gate_probs)
        # Mask each distribution by the ground-truth copy gate so each
        # token is scored against exactly one of the two criteria.
        c_log_wid_probs = c_log_wid_probs * (
            cpy_truth_gates.unsqueeze(2).expand_as(c_log_wid_probs))
        g_log_wid_probs = g_log_wid_probs * (
            (1 - cpy_truth_gates).unsqueeze(2).expand_as(g_log_wid_probs))
        g_log_wid_probs = g_log_wid_probs.view(-1, g_log_wid_probs.size(-1))
        c_log_wid_probs = c_log_wid_probs.view(-1, c_log_wid_probs.size(-1))
        g_loss = criterion_gen(g_log_wid_probs,
                               gen_targets.contiguous().view(-1))
        c_loss = criterion_cpy(c_log_wid_probs,
                               cpy_targets.contiguous().view(-1))
        loss = g_loss + c_loss
        # compute acc
        # Deep copies: tgt and g_preds_i are mutated in place below.
        tgt = copy.deepcopy(gen_targets.view(-1, 1).squeeze(1))
        c_sw = cpy_truth_gates.view(-1, 1).squeeze(1)
        c_tg = cpy_targets.view(-1, 1).squeeze(1)
        g_preds_i = copy.deepcopy(g_log_wid_probs.max(1)[1])
        c_preds_i = c_log_wid_probs.max(1)[1]
        g_preds_v = g_log_wid_probs.max(1)[0]
        # Where the generation distribution was fully masked out (max
        # prob 0), fall back to the copy prediction for that token.
        for i in range(g_preds_i.shape[0]):
            if g_preds_v[i] == 0: g_preds_i[i] = c_preds_i[i]
        # Substitute the copy target where the gate says the token was
        # copied, so predictions are compared against the right truth.
        for i in range(tgt.shape[0]):
            if c_sw[i] == 1: tgt[i] = c_tg[i]
        n_correct = g_preds_i.data.eq(tgt.data)
        n_correct = n_correct.masked_select(tgt.ne(pad_idx).data).sum()
        total_loss += loss.item()
        total_correct += n_correct.item()
        total_tokens += n_tokens
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                max_grad_norm)
            optimizer.step()
    # NOTE(review): assumes the loader yields at least one non-empty
    # batch; total_tokens == 0 would raise ZeroDivisionError here.
    loss_report = total_loss / total_tokens
    acc = total_correct / total_tokens
    elapsed = time.time() - start
    info = desc + " epoch %d loss %f, acc %f ppl %f elapsed time %f" % (
        curr_epoch, loss_report, acc, math.exp(loss_report), elapsed)
    print(info)
    write_line_to_file(info, logs_dir + model_name + "_train_info.txt")
    return loss_report, acc
예제 #16
0
def train_dv_seq2seq(params,
                     model,
                     train_loader,
                     criterion_gen,
                     criterion_cpy,
                     optimizer,
                     completed_epochs=0,
                     eval_loader=None,
                     best_eval_result=0,
                     best_eval_epoch=0,
                     past_eval_results=None,
                     checkpoint=True):
    """Train the dual-vocab seq2seq model for params.epochs epochs.

    Runs one training epoch at a time, optionally evaluates on
    eval_loader, checkpoints the latest/best models, and returns
    (best_eval_result, best_eval_epoch).
    """
    # Fix for the mutable-default-argument bug: the old default of []
    # was appended to inside the function and therefore shared across
    # calls.  A None sentinel keeps the interface backward-compatible.
    if past_eval_results is None:
        past_eval_results = []
    model = model.to(device())
    criterion_gen = criterion_gen.to(device())
    criterion_cpy = criterion_cpy.to(device())
    for epoch in range(params.epochs):
        report_epoch = epoch + completed_epochs + 1
        model.train()
        train_loss, _ = run_dv_seq2seq_epoch(
            model,
            train_loader,
            criterion_gen,
            criterion_cpy,
            curr_epoch=report_epoch,
            optimizer=optimizer,
            max_grad_norm=params.max_gradient_norm,
            desc="Train",
            pad_idx=params.pad_idx,
            model_name=params.model_name,
            logs_dir=params.logs_dir)
        # Optionally decay the LR based on training performance.
        if params.lr_decay_with_train_perf and hasattr(optimizer,
                                                       "update_learning_rate"):
            optimizer.update_learning_rate(train_loss, "max")
        if eval_loader is not None:
            model.eval()
            with torch.no_grad():
                # Full evaluation only after a warm-up epoch count and
                # at the configured interval.
                if report_epoch >= params.full_eval_start_epoch and \
                   report_epoch % params.full_eval_every_epoch == 0:
                    eval_score = eval_dv_seq2seq(model, eval_loader, params)
                    if eval_score > best_eval_result:
                        best_eval_result = eval_score
                        best_eval_epoch = report_epoch
                        print("Model best checkpoint with score {}".format(
                            eval_score))
                        fn = params.saved_models_dir + params.model_name + "_best.pt"
                        if checkpoint:
                            model_checkpoint(fn, report_epoch, model,
                                             optimizer, params,
                                             past_eval_results,
                                             best_eval_result, best_eval_epoch)
                    info = "Best {} so far {} from epoch {}".format(
                        params.eval_metric, best_eval_result, best_eval_epoch)
                    print(info)
                    write_line_to_file(
                        info, params.logs_dir + params.model_name +
                        "_train_info.txt")
                    # Eval-driven LR decay when not decaying on train perf.
                    if hasattr(optimizer, "update_learning_rate"
                               ) and not params.lr_decay_with_train_perf:
                        optimizer.update_learning_rate(eval_score)
                    past_eval_results.append(eval_score)
                    # Keep only the most recent eval scores.
                    if len(past_eval_results
                           ) > params.past_eval_scores_considered:
                        past_eval_results = past_eval_results[1:]
        fn = params.saved_models_dir + params.model_name + "_latest.pt"
        if checkpoint:
            model_checkpoint(fn, report_epoch, model, optimizer, params,
                             past_eval_results, best_eval_result,
                             best_eval_epoch)
        print("")
    return best_eval_result, best_eval_epoch
예제 #17
0
 def make_init_att(self, context):
     """Return a zero-filled initial attention context sized for the batch."""
     batch_size = context.size(1)
     hidden_width = (self.params.s2s_encoder_hidden_size *
                     self.params.s2s_encoder_rnn_dir)
     return context.data.new(batch_size, 1,
                             hidden_width).zero_().float().to(device())
예제 #18
0
 def get_curr_context(self):
     """Concatenate the non-None contexts of current candidates (None if empty)."""
     if not self.curr_candidates:
         return None
     ctxs = [tup[5] for tup in self.curr_candidates if tup[5] is not None]
     return torch.cat(ctxs, dim=0).float().to(device())
예제 #19
0
 def get_curr_dec_hidden(self):
     """Concatenate the decoder hidden states of current candidates (None if empty)."""
     if not self.curr_candidates:
         return None
     hiddens = [tup[4] for tup in self.curr_candidates]
     return torch.cat(hiddens, dim=1).float().to(device())
예제 #20
0
 def get_curr_tgt(self):
     """Concatenate the target tensors of current candidates (None if empty)."""
     if not self.curr_candidates:
         return None
     tgts = [tup[0] for tup in self.curr_candidates]
     return torch.cat(tgts, dim=0).long().to(device())
예제 #21
0
def train_rwg(params,
              model,
              train_loader,
              criterion,
              optimizer,
              completed_epochs=0,
              eval_loader=None,
              best_eval_result=0,
              best_eval_epoch=0,
              past_eval_results=None,
              past_train_loss=None):
    """Train the RWG model for params.epochs epochs.

    Runs one training epoch at a time, applies the configured learning
    rate decay policy, checkpoints the latest/best models, and
    optionally evaluates on eval_loader.
    """
    # Fix for the mutable-default-argument bug: the old defaults of []
    # were appended to inside the function and therefore shared across
    # calls.  None sentinels keep the interface backward-compatible.
    if past_eval_results is None:
        past_eval_results = []
    if past_train_loss is None:
        past_train_loss = []
    model = model.to(device())
    criterion = criterion.to(device())
    for epoch in range(params.epochs):
        report_epoch = epoch + completed_epochs + 1
        model.train()
        train_loss, _ = run_rwg_epoch(train_loader,
                                      model,
                                      criterion,
                                      optimizer,
                                      model_name=params.model_name,
                                      report_acc=False,
                                      max_grad_norm=params.max_gradient_norm,
                                      pad_idx=params.pad_idx,
                                      curr_epoch=report_epoch,
                                      logs_dir=params.logs_dir)

        # Only the two most recent train losses are kept, to measure the
        # improvement between consecutive epochs.
        past_train_loss.append(train_loss)
        if len(past_train_loss) > 2: past_train_loss = past_train_loss[1:]
        if params.lr_decay_with_train_perf:
            # Shrink the LR when training loss improved, but by less
            # than the configured threshold (training has plateaued).
            if params.lr_decay_with_train_loss_diff and \
                len(past_train_loss) == 2 and \
                past_train_loss[1] <= past_train_loss[0] and \
                past_train_loss[0] - past_train_loss[1] < params.train_loss_diff_threshold:
                print("updating lr by train loss diff, threshold: {}".format(
                    params.train_loss_diff_threshold))
                optimizer.shrink_learning_rate()
                past_train_loss = []
            elif hasattr(optimizer, "update_learning_rate"):
                optimizer.update_learning_rate(train_loss, "max")

        fn = params.saved_models_dir + params.model_name + "_latest.pt"
        model_checkpoint(fn, report_epoch, model, optimizer, params,
                         past_eval_results, best_eval_result, best_eval_epoch)

        if eval_loader is not None:
            model.eval()
            with torch.no_grad():
                # Full evaluation only after a warm-up epoch count and
                # at the configured interval.
                if report_epoch >= params.full_eval_start_epoch and \
                   report_epoch % params.full_eval_every_epoch == 0:
                    eval_loss, eval_score = run_rwg_epoch(
                        eval_loader,
                        model,
                        criterion,
                        None,
                        pad_idx=params.pad_idx,
                        model_name=params.model_name,
                        report_acc=True,
                        curr_epoch=report_epoch,
                        logs_dir=None,
                        desc="Eval")
                    if eval_score > best_eval_result:
                        best_eval_result = eval_score
                        best_eval_epoch = report_epoch
                        print("Model best checkpoint with score {}".format(
                            eval_score))
                        fn = params.saved_models_dir + params.model_name + "_best.pt"
                        model_checkpoint(fn, report_epoch, model, optimizer,
                                         params, past_eval_results,
                                         best_eval_result, best_eval_epoch)
                    info = "Best {} so far {} from epoch {}".format(
                        params.eval_metric, best_eval_result, best_eval_epoch)
                    print(info)
                    write_line_to_file(
                        info, params.logs_dir + params.model_name +
                        "_train_info.txt")
                    # Eval-driven LR decay when not decaying on train perf.
                    if hasattr(optimizer, "update_learning_rate"
                               ) and not params.lr_decay_with_train_perf:
                        optimizer.update_learning_rate(eval_score, "min")
                    past_eval_results.append(eval_score)
                    # Keep only the most recent eval scores.
                    if len(past_eval_results
                           ) > params.past_eval_scores_considered:
                        past_eval_results = past_eval_results[1:]

        print("")
예제 #22
0
def run_rwg_epoch(data_iter,
                  model,
                  criterion,
                  optimizer,
                  model_name="rwg",
                  desc="Train",
                  curr_epoch=0,
                  pad_idx=0,
                  logs_dir=None,
                  max_grad_norm=5.0,
                  report_acc=False):
    """Run one epoch of the RWG model over data_iter.

    Trains when optimizer is given, otherwise only evaluates.  When
    report_acc is True, additionally computes top-k hit rates at
    several cutoffs.  Returns (per-token loss, top-100 accuracy — 0
    when accuracy was not computed).
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    # Hit counters for the various top-k cutoffs reported below.
    total_correct_k = 0
    total_correct_2k = 0
    total_correct_10 = 0
    total_correct_25 = 0
    total_correct_50 = 0
    total_correct_100 = 0
    total_correct_500 = 0
    total_acc_tokens = 0
    for batch in tqdm(data_iter,
                      mininterval=2,
                      desc=desc,
                      leave=False,
                      ascii=True):
        probs = model.forward(batch)
        gen_targets = batch[DK_TGT_GEN_WID]
        # Targets become multi-hot vectors over the output vocabulary.
        gen_targets = label_tsr_to_one_hot_tsr(gen_targets, probs.size(-1))
        gen_targets = gen_targets.to(device())
        n_tokens = batch[DK_TGT_N_TOKENS].item()
        probs = probs.view(-1, probs.size(-1))
        loss = criterion(probs, gen_targets.contiguous())
        total_loss += loss.item()
        total_tokens += n_tokens
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                max_grad_norm)
            optimizer.step()
            if torch.isnan(loss).any():
                assert False, "nan detected after step()"

        if report_acc:
            gen_targets = batch[DK_TGT_GEN_WID]
            # NOTE(review): probs is indexed per batch row here, which
            # assumes the earlier view(-1, vocab) kept one row per
            # instance (i.e. a single prediction step) — confirm.
            for bi in range(batch[DK_BATCH_SIZE]):
                # k is the per-instance number of ground-truth words.
                k = batch[DK_WI_N_WORDS][bi]
                prob = probs[bi, :].squeeze()
                _, top_k_ids = prob.topk(500)
                pred_k_ids = top_k_ids.tolist()
                # Prediction sets at increasing cutoffs (all prefixes
                # of the same top-500 ranking).
                pred_tk_ids = set(pred_k_ids[:k])
                pred_2k_ids = set(pred_k_ids[:int(2 * k)])
                pred_10_ids = set(pred_k_ids[:10])
                pred_25_ids = set(pred_k_ids[:25])
                pred_50_ids = set(pred_k_ids[:50])
                pred_100_ids = set(pred_k_ids[:100])
                pred_500_ids = set(pred_k_ids)
                truth_ids = gen_targets[bi, :].squeeze()
                for truth_id in truth_ids.tolist():
                    # Padding tokens never count toward accuracy.
                    if truth_id == pad_idx: continue
                    if truth_id in pred_tk_ids:
                        total_correct_k += 1
                    if truth_id in pred_2k_ids:
                        total_correct_2k += 1
                    if truth_id in pred_10_ids:
                        total_correct_10 += 1
                    if truth_id in pred_25_ids:
                        total_correct_25 += 1
                    if truth_id in pred_50_ids:
                        total_correct_50 += 1
                    if truth_id in pred_100_ids:
                        total_correct_100 += 1
                    if truth_id in pred_500_ids:
                        total_correct_500 += 1
                    total_acc_tokens += 1

    elapsed = time.time() - start
    if report_acc:
        info = desc + " epoch %d loss %f top_k acc %f top_2k acc %f top_10 acc %f top_25 acc %f top_50 acc %f top_100 " \
                      "acc %f top_500 acc % f ppl %f elapsed time %f" % (
                curr_epoch, total_loss / total_tokens,
                total_correct_k / total_acc_tokens,
                total_correct_2k / total_acc_tokens,
                total_correct_10 / total_acc_tokens,
                total_correct_25 / total_acc_tokens,
                total_correct_50 / total_acc_tokens,
                total_correct_100 / total_acc_tokens,
                total_correct_500 / total_acc_tokens,
                math.exp(total_loss / total_tokens),
                elapsed)
    else:
        info = desc + " epoch %d loss %f ppl %f elapsed time %f" % (
            curr_epoch, total_loss / total_tokens,
            math.exp(total_loss / total_tokens), elapsed)
    print(info)
    if logs_dir is not None:
        write_line_to_file(info, logs_dir + model_name + "_train_info.txt")
    rv_loss = total_loss / total_tokens
    rv_perf = total_correct_100 / total_acc_tokens if total_acc_tokens > 0 else 0
    return rv_loss, rv_perf