Example #1
def t2g_teach_g2t_one_step(raw_batch, model_t2g, model_g2t, optimizer, config, vocab):
    # Train the g2t model on synthetic graphs predicted by the (frozen) t2g model.
    model_t2g.eval()
    model_g2t.train()
    batch_t2g = batch2tensor_t2g(raw_batch, config['t2g']['device'], vocab)
    with torch.no_grad():
        t2g_pred = model_t2g(batch_t2g).argmax(-1).cpu()
    syn_batch = []
    for _i, sample in enumerate(t2g_pred):
        rel = []
        for e1 in range(len(raw_batch[_i]['ent_text'])):
            for e2 in range(len(raw_batch[_i]['ent_text'])):
                try:
                    if sample[e1, e2] != 3 and sample[e1, e2] != 0: # 3 is 'no relation', 0 is <PAD>
                        rel.append([e1, int(sample[e1, e2]), e2])
                except IndexError:
                    # the predicted relation matrix can be smaller than the entity list; log and skip
                    logging.warning('{0:}'.format([[vocab['entity'](x) for x in y] for y in raw_batch[_i]['ent_text']]))
                    logging.warning('{0:}'.format(sample.size()))
        _syn = tensor2data_t2g(raw_batch[_i], rel, vocab)
        syn_batch.append(_syn)
    if len(syn_batch) == 0:
        return None
    batch_g2t = batch2tensor_g2t(syn_batch, config['g2t']['device'], vocab)
    loss = train_g2t_one_step(batch_g2t, model_g2t, optimizer, config['g2t'])
    return loss
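
For context, a step like this is normally driven by an alternating back-translation loop. The sketch below is illustrative only: the pool.draw_with_type call is borrowed from Example #4, config['main'] is an assumed config section, and the pairing with g2t_teach_t2g_one_step (Example #5) mirrors the cycle-training idea rather than any original driver code.

import logging

def unsupervised_epoch(pool, model_g2t, model_t2g, optimizerG2T, optimizerT2G,
                       config, vocab, t2g_weight=None):
    # One back-translation epoch: each model provides synthetic inputs for the
    # other, as in Examples #1 and #5. config['main']['batch_size'] is assumed.
    for raw_batch in pool.draw_with_type(config['main']['batch_size'], True, 'train'):
        loss_g2t = t2g_teach_g2t_one_step(
            raw_batch, model_t2g, model_g2t, optimizerG2T, config, vocab)
        loss_t2g = g2t_teach_t2g_one_step(
            raw_batch, model_g2t, model_t2g, optimizerT2G, config, vocab,
            t2g_weight=t2g_weight)
        if loss_g2t is not None:
            logging.info('g2t loss {0:}, t2g loss {1:}'.format(loss_g2t, loss_t2g))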
Example #2
def supervise(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T, optimizerT2G, config, t2g_weight, vocab):
    model_g2t.blind, model_t2g.blind = False, False
    batch = batch2tensor_t2g(batch_t2g, config['t2g']['device'], vocab)
    _loss1 = train_t2g_one_step(batch, model_t2g, optimizerT2G, config['t2g'], t2g_weight=t2g_weight)
    batch = batch2tensor_g2t(batch_g2t, config['g2t']['device'], vocab)
    _loss2 = train_g2t_one_step(batch, model_g2t, optimizerG2T, config['g2t'])
    return _loss1, _loss2
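
t2g_weight is passed through to train_t2g_one_step here and in Examples #3 and #5, but none of these snippets construct it. If it is a per-relation class weight for the t2g loss (a natural reading, given that the 'no relation' class dominates the relation matrix), inverse frequency is a common construction. The helper below is a sketch of that idea only; rel_counts and num_rels are assumed inputs, not part of the original code.

import numpy as np

def build_t2g_weight(rel_counts, num_rels, clip=5.0):
    # Sketch: inverse-frequency class weights so rare relations are not
    # drowned out by the dominant "no relation" class. `rel_counts`
    # (relation id -> corpus count) is an assumed input.
    freq = np.array([max(rel_counts.get(i, 1), 1) for i in range(num_rels)],
                    dtype=np.float32)
    weight = freq.mean() / freq        # mean weight is 1.0; rare ids get > 1
    return np.clip(weight, 1.0 / clip, clip)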
Example #3
def warmup_step1(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T, optimizerT2G, config, t2g_weight, vocab):
    model_g2t.blind, model_t2g.blind = True, True
    batch = batch2tensor_t2g(batch_t2g, config['t2g']['device'], vocab)
    loss1 = train_t2g_one_step(batch, model_t2g, optimizerT2G, config['t2g'], t2g_weight=t2g_weight)
    batch = batch2tensor_g2t(batch_g2t, config['g2t']['device'], vocab)
    loss2 = train_g2t_one_step(batch, model_g2t, optimizerG2T, config['g2t'])
    return loss1, loss2
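
This differs from supervise (Example #2) only in the blind flags, suggesting a warm-up phase in which each model trains without the other's signal. A driver might interleave the two as below; the step count and the paired_batches iterator are assumptions, not original code.

# Illustrative warm-up schedule: a fixed number of "blind" steps before
# switching to the fully supervised step.
WARMUP_STEPS = 500

for step, (batch_g2t, batch_t2g) in enumerate(paired_batches):
    step_fn = warmup_step1 if step < WARMUP_STEPS else supervise
    loss1, loss2 = step_fn(batch_g2t, batch_t2g, model_g2t, model_t2g,
                           optimizerG2T, optimizerT2G, config, t2g_weight, vocab)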
Example #4
def eval_g2t(pool, _type, vocab, model, config, display=True):
    logging.info('Eval on {0:}'.format(_type))
    model.eval()
    hyp, ref, _same = [], [], []
    unq_hyp = {}
    unq_ref = defaultdict(list)
    batch_size = 8 * config['batch_size']
    with tqdm.tqdm(list(pool.draw_with_type(batch_size, False, _type)), disable=(not display)) as tqb:
        for i, _batch in enumerate(tqb):
            with torch.no_grad():
                batch = batch2tensor_g2t(_batch, config['device'], vocab)
                seq = model(batch, beam_size=config['beam_size'])
            r = write_txt(batch, batch['tgt'], vocab['text'])
            h = write_txt(batch, seq, vocab['text'])
            # key identifying identical inputs, used below to group multi-reference sets
            _same.extend([str(x['raw_relation']) + str(x['ent_text']) for x in _batch])
            hyp.extend(h)
            ref.extend(r)
        hyp = [x[0] for x in hyp]
        ref = [x[0] for x in ref]
        # sort by the input key so that identical inputs become adjacent
        idxs, _same = list(
            zip(*sorted(enumerate(_same), key=lambda x: x[1]))
        )

        # group hypotheses/references that share an input (multi-reference eval)
        ptr = 0
        for i in range(len(hyp)):
            if i > 0 and _same[i] != _same[i - 1]:
                ptr += 1
            unq_hyp[ptr] = hyp[idxs[i]]
            unq_ref[ptr].append(ref[idxs[i]])
            
        unq_hyp = sorted(unq_hyp.items(), key=lambda x: x[0])
        unq_ref = sorted(unq_ref.items(), key=lambda x: x[0])
        hyp = [x[1] for x in unq_hyp]
        ref = [[x.lower() for x in y[1]] for y in unq_ref]

    with open('hyp.txt', 'w') as wf_h:
        for h in hyp:
            wf_h.write(str(h) + '\n')
    hyp = dict(zip(range(len(hyp)), [[x.lower()] for x in hyp]))
    ref = dict(zip(range(len(ref)), ref))
    ret = bleu.compute_score(ref, hyp)
    logging.info('BLEU INP {0:}'.format(len(hyp)))
    logging.info('BLEU 1-4 {0:}'.format(ret[0]))
    logging.info('METEOR {0:}'.format(meteor.compute_score(ref, hyp)[0]))
    logging.info('ROUGE_L {0:}'.format(rouge.compute_score(ref, hyp)[0]))
    logging.info('Cider {0:}'.format(cider.compute_score(ref, hyp)[0]))
    return ret[0][-1]
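
The module-level bleu, meteor, rouge, and cider scorers are not defined in this snippet. Their compute_score(ref, hyp) signature, taking dicts that map an id to a list of strings, matches the pycocoevalcap scorers, so a plausible setup is the sketch below; the package choice is an assumption.

# Plausible setup for the module-level scorers used above (an assumption:
# the compute_score(dict, dict) signature matches pycocoevalcap's scorers).
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

bleu = Bleu(4)    # compute_score returns ([BLEU-1..BLEU-4], per-sample scores)
meteor = Meteor()
rouge = Rouge()
cider = Cider()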
Example #5
def g2t_teach_t2g_one_step(raw_batch, model_g2t, model_t2g, optimizer, config, vocab, t2g_weight=None):
    # Train the t2g model on synthetic text generated by the (frozen) g2t model.
    model_g2t.eval()
    model_t2g.train()
    syn_batch = []
    if len(raw_batch) > 0:
        batch_g2t = batch2tensor_g2t(raw_batch, config['g2t']['device'], vocab)
        with torch.no_grad():
            g2t_pred = model_g2t(batch_g2t, beam_size=1).cpu()
        for _i, sample in enumerate(g2t_pred):
            _s = sample.tolist()
            if 2 in _s: # truncate at <EOS> (token id 2)
                _s = _s[:_s.index(2)]
            _syn = tensor2data_g2t(raw_batch[_i], _s)
            syn_batch.append(_syn)
    batch_t2g = batch2tensor_t2g(syn_batch, config['t2g']['device'], vocab, add_inp=True)
    loss = train_t2g_one_step(batch_t2g, model_t2g, optimizer, config['t2g'], t2g_weight=t2g_weight)
    return loss
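
The truncation inside the loop relies on token id 2 being <EOS>. The decode-and-truncate pattern, shown in isolation with invented token ids:

# Decode-and-truncate pattern from the loop above, in isolation.
# Token ids are invented for illustration; in the original, 2 is <EOS>.
EOS = 2
pred = [17, 4, 9, EOS, 0, 0]   # greedy output, padded after <EOS>
if EOS in pred:
    pred = pred[:pred.index(EOS)]
assert pred == [17, 4, 9]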