import logging
from collections import defaultdict

import torch
import tqdm

# The helpers used below (batch2tensor_*, tensor2data_*, train_*_one_step,
# write_txt) and the scorer objects (bleu, meteor, rouge, cider) are defined
# elsewhere in the repository.


def t2g_teach_g2t_one_step(raw_batch, model_t2g, model_g2t, optimizer, config, vocab):
    # Train the g2t model on synthetic graphs predicted by the t2g model.
    model_t2g.eval()
    model_g2t.train()
    batch_t2g = batch2tensor_t2g(raw_batch, config['t2g']['device'], vocab)
    with torch.no_grad():
        t2g_pred = model_t2g(batch_t2g).argmax(-1).cpu()
    syn_batch = []
    for _i, sample in enumerate(t2g_pred):
        rel = []
        for e1 in range(len(raw_batch[_i]['ent_text'])):
            for e2 in range(len(raw_batch[_i]['ent_text'])):
                try:
                    # label 3 is "no relation" and label 0 is <PAD>; keep the rest
                    if sample[e1, e2] != 3 and sample[e1, e2] != 0:
                        rel.append([e1, int(sample[e1, e2]), e2])
                except Exception:
                    logging.warning('{0:}'.format(
                        [[vocab['entity'](x) for x in y] for y in raw_batch[_i]['ent_text']]))
                    logging.warning('{0:}'.format(sample.size()))
        _syn = tensor2data_t2g(raw_batch[_i], rel, vocab)
        syn_batch.append(_syn)
    if len(syn_batch) == 0:
        return None
    batch_g2t = batch2tensor_g2t(syn_batch, config['g2t']['device'], vocab)
    loss = train_g2t_one_step(batch_g2t, model_g2t, optimizer, config['g2t'])
    return loss
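

# Illustration only: a minimal, self-contained sketch of the relation-matrix
# decoding performed in t2g_teach_g2t_one_step. The t2g model emits an
# (n_ent x n_ent) label matrix per sample; every cell that is neither <PAD>
# (0) nor "no relation" (3) becomes an [e1, label, e2] triple. The tensor and
# label ids below are fabricated for the example.
def _demo_decode_relation_matrix():
    pred = torch.tensor([[3, 5, 0],
                         [3, 3, 7],
                         [0, 3, 3]])
    triples = []
    for e1 in range(pred.size(0)):
        for e2 in range(pred.size(1)):
            lab = int(pred[e1, e2])
            if lab not in (0, 3):  # skip <PAD> and "no relation"
                triples.append([e1, lab, e2])
    return triples  # [[0, 5, 1], [1, 7, 2]]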


def supervise(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T,
              optimizerT2G, config, t2g_weight, vocab):
    # Fully supervised step: both models train on real paired data, so the
    # blind flag is switched off.
    model_g2t.blind, model_t2g.blind = False, False
    batch = batch2tensor_t2g(batch_t2g, config['t2g']['device'], vocab)
    _loss1 = train_t2g_one_step(batch, model_t2g, optimizerT2G, config['t2g'],
                                t2g_weight=t2g_weight)
    batch = batch2tensor_g2t(batch_g2t, config['g2t']['device'], vocab)
    _loss2 = train_g2t_one_step(batch, model_g2t, optimizerG2T, config['g2t'])
    return _loss1, _loss2
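

# Assumption (not shown in this section): t2g_weight is a per-relation class
# weight that train_t2g_one_step passes to its classification loss, typically
# to damp the dominant "no relation" label. A plausible smoothed
# inverse-frequency construction is sketched below; the counts and smoothing
# constant are hypothetical.
def _demo_t2g_weight(relation_counts, smoothing=1000.0):
    # relation_counts: raw corpus frequency of each relation label, by id.
    max_count = max(relation_counts)
    return [(max_count + smoothing) / (c + smoothing) for c in relation_counts]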


def warmup_step1(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T,
                 optimizerT2G, config, t2g_weight, vocab):
    # Warm-up step: identical to supervise() except that both models run in
    # "blind" mode (the blind flag is consumed inside the model code).
    model_g2t.blind, model_t2g.blind = True, True
    batch = batch2tensor_t2g(batch_t2g, config['t2g']['device'], vocab)
    loss1 = train_t2g_one_step(batch, model_t2g, optimizerT2G, config['t2g'],
                               t2g_weight=t2g_weight)
    batch = batch2tensor_g2t(batch_g2t, config['g2t']['device'], vocab)
    loss2 = train_g2t_one_step(batch, model_g2t, optimizerG2T, config['g2t'])
    return loss1, loss2
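

# Sketch of how the warm-up phase might be driven. The pool object and its
# draw_with_type(batch_size, shuffle, split) API are assumed to match the one
# used in eval_g2t below; the split names, batch-size lookup, and epoch count
# are placeholders, not the project's actual schedule.
def _demo_warmup_loop(pool, model_g2t, model_t2g, optimizerG2T, optimizerT2G,
                      config, t2g_weight, vocab, warmup_epochs=10):
    for _ in range(warmup_epochs):
        for batch_g2t, batch_t2g in zip(
                pool.draw_with_type(config['g2t']['batch_size'], True, 'train'),
                pool.draw_with_type(config['t2g']['batch_size'], True, 'train')):
            warmup_step1(batch_g2t, batch_t2g, model_g2t, model_t2g,
                         optimizerG2T, optimizerT2G, config, t2g_weight, vocab)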


def eval_g2t(pool, _type, vocab, model, config, display=True):
    logging.info('Eval on {0:}'.format(_type))
    model.eval()
    hyp, ref, _same = [], [], []
    unq_hyp = {}
    unq_ref = defaultdict(list)
    batch_size = 8 * config['batch_size']
    with tqdm.tqdm(list(pool.draw_with_type(batch_size, False, _type)),
                   disable=(not display)) as tqb:
        for _batch in tqb:
            with torch.no_grad():
                batch = batch2tensor_g2t(_batch, config['device'], vocab)
                seq = model(batch, beam_size=config['beam_size'])
            r = write_txt(batch, batch['tgt'], vocab['text'])
            h = write_txt(batch, seq, vocab['text'])
            # Key each example by its input (relations + entities) so that
            # examples with identical inputs can be grouped below.
            _same.extend([str(x['raw_relation']) + str(x['ent_text']) for x in _batch])
            hyp.extend(h)
            ref.extend(r)
    hyp = [x[0] for x in hyp]
    ref = [x[0] for x in ref]
    # Sort by input key, then collapse runs of identical inputs: each unique
    # input keeps one hypothesis and accumulates all of its references.
    idxs, _same = list(zip(*sorted(enumerate(_same), key=lambda x: x[1])))
    ptr = 0
    for i in range(len(hyp)):
        if i > 0 and _same[i] != _same[i - 1]:
            ptr += 1
        unq_hyp[ptr] = hyp[idxs[i]]
        unq_ref[ptr].append(ref[idxs[i]])
    unq_hyp = sorted(unq_hyp.items(), key=lambda x: x[0])
    unq_ref = sorted(unq_ref.items(), key=lambda x: x[0])
    hyp = [x[1] for x in unq_hyp]
    ref = [[x.lower() for x in y[1]] for y in unq_ref]
    with open('hyp.txt', 'w') as wf_h:
        for h in hyp:
            wf_h.write(str(h) + '\n')
    hyp = dict(zip(range(len(hyp)), [[x.lower()] for x in hyp]))
    ref = dict(zip(range(len(ref)), ref))
    ret = bleu.compute_score(ref, hyp)
    logging.info('BLEU INP {0:}'.format(len(hyp)))
    logging.info('BLEU 1-4 {0:}'.format(ret[0]))
    logging.info('METEOR {0:}'.format(meteor.compute_score(ref, hyp)[0]))
    logging.info('ROUGE_L {0:}'.format(rouge.compute_score(ref, hyp)[0]))
    logging.info('Cider {0:}'.format(cider.compute_score(ref, hyp)[0]))
    return ret[0][-1]  # corpus-level BLEU-4
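

# The scorer objects above (bleu, meteor, rouge, cider) appear to follow the
# pycocoevalcap convention: compute_score takes two dicts mapping an example
# id to a list of strings, with one or more references per id and exactly one
# hypothesis per id. A minimal sketch, assuming the pycocoevalcap package:
def _demo_scorer_format():
    from pycocoevalcap.bleu.bleu import Bleu
    ref = {0: ['the cat sat on the mat', 'a cat sat on a mat'],
           1: ['hello world']}
    hyp = {0: ['the cat sat on a mat'],
           1: ['hello world']}
    score, _ = Bleu(4).compute_score(ref, hyp)
    return score  # [BLEU-1, BLEU-2, BLEU-3, BLEU-4], corpus level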


def g2t_teach_t2g_one_step(raw_batch, model_g2t, model_t2g, optimizer, config,
                           vocab, t2g_weight=None):
    # Train the t2g model on synthetic text generated by the g2t model.
    model_g2t.eval()
    model_t2g.train()
    syn_batch = []
    if len(raw_batch) > 0:
        batch_g2t = batch2tensor_g2t(raw_batch, config['g2t']['device'], vocab)
        with torch.no_grad():
            g2t_pred = model_g2t(batch_g2t, beam_size=1).cpu()
        for _i, sample in enumerate(g2t_pred):
            _s = sample.tolist()
            if 2 in _s:  # token id 2 is <EOS>; truncate the tail after it
                _s = _s[:_s.index(2)]
            _syn = tensor2data_g2t(raw_batch[_i], _s)
            syn_batch.append(_syn)
    if len(syn_batch) == 0:
        # Mirror the empty-batch guard in t2g_teach_g2t_one_step.
        return None
    batch_t2g = batch2tensor_t2g(syn_batch, config['t2g']['device'], vocab,
                                 add_inp=True)
    loss = train_t2g_one_step(batch_t2g, model_t2g, optimizer, config['t2g'],
                              t2g_weight=t2g_weight)
    return loss
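

# Sketch of one cycle (back-translation) pass combining the two teach steps:
# each model generates synthetic training pairs for the other. Which data
# split feeds which direction, and the pool API, are assumptions carried over
# from the functions above rather than the project's actual driver.
def _demo_cycle_epoch(pool, model_g2t, model_t2g, optimizerG2T, optimizerT2G,
                      config, t2g_weight, vocab):
    for batch_g2t, batch_t2g in zip(
            pool.draw_with_type(config['g2t']['batch_size'], True, 'train'),
            pool.draw_with_type(config['t2g']['batch_size'], True, 'train')):
        # graph -> synthetic text, then train t2g on (synthetic text, graph)
        g2t_teach_t2g_one_step(batch_g2t, model_g2t, model_t2g, optimizerT2G,
                               config, vocab, t2g_weight=t2g_weight)
        # text -> synthetic graph, then train g2t on (synthetic graph, text)
        t2g_teach_g2t_one_step(batch_t2g, model_t2g, model_g2t, optimizerG2T,
                               config, vocab)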