def g2t_teach_t2g_one_step(raw_batch, model_g2t, model_t2g, optimizer, config, vocab, t2g_weight=None):
    """Train the t2g model for one step on text synthesised by the g2t model.

    The g2t model is put in eval mode and decodes ``raw_batch`` with
    ``beam_size=1`` under ``torch.no_grad()``. Each decoded id sequence is cut
    just before the first id ``2`` (<EOS>) and paired with its source sample via
    ``tensor2data_g2t``. The resulting synthetic batch is tensorised with
    ``add_inp=True`` and passed to ``train_t2g_one_step``.

    Args:
        raw_batch: sequence of raw samples, indexable by position.
        model_g2t: generator of the synthetic text; switched to eval mode.
        model_t2g: model being trained; switched to train mode.
        optimizer: optimizer forwarded to ``train_t2g_one_step``.
        config: mapping with ``'g2t'`` and ``'t2g'`` sub-configs; each
            sub-config's ``'device'`` entry is read here.
        vocab: vocabulary object forwarded to the tensorisation helpers.
        t2g_weight: optional weight forwarded to ``train_t2g_one_step``.

    Returns:
        Whatever ``train_t2g_one_step`` returns.
    """
    model_g2t.eval()
    model_t2g.train()

    synthetic = []
    if len(raw_batch) > 0:
        graph_tensors = batch2tensor_g2t(raw_batch, config['g2t']['device'], vocab)
        with torch.no_grad():
            decoded = model_g2t(graph_tensors, beam_size=1).cpu()
        for idx, row in enumerate(decoded):
            token_ids = row.tolist()
            # Drop <EOS> (id 2) and everything after it, when present.
            token_ids = token_ids[:token_ids.index(2)] if 2 in token_ids else token_ids
            synthetic.append(tensor2data_g2t(raw_batch[idx], token_ids))

    # An empty raw_batch still reaches this point with an empty synthetic batch.
    text_tensors = batch2tensor_t2g(synthetic, config['t2g']['device'], vocab, add_inp=True)
    return train_t2g_one_step(text_tensors, model_t2g, optimizer, config['t2g'], t2g_weight=t2g_weight)
def t2g_teach_g2t_one_step(raw_batch, model_t2g, model_g2t, optimizer, config, vocab):
    """Train the g2t model for one step on graphs synthesised by the t2g model.

    The t2g model is put in eval mode and run on ``raw_batch`` under
    ``torch.no_grad()``; the argmax over the last axis gives one relation label
    per entity pair. For every sample, each (e1, e2) pair whose label is neither
    ``3`` (no relation) nor ``0`` (<PAD>) becomes a ``[e1, label, e2]`` triple.
    The triples are turned into synthetic samples with ``tensor2data_t2g`` and
    used to train the g2t model.

    Args:
        raw_batch: sequence of raw samples; each sample's ``'ent_text'`` entry
            determines how many entities are enumerated.
        model_t2g: generator of the synthetic graphs; switched to eval mode.
        model_g2t: model being trained; switched to train mode.
        optimizer: optimizer forwarded to ``train_g2t_one_step``.
        config: mapping with ``'t2g'`` and ``'g2t'`` sub-configs; each
            sub-config's ``'device'`` entry is read here.
        vocab: vocabulary object; ``vocab['entity']`` is called when logging a
            prediction that is too small for the sample.

    Returns:
        ``None`` when no synthetic sample was produced, otherwise whatever
        ``train_g2t_one_step`` returns.
    """
    model_t2g.eval()
    model_g2t.train()
    batch_t2g = batch2tensor_t2g(raw_batch, config['t2g']['device'], vocab)
    with torch.no_grad():
        t2g_pred = model_t2g(batch_t2g).argmax(-1).cpu()

    syn_batch = []
    for _i, sample in enumerate(t2g_pred):
        n_ent = len(raw_batch[_i]['ent_text'])
        rel = []
        for e1 in range(n_ent):
            for e2 in range(n_ent):
                try:
                    label = int(sample[e1, e2])
                except IndexError:
                    # The prediction has fewer rows/columns than the sample has
                    # entities: report it and skip the pair (best effort).
                    logging.warning('{0:}'.format(
                        [[vocab['entity'](x) for x in y] for y in raw_batch[_i]['ent_text']]))
                    logging.warning('{0:}'.format(sample.size()))
                    continue
                if label != 3 and label != 0:  # 3 is no relation and 0 is <PAD>
                    rel.append([e1, label, e2])
        syn_batch.append(tensor2data_t2g(raw_batch[_i], rel, vocab))

    if not syn_batch:
        return None
    batch_g2t = batch2tensor_g2t(syn_batch, config['g2t']['device'], vocab)
    return train_g2t_one_step(batch_g2t, model_g2t, optimizer, config['g2t'])
def supervise(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T, optimizerT2G, config, t2g_weight, vocab):
    """Run one supervised training step for the t2g model, then the g2t model.

    Both models have their ``blind`` attribute set to ``False`` first.

    NOTE(review): a second ``supervise`` defined later in this file rebinds the
    name, so this two-value version looks shadowed at import time -- confirm
    whether it is still needed.

    Returns:
        ``(t2g_loss, g2t_loss)`` as returned by ``train_t2g_one_step`` and
        ``train_g2t_one_step`` respectively.
    """
    model_g2t.blind = False
    model_t2g.blind = False

    # Text-to-graph update.
    t2g_cfg = config['t2g']
    t2g_inputs = batch2tensor_t2g(batch_t2g, t2g_cfg['device'], vocab)
    t2g_loss = train_t2g_one_step(t2g_inputs, model_t2g, optimizerT2G, t2g_cfg, t2g_weight=t2g_weight)

    # Graph-to-text update.
    g2t_cfg = config['g2t']
    g2t_inputs = batch2tensor_g2t(batch_g2t, g2t_cfg['device'], vocab)
    g2t_loss = train_g2t_one_step(g2t_inputs, model_g2t, optimizerG2T, g2t_cfg)

    return t2g_loss, g2t_loss
def warmup_step1(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T, optimizerT2G, config, t2g_weight, vocab):
    """Run one warm-up training step for the t2g model, then the g2t model.

    Both models have their ``blind`` attribute set to ``True`` first; otherwise
    the sequence of calls matches the supervised step.

    NOTE(review): a second ``warmup_step1`` defined later in this file rebinds
    the name, so this two-value version looks shadowed at import time --
    confirm whether it is still needed.

    Returns:
        ``(t2g_loss, g2t_loss)`` as returned by ``train_t2g_one_step`` and
        ``train_g2t_one_step`` respectively.
    """
    model_g2t.blind = True
    model_t2g.blind = True

    # Text-to-graph update.
    t2g_cfg = config['t2g']
    t2g_inputs = batch2tensor_t2g(batch_t2g, t2g_cfg['device'], vocab)
    t2g_loss = train_t2g_one_step(t2g_inputs, model_t2g, optimizerT2G, t2g_cfg, t2g_weight=t2g_weight)

    # Graph-to-text update.
    g2t_cfg = config['g2t']
    g2t_inputs = batch2tensor_g2t(batch_g2t, g2t_cfg['device'], vocab)
    g2t_loss = train_g2t_one_step(g2t_inputs, model_g2t, optimizerG2T, g2t_cfg)

    return t2g_loss, g2t_loss
def supervise(
    batch_g2t,
    batch_t2g,
    model_g2t,
    model_t2g,
    optimizerG2T,
    optimizerT2G,
    config,
    t2g_weight,
    vocab,
):
    """Run one supervised training step for the t2g model, then the g2t model.

    Both models have their ``blind`` attribute set to ``False`` first. The g2t
    training helper is unpacked as a ``(loss, kld)`` pair here.

    Returns:
        ``(t2g_loss, g2t_loss, kld)``: the value returned by
        ``train_t2g_one_step`` followed by the two values returned by
        ``train_g2t_one_step``.
    """
    model_g2t.blind = False
    model_t2g.blind = False

    # Text-to-graph update.
    t2g_cfg = config["t2g"]
    t2g_inputs = batch2tensor_t2g(batch_t2g, t2g_cfg["device"], vocab)
    t2g_loss = train_t2g_one_step(t2g_inputs, model_t2g, optimizerT2G, t2g_cfg, t2g_weight=t2g_weight)

    # Graph-to-text update; this helper also yields a second value (kld).
    g2t_cfg = config["g2t"]
    g2t_inputs = batch2tensor_g2t(batch_g2t, g2t_cfg["device"], vocab)
    g2t_loss, kld = train_g2t_one_step(g2t_inputs, model_g2t, optimizerG2T, g2t_cfg)

    return t2g_loss, g2t_loss, kld
def warmup_step1(
    batch_g2t,
    batch_t2g,
    model_g2t,
    model_t2g,
    optimizerG2T,
    optimizerT2G,
    config,
    t2g_weight,
    vocab,
):
    """Run one warm-up training step for the t2g model, then the g2t model.

    Both models have their ``blind`` attribute set to ``True`` first. The g2t
    training helper is unpacked as a ``(loss, kld)`` pair here.

    Returns:
        ``(t2g_loss, g2t_loss, kld)``: the value returned by
        ``train_t2g_one_step`` followed by the two values returned by
        ``train_g2t_one_step``.
    """
    model_g2t.blind = True
    model_t2g.blind = True

    # Text-to-graph update.
    t2g_cfg = config["t2g"]
    t2g_inputs = batch2tensor_t2g(batch_t2g, t2g_cfg["device"], vocab)
    t2g_loss = train_t2g_one_step(t2g_inputs, model_t2g, optimizerT2G, t2g_cfg, t2g_weight=t2g_weight)

    # Graph-to-text update; this helper also yields a second value (kld).
    g2t_cfg = config["g2t"]
    g2t_inputs = batch2tensor_g2t(batch_g2t, g2t_cfg["device"], vocab)
    g2t_loss, kld = train_g2t_one_step(g2t_inputs, model_g2t, optimizerG2T, g2t_cfg)

    return t2g_loss, g2t_loss, kld
def _relation_triples(label_matrix, num_entities):
    """Return ``(e1, label, e2)`` for every entity pair with a real relation.

    Labels ``3`` (no relation) and ``0`` (<PAD>) are skipped. Pairs are visited
    in row-major order over ``range(num_entities)``.
    """
    triples = []
    for e1 in range(num_entities):
        for e2 in range(num_entities):
            label = int(label_matrix[e1, e2])
            if label != 3 and label != 0:
                triples.append((e1, label, e2))
    return triples


def eval_t2g(pool, _type, vocab, model, config, display=True):
    """Evaluate a t2g model on one split and return its micro-averaged F1.

    For every batch drawn from ``pool`` the model is run under
    ``torch.no_grad()`` and the argmax over the last axis is taken as the
    predicted relation label. Positions whose gold label is ``0`` (<PAD>) are
    dropped before scoring. F1 is computed only over the labels that occur in
    the gold data and are not ``3`` (no relation).

    Side effect: overwrites ``t2g_show.txt`` in the current working directory.
    For each sample it writes a separator line, a line with the predicted
    triples and a line with the gold triples.

    Args:
        pool: data source exposing ``draw_with_type(batch_size, <flag>, _type)``
            (the flag is passed as ``False``; its meaning is not visible here).
        _type: split identifier, logged and forwarded to ``draw_with_type``.
        vocab: vocabulary; ``vocab['entity']`` and ``vocab['relation']`` are
            called to render ids for the dump file.
        model: t2g model; switched to eval mode.
        config: mapping providing ``'batch_size'`` and ``'device'``.
        display: when false, the tqdm progress bar is disabled.

    Returns:
        The micro-averaged F1 score (macro F1 is logged but not returned).
    """
    logging.info('Eval on {0:}'.format(_type))
    hyp, ref = [], []
    pos_label = set()
    model.eval()

    batches = list(pool.draw_with_type(config['batch_size'], False, _type))
    # Context managers guarantee the dump file is closed even if a batch fails.
    with open('t2g_show.txt', 'w') as wf, tqdm.tqdm(batches, disable=not display) as tqb:
        for _batch in tqb:
            with torch.no_grad():
                batch = batch2tensor_t2g(_batch, config['device'], vocab)
                pred = model(batch)
            _pred = pred.view(-1, pred.shape[-1]).argmax(-1).cpu().long().tolist()
            _gold = batch['tgt'].view(-1).cpu().long().tolist()
            tpred = pred.argmax(-1).cpu().numpy()
            tgold = batch['tgt'].cpu().numpy()

            for j in range(len(_batch)):
                # Entity tokens, with tokens starting with '<' filtered out.
                ents = [[y for y in vocab['entity'](x) if y[0] != '<'] for x in _batch[j]['ent_text']]
                wf.write('=====================\n')
                # Predicted triples first, then gold triples, one line each.
                for matrix in (tpred[j], tgold[j]):
                    rels = _relation_triples(matrix, len(ents))
                    wf.write(
                        str([(ents[e1], vocab['relation'](r), ents[e2]) for e1, r, e2 in rels]) + '\n')

            # Keep only positions whose gold label is not <PAD>.
            kept = [(p, g) for p, g in zip(_pred, _gold) if g > 0]
            pos_label.update(g for _, g in kept if g != 3)  # 3 is no relation
            hyp.extend(p for p, _ in kept)
            ref.extend(g for _, g in kept)

    labels = sorted(pos_label)
    f1_micro = f1_score(ref, hyp, average='micro', labels=labels, zero_division=0)
    f1_macro = f1_score(ref, hyp, average='macro', labels=labels, zero_division=0)
    logging.info('F1 micro {0:} F1 macro {1:}'.format(f1_micro, f1_macro))
    return f1_micro