def forward(self, data_dict):
    src = data_dict[DK_SRC_WID].to(device())
    src_oov = data_dict[DK_SRC_OOV_WID].to(device())
    src = self.src_embed(src, src_oov)
    src_mask = data_dict[DK_SRC_WID_MASK].to(device())
    encoder_op = self.encoder(src, mask=src_mask)
    encoder_op = self.dropout(encoder_op)
    # Some aggregators need the source mask (e.g. masked pooling over padded positions).
    if hasattr(self.aggr, "require_mask"):
        aggr_op = self.aggr(encoder_op, mask=src_mask)
    else:
        aggr_op = self.aggr(encoder_op)
    word_probs = self.word_generator(aggr_op)
    return word_probs
def forward(self, x, target):
    assert x.size(1) == self.size
    x = x.to(device())
    target = target.to(device())
    # Build the smoothed target distribution: the true label gets
    # self.confidence, the remaining smoothing mass is spread over the
    # vocabulary (the "-2" excludes the true label and the padding index).
    true_dist = x.data.clone()
    true_dist.fill_(self.smoothing / (self.size - 2))
    indices = target.data.unsqueeze(1)
    true_dist.scatter_(1, indices, self.confidence)
    if self.padding_idx is not None:
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.shape[0] > 0:
            # squeeze(-1) keeps the index 1-D even when only one row is padding.
            true_dist.index_fill_(0, mask.squeeze(-1), 0.0)
    return self.criterion(x, true_dist)
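# A minimal usage sketch of the label-smoothing criterion above, assuming a
# hypothetical module whose __init__ sets self.size (vocab size),
# self.padding_idx, self.smoothing, self.confidence = 1 - smoothing, and
# self.criterion = nn.KLDivLoss(reduction="sum"); the sketch feeds it
# log-probabilities, which is what KLDivLoss expects.
def _label_smoothing_usage_sketch():
    import torch
    import torch.nn as nn
    vocab_size, smoothing, pad_idx = 5, 0.1, 0
    log_probs = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)
    target = torch.tensor([2, 1, 0])  # last position is padding
    # Build the smoothed distribution exactly as forward() does.
    true_dist = log_probs.data.clone()
    true_dist.fill_(smoothing / (vocab_size - 2))
    true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
    true_dist[:, pad_idx] = 0
    mask = torch.nonzero(target == pad_idx)
    if mask.shape[0] > 0:
        true_dist.index_fill_(0, mask.squeeze(-1), 0.0)
    return nn.KLDivLoss(reduction="sum")(log_probs, true_dist)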
def forward(self, data_dict):
    tgt = data_dict[DK_TGT_GEN_WID].to(device())
    src_mask = data_dict[DK_SRC_WID_MASK].to(device())
    target_len = tgt.shape[1]
    encoder_op, encoder_hidden = self.encode(data_dict)
    decoder_hidden = self.prep_enc_hidden_for_dec(encoder_hidden)
    # Teacher forcing: feed gold tokens with probability
    # s2s_teacher_forcing_ratio, otherwise feed the model's own predictions.
    use_teacher_forcing = random.random() < self.params.s2s_teacher_forcing_ratio
    g_probs, attns, c_probs = self.decode(
        encoder_op, decoder_hidden, src_mask, target_len,
        tgt=tgt if use_teacher_forcing else None)
    return g_probs, attns, c_probs
def forward(self, embedded, hidden=None, lens=None, mask=None):
    if mask is not None:
        # Derive per-sequence lengths from the padding mask; squeeze(-1)
        # keeps lens 1-D even for a batch of one.
        lens = mask.sum(dim=-1).squeeze(-1)
    self.rnn.flatten_parameters()
    if lens is None:
        outputs, hidden = self.rnn(embedded, hidden)
        rv_lens = None
    else:
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lens, batch_first=True)
        outputs, hidden = self.rnn(packed, hidden)
        outputs, rv_lens = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
    if self.output_resize_layer is not None:
        outputs = self.output_resize_layer(outputs)
    if self.return_aggr_vector_only:
        if rv_lens is not None:
            # Gather the last valid (non-padded) output of each sequence.
            rv_lens = rv_lens.to(device()) - 1
            rv = torch.gather(
                outputs, 1,
                rv_lens.view(-1, 1).unsqueeze(2).repeat(1, 1, outputs.size(-1)))
        else:
            rv = outputs
        return rv
    elif self.return_output_vector_only:
        return outputs
    else:
        return outputs, hidden
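# A minimal sketch of the pack/pad round trip the encoder above relies on.
# Note pack_padded_sequence expects lengths sorted in descending order unless
# enforce_sorted=False is passed (PyTorch >= 1.1); the data loader here is
# assumed to pre-sort batches.
def _packed_rnn_sketch():
    import torch
    import torch.nn as nn
    rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
    embedded = torch.randn(2, 5, 8)  # (batch, max_len, emb)
    lens = torch.tensor([5, 3])      # true lengths, descending
    packed = nn.utils.rnn.pack_padded_sequence(embedded, lens, batch_first=True)
    outputs, hidden = rnn(packed)
    outputs, rv_lens = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
    # Last valid output per sequence, as the encoder's gather does.
    idx = (rv_lens - 1).view(-1, 1, 1).repeat(1, 1, outputs.size(-1))
    return torch.gather(outputs, 1, idx)  # (batch, 1, hidden)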
def make_std_mask(tgt, pad):
    # Combine the padding mask with a causal (subsequent-position) mask so
    # decoder self-attention cannot look ahead.
    if pad is not None:
        tgt_mask = (tgt != pad).unsqueeze(-2)
    else:
        tgt_mask = torch.ones(tgt.size()).type(torch.ByteTensor).unsqueeze(-2)
    tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
    return tgt_mask.to(device())
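# A minimal sketch of the causal mask make_std_mask combines in, assuming the
# usual Annotated-Transformer-style subsequent_mask: an upper-triangular ban
# on attending to future positions.
def _subsequent_mask_sketch(size=4):
    import numpy as np
    import torch
    attn_shape = (1, size, size)
    upper = np.triu(np.ones(attn_shape), k=1).astype("uint8")
    # True where attention is allowed: position i may attend to positions <= i.
    return torch.from_numpy(upper) == 0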
def __getitem__(self, word):
    with torch.no_grad():
        if word in self._model_w2i:
            # In-vocabulary: run the word id through the trained embedding model.
            word_idx = self._model_w2i[word]
            tsr = torch.ones(1, 1, 1).fill_(word_idx).type(torch.LongTensor).to(device())
            embedding = self._model(tsr).squeeze().cpu().detach().numpy()
        else:
            # Out-of-vocabulary: fall back to the original word2vec vectors.
            embedding = self._original_w2v[word]
    return embedding
def load_state_dict(self, state_dict):
    self._step = state_dict["_step"]
    self.warmup = state_dict["warmup"]
    self.factor = state_dict["factor"]
    self.model_size = state_dict["model_size"]
    self._rate = state_dict["_rate"]
    self.optimizer.load_state_dict(state_dict["opt_state_dict"])
    # Move restored optimizer state (e.g. Adam moments) onto the active device.
    for state in self.optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device())
def mean_of_w2v(word_seg_list, w2v, to_torch_tensor=False):
    # Average the word2vec vectors of all in-vocabulary words.
    word_seg_list = [w for w in word_seg_list if w in w2v]
    if len(word_seg_list) == 0:
        raise ValueError("Input is empty")
    word_vecs = [w2v[w].reshape(1, -1) for w in word_seg_list]
    word_vec = np.concatenate(word_vecs, axis=0)
    word_vec = np.mean(word_vec, axis=0, keepdims=False)
    if to_torch_tensor:
        rv = torch.from_numpy(word_vec).type(torch.FloatTensor).to(device())
    else:
        rv = word_vec
    return rv
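# A small usage sketch of mean_of_w2v with a toy dict standing in for a real
# word2vec lookup; "unknown" is silently dropped because it is out of vocabulary.
def _mean_of_w2v_sketch():
    import numpy as np
    toy_w2v = {
        "cat": np.array([1.0, 0.0]),
        "dog": np.array([0.0, 1.0]),
    }
    return mean_of_w2v(["cat", "dog", "unknown"], toy_w2v)  # array([0.5, 0.5])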
def encode(self, data_dict):
    src = data_dict[DK_SRC_WID].to(device())
    src_mask = data_dict[DK_SRC_WID_MASK]
    encoder_hidden = None
    src_lens = torch.sum(src_mask.squeeze(1), dim=1)
    # Both nn.GRU and nn.LSTM zero-initialize their state when the passed
    # hidden is None, so no (hidden, cell) tuple of Nones is needed for the
    # LSTM case (such a tuple would in fact crash nn.LSTM).
    src = self.src_embed(src)
    encoder_op, encoder_hidden = self.encoder(src, encoder_hidden, src_lens)
    return encoder_op, encoder_hidden
def update(self, probs, next_vals, next_wids, next_words, dec_hs, ctx=None):
    assert len(next_wids) == len(self.curr_candidates)
    next_candidates = []
    for i, tup in enumerate(self.curr_candidates):
        # Candidate tuple layout: (tgt, score, prob_list, words, dec_hidden, context).
        score = tup[1]
        prev_prob_list = [t for t in tup[2]]
        prev_words = [t for t in tup[3]]
        decoder_hidden = dec_hs[i]
        context = ctx[i] if ctx is not None else None
        preds = next_wids[i]
        vals = next_vals[i]
        pred_words = next_words[i]
        prev_prob_list.append(probs)
        for bi in range(len(preds)):
            wi = preds[bi]
            val = vals[bi]
            word = pred_words[bi]
            # Diverse beam search: penalize lower-ranked expansions.
            div_penalty = 0.0
            if i > 0:
                div_penalty = self.gamma * (bi + 1)
            new_score = score + val - div_penalty
            new_tgt = torch.ones(1, 1).long().fill_(wi).to(device())
            new_words = [w for w in prev_words]
            new_words.append(word)
            if wi == self.eos_idx:
                # Finished hypothesis: apply length normalization so beam
                # search does not systematically prefer short outputs.
                if self.len_norm > 0:
                    length_penalty = (self.len_norm + new_tgt.shape[1]) / (self.len_norm + 1)
                    new_score /= length_penalty ** self.len_norm
                else:
                    new_score = new_score / new_tgt.shape[1] if new_tgt.shape[1] > 0 else new_score
                ppl = 0  # TODO: add perplexity later
                self.completed_insts.append((new_tgt, new_score, ppl, new_words))
            else:
                next_candidates.append(
                    (new_tgt, new_score, prev_prob_list, new_words, decoder_hidden, context))
    # Keep only the top beam_width live hypotheses.
    next_candidates = sorted(next_candidates, key=lambda t: t[1], reverse=True)
    next_candidates = next_candidates[:self.beam_width]
    self.curr_candidates = next_candidates
    self.done = len(self.curr_candidates) == 0
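# A worked sketch of the GNMT-style length penalty applied when a hypothesis
# finishes above: lp(len) = ((alpha + len) / (alpha + 1)) ** alpha with
# alpha = self.len_norm; the finished score is divided by lp, and alpha == 0
# falls back to plain per-token averaging.
def _length_penalty_sketch(score, length, alpha=1.0):
    if alpha > 0:
        lp = ((alpha + length) / (alpha + 1.0)) ** alpha
        return score / lp
    return score / length if length > 0 else score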
def decode(self, enc_hs, dec_hidden, src_mask, target_length, tgt=None):
    sos = torch.ones(enc_hs.shape[0], 1).fill_(self.params.sos_idx).long().to(device())
    gen_probs = []
    cpy_probs = []
    dec_attns = []
    dec_input = sos
    context = self.make_init_att(dec_hidden)
    precomp = None
    for di in range(target_length):
        dec_output, dec_attn, cpy_prob, dec_hidden, context, precomp = self.decode_step(
            dec_input, dec_hidden, enc_hs, src_mask, context, precomp)
        if tgt is not None:
            # Teacher forcing: feed the gold token.
            dec_input = tgt[:, di].unsqueeze(1)
        else:
            # Free running: feed the model's own greedy prediction.
            _, next_wi = dec_output.topk(1)
            dec_input = next_wi.squeeze(2).detach().long().to(device())
        gen_probs.append(dec_output)
        cpy_probs.append(cpy_prob)
        dec_attns.append(dec_attn)
    return torch.cat(gen_probs, dim=1), torch.cat(dec_attns, dim=1), torch.cat(cpy_probs, dim=1)
def load_state_dict(self, state_dict):
    self.curr_lr = state_dict["curr_lr"]
    self.shrink_factor = state_dict["shrink_factor"]
    self.past_scores_considered = state_dict["past_scores_considered"]
    self.verbose = state_dict["verbose"]
    self.min_lr = state_dict["min_lr"]
    self.past_scores_list = state_dict["past_scores_list"]
    self.score_method = state_dict["score_method"]
    self.max_fail_limit = state_dict["max_fail_limit"]
    self.curr_fail_count = state_dict["curr_fail_count"]
    self.optimizer.load_state_dict(state_dict["opt_sd"])
    for state in self.optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device())
def __init__(self, enc_h, i2w, idx_in_batch, src_seg_list, beam_width=4,
             sos_idx=2, eos_idx=3, ctx=None, gamma=0.0, len_norm=0.0):
    self.idx_in_batch = idx_in_batch
    self.beam_width = beam_width
    self.gamma = gamma
    self.len_norm = len_norm
    self.eos_idx = eos_idx
    self.src_seg_list = src_seg_list
    sos = torch.ones(1, 1).fill_(sos_idx).long().to(device())
    # Candidate tuple layout: (tgt, score, prob_list, words, dec_hidden, context).
    self.curr_candidates = [(sos, 0.0, [], [i2w[sos_idx]], enc_h, ctx)]
    self.completed_insts = []
    self.done = False
def dv_seq2seq_beam_decode_batch(model, batch, start_idx, i2w, max_len,
                                 gamma=0.0, oov_idx=1, beam_width=4,
                                 eos_idx=3, len_norm=1.0, topk=1):
    batch_size = batch[DK_SRC_WID].shape[0]
    model = model.to(device())
    encoder_op, encoder_hidden = model.encode(batch)
    src_mask = batch[DK_SRC_WID_MASK].to(device())
    encoder_hidden = model.prep_enc_hidden_for_dec(encoder_hidden)
    context = model.make_init_att(encoder_hidden)
    # One beam-search bookkeeping object per batch instance.
    batch_results = [
        S2SBeamSearchResult(idx_in_batch=bi, i2w=i2w,
                            src_seg_list=batch[DK_SRC_SEG_LISTS][bi],
                            enc_h=encoder_hidden[:, bi, :].unsqueeze(0),
                            beam_width=beam_width, sos_idx=start_idx,
                            eos_idx=eos_idx, ctx=context[bi, :].unsqueeze(0),
                            gamma=gamma, len_norm=len_norm)
        for bi in range(batch_size)
    ]
    final_ans = []
    for i in range(max_len):
        curr_actives = [b for b in batch_results if not b.done]
        if len(curr_actives) == 0:
            break
        # Flatten all live candidates of all instances into one decode batch.
        b_tgt = torch.cat([b.get_curr_tgt() for b in curr_actives], dim=0)
        b_hidden = torch.cat([b.get_curr_dec_hidden() for b in curr_actives], dim=1).to(device())
        b_context = torch.cat([b.get_curr_context() for b in curr_actives], dim=0).to(device())
        b_cand_size_list = [b.get_curr_candidate_size() for b in curr_actives]
        b_src_seg_list = [b.get_curr_src_seg_list() for b in curr_actives]
        enc_op = torch.cat([
            encoder_op[b.idx_in_batch, :, :].unsqueeze(0).repeat(b.get_curr_candidate_size(), 1, 1)
            for b in curr_actives
        ], dim=0)
        s_mask = torch.cat([
            src_mask[b.idx_in_batch, :, :].unsqueeze(0).repeat(b.get_curr_candidate_size(), 1, 1)
            for b in curr_actives
        ], dim=0)
        gen_wid_probs, cpy_wid_probs, cpy_gate_probs, b_hidden, b_context, _ = model.decode_step(
            b_tgt, b_hidden, enc_op, s_mask, b_context, None)
        gen_wid_probs, cpy_wid_probs = get_gen_cpy_log_probs(
            gen_wid_probs, cpy_wid_probs, cpy_gate_probs)
        # Generation and copy vocabularies are concatenated: ids >= len(i2w)
        # point back into the source sequence.
        comb_prob = torch.cat([gen_wid_probs, cpy_wid_probs], dim=2)
        beam_i = 0
        for bi, size in enumerate(b_cand_size_list):
            g_probs = comb_prob[beam_i:beam_i + size, :].view(size, -1, comb_prob.size(-1))
            hiddens = b_hidden[:, beam_i:beam_i + size, :].view(size, -1, b_hidden.size(-1))
            ctxs = b_context[beam_i:beam_i + size, :].view(size, -1, b_context.size(-1))
            vt, it = g_probs.topk(beam_width)
            next_vals, next_wids, next_words = [], [], []
            for ci in range(size):
                vals, wis, words = [], [], []
                for idx in range(beam_width):
                    vals.append(vt[ci, 0, idx].item())
                    wi = it[ci, 0, idx].item()
                    if wi in i2w:  # generate
                        word = i2w[wi]
                    else:  # copy
                        c_wi = wi - len(i2w)
                        if c_wi < len(b_src_seg_list[bi][ci]):
                            word = b_src_seg_list[bi][ci][c_wi]
                        else:
                            word = i2w[oov_idx]
                            wi = oov_idx
                    wis.append(wi)
                    words.append(word)
                next_vals.append(vals)
                next_wids.append(wis)
                next_words.append(words)
            dec_hs = [hiddens[j, :].unsqueeze(1) for j in range(hiddens.shape[0])]
            ctx = [ctxs[j, :].unsqueeze(0) for j in range(ctxs.shape[0])]
            curr_actives[bi].update(g_probs, next_vals, next_wids, next_words, dec_hs, ctx)
            beam_i += size
    for b in batch_results:
        final_ans.append(b.collect_results(topk=topk))
    return final_ans
def run_dv_seq2seq_epoch(model, loader, criterion_gen, criterion_cpy,
                         curr_epoch=0, max_grad_norm=5.0, optimizer=None,
                         desc="Train", pad_idx=0, model_name="dv_seq2seq",
                         logs_dir=""):
    start = time.time()
    total_tokens = 0
    total_loss = 0
    total_correct = 0
    for batch in tqdm(loader, mininterval=2, desc=desc, leave=False, ascii=True):
        g_wid_probs, c_wid_probs, c_gate_probs = model(batch)
        gen_targets = batch[DK_TGT_GEN_WID].to(device())
        cpy_targets = batch[DK_TGT_CPY_WID].to(device())
        cpy_truth_gates = batch[DK_TGT_CPY_GATE].to(device())
        n_tokens = batch[DK_TGT_N_TOKENS].item()
        g_log_wid_probs, c_log_wid_probs = get_gen_cpy_log_probs(
            g_wid_probs, c_wid_probs, c_gate_probs)
        # Mask each head by the gold copy gate: copy loss applies only where
        # the reference copies, generation loss only where it generates.
        c_log_wid_probs = c_log_wid_probs * (
            cpy_truth_gates.unsqueeze(2).expand_as(c_log_wid_probs))
        g_log_wid_probs = g_log_wid_probs * (
            (1 - cpy_truth_gates).unsqueeze(2).expand_as(g_log_wid_probs))
        g_log_wid_probs = g_log_wid_probs.view(-1, g_log_wid_probs.size(-1))
        c_log_wid_probs = c_log_wid_probs.view(-1, c_log_wid_probs.size(-1))
        g_loss = criterion_gen(g_log_wid_probs, gen_targets.contiguous().view(-1))
        c_loss = criterion_cpy(c_log_wid_probs, cpy_targets.contiguous().view(-1))
        loss = g_loss + c_loss
        # Compute accuracy against the mixed generate/copy targets.
        tgt = copy.deepcopy(gen_targets.view(-1, 1).squeeze(1))
        c_sw = cpy_truth_gates.view(-1, 1).squeeze(1)
        c_tg = cpy_targets.view(-1, 1).squeeze(1)
        g_preds_i = copy.deepcopy(g_log_wid_probs.max(1)[1])
        c_preds_i = c_log_wid_probs.max(1)[1]
        g_preds_v = g_log_wid_probs.max(1)[0]
        for i in range(g_preds_i.shape[0]):
            if g_preds_v[i] == 0:
                g_preds_i[i] = c_preds_i[i]
        for i in range(tgt.shape[0]):
            if c_sw[i] == 1:
                tgt[i] = c_tg[i]
        n_correct = g_preds_i.data.eq(tgt.data)
        n_correct = n_correct.masked_select(tgt.ne(pad_idx).data).sum()
        total_loss += loss.item()
        total_correct += n_correct.item()
        total_tokens += n_tokens
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                max_grad_norm)
            optimizer.step()
    loss_report = total_loss / total_tokens
    acc = total_correct / total_tokens
    elapsed = time.time() - start
    info = desc + " epoch %d loss %f, acc %f ppl %f elapsed time %f" % (
        curr_epoch, loss_report, acc, math.exp(loss_report), elapsed)
    print(info)
    write_line_to_file(info, logs_dir + model_name + "_train_info.txt")
    return loss_report, acc
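# A minimal sketch of the gate mixing that get_gen_cpy_log_probs is assumed
# to perform (its actual implementation lives elsewhere in the repo): a
# pointer-generator combines generation and copy distributions through a soft
# copy gate, here moved into log space per decoding position.
def _gen_cpy_mix_sketch(gen_probs, cpy_probs, gate_probs, eps=1e-8):
    import torch
    # gen_probs: (batch, len, vocab); cpy_probs: (batch, len, src_len);
    # gate_probs: (batch, len, 1), the probability of copying at each step.
    g_log = torch.log((1.0 - gate_probs) * gen_probs + eps)
    c_log = torch.log(gate_probs * cpy_probs + eps)
    return g_log, c_log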
def train_dv_seq2seq(params, model, train_loader, criterion_gen, criterion_cpy,
                     optimizer, completed_epochs=0, eval_loader=None,
                     best_eval_result=0, best_eval_epoch=0,
                     past_eval_results=[], checkpoint=True):
    # Note: the list default is mutable and shared across calls; callers are
    # expected to pass a fresh list (e.g. when resuming from a checkpoint).
    model = model.to(device())
    criterion_gen = criterion_gen.to(device())
    criterion_cpy = criterion_cpy.to(device())
    for epoch in range(params.epochs):
        report_epoch = epoch + completed_epochs + 1
        model.train()
        train_loss, _ = run_dv_seq2seq_epoch(
            model, train_loader, criterion_gen, criterion_cpy,
            curr_epoch=report_epoch, optimizer=optimizer,
            max_grad_norm=params.max_gradient_norm, desc="Train",
            pad_idx=params.pad_idx, model_name=params.model_name,
            logs_dir=params.logs_dir)
        if params.lr_decay_with_train_perf and hasattr(optimizer, "update_learning_rate"):
            optimizer.update_learning_rate(train_loss, "max")
        if eval_loader is not None:
            model.eval()
            with torch.no_grad():
                if report_epoch >= params.full_eval_start_epoch and \
                        report_epoch % params.full_eval_every_epoch == 0:
                    eval_score = eval_dv_seq2seq(model, eval_loader, params)
                    if eval_score > best_eval_result:
                        best_eval_result = eval_score
                        best_eval_epoch = report_epoch
                        print("Model best checkpoint with score {}".format(eval_score))
                        fn = params.saved_models_dir + params.model_name + "_best.pt"
                        if checkpoint:
                            model_checkpoint(fn, report_epoch, model, optimizer,
                                             params, past_eval_results,
                                             best_eval_result, best_eval_epoch)
                    info = "Best {} so far {} from epoch {}".format(
                        params.eval_metric, best_eval_result, best_eval_epoch)
                    print(info)
                    write_line_to_file(
                        info, params.logs_dir + params.model_name + "_train_info.txt")
                    if hasattr(optimizer, "update_learning_rate") and \
                            not params.lr_decay_with_train_perf:
                        optimizer.update_learning_rate(eval_score)
                    past_eval_results.append(eval_score)
                    if len(past_eval_results) > params.past_eval_scores_considered:
                        past_eval_results = past_eval_results[1:]
        fn = params.saved_models_dir + params.model_name + "_latest.pt"
        if checkpoint:
            model_checkpoint(fn, report_epoch, model, optimizer, params,
                             past_eval_results, best_eval_result, best_eval_epoch)
        print("")
    return best_eval_result, best_eval_epoch
def make_init_att(self, context):
    # Zero-initialized attention context matching the encoder output width
    # (hidden size times the number of RNN directions).
    batch_size = context.size(1)
    h_size = (batch_size, 1,
              self.params.s2s_encoder_hidden_size * self.params.s2s_encoder_rnn_dir)
    return context.data.new(*h_size).zero_().float().to(device())
def get_curr_context(self):
    if len(self.curr_candidates) == 0:
        return None
    return torch.cat(
        [tup[5] for tup in self.curr_candidates if tup[5] is not None],
        dim=0).float().to(device())
def get_curr_dec_hidden(self):
    if len(self.curr_candidates) == 0:
        return None
    return torch.cat([tup[4] for tup in self.curr_candidates],
                     dim=1).float().to(device())
def get_curr_tgt(self):
    if len(self.curr_candidates) == 0:
        return None
    return torch.cat([tup[0] for tup in self.curr_candidates],
                     dim=0).long().to(device())
def train_rwg(params, model, train_loader, criterion, optimizer,
              completed_epochs=0, eval_loader=None, best_eval_result=0,
              best_eval_epoch=0, past_eval_results=[], past_train_loss=[]):
    # Note: the list defaults are mutable and shared across calls; callers
    # are expected to pass fresh lists (e.g. when resuming from a checkpoint).
    model = model.to(device())
    criterion = criterion.to(device())
    for epoch in range(params.epochs):
        report_epoch = epoch + completed_epochs + 1
        model.train()
        train_loss, _ = run_rwg_epoch(
            train_loader, model, criterion, optimizer,
            model_name=params.model_name, report_acc=False,
            max_grad_norm=params.max_gradient_norm, pad_idx=params.pad_idx,
            curr_epoch=report_epoch, logs_dir=params.logs_dir)
        # Track the two most recent train losses to detect a plateau.
        past_train_loss.append(train_loss)
        if len(past_train_loss) > 2:
            past_train_loss = past_train_loss[1:]
        if params.lr_decay_with_train_perf:
            if params.lr_decay_with_train_loss_diff and \
                    len(past_train_loss) == 2 and \
                    past_train_loss[1] <= past_train_loss[0] and \
                    past_train_loss[0] - past_train_loss[1] < params.train_loss_diff_threshold:
                print("updating lr by train loss diff, threshold: {}".format(
                    params.train_loss_diff_threshold))
                optimizer.shrink_learning_rate()
                past_train_loss = []
            elif hasattr(optimizer, "update_learning_rate"):
                optimizer.update_learning_rate(train_loss, "max")
        fn = params.saved_models_dir + params.model_name + "_latest.pt"
        model_checkpoint(fn, report_epoch, model, optimizer, params,
                         past_eval_results, best_eval_result, best_eval_epoch)
        if eval_loader is not None:
            model.eval()
            with torch.no_grad():
                if report_epoch >= params.full_eval_start_epoch and \
                        report_epoch % params.full_eval_every_epoch == 0:
                    eval_loss, eval_score = run_rwg_epoch(
                        eval_loader, model, criterion, None,
                        pad_idx=params.pad_idx, model_name=params.model_name,
                        report_acc=True, curr_epoch=report_epoch,
                        logs_dir=None, desc="Eval")
                    if eval_score > best_eval_result:
                        best_eval_result = eval_score
                        best_eval_epoch = report_epoch
                        print("Model best checkpoint with score {}".format(eval_score))
                        fn = params.saved_models_dir + params.model_name + "_best.pt"
                        model_checkpoint(fn, report_epoch, model, optimizer,
                                         params, past_eval_results,
                                         best_eval_result, best_eval_epoch)
                    info = "Best {} so far {} from epoch {}".format(
                        params.eval_metric, best_eval_result, best_eval_epoch)
                    print(info)
                    write_line_to_file(
                        info, params.logs_dir + params.model_name + "_train_info.txt")
                    if hasattr(optimizer, "update_learning_rate") and \
                            not params.lr_decay_with_train_perf:
                        optimizer.update_learning_rate(eval_score, "min")
                    past_eval_results.append(eval_score)
                    if len(past_eval_results) > params.past_eval_scores_considered:
                        past_eval_results = past_eval_results[1:]
        print("")
def run_rwg_epoch(data_iter, model, criterion, optimizer, model_name="rwg",
                  desc="Train", curr_epoch=0, pad_idx=0, logs_dir=None,
                  max_grad_norm=5.0, report_acc=False):
    start = time.time()
    total_tokens = 0
    total_loss = 0
    total_correct_k = 0
    total_correct_2k = 0
    total_correct_10 = 0
    total_correct_25 = 0
    total_correct_50 = 0
    total_correct_100 = 0
    total_correct_500 = 0
    total_acc_tokens = 0
    for batch in tqdm(data_iter, mininterval=2, desc=desc, leave=False, ascii=True):
        probs = model.forward(batch)
        gen_targets = batch[DK_TGT_GEN_WID]
        gen_targets = label_tsr_to_one_hot_tsr(gen_targets, probs.size(-1))
        gen_targets = gen_targets.to(device())
        n_tokens = batch[DK_TGT_N_TOKENS].item()
        probs = probs.view(-1, probs.size(-1))
        loss = criterion(probs, gen_targets.contiguous())
        total_loss += loss.item()
        total_tokens += n_tokens
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                max_grad_norm)
            optimizer.step()
            if torch.isnan(loss).any():
                assert False, "nan detected after step()"
        if report_acc:
            # Score recall at several top-n cutoffs against the gold word set.
            gen_targets = batch[DK_TGT_GEN_WID]
            for bi in range(batch[DK_BATCH_SIZE]):
                k = batch[DK_WI_N_WORDS][bi]
                prob = probs[bi, :].squeeze()
                _, top_k_ids = prob.topk(500)
                pred_k_ids = top_k_ids.tolist()
                pred_tk_ids = set(pred_k_ids[:k])
                pred_2k_ids = set(pred_k_ids[:int(2 * k)])
                pred_10_ids = set(pred_k_ids[:10])
                pred_25_ids = set(pred_k_ids[:25])
                pred_50_ids = set(pred_k_ids[:50])
                pred_100_ids = set(pred_k_ids[:100])
                pred_500_ids = set(pred_k_ids)
                truth_ids = gen_targets[bi, :].squeeze()
                for truth_id in truth_ids.tolist():
                    if truth_id == pad_idx:
                        continue
                    if truth_id in pred_tk_ids:
                        total_correct_k += 1
                    if truth_id in pred_2k_ids:
                        total_correct_2k += 1
                    if truth_id in pred_10_ids:
                        total_correct_10 += 1
                    if truth_id in pred_25_ids:
                        total_correct_25 += 1
                    if truth_id in pred_50_ids:
                        total_correct_50 += 1
                    if truth_id in pred_100_ids:
                        total_correct_100 += 1
                    if truth_id in pred_500_ids:
                        total_correct_500 += 1
                    total_acc_tokens += 1
    elapsed = time.time() - start
    if report_acc:
        info = desc + " epoch %d loss %f top_k acc %f top_2k acc %f top_10 acc %f " \
               "top_25 acc %f top_50 acc %f top_100 acc %f top_500 acc %f " \
               "ppl %f elapsed time %f" % (
                   curr_epoch, total_loss / total_tokens,
                   total_correct_k / total_acc_tokens,
                   total_correct_2k / total_acc_tokens,
                   total_correct_10 / total_acc_tokens,
                   total_correct_25 / total_acc_tokens,
                   total_correct_50 / total_acc_tokens,
                   total_correct_100 / total_acc_tokens,
                   total_correct_500 / total_acc_tokens,
                   math.exp(total_loss / total_tokens), elapsed)
    else:
        info = desc + " epoch %d loss %f ppl %f elapsed time %f" % (
            curr_epoch, total_loss / total_tokens,
            math.exp(total_loss / total_tokens), elapsed)
    print(info)
    if logs_dir is not None:
        write_line_to_file(info, logs_dir + model_name + "_train_info.txt")
    rv_loss = total_loss / total_tokens
    rv_perf = total_correct_100 / total_acc_tokens if total_acc_tokens > 0 else 0
    return rv_loss, rv_perf
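# A minimal sketch of the top-n recall computed above: for each instance the
# model's n highest-probability word ids are treated as a set, and recall is
# the fraction of gold (non-pad) ids that fall inside that set.
def _top_n_recall_sketch(probs, truth_ids, n, pad_idx=0):
    # probs: 1-D tensor of word probabilities; truth_ids: iterable of gold ids.
    _, top_ids = probs.topk(n)
    pred = set(top_ids.tolist())
    gold = [t for t in truth_ids if t != pad_idx]
    hits = sum(1 for t in gold if t in pred)
    return hits / len(gold) if gold else 0.0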