def evaluate(self):
    """Evaluate the MatchPyramid model on the held-out test set.

    Runs one pass over ``self.test_data`` without gradients, collects
    argmax predictions, and logs accuracy and F1 for the current epoch.
    """
    logger.info("Evaluating in epoch %i" % self.epoch_cnt)
    # switch both sub-modules to eval mode (disables dropout/batchnorm updates)
    self.embedding.eval()
    self.matchPyramid.eval()
    data_loader = DataLoader(self.test_data,
                             batch_size=self.params.batch_size,
                             shuffle=False,
                             collate_fn=collate_fn)
    pred_list = list()
    label_list = list()
    with torch.no_grad():
        for data_iter in data_loader:
            # each batch: two sentences with lengths, plus the pair label
            sen1, len1, sen2, len2, label = data_iter
            # pad/index the raw batch to fixed max_seq_len tensors
            sen1_ts, len1_ts, sen2_ts, len2_ts, label_ts = truncate(
                sen1, len1, sen2, len2, label, self.params.word2idx,
                max_seq_len=self.params.max_seq_len)
            sen1_ts, len1_ts, sen2_ts, len2_ts, label_ts = to_cuda(
                sen1_ts, len1_ts, sen2_ts, len2_ts, label_ts)
            sen1_embedding = self.embedding(sen1_ts)
            sen2_embedding = self.embedding(sen2_ts)
            mp_output = self.matchPyramid(sen1_embedding, sen2_embedding)
            # argmax over class logits -> predicted class ids
            predictions = mp_output.data.max(1)[1]
            pred_list.extend(predictions.tolist())
            label_list.extend(label_ts.tolist())
    acc = accuracy_score(label_list, pred_list)
    # NOTE(review): f1_score with default 'binary' averaging — assumes a
    # two-class task; confirm against the dataset.
    f1 = f1_score(label_list, pred_list)
    logger.info("ACC score in epoch %i :%.4f" % (self.epoch_cnt, acc))
    logger.info("F1 score in epoch %i :%.4f" % (self.epoch_cnt, f1))
def top_N_goal_acc(args, data_loader, model, N):
    """Fraction of trajectories whose true goal label is in the model's top-N.

    Trajectories with label ``-1`` are treated as "no valid goal" and are
    excluded from both numerator and denominator.
    """
    goal_acc = []
    total_traj = 0
    with torch.no_grad():
        for batch in data_loader:
            batch = global_utils.to_cuda(batch)
            (obs_traj, gt_traj, obs_traj_rel, gt_traj_rel, loss_mask,
             seq_start_end) = global_utils.get_trajectories_from_batch(batch)
            pred_scores, labels = predict_goal_from_batch(args, batch, model)
            # per-score-tensor: indices of the N highest-scoring goals
            # (argsort ascending, so the last N columns are the top N)
            top_n = [torch.argsort(tmp, dim=1)[:, -N:] for tmp in pred_scores]
            top_n = torch.cat(top_n, dim=0)
            labels = torch.cat(labels, dim=0)
            # mask computed on the 1-D labels BEFORE the repeat below:
            # it selects whole rows of the 2-D tensors.  (sic: "candiate")
            candiate_index = labels != -1
            # tile labels to (num_traj, N) so they compare elementwise
            # against each of the top-N candidates
            labels = labels.unsqueeze(1).repeat(1, N)
            out = top_n[candiate_index] == labels[candiate_index]
            out = out.long()
            # row sum > 0  <=>  true label appeared among the top N
            out = out.sum(dim=1)
            goal_acc.append(len(out[out > 0]))
            total_traj += len(out)
    return sum(goal_acc) * 1.0 / total_traj
def get_ade_fde_batch(args, batch, model):
    """Compute raw (per-trajectory) ADE and FDE for a single batch.

    Puts ``model`` in eval mode, predicts relative displacements, converts
    them to absolute coordinates anchored at the last observed position,
    and returns ``(ade, fde)`` in ``mode='raw'`` (un-aggregated).
    """
    model.eval()
    with torch.no_grad():
        batch = global_utils.to_cuda(batch)
        (obs_traj, gt_traj, obs_traj_rel, gt_traj_rel, loss_mask,
         seq_start_end) = global_utils.get_trajectories_from_batch(batch)
        # only the prediction horizon of the loss mask is relevant here
        pred_loss_mask = loss_mask[:, args.obs_len:]
        pred_traj_rel = predict_trajectory_from_batch(args, batch, model)
        # anchor relative predictions at the last observed position
        pred_traj = global_utils.relative_to_abs(pred_traj_rel, obs_traj[-1])
        ade = displacement_error(pred_traj, gt_traj, pred_loss_mask,
                                 mode='raw')
        fde = final_displacement_error(pred_traj, gt_traj, pred_loss_mask,
                                       mode='raw')
        return ade, fde
def loader_get_accuracy(data_loader):
    # NOTE(review): this function closes over `is_change_perspective`,
    # `args` and `model` from an enclosing scope not visible in this chunk
    # — presumably it is nested inside an evaluation routine; confirm.
    if is_change_perspective:
        # Evaluate in a rotated/canonical perspective: transform each batch,
        # predict there, then map predictions back before scoring.
        batch_array = []
        predicted_traj_array = []
        score_array = []
        label_array = []
        for batch in data_loader:
            # rotate the whole batch into the canonical frame; m_array holds
            # the per-sequence transform matrices needed to invert it
            new_batch, m_array = transform_perspective(batch, is_batch=True)
            new_batch = global_utils.to_cuda(new_batch)
            (obs_traj, gt_traj, obs_traj_rel, gt_traj_rel, loss_mask,
             seq_start_end
             ) = global_utils.get_trajectories_from_batch(new_batch)
            new_pred_traj_rel = predict_trajectory_from_batch(
                args, new_batch, model)
            # map the predicted displacements back to the original frame
            pred_traj_rel = inverse_transform(m_array,
                                              new_pred_traj_rel.detach(),
                                              new_batch,
                                              is_batch=True)
            # goal scores are kept in the transformed frame
            pred_scores, labels = predict_goal_from_batch(
                args, new_batch, model)
            batch_array.append(global_utils.to_cuda(batch))
            predicted_traj_array.append(pred_traj_rel.cuda())
            score_array.append(pred_scores)
            label_array.append(labels)
        return check_accuracy_predicted(args, batch_array,
                                        predicted_traj_array, score_array,
                                        label_array)
    else:
        # no perspective change: score directly from the loader
        return check_accuracy(args, data_loader, model)
def eval(self):
    """Compute classification accuracy on the validation split.

    Returns:
        dict: ``{'acc': <accuracy in percent>}``.
    """
    params = self.params
    self.model.eval()
    lang_id1 = params.lang2id[params.src_lang]
    lang_id2 = params.lang2id[params.trg_lang]
    valid = 0
    total = 0
    # NOTE(review): this loop runs WITHOUT torch.no_grad(); the caller
    # (train) wraps it, but a direct call builds autograd graphs — verify.
    for sent1, len1, sent2, len2, y, _, _ in tqdm(
            self.dataloader['valid']):
        # cap both sentences at max_len (keeping the EOS token)
        sent1, len1 = truncate(sent1, len1, params.max_len,
                               params.eos_index)
        sent2, len2 = truncate(sent2, len2, params.max_len,
                               params.eos_index)
        # pack the pair into a single XLM input with language embeddings
        x, lengths, positions, langs = concat_batches(sent1, len1, lang_id1,
                                                      sent2, len2, lang_id2,
                                                      params.pad_index,
                                                      params.eos_index,
                                                      reset_positions=True)
        # cuda
        x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions,
                                                  langs, gpu=self.gpu)
        # forward
        output = self.model(x, lengths, positions, langs)
        predictions = output.data.max(1)[1]
        # update statistics
        valid += predictions.eq(y).sum().item()
        total += len(len1)
    # compute accuracy (percentage)
    acc = 100.0 * valid / total
    scores = {}
    scores['acc'] = acc
    return scores
def train(self):
    """Run one training epoch of the MatchPyramid matcher."""
    logger.info("Training in epoch %i" % self.epoch_cnt)
    # put both trainable sub-modules into training mode
    self.embedding.train()
    self.matchPyramid.train()
    loader = DataLoader(self.train_data,
                        batch_size=self.params.batch_size,
                        shuffle=True,
                        collate_fn=collate_fn)
    for batch in loader:
        left, left_len, right, right_len, target = batch
        # map words to indices and pad/cut to max_seq_len
        left_ts, left_len_ts, right_ts, right_len_ts, target_ts = truncate(
            left, left_len, right, right_len, target, self.params.word2idx,
            max_seq_len=self.params.max_seq_len)
        left_ts, left_len_ts, right_ts, right_len_ts, target_ts = to_cuda(
            left_ts, left_len_ts, right_ts, right_len_ts, target_ts)
        # embed each side, then score the pair with MatchPyramid
        left_emb = self.embedding(left_ts)
        right_emb = self.embedding(right_ts)
        logits = self.matchPyramid(left_emb, right_emb)
        loss = F.cross_entropy(logits, target_ts)
        # standard optimizer step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
def check_accuracy(args, data_loader, model, predict_trajectory=True,
                   limit=False):
    """Evaluate ADE/FDE and goal-classification accuracy over a loader.

    Returns a dict with keys ``ade``, ``fde`` and ``goal_acc``, each
    normalized by the total number of trajectories seen.  Restores the
    model to train mode before returning.
    """
    # seeded with 0.0 so sum() works even when no batches contribute
    disp_error = [0.0]
    f_disp_error = [0.0]
    goal_acc = [0.0]
    total_traj = 0
    metrics = {}
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            batch = global_utils.to_cuda(batch)
            (obs_traj, gt_traj, obs_traj_rel, gt_traj_rel, loss_mask,
             seq_start_end) = global_utils.get_trajectories_from_batch(batch)
            # only the prediction horizon of the mask is scored
            pred_loss_mask = loss_mask[:, args.obs_len:]
            if predict_trajectory:
                pred_traj_rel = predict_trajectory_from_batch(
                    args, batch, model)
                pred_traj = global_utils.relative_to_abs(
                    pred_traj_rel, obs_traj[-1])
                # compute the sum of the average loss over each sequence
                ade = displacement_error(pred_traj, gt_traj, pred_loss_mask)
                # compute the sum of the loss at the end of each sequence
                fde = final_displacement_error(pred_traj, gt_traj,
                                               pred_loss_mask)
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())
            pred_scores, labels = predict_goal_from_batch(args, batch, model)
            labels = torch.cat(labels, dim=0)
            pred_labels = torch.cat(
                [torch.argmax(sc, dim=1) for sc in pred_scores], dim=0)
            # label -1 marks "no valid goal" — excluded from accuracy
            candiate_index = labels != -1
            goal_acc.append(
                torch.sum(
                    (pred_labels[candiate_index] == labels[candiate_index]
                     ).long()))
            total_traj += gt_traj.size(1)
            if limit and total_traj > args.num_samples_check:
                break
    # NOTE(review): goal_acc counts only label != -1 hits but divides by
    # total_traj (ALL trajectories) — confirm this denominator is intended.
    metrics["ade"] = sum(disp_error) / total_traj
    metrics["fde"] = sum(f_disp_error) / total_traj
    metrics["goal_acc"] = sum(goal_acc) / total_traj
    model.train()
    return metrics
def train(self):
    """Train the XLM pair classifier for one epoch.

    Periodically logs the running loss and, every ``eval_interval`` steps
    on GPU 0, evaluates on the validation set and checkpoints the best
    model by accuracy.
    """
    params = self.params
    self.model.train()
    # training variables
    losses = []
    ns = 0  # number of sentences (NOTE(review): never incremented — dead)
    nw = 0  # number of words (NOTE(review): never incremented — dead)
    t = time.time()
    lang_id1 = params.lang2id[params.src_lang]
    lang_id2 = params.lang2id[params.trg_lang]
    for sent1, len1, sent2, len2, y, _, _ in self.dataloader['train']:
        self.global_step += 1
        # cap both sentences at max_len (keeping the EOS token)
        sent1, len1 = truncate(sent1, len1, params.max_len,
                               params.eos_index)
        sent2, len2 = truncate(sent2, len2, params.max_len,
                               params.eos_index)
        # pack the pair into one XLM input with language/position ids
        x, lengths, positions, langs = concat_batches(sent1, len1, lang_id1,
                                                      sent2, len2, lang_id2,
                                                      params.pad_index,
                                                      params.eos_index,
                                                      reset_positions=True)
        bs = len(len1)
        # cuda
        x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions,
                                                  langs, gpu=self.gpu)
        # loss
        output = self.model(x, lengths, positions, langs)
        loss = self.criterion(output, y)
        # backward / optimization — two optimizers (presumably encoder
        # and projection/classifier head; confirm against the constructor)
        self.optimizer_e.zero_grad()
        self.optimizer_p.zero_grad()
        loss.backward()
        self.optimizer_e.step()
        self.optimizer_p.step()
        losses.append(loss.item())
        # log
        if self.global_step % self.params.report_interval == 0:
            logger.info("GPU %i - Epoch %i - Global_step %i - Loss: %.4f" %
                        (self.gpu, self.epoch, self.global_step,
                         sum(losses) / len(losses)))
            nw, t = 0, time.time()
            losses = []
        if self.global_step % params.eval_interval == 0:
            # only rank 0 evaluates and checkpoints
            if self.gpu == 0:
                logger.info("XLM - Evaluating")
                with torch.no_grad():
                    scores = self.eval()
                    if scores['acc'] > self.best_acc:
                        self.best_acc = scores['acc']
                        # saves the unwrapped module (assumes DataParallel/
                        # DDP wrapping — .module)
                        torch.save(
                            self.model.module,
                            os.path.join(params.save_model,
                                         'best_acc_model.pkl'))
                        with open(
                                os.path.join(params.save_model,
                                             'best_acc.note'), 'a') as f:
                            f.write(str(self.best_acc) + '\n')
                    with open(os.path.join(params.save_model, 'acc.note'),
                              'a') as f:
                        f.write(str(scores['acc']) + '\n')
                    # NOTE(review): %i truncates the float accuracy — likely
                    # meant %.2f or similar
                    logger.info("acc - %i " % scores['acc'])
                self.model.train()
def run(model, params, dico, data, split, src_lang, trg_lang,
        gen_type="src2trg", alpha=1., beta=1., gamma=0., uniform=False,
        iter_mult=1, mask_schedule="constant", constant_k=1, batch_size=8,
        gpu_id=0):
    """Iterative mask-predict decoding over a parallel split, then BLEU.

    For each batch, tries up to ``params.num_topk_lengths`` candidate output
    lengths (from dataset length statistics), decodes each with
    ``_evaluate_batch``, keeps the highest-scoring candidate, and writes the
    resulting hypotheses plus a BLEU score to ``params.hyp_path``.
    """
    #n_batches = math.ceil(len(srcs) / batch_size)
    if gen_type == "src2trg":
        ref_path = params.ref_paths[(src_lang, trg_lang, split)]
    elif gen_type == "trg2src":
        ref_path = params.ref_paths[(trg_lang, src_lang, split)]
    refs = [s.strip() for s in open(ref_path, encoding="utf-8").readlines()]
    hypothesis = []
    #hypothesis_selected_pos = []
    # NOTE(review): language pair is hard-coded to ("de", "en") here rather
    # than (src_lang, trg_lang) — confirm this is intentional.
    for batch_n, batch in enumerate(
            get_iterator(params, data, split, "de", "en")):
        (src_x, src_lens), (trg_x, trg_lens) = batch
        batches, batches_src_lens, batches_trg_lens, total_scores = [], [], [], []
        #batches_selected_pos = []
        for i_topk_length in range(params.num_topk_lengths):
            # overwrite source/target lengths according to dataset stats if necessary
            if params.de2en_lengths != None and params.en2de_lengths != None:
                src_lens_item = src_lens[0].item() - 2  # remove BOS, EOS
                trg_lens_item = trg_lens[0].item() - 2  # remove BOS, EOS
                if gen_type == "src2trg":
                    # stop when fewer than i_topk_length+1 candidate lengths exist
                    if len(params.de2en_lengths[src_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    # sort candidate target lengths by frequency (ascending)
                    data_trg_lens = sorted(
                        params.de2en_lengths[src_lens_item].items(),
                        key=operator.itemgetter(1))
                    # take the i-th most likely length and re-add BOS, EOS
                    data_trg_lens_item = data_trg_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite trg_lens
                    trg_lens = torch.ones_like(trg_lens) * data_trg_lens_item
                elif gen_type == "trg2src":
                    if len(params.en2de_lengths[trg_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    data_src_lens = sorted(
                        params.en2de_lengths[trg_lens_item].items(),
                        key=operator.itemgetter(1))
                    # take i_topk_length most likely length and add BOS, EOS
                    data_src_lens_item = data_src_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite src_lens
                    src_lens = torch.ones_like(src_lens) * data_src_lens_item
            # build the input: real sentence on one side, fully-masked
            # placeholder of the chosen length on the side being generated
            if gen_type == "src2trg":
                sent1_input = src_x
                sent2_input = create_masked_batch(trg_lens, params, dico)
                dec_len = torch.max(trg_lens).item() - 2  # cut BOS, EOS
            elif gen_type == "trg2src":
                sent1_input = create_masked_batch(src_lens, params, dico)
                sent2_input = trg_x
                dec_len = torch.max(src_lens).item() - 2  # cut BOS, EOS
            batch, lengths, positions, langs = concat_batches(sent1_input, src_lens, params.lang2id[src_lang], \
                                                              sent2_input, trg_lens, params.lang2id[trg_lang], \
                                                              params.pad_index, params.eos_index, \
                                                              reset_positions=True, assert_eos=True)  # not sure about it
            if gpu_id >= 0:
                batch, lengths, positions, langs, src_lens, trg_lens = \
                    to_cuda(batch, lengths, positions, langs, src_lens, trg_lens)
            with torch.no_grad():
                # iterative refinement decode; returns the filled-in batch
                # and its total score for length-candidate selection
                batch, total_score_argmax_toks = \
                    _evaluate_batch(model, params, dico, batch, lengths,
                                    positions, langs, src_lens, trg_lens,
                                    gen_type, alpha, beta, gamma, uniform,
                                    dec_len, iter_mult, mask_schedule,
                                    constant_k)
            batches.append(batch.clone())
            batches_src_lens.append(src_lens.clone())
            batches_trg_lens.append(trg_lens.clone())
            total_scores.append(total_score_argmax_toks)
            #batches_selected_pos.append(selected_pos)
        # keep the decode whose candidate length scored best
        best_score_idx = np.array(total_scores).argmax()
        batch, src_lens, trg_lens = batches[best_score_idx], batches_src_lens[
            best_score_idx], batches_trg_lens[best_score_idx]
        #selected_pos = batches_selected_pos[best_score_idx]
        #if gen_type == "src2trg":
        #    hypothesis_selected_pos.append([selected_pos, trg_lens.item()-2])
        #elif gen_type == "trg2src":
        #    hypothesis_selected_pos.append([selected_pos, src_lens.item()-2])
        for batch_idx in range(batch_size):
            src_len = src_lens[batch_idx].item()
            tgt_len = trg_lens[batch_idx].item()
            # the concatenated batch lays out [source | target] along dim 0
            if gen_type == "src2trg":
                generated = batch[src_len:src_len + tgt_len, batch_idx]
            else:
                generated = batch[:src_len, batch_idx]
            # extra <eos>: cut at the second EOS if more than two are present
            eos_pos = (generated == params.eos_index).nonzero()
            if eos_pos.shape[0] > 2:
                generated = generated[:(eos_pos[1, 0].item() + 1)]
            hypothesis.extend(convert_to_text(generated.unsqueeze(1), \
                                              torch.Tensor([generated.shape[0]]).int(), \
                                              dico, params))
            print("Ex {0}\nRef: {1}\nHyp: {2}\n".format(
                batch_n, refs[batch_n].encode("utf-8"),
                hypothesis[-1].encode("utf-8")))
    hyp_path = os.path.join(params.hyp_path, 'decoding.txt')
    hyp_path_tok = os.path.join(params.hyp_path, 'decoding.tok.txt')
    #hyp_selected_pos_path = os.path.join(params.hyp_path, "selected_pos.pkl")
    # export sentences to hypothesis file / restore BPE segmentation
    with open(hyp_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    with open(hyp_path_tok, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    #with open(hyp_selected_pos_path, 'wb') as f:
    #    pkl.dump(hypothesis_selected_pos, f)
    # undo BPE in hyp_path in place (hyp_path_tok keeps the tokenized copy)
    restore_segmentation(hyp_path)
    # evaluate BLEU score
    bleu = eval_moses_bleu(ref_path, hyp_path)
    print("BLEU %s-%s; %s %s : %f" %
          (src_lang, trg_lang, hyp_path, ref_path, bleu))
    # write BLEU score result to file
    result_path = os.path.join(params.hyp_path, "result.txt")
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write("BLEU %s-%s; %s %s : %f\n" %
                (src_lang, trg_lang, hyp_path, ref_path, bleu))
def main(params):
    """Translate an input dataset with a reloaded XLM encoder-decoder.

    Reloads model weights and dictionary from ``params.model_path``,
    translates every batch of ``params.input``, writes hypotheses to
    ``params.output_path`` and a copy of the input text to
    ``<dump_path>/input.txt``.
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): this overwrites the incoming `params` argument with
    # freshly parsed CLI arguments (original behavior, preserved).
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index',
                 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True,
                               with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    input_data = torch.load(params.input)
    eval_dataset = Dataset(input_data["sentences"], input_data["positions"],
                           params)
    if params.subset_start is not None:
        assert params.subset_end
        eval_dataset.select_data(params.subset_start, params.subset_end)
    eval_dataset.remove_empty_sentences()
    eval_dataset.remove_long_sentences(params.max_len)

    n_batch = 0
    logger.info("logging to {}".format(
        os.path.join(params.dump_path, 'input.txt')))
    # BUG FIX: the original called `io.open(params.output_path, ...)` and
    # then immediately shadowed the handle with `with open(params.output_path,
    # ...) as out`, leaking the first handle and opening the file twice in
    # write mode; it also called out.close() on the already-closed handle.
    # Both output files are now opened exactly once and closed by `with`.
    with open(params.output_path, "w", encoding="utf-8") as out, \
            io.open(os.path.join(params.dump_path, "input.txt"), "w",
                    encoding="utf-8") as inp_dump:
        for batch in eval_dataset.get_iterator(shuffle=False):
            n_batch += 1
            (x1, len1) = batch
            input_text = convert_to_text(x1, len1, input_data["dico"],
                                         params)
            inp_dump.write("\n".join(input_text))
            inp_dump.write("\n")
            langs1 = x1.clone().fill_(params.src_id)

            # cuda
            x1, len1, langs1 = to_cuda(x1, len1, langs1)

            # encode source sentence
            enc1 = encoder("fwd", x=x1, lengths=len1, langs=langs1,
                           causal=False)
            enc1 = enc1.transpose(0, 1)

            # generate translation - translate / convert to text
            max_len = int(1.5 * len1.max().item() + 10)
            if params.beam_size == 1:
                generated, lengths = decoder.generate(enc1, len1,
                                                      params.tgt_id,
                                                      max_len=max_len)
            else:
                generated, lengths = decoder.generate_beam(
                    enc1, len1, params.tgt_id,
                    beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len)
            hypotheses_batch = convert_to_text(generated, lengths,
                                               input_data["dico"], params)
            out.write("\n".join(hypotheses_batch))
            out.write("\n")
            if n_batch % 100 == 0:
                logger.info("{} batches processed".format(n_batch))
def qg4dataset(self, direction, split="test"):
    """Generate questions for a dataset in the given language direction.

    ``direction`` is a string like ``"xx-yy"``.  Decodes every batch of the
    test iterator and returns the generated sentences as a list of
    whitespace-joined word strings.
    NOTE(review): the ``split`` parameter is unused — the iterator below is
    hard-coded to "test"; confirm whether split was meant to be passed.
    """
    direction = direction.split("-")
    params = self.params
    encoder = self.encoder
    decoder = self.decoder
    encoder.eval()
    decoder.eval()
    dico = self.dico
    src_lang, trg_lang = direction
    print("Performing %s-%s-xsumm" % (src_lang, trg_lang))
    results = []
    trg_lang_id = params.lang2id[trg_lang]
    # restrict decoding to the target-language vocabulary if configured
    vocab_mask = self.vocab_mask[
        trg_lang] if params.decode_with_vocab else None
    for batch in tqdm(
            self.get_iterator_v2("test",
                                 ae_lang=src_lang,
                                 q_lang=src_lang,
                                 ds_name=params.ds_name)):
        # (sent_q, len_q), (sent_a, len_a), (sent_e, len_e), _ = batch
        # pack question/answer/evidence into one encoder input
        x, lens, _, src_langs, _, _ = self.concat_qae_batch(
            batch, src_lang[-2:], use_task_emb=False, is_test=True)
        x, lens, src_langs = to_cuda(x, lens, src_langs)
        max_len = params.max_dec_len
        with torch.no_grad():
            encoded = encoder("fwd", x=x, lengths=lens, langs=src_langs,
                              causal=False)
            encoded = encoded.transpose(0, 1)
            if params.beam_size == 1:
                decoded, _ = decoder.generate(encoded, lens, trg_lang_id,
                                              max_len=max_len,
                                              vocab_mask=vocab_mask)
            else:
                decoded, _ = decoder.generate_beam(
                    encoded, lens, trg_lang_id,
                    beam_size=params.beam_size,
                    length_penalty=0.9,
                    early_stopping=False,
                    max_len=max_len,
                    vocab_mask=vocab_mask)
        for j in range(decoded.size(1)):
            sent = decoded[:, j]
            # sentences are delimited by EOS; the first token must be EOS
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            # strip leading EOS; cut at the second EOS if one exists
            sent = sent[1:] if len(
                delimiters) == 1 else sent[1:delimiters[1]]
            trg_tokens = [dico[sent[k].item()] for k in range(len(sent))]
            trg_words = tokens2words(trg_tokens)
            if trg_lang == "zh":
                # Chinese: split into characters, then space-join
                results.append(" ".join("".join(trg_words)))
            else:
                results.append(" ".join(trg_words))
    return results
def summ4dataset(self, direction):
    """Generate summaries for the test split in the given direction.

    ``direction`` is a string like ``"xx-yy"`` (last two characters of each
    side are the language codes).  Returns the generated summaries as a
    list of whitespace-joined word strings.
    NOTE(review): X is accumulated but never filled or returned — dead.
    """
    direction = direction.split("-")
    params = self.params
    encoder = self.encoder
    decoder = self.decoder
    encoder.eval()
    decoder.eval()
    dico = self.dico
    x_lang, y_lang = direction
    print("Performing %s-%s-xsumm" % (x_lang, y_lang))
    X, Y = [], []
    x_lang_id = params.lang2id[x_lang[-2:]]
    y_lang_id = params.lang2id[y_lang[-2:]]
    # restrict decoding to the target-language vocabulary if configured
    vocab_mask = self.vocab_mask[
        y_lang[-2:]] if params.decode_with_vocab else None
    for batch in tqdm(self.get_iterator("test", x_lang, y_lang)):
        (sent_x, len_x), (sent_y, len_y), _ = batch
        lang_x = sent_x.clone().fill_(x_lang_id)
        # lang_y = sent_y.clone().fill_(y_lang_id)
        sent_x, len_x, lang_x = to_cuda(sent_x, len_x, lang_x)
        with torch.no_grad():
            encoded = encoder("fwd", x=sent_x, lengths=len_x, langs=lang_x,
                              causal=False)
            encoded = encoded.transpose(0, 1)
            if params.beam_size == 1:
                decoded, _ = decoder.generate(encoded, len_x, y_lang_id,
                                              max_len=params.max_dec_len,
                                              vocab_mask=vocab_mask)
            else:
                decoded, _ = decoder.generate_beam(
                    encoded, len_x, y_lang_id,
                    beam_size=params.beam_size,
                    length_penalty=0.9,
                    early_stopping=False,
                    max_len=params.max_dec_len,
                    vocab_mask=vocab_mask)
        for j in range(decoded.size(1)):
            sent = decoded[:, j]
            # sentences are delimited by EOS; the first token must be EOS
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            # strip leading EOS; cut at the second EOS if one exists
            sent = sent[1:] if len(
                delimiters) == 1 else sent[1:delimiters[1]]
            trg_tokens = [dico[sent[k].item()] for k in range(len(sent))]
            trg_words = tokens2words(trg_tokens)
            if y_lang.endswith("zh"):
                # Chinese: split into characters, then space-join
                Y.append(" ".join("".join(trg_words)))
            else:
                Y.append(" ".join(trg_words))
    return Y
def main(args):
    """Score sentence pairs with a reloaded XLM model.

    Reloads a checkpoint, iterates over every parallel dataset, and writes
    one sigmoid score per sentence pair to ``<dump_path>/<src>-<tgt>.pred``.
    """
    # NOTE(review): rng is never used in this function — dead.
    rng = np.random.RandomState(0)
    # Make dump path
    if not os.path.exists(args.dump_path):
        # NOTE(review): shell=True with interpolated path — fine for trusted
        # args, but os.makedirs would be safer and portable.
        subprocess.Popen("mkdir -p %s" % args.dump_path, shell=True).wait()
    else:
        if os.listdir(args.dump_path):
            m = "Directory {} is not empty.".format(args.dump_path)
            raise ValueError(m)
    if len(args.log_file):
        write_log = True
    else:
        write_log = False
    # load model parameters
    model_dir = os.path.dirname(args.load_model)
    params_path = os.path.join(model_dir, 'params.pkl')
    with open(params_path, "rb") as f:
        params = pickle.load(f)
    # load data parameters and model parameters from checkpoint
    checkpoint_path = os.path.join(model_dir, 'checkpoint.pth')
    assert os.path.isfile(checkpoint_path)
    data = torch.load(
        checkpoint_path,
        map_location=lambda storage, loc: storage.cuda(params.local_rank))
    for k, v in data["params"].items():
        params.__dict__[k] = v
    dico = Dictionary(data["dico_id2word"], data["dico_word2id"],
                      data["dico_counts"])
    # Print score
    for k, v in data["best_metrics"].items():
        print("- {}: {}".format(k, v))
    # Fix some of the params we pass to load_data
    params.debug_train = False
    params.max_vocab = -1
    params.min_count = 0
    params.tokens_per_batch = -1
    params.max_batch_size = args.batch_size
    params.batch_size = args.batch_size
    # load data (rebinds `data` from the checkpoint dict to the dataset dict)
    data = load_data(args.data_path, params)
    # Print data summary
    for (src, tgt), dataset in data['para'].items():
        datatype = "Para data (%s)" % (
            "WITHOUT labels" if dataset.labels is None else "WITH labels")
        m = '{: <27} - {: >12}:{: >10}'.format(datatype,
                                               '%s-%s' % (src, tgt),
                                               len(dataset))
        print(m)
    # Fix some of the params we pass to the model builder
    params.reload_model = args.load_model
    # build model
    if params.encoder_only:
        model = build_model(params, dico)
    else:
        encoder, decoder = build_model(params, dico)
        model = encoder
    # Predict
    model = model.module if params.multi_gpu else model
    model.eval()
    start = time.time()
    for (src, tgt), dataset in data['para'].items():
        path = os.path.join(args.dump_path, "{}-{}.pred".format(src, tgt))
        scores_file = open(path, "w")
        lang1_id = params.lang2id[src]
        lang2_id = params.lang2id[tgt]
        diffs = []
        nb_written = 0
        for batch in dataset.get_iterator(False,
                                          group_by_size=False,
                                          n_sentences=-1,
                                          return_indices=False):
            (sent1, len1), (sent2, len2), labels = batch
            # cap both sentences at max_len (keeping the EOS token)
            sent1, len1 = truncate(sent1, len1, params.max_len,
                                   params.eos_index)
            sent2, len2 = truncate(sent2, len2, params.max_len,
                                   params.eos_index)
            x, lengths, positions, langs = concat_batches(
                sent1, len1, lang1_id, sent2, len2, lang2_id,
                params.pad_index, params.eos_index, reset_positions=True)
            x, lengths, positions, langs = to_cuda(x, lengths, positions,
                                                   langs)
            with torch.no_grad():
                # Get sentence pair embedding (first output position)
                h = model('fwd',
                          x=x,
                          lengths=lengths,
                          positions=positions,
                          langs=langs,
                          causal=False)[0]
            CLF_ID1, CLF_ID2 = 8, 9  # very hacky, use embeddings to make weights for the classifier
            emb = (model.module
                   if params.multi_gpu else model).embeddings.weight
            # NOTE(review): bias emb[CLF_ID2, 0] is a 0-dim tensor — relies
            # on broadcasting inside F.linear; confirm intended.
            pred = F.linear(h, emb[CLF_ID1].unsqueeze(0), emb[CLF_ID2, 0])
            pred = torch.sigmoid(pred)
            pred = pred.view(-1).cpu().numpy().tolist()
            for p, l1, l2 in zip(pred, len1, len2):
                # pairs where both sides were truncated to nothing get 0
                if l1.item() == 0 and l2.item() == 0:
                    scores_file.write("0.00000000\n")
                else:
                    scores_file.write("{:.8f}\n".format(p))
            nb_written += len(pred)
            if nb_written % 1000 == 0:
                elapsed = int(time.time() - start)
                lpss = elapsed % 60
                lpsm = elapsed // 60
                lpsh = lpsm // 60
                lpsm = lpsm % 60
                msg = "[{:02d}:{:02d}:{:02d} {}-{}]".format(
                    lpsh, lpsm, lpss, src, tgt)
                msg += " {}/{} ({:.2f}%) sentences processed".format(
                    nb_written, len(dataset),
                    100 * nb_written / len(dataset))
                print(msg)
                if write_log:
                    with open(args.log_file, "a") as fout:
                        fout.write(msg + "\n")
            # Try reversing order (sanity check: score should be symmetric)
            if TEST_REVERSE:
                x, lengths, positions, langs = concat_batches(
                    sent2, len2, lang2_id, sent1, len1, lang1_id,
                    params.pad_index, params.eos_index,
                    reset_positions=True)
                x, lengths, positions, langs = to_cuda(
                    x, lengths, positions, langs)
                with torch.no_grad():
                    # Get sentence pair embedding
                    h = model('fwd',
                              x=x,
                              lengths=lengths,
                              positions=positions,
                              langs=langs,
                              causal=False)[0]
                CLF_ID1, CLF_ID2 = 8, 9  # very hacky, use embeddings to make weights for the classifier
                emb = (model.module
                       if params.multi_gpu else model).embeddings.weight
                pred_rev = F.linear(h, emb[CLF_ID1].unsqueeze(0),
                                    emb[CLF_ID2, 0])
                pred_rev = torch.sigmoid(pred_rev)
                pred_rev = pred_rev.view(-1).cpu().numpy().tolist()
                for p, pp in zip(pred, pred_rev):
                    diffs.append(p - pp)
        if TEST_REVERSE:
            print(
                "Average absolute diff between score(l1,l2) and score(l2,l1): {}"
                .format(np.mean(np.abs(diffs))))
        scores_file.close()
def transform_perspective(data, is_batch=False, is_seq=False, is_ped=False):
    """Rotate trajectories (and goals) into a canonical perspective.

    Dispatches on the shape of ``data``: a full batch (``is_batch``), a
    single sequence tuple (``is_seq``), or a single pedestrian tuple
    (default).  Returns the transformed data plus ``m_array``, the
    per-sequence transform matrices needed by ``inverse_transform``.
    """
    m_array = []
    if is_batch:
        (x_origin, y_origin, x_rel, y_rel, loss_mask, seq_start_end,
         frame_idx, ped_idx, cluster_label, destinations, datasets) = data
        new_x_origin = []
        new_y_origin = []
        new_destinations = []
        # transform each (start, end) sequence of the batch independently
        for i, (start, end) in enumerate(seq_start_end.data):
            seq_x_origin = x_origin[:, start:end].contiguous()
            seq_y_origin = y_origin[:, start:end].contiguous()
            seq_destinations = destinations[i]
            (new_seq_x_origin, new_seq_y_origin, new_seq_destinations,
             m) = change_perspective(seq_x_origin, seq_y_origin,
                                     seq_destinations, get_m=True)
            new_x_origin.append(new_seq_x_origin)
            new_y_origin.append(new_seq_y_origin)
            new_destinations.append(new_seq_destinations.contiguous())
            m_array.append(m)
        new_x_origin = torch.cat(new_x_origin, dim=1)
        new_y_origin = torch.cat(new_y_origin, dim=1)
        # rebuild relative displacements in the new frame
        new_x_rel = torch.zeros_like(new_x_origin)
        new_y_rel = torch.zeros_like(new_y_origin)
        new_x_rel[1:] = new_x_origin[1:] - new_x_origin[:-1]
        new_y_rel[1:] = new_y_origin[1:] - new_y_origin[:-1]
        # first future step is relative to the LAST observed position
        new_y_rel[0] = new_y_origin[0] - new_x_origin[-1]
        new_batch = (new_x_origin, new_y_origin, new_x_rel, new_y_rel,
                     loss_mask, seq_start_end, frame_idx, ped_idx,
                     cluster_label, new_destinations, datasets)
        return global_utils.to_cuda(new_batch), m_array
    elif is_seq:
        (seq_x_origin, seq_y_origin, seq_x_rel, seq_y_rel, seq_loss_mask,
         seq_dataset, seq_dest, seq_label, start, end, seq_start_frame,
         seq_ped_idx) = data
        # BUG FIX: the original passed the undefined name `seq_destinations`
        # here (NameError at runtime); the unpacked variable is `seq_dest`.
        (new_seq_x_origin, new_seq_y_origin, new_seq_destinations,
         m) = change_perspective(seq_x_origin, seq_y_origin, seq_dest,
                                 get_m=True)
        m_array.append(m)
        new_seq_x_origin = new_seq_x_origin.cuda()
        new_seq_y_origin = new_seq_y_origin.cuda()
        new_seq_destinations = new_seq_destinations.cuda()
        new_seq_x_rel = torch.zeros_like(new_seq_x_origin)
        new_seq_y_rel = torch.zeros_like(new_seq_y_origin)
        new_seq_x_rel[1:] = new_seq_x_origin[1:] - new_seq_x_origin[:-1]
        new_seq_y_rel[1:] = new_seq_y_origin[1:] - new_seq_y_origin[:-1]
        # first future step is relative to the LAST observed position
        new_seq_y_rel[0] = new_seq_y_origin[0] - new_seq_x_origin[-1]
        return (new_seq_x_origin, new_seq_y_origin, new_seq_x_rel,
                new_seq_y_rel, seq_loss_mask, seq_dataset,
                new_seq_destinations, seq_label, start, end,
                seq_start_frame, seq_ped_idx), m_array
    else:
        # single pedestrian: observed positions only, no future trajectory
        (ped_x_origin, ped_x_rel, ped_loss_mask, ped_dest) = data
        new_ped_x_origin, _, new_ped_dest, m = change_perspective(
            ped_x_origin, None, ped_dest, get_m=True)
        m_array.append(m)
        new_ped_x_origin = new_ped_x_origin.cuda()
        new_ped_dest = new_ped_dest.cuda()
        new_ped_x_rel = torch.zeros_like(new_ped_x_origin)
        new_ped_x_rel[1:] = new_ped_x_origin[1:] - new_ped_x_origin[:-1]
        return (new_ped_x_origin, new_ped_x_rel, ped_loss_mask,
                new_ped_dest), m_array
def run_test(self):
    """Score every test pair and append results to a per-GPU result file.

    Writes delimiter-separated ``src_text, trg_text, P(label=1)`` lines,
    flushing to disk whenever more than ``params.flush_frequency`` scores
    have accumulated.
    """
    params = self.params
    # one result file per GPU/rank to avoid write contention
    result_path = params.test_result_path + '_{}'.format(self.gpu)
    self.model.eval()
    lang_id1 = params.lang2id[params.src_lang]
    lang_id2 = params.lang2id[params.trg_lang]
    proba_result = []
    src_text_list = []
    trg_text_list = []
    with torch.no_grad():
        for sent1, len1, sent2, len2, _, src_text, trg_text in tqdm(
                self.dataloader['test']):
            # cap both sentences at max_len (keeping the EOS token)
            sent1, len1 = truncate(sent1, len1, params.max_len,
                                   params.eos_index)
            sent2, len2 = truncate(sent2, len2, params.max_len,
                                   params.eos_index)
            x, lengths, positions, langs = concat_batches(
                sent1, len1, lang_id1, sent2, len2, lang_id2,
                params.pad_index, params.eos_index, reset_positions=True)
            # cuda
            x, lengths, positions, langs = to_cuda(x, lengths, positions,
                                                   langs, gpu=self.gpu)
            # forward
            output = self.model(x, lengths, positions, langs)
            # probability of the positive class (column 1)
            proba = F.softmax(output, 1)[:, 1]
            proba_result.extend(proba.cpu().numpy())
            src_text_list.extend(src_text)
            trg_text_list.extend(trg_text)
            assert len(proba_result) == len(src_text_list)
            assert len(proba_result) == len(trg_text_list)
            # flush buffered scores to disk periodically
            if len(proba_result) > params.flush_frequency:
                logger.info(" GPU %i - write out score..." % self.gpu)
                with open(result_path, 'a') as f:
                    for i in range(len(proba_result)):
                        f.write('{}{}{}{}{}'.format(
                            src_text_list[i], params.delimeter,
                            trg_text_list[i], params.delimeter,
                            str(proba_result[i])) + os.linesep)
                proba_result = []
                src_text_list = []
                trg_text_list = []
    # write out the remainings
    logger.info(" GPU %i - write out score..." % self.gpu)
    with open(result_path, 'a') as f:
        for i in range(len(proba_result)):
            f.write('{}{}{}{}{}'.format(src_text_list[i], params.delimeter,
                                        trg_text_list[i], params.delimeter,
                                        str(proba_result[i])) + os.linesep)
    proba_result = []
    src_text_list = []
    trg_text_list = []
def fwd(self,
        x,
        lengths,
        ipm=False,
        cmlm=False,
        candidates=None,
        spatials=None,
        src_enc=None,
        src_len=None,
        positions=None,
        langs=None,
        cache=None):
    """
    Two-stream (word / image) transformer forward pass.

    Inputs:
        `x` LongTensor(slen, bs), containing word indices/
            LongTensor(flen, bs, features), containing region features
        `lengths` LongTensor(bs), containing the length of each sentence/
            length of the region per pass
        `candidates` only when cmlm or ipm is True
        `ipm` Boolean, if True, input is ipm (image-only stream)
        `cmlm` Boolean, if True, input is cap-img pair (both streams active)
        `positions` LongTensor(slen, bs), containing word positions

    Returns (per mode):
        cmlm  -> (tensor_i, tensor_w, candidates)
        ipm   -> (tensor_i, candidates)
        else  -> tensor_w

    Bug fix vs. previous version: in the incremental-decoding branch
    (`cache is not None` and `ipm`), the image masks were sliced from
    `mask_w` / `attn_mask_w`, which are never defined on the ipm path
    (only `mask_i` / `attn_mask_i` are) — a guaranteed NameError copied
    from the mlm branch. They now slice `mask_i` / `attn_mask_i`.
    """
    # ---- check inputs / build masks, positions and modality ids ----
    if (cmlm is False):
        if ipm:
            inputlen_ipm, bs, _ = x.size()
            assert candidates is not None
            spatials = spatials.transpose(0, 1)
            x = x.transpose(0, 1)  # batch size as dimension 0
            # generate masks
            mask_i, attn_mask_i = get_masks_image(x)
            mask_i, attn_mask_i = to_cuda(mask_i, attn_mask_i)
            # assume img id is 1
            modality = x.new_ones([bs, inputlen_ipm]).long()
        else:
            inputlen_mlm, bs = x.size()
            # generate masks
            mask_w, attn_mask_w = get_masks_word(inputlen_mlm, lengths)
            x = x.transpose(0, 1)  # batch size as dimension 0
            positions = x.new(inputlen_mlm).long()
            positions = torch.arange(inputlen_mlm,
                                     out=positions).unsqueeze(0)
            # assume cap id is zero
            modality = x.new_zeros(x.size()).long()
            modality = to_cuda(modality)[0]
    else:
        # cmlm: x is a (caption, image) pair
        assert len(x) == 2
        inputlen_cap, bs = x[0].size()
        inputlen_img, _, _ = x[1].size()
        assert spatials[
            0] is None  # sanity check for word spatial, must be None
        spatials_img = spatials[1].transpose(0, 1)  # bs, n_features, 6
        assert lengths[0].size()[0] == lengths[1].size()[0] == bs
        assert lengths[0].max().item() <= inputlen_cap and lengths[1].max(
        ).item() <= inputlen_img
        x_cap = x[0].transpose(0, 1)  # batch size as dimension 0
        x_img = x[1].transpose(0, 1)  # batch size as dimension 0
        mask_w, attn_mask_w = get_masks_word(inputlen_cap, lengths[0])
        mask_i, attn_mask_i = get_masks_image(x_img)
        mask_w, attn_mask_w, mask_i, attn_mask_i = to_cuda(
            mask_w, attn_mask_w, mask_i, attn_mask_i)
        positions = positions.transpose(0, 1)
        modality_cap = x_cap.new_zeros(bs, inputlen_cap).long()
        modality_img = x_img.new_ones(bs, inputlen_img).long()
        modality_cap, modality_img = to_cuda(modality_cap, modality_img)

    # ---- do not recompute cached elements (incremental decoding) ----
    if cache is not None:
        if (cmlm is False):
            if ipm:
                _inputlen_ipm = inputlen_ipm - cache['inputlen_ipm']
                spatials = spatials[:, -_inputlen_ipm:, :]
                x = x[:, -_inputlen_ipm:, :]
                # FIX: slice the image masks (mask_w/attn_mask_w do not
                # exist on this path; previous code raised NameError)
                mask_i = mask_i[:, -_inputlen_ipm:]
                attn_mask_i = attn_mask_i[:, -_inputlen_ipm:]
            else:
                _inputlen_mlm = inputlen_mlm - cache['inputlen_mlm']
                x = x[:, -_inputlen_mlm:]
                positions = positions[:, -_inputlen_mlm:]
                mask_w = mask_w[:, -_inputlen_mlm:]
                attn_mask_w = attn_mask_w[:, -_inputlen_mlm:]

    # ---- embedding / position embeddings ----
    if (cmlm is False):
        if ipm:
            # image embedding: features + spatial + modality, then LN/dropout
            tensor_i = self.img_feature_embeddings(x)
            tensor_i = tensor_i + self.img_spatial_embeddings(spatials)
            tensor_i = tensor_i + self.modality_embeddings(modality)
            tensor_i = self.layer_norm_emb_i(tensor_i)
            tensor_i = F.dropout(tensor_i,
                                 p=self.dropout,
                                 training=self.training)
            tensor_i *= mask_i.unsqueeze(-1).to(tensor_i.dtype)
            # turn image candidates into language space
            candidates = self.img_feature_embeddings(candidates)
            # word stream is unused in ipm mode: keep an all-zero placeholder
            with torch.no_grad():
                tensor_w = tensor_i.new_zeros(tensor_i.size())
                tensor_w = to_cuda(tensor_w)[0]
        else:
            # word embeddings: token + position + modality, then LN/dropout
            tensor_w = self.word_embeddings(x)
            tensor_w = tensor_w + self.position_embeddings(
                positions).expand_as(tensor_w)
            tensor_w = tensor_w + self.modality_embeddings(modality)
            tensor_w = self.layer_norm_emb_w(tensor_w)
            tensor_w = F.dropout(tensor_w,
                                 p=self.dropout,
                                 training=self.training)
            tensor_w *= mask_w.unsqueeze(-1).to(tensor_w.dtype)
            # image stream unused in mlm mode: zero placeholder
            with torch.no_grad():
                tensor_i = tensor_w.new_zeros(tensor_w.size())
                tensor_i = to_cuda(tensor_i)[0]
    else:
        # image stream: slot 0 of x_img is a zero sentinel replaced by an
        # averaged image embedding (assert guards the sentinel)
        tensor_i = self.img_feature_embeddings(x_img[:, 1:])
        assert x_img[:, 0].sum().long() == 0
        tensor_i_ave = self.img_embedding(
            torch.sum(x_img[:, 0], dim=-1).long().unsqueeze(-1))
        tensor_i = torch.cat((tensor_i_ave, tensor_i), 1)
        tensor_i = tensor_i + self.img_spatial_embeddings(spatials_img)
        tensor_i = tensor_i + self.modality_embeddings(modality_img)
        tensor_i = self.layer_norm_emb_i(tensor_i)
        tensor_i = F.dropout(tensor_i,
                             p=self.dropout,
                             training=self.training)
        tensor_i *= mask_i.unsqueeze(-1).to(tensor_i.dtype)
        # tranform raw img fetures into same img embedding space
        candidates = self.img_feature_embeddings(candidates)
        # word stream
        tensor_w = self.word_embeddings(x_cap)
        tensor_w = tensor_w + self.position_embeddings(
            positions).expand_as(tensor_w)
        tensor_w = tensor_w + self.modality_embeddings(modality_cap)
        tensor_w = self.layer_norm_emb_w(tensor_w)
        tensor_w = F.dropout(tensor_w,
                             p=self.dropout,
                             training=self.training)
        tensor_w *= mask_w.unsqueeze(-1).to(tensor_w.dtype)

    # ---- transformer layers (fusion, then per-stream attn + FFN) ----
    for i in range(self.n_layers):
        tensor_i, tensor_w = self.fusion[i](tensor_i, tensor_w, cmlm)
        if (ipm or cmlm):
            # self attention for image transformer
            attn_i = self.attentions_i[i](tensor_i,
                                          attn_mask_i,
                                          cache=cache)
            attn_i = F.dropout(attn_i,
                               p=self.dropout,
                               training=self.training)
            tensor_i = tensor_i + attn_i
            tensor_i = self.layer_norm1_i[i](tensor_i)
            # FFN for image transformer
            tensor_i = tensor_i + self.ffns_i[i](tensor_i)
            tensor_i = self.layer_norm2_i[i](tensor_i)
            tensor_i *= mask_i.unsqueeze(-1).to(tensor_i.dtype)
        if ((not ipm) or cmlm):
            # self attention for caption/language transformer
            attn_w = self.attentions_w[i](tensor_w,
                                          attn_mask_w,
                                          cache=cache)
            attn_w = F.dropout(attn_w,
                               p=self.dropout,
                               training=self.training)
            tensor_w = tensor_w + attn_w
            tensor_w = self.layer_norm1_w[i](tensor_w)
            # FFN for caption/language transformer
            tensor_w = tensor_w + self.ffns_w[i](tensor_w)
            tensor_w = self.layer_norm2_w[i](tensor_w)
            tensor_w *= mask_w.unsqueeze(-1).to(tensor_w.dtype)

    # ---- update cache length ----
    if cache is not None:
        if (cmlm is False):
            if ipm:
                cache['inputlen_ipm'] += tensor_i.size(1)
            else:
                cache['inputlen_mlm'] += tensor_w.size(1)

    # move back sequence length to dimension 0
    tensor_w = tensor_w.transpose(0, 1)
    tensor_i = tensor_i.transpose(0, 1)

    if cmlm:
        return tensor_i, tensor_w, candidates
    elif ipm:
        return tensor_i, candidates
    else:
        return tensor_w
def enc_dec(self, x, lengths, langs, z, non_mask_deb, bs):
    """Reconstruct the (optionally sub-selected) batch through the
    pre-trained encoder/decoder and return the reconstruction loss.

    Args:
        x: token ids, shape (seq_len, bs) — per inline comments below.
        lengths: per-sentence lengths, shape (bs,).
        langs: language ids aligned with x.
        z: encoder states, (bs, seq_len, dim) once indexed — TODO confirm.
        non_mask_deb: boolean mask selecting which batch items to
            reconstruct, or None to use the whole batch.
        bs: batch size (used only for the bs == 1 special case).

    Returns:
        (loss_rec, word_scores, y): scaled reconstruction loss, decoder
        prediction scores, and the flattened target tokens.
    """
    if non_mask_deb is not None:
        # keep only the selected columns of the batch
        x_non_deb = x[:, non_mask_deb]  # (seq_len, bs)
        lengths_non_deb = lengths[non_mask_deb]
        langs_non_deb = langs[:, non_mask_deb]
        if bs == 1:
            # single-sentence batch: drop the batch dimension before
            # re-arranging (NOTE(review): exact nesting of arrange_x was
            # ambiguous in the mangled source — verify it belongs here)
            x_non_deb = x_non_deb.squeeze(1)
            langs_non_deb = langs_non_deb.squeeze(1)
            lengths_non_deb = lengths_non_deb.squeeze(0)
            x_non_deb, langs_non_deb = arrange_x(x_non_deb,
                                                 lengths_non_deb,
                                                 langs_non_deb)
        z = z[
            non_mask_deb]  #.squeeze(0) if bs == 1 else z[non_mask_deb] # (bs, seq_len, dim)
    else:
        # no sub-selection: copy x (x + 0 forces a new tensor)
        x_non_deb = x + 0  # (seq_len, bs)
        lengths_non_deb = lengths
        langs_non_deb = langs
    max_len = lengths_non_deb.max()
    if self.denoising_ae:
        # denoising auto-encoder path: encode a noised copy, predict the
        # clean tokens
        (x2, len2) = (x_non_deb.cpu(), lengths_non_deb.cpu())
        #(x1, len1) = (x_non_deb, lengths_non_deb)
        (x1, len1) = self.pre_trainer.add_noise(x_non_deb.cpu(),
                                                lengths_non_deb.cpu())
        # target words to predict
        alen = torch.arange(max_len, dtype=torch.long, device=len2.device)
        pred_mask = alen[:, None] < len2[
            None] - 1  # do not predict anything given the last target word
        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()
        # cuda
        x1, len1, x2, len2, y = to_cuda(x1, len1, x2, len2, y)
        # encode source sentence
        langs1 = langs_non_deb[torch.arange(x1.size(0))]
        enc1 = self.pre_trainer.encoder('fwd',
                                        x=x1,
                                        lengths=len1,
                                        langs=langs1,
                                        causal=False)
        enc1 = enc1.transpose(0, 1)
        #lambda_coeff = self.pre_trainer.params.lambda_ae
        lambda_coeff = 1
    else:
        # plain path: reuse the provided latent z as encoder output
        x2, len1, len2 = x_non_deb, lengths_non_deb, lengths_non_deb
        enc1 = z
        # target words to predict
        alen = torch.arange(max_len, dtype=torch.long, device=len2.device)
        pred_mask = alen[:, None] < len2[
            None] - 1  # do not predict anything given the last target word
        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()
        # cuda
        y = y.to(x_non_deb.device)
        lambda_coeff = 1
    # causal decode conditioned on the encoder output, then score targets
    dec2 = self.pre_trainer.decoder('fwd',
                                    x=x2,
                                    lengths=len2,
                                    langs=langs_non_deb,
                                    causal=True,
                                    src_enc=enc1,
                                    src_len=len1)
    word_scores, loss_rec = self.pre_trainer.decoder('predict',
                                                     tensor=dec2,
                                                     pred_mask=pred_mask,
                                                     y=y,
                                                     get_scores=False)
    loss_rec = lambda_coeff * loss_rec
    return loss_rec, word_scores, y
def fgim_algorithm(self, get_loss, end_of_epoch):
    """ Controllable Unsupervised Text Attribute Transfer via Editing Entangled Latent Representation

    Fast-Gradient-Iterative-Modification: for each batch, encode the text,
    flip the target label, then repeatedly perturb the latent in the
    direction that lowers the classifier loss until the classifier
    predicts the flipped label (or iteration/step-size budgets run out).
    Generated text, labels and bookkeeping are collected and passed to
    self.end_eval.
    """
    # attack hyper-parameters: stop threshold (declared but unused below),
    # step-size decay, per-step-size iteration cap, candidate step sizes,
    # and a hard cap on how many batches to process
    threshold = 0.001
    lambda_ = 0.9
    max_iter_per_epsilon = 100
    w = [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    limit_batches = 10
    # eval mode
    #self.deb.eval()
    #self.model.eval()
    self.pre_trainer.encoder.eval()
    self.pre_trainer.decoder.eval()
    # accumulators for generated text and per-batch metadata
    text_z_prime = {
        KEYS["input"]: [],
        KEYS["gen"]: [],
        KEYS["deb"]: [],
        "origin_labels": [],
        "pred_label": [],
        "change": [],
        "w_i": []
    }
    references = []
    hypothesis = []
    hypothesis2 = []
    n_batches = 0

    def get_y_hat(logits):
        # predicted label: rounded sigmoid for binary, argmax otherwise
        if self.bin_classif:
            probs = torch.sigmoid(logits)
            y_hat = probs.round().int()
        else:
            probs, y_hat = logits.max(dim=1)
        return y_hat, probs

    for batch in tqdm(self.train_data_iter):
        stats_ = {}
        (x, lengths, langs), y1, y2, weight_out = batch
        stats_['n_words'] = lengths.sum().item()
        flag = True
        if self.params.train_only_on_negative_examples:
            # keep only examples above the threshold; flag is False when
            # the mask empties the batch
            #negative_examples = ~(y2.squeeze() < self.params.threshold)
            negative_examples = y2.squeeze() > self.params.threshold
            batch, flag = select_with_mask(batch, mask=negative_examples)
            (x, lengths, langs), y1, y2, weight_out = batch
        if flag:
            y = y2 if self.params.version == 3 else y1
            x, y, y2, lengths, langs = to_cuda(x, y, y2, lengths, langs)
            #langs = None
            batch = [(x, lengths, langs), y1, y2, weight_out]
            # latent representation of the input text
            origin_data = self.pre_trainer.encoder('fwd',
                                                   x=x,
                                                   lengths=lengths,
                                                   langs=langs,
                                                   causal=False)
            # Define target label (flip the current one)
            if self.bin_classif:
                #y_prime = self.max_label - y
                y_prime = self.max_label - (y >
                                            self.params.threshold).float()
            else:
                #y_prime = self.max_label - (y > self.params.threshold).float()
                y_prime = self.max_label - y  #.float()
            batch[2 if self.params.version == 3 else 1] = y_prime
            flag = False  # becomes True once the attack succeeds
            for w_i in w:
                #print("---------- w_i:", w_i)
                data = to_var(origin_data.clone()
                              )  # (batch_size, seq_length, latent_size)
                b = True
                if b:
                    # take one gradient step before entering the loop
                    data.requires_grad = True
                    logits, classif_loss = self.model.predict(
                        data,
                        y_prime,
                        weights=self.train_data_iter.weights)
                    #_, stats = get_loss(None, batch, self.params, None, logits = logits, loss = classif_loss, mode="train", epoch = self.epoch)
                    #y_hat = stats["label_pred"]
                    y_hat, _ = get_y_hat(logits)
                    self.model.zero_grad()
                    classif_loss.backward()
                    data = data - w_i * data.grad.data
                else:
                    logits, classif_loss = self.model.predict(
                        data,
                        y_prime,
                        weights=self.train_data_iter.weights)
                    #_, stats = get_loss(None, batch, self.params, None, logits = logits, loss = classif_loss, mode="train", epoch = self.epoch)
                    #y_hat = stats["label_pred"]
                    y_hat, _ = get_y_hat(logits)
                it = 0
                while True:
                    # stop when the classifier now predicts the target label
                    #if torch.cdist(y_hat, y_prime) < threshold :
                    #if ((y_hat - y_prime)**2).sum().float().sqrt() < threshold :
                    #if (y_hat - y_prime).abs().float().mean() < threshold :
                    if y_hat == y_prime:
                        flag = True
                        break
                    data = to_var(data.clone(
                    ))  # (batch_size, seq_length, latent_size)
                    # Set requires_grad attribute of tensor. Important for Attack
                    data.requires_grad = True
                    logits, classif_loss = self.model.predict(
                        data,
                        y_prime,
                        weights=self.train_data_iter.weights)
                    # Calculate gradients of model in backward pass
                    self.model.zero_grad()
                    classif_loss.backward()
                    data = data - w_i * data.grad.data
                    it += 1
                    # data = perturbed_data
                    # decay the step size each iteration
                    w_i = lambda_ * w_i
                    if it > max_iter_per_epsilon:
                        break
            # decode both original and perturbed latents back to text
            data = data.transpose(0, 1)
            origin_data = origin_data.transpose(0, 1)
            texts = self.generate(x,
                                  lengths,
                                  langs,
                                  origin_data,
                                  z_prime=data,
                                  log=False)
            for k, v in texts.items():
                text_z_prime[k].append(v)
            references.extend(texts[KEYS["input"]])
            hypothesis.extend(texts[KEYS["gen"]])
            hypothesis2.extend(texts[KEYS["deb"]])
            text_z_prime["origin_labels"].append(y2.cpu().numpy())
            text_z_prime["pred_label"].append(y_hat.cpu().numpy())
            text_z_prime["change"].append([flag] * len(y2))
            text_z_prime["w_i"].append([w_i] * len(y2))
            n_batches += 1
            if n_batches > limit_batches:
                break
    self.end_eval(text_z_prime, references, hypothesis, hypothesis2)
def eval_step(self, get_loss, test=False, prefix=""):
    """Run one validation pass: classify, debias the latent, reconstruct,
    and collect generated text plus reconstruction perplexity/accuracy.

    Args:
        get_loss: loss callback forwarded to self.classif_step.
        test: if True, additionally return an (empty) pre-train score dict.
        prefix: unused here — presumably for log naming; verify callers.

    Returns:
        total_stats, or (total_stats, pre_train_scores) when test is True.
    """
    # eval mode
    self.deb.eval()
    self.model.eval()
    self.pre_trainer.encoder.eval()
    self.pre_trainer.decoder.eval()
    total_stats = []
    # accumulators for generated text and labels
    text_z_prime = {
        KEYS["input"]: [],
        KEYS["gen"]: [],
        KEYS["deb"]: [],
        "origin_labels": [],
        "pred_label": []
    }
    references = []
    hypothesis = []
    hypothesis2 = []
    with torch.no_grad():
        for batch in tqdm(self.val_data_iter, desc='val'):
            n_words, xe_loss, n_valid = 0, 0, 0
            (x, lengths, langs), y1, y2, weight_out = batch
            flag = True
            # disabled: negative-example-only filtering (kept for reference)
            """
            # only on negative example
            #negative_examples = ~(y2.squeeze() < self.params.threshold)
            negative_examples = y2.squeeze() > self.params.threshold
            batch, flag = select_with_mask(batch, mask = negative_examples)
            (x, lengths, langs), y1, y2, weight_out = batch
            #"""
            if flag:
                y = y2 if self.params.version == 3 else y1
                x, y, lengths, langs = to_cuda(x, y, lengths, langs)
                #langs = langs if self.params.n_langs > 1 else None
                #langs = None
                batch = (x, lengths, langs), y1, y2, weight_out
                # classify and get the latent z
                _, _, z, _, stats, y_hat = self.classif_step(
                    get_loss, y, batch)
                z = z.transpose(0, 1)  # (bs-ϵ, seq_len, dim)
                bs = z.size(0)
                # debias the latent with the deb transformer
                z_prime = self.deb('fwd', x=z, lengths=lengths,
                                   causal=False)
                z_prime = z_prime.transpose(0,
                                            1)  # (bs-ϵ, seq_len, dim)
                # reconstruct the full batch (all-True mask)
                non_mask_deb = torch.BoolTensor([True] * bs)
                loss_rec, word_scores, y_ = self.enc_dec(
                    x, lengths, langs, z, non_mask_deb, bs)
                # update stats
                n_words += y_.size(0)
                xe_loss += loss_rec.item() * len(y_)
                n_valid += (word_scores.max(1)[1] == y_).sum().item()
                # compute perplexity and prediction accuracy
                # (eps guards against division by zero)
                n_words = n_words + eps
                stats['rec_ppl'] = np.exp(xe_loss / n_words)
                stats['rec_acc'] = 100. * n_valid / n_words
                texts = self.generate(x,
                                      lengths,
                                      langs,
                                      z,
                                      z_prime=z_prime,
                                      log=False)
                for k, v in texts.items():
                    text_z_prime[k].append(v)
                references.extend(texts[KEYS["input"]])
                hypothesis.extend(texts[KEYS["gen"]])
                hypothesis2.extend(texts[KEYS["deb"]])
                text_z_prime["origin_labels"].append(y.cpu().numpy())
                text_z_prime["pred_label"].append(y_hat.cpu().numpy())
                total_stats.append(stats)
    self.end_eval(text_z_prime, references, hypothesis, hypothesis2)
    if test:
        pre_train_scores = {}
        return total_stats, pre_train_scores
    return total_stats
def train_step(self, get_loss):
    """Run one training epoch slice: classification, optional debias step,
    reconstruction, and bookkeeping of progress statistics.

    Args:
        get_loss: loss callback forwarded to self.classif_step.

    Returns:
        total_stats: list of per-batch stat dicts.
    """
    # train mode
    self.deb.train()
    self.model.train()
    self.pre_trainer.encoder.train()
    self.pre_trainer.decoder.train()
    total_stats = []
    for i, batch in enumerate(self.train_data_iter):
        stats_ = {}
        n_words, xe_loss, n_valid = 0, 0, 0
        (x, lengths, langs), y1, y2, weight_out = batch
        stats_['n_words'] = lengths.sum().item()
        flag = True
        if self.params.train_only_on_negative_examples:
            # keep only examples above the threshold; flag is False when
            # the mask empties the batch
            #negative_examples = ~(y2.squeeze() < self.params.threshold)
            negative_examples = y2.squeeze() > self.params.threshold
            batch, flag = select_with_mask(batch, mask=negative_examples)
            (x, lengths, langs), y1, y2, weight_out = batch
        if flag:
            y = y2 if self.params.version == 3 else y1
            x, y, lengths, langs = to_cuda(x, y, lengths, langs)
            #langs = langs if self.params.n_langs > 1 else None
            #langs = None
            batch = (x, lengths, langs), y1, y2, weight_out
            # classification on the latent; returns loss, logits, latent z
            classif_loss, logits, z, z_list, stats, y_hat = self.classif_step(
                get_loss, y, batch)
            #self.optimize(classif_loss, retain_graph = True)
            stats_ = {**stats, **stats_}
            # version 0: route each example to debias or reconstruction by
            # classifier confidence; version 1 (active): do both on all
            version = 1
            if version == 0:
                mask_deb = y_hat.squeeze(
                ) >= self.lambda_ if self.params.positive_label == 0 else y_hat.squeeze(
                ) < self.lambda_
                non_mask_deb = ~mask_deb
                flag = mask_deb.any()
                rec_step = non_mask_deb.any()
            else:
                mask_deb = None
                non_mask_deb = None
                flag = True
                rec_step = True
            ###############
            z = z.transpose(0, 1)  # (bs-ϵ, seq_len, dim)
            bs = z.size(0)
            loss_deb = 0  # torch.tensor(float("nan"))
            loss_rec = 0  # torch.tensor(float("nan"))
            if flag:  # if f(z) > lambda :
                loss_deb, _, _ = self.debias_step(y, lengths, z, z_list,
                                                  mask_deb, bs)
            if rec_step:  # else :
                loss_rec, word_scores, y_ = self.enc_dec(
                    x, lengths, langs, z, non_mask_deb, bs)
                # update stats
                n_words += y_.size(0)
                xe_loss += loss_rec.item() * len(y_)
                n_valid += (word_scores.max(1)[1] == y_).sum().item()
                # compute perplexity and prediction accuracy
                # (eps guards against division by zero)
                n_words = n_words + eps
                stats_['rec_ppl'] = np.exp(xe_loss / n_words)
                stats_['rec_acc'] = 100. * n_valid / n_words
            # optimize (combined objective; optimizer call currently disabled)
            loss = classif_loss + loss_deb + loss_rec
            #self.pre_trainer.optimize(loss)
            stats_["loss_"] = loss.item()
            # periodically log generated samples
            #if True :
            if self.n_total_iter % self.log_interval == 0:
                self.generate(x, lengths, langs, z)
        # number of processed sentences / words
        self.n_sentences += self.params.batch_size
        self.stats['processed_s'] += self.params.batch_size
        self.stats['processed_w'] += stats_['n_words']
        self.stats['progress'] = min(
            int(((i + 1) / self.params.train_num_step) * 100), 100)
        total_stats.append(stats_)
        self.put_in_stats(stats_)
        self.iter()
        self.print_stats()
        if self.epoch_size < self.n_sentences:
            break
    return total_stats