Esempio n. 1
0
def g2t_teach_t2g_one_step(raw_batch,
                           model_g2t,
                           model_t2g,
                           optimizer,
                           config,
                           vocab,
                           t2g_weight=None):
    """Train the t2g model for one step on input synthesized by the g2t model.

    The g2t model is put in eval mode and decodes ``raw_batch`` greedily
    (``beam_size=1``) without gradients. Each decoded id sequence is cut at the
    first id 2 (<EOS>) and turned into a synthetic sample with
    ``tensor2data_g2t``. The t2g model is then trained (in train mode) on those
    samples via ``train_t2g_one_step``.

    Args:
        raw_batch: list of raw samples; may be empty.
        model_g2t: graph-to-text model used as a frozen generator.
        model_t2g: text-to-graph model being trained.
        optimizer: optimizer for ``model_t2g``.
        config: dict with ``'g2t'`` and ``'t2g'`` sub-configs (each with
            ``'device'``); ``config['t2g']`` is forwarded to the trainer.
        vocab: vocabulary passed to the tensor conversion helpers.
        t2g_weight: optional weight forwarded to ``train_t2g_one_step``.

    Returns:
        The value returned by ``train_t2g_one_step``, or ``None`` when
        ``raw_batch`` is empty (there is nothing to synthesize from).
    """
    model_g2t.eval()
    model_t2g.train()
    if not raw_batch:
        # Avoid building and training on an empty tensor batch.
        return None
    batch_g2t = batch2tensor_g2t(raw_batch, config['g2t']['device'], vocab)
    with torch.no_grad():
        g2t_pred = model_g2t(batch_g2t, beam_size=1).cpu()
    syn_batch = []
    for _i, sample in enumerate(g2t_pred):
        _s = sample.tolist()
        if 2 in _s:  # <EOS> in list: drop it and everything after
            _s = _s[:_s.index(2)]
        syn_batch.append(tensor2data_g2t(raw_batch[_i], _s))
    batch_t2g = batch2tensor_t2g(syn_batch,
                                 config['t2g']['device'],
                                 vocab,
                                 add_inp=True)
    return train_t2g_one_step(batch_t2g,
                              model_t2g,
                              optimizer,
                              config['t2g'],
                              t2g_weight=t2g_weight)
Esempio n. 2
0
def t2g_teach_g2t_one_step(raw_batch, model_t2g, model_g2t, optimizer, config,
                           vocab):
    """Train the g2t model for one step on graphs synthesized by the t2g model.

    The t2g model is put in eval mode and predicts a relation-id matrix per
    sample (argmax over the last axis, no gradients). Every cell that is
    neither 3 (no relation) nor 0 (<PAD>) becomes an ``[e1, rel, e2]`` triple.
    The triples are turned into synthetic samples with ``tensor2data_t2g``, and
    the g2t model is trained (in train mode) on them via ``train_g2t_one_step``.

    Args:
        raw_batch: list of raw samples, each with an ``'ent_text'`` list;
            may be empty.
        model_t2g: text-to-graph model used as a frozen generator.
        model_g2t: graph-to-text model being trained.
        optimizer: optimizer for ``model_g2t``.
        config: dict with ``'t2g'`` and ``'g2t'`` sub-configs (each with
            ``'device'``); ``config['g2t']`` is forwarded to the trainer.
        vocab: vocabulary; ``vocab['entity']`` is used to decode entities
            for diagnostics.

    Returns:
        The value returned by ``train_g2t_one_step``, or ``None`` when there
        are no samples to train on.
    """
    model_t2g.eval()
    model_g2t.train()
    if not raw_batch:
        # Nothing to synthesize; bail out before building an empty tensor batch.
        return None
    batch_t2g = batch2tensor_t2g(raw_batch, config['t2g']['device'], vocab)
    with torch.no_grad():
        t2g_pred = model_t2g(batch_t2g).argmax(-1).cpu()
    syn_batch = []
    for _i, sample in enumerate(t2g_pred):
        ent_text = raw_batch[_i]['ent_text']
        n_ents = len(ent_text)
        rel = []
        for e1 in range(n_ents):
            for e2 in range(n_ents):
                try:
                    rel_id = sample[e1, e2]
                except IndexError:
                    # The sample has more entities than the prediction matrix
                    # covers; log the details and skip this cell.
                    logging.warning('%s', [[vocab['entity'](x) for x in y]
                                           for y in ent_text])
                    logging.warning('%s', sample.size())
                    continue
                if rel_id != 3 and rel_id != 0:  # 3 is no relation, 0 is <PAD>
                    rel.append([e1, int(rel_id), e2])
        syn_batch.append(tensor2data_t2g(raw_batch[_i], rel, vocab))
    if not syn_batch:
        return None
    batch_g2t = batch2tensor_g2t(syn_batch, config['g2t']['device'], vocab)
    return train_g2t_one_step(batch_g2t, model_g2t, optimizer, config['g2t'])
Esempio n. 3
0
def supervise(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T,
              optimizerT2G, config, t2g_weight, vocab):
    """Run one supervised step in each direction with ``blind`` disabled.

    Sets ``blind = False`` on both models. It then trains the t2g model on
    ``batch_t2g`` and afterwards the g2t model on ``batch_g2t``.

    Returns:
        ``(t2g_loss, g2t_loss)`` as returned by ``train_t2g_one_step`` and
        ``train_g2t_one_step``.
    """
    model_g2t.blind = False
    model_t2g.blind = False

    t2g_cfg = config['t2g']
    t2g_tensors = batch2tensor_t2g(batch_t2g, t2g_cfg['device'], vocab)
    t2g_loss = train_t2g_one_step(t2g_tensors, model_t2g, optimizerT2G,
                                  t2g_cfg, t2g_weight=t2g_weight)

    g2t_cfg = config['g2t']
    g2t_tensors = batch2tensor_g2t(batch_g2t, g2t_cfg['device'], vocab)
    g2t_loss = train_g2t_one_step(g2t_tensors, model_g2t, optimizerG2T,
                                  g2t_cfg)
    return t2g_loss, g2t_loss
Esempio n. 4
0
def warmup_step1(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T,
                 optimizerT2G, config, t2g_weight, vocab):
    """Warm-up step: train both directions with ``blind`` enabled.

    Sets ``blind = True`` on both models. It then trains the t2g model on
    ``batch_t2g`` and afterwards the g2t model on ``batch_g2t``.

    Returns:
        ``(t2g_loss, g2t_loss)`` as returned by ``train_t2g_one_step`` and
        ``train_g2t_one_step``.
    """
    model_g2t.blind = True
    model_t2g.blind = True

    t2g_cfg = config['t2g']
    t2g_loss = train_t2g_one_step(
        batch2tensor_t2g(batch_t2g, t2g_cfg['device'], vocab),
        model_t2g,
        optimizerT2G,
        t2g_cfg,
        t2g_weight=t2g_weight,
    )

    g2t_cfg = config['g2t']
    g2t_loss = train_g2t_one_step(
        batch2tensor_g2t(batch_g2t, g2t_cfg['device'], vocab),
        model_g2t,
        optimizerG2T,
        g2t_cfg,
    )
    return t2g_loss, g2t_loss
Esempio n. 5
0
def supervise(
    batch_g2t,
    batch_t2g,
    model_g2t,
    model_t2g,
    optimizerG2T,
    optimizerT2G,
    config,
    t2g_weight,
    vocab,
):
    """Run one supervised step in each direction with ``blind`` disabled.

    Sets ``blind = False`` on both models. It then trains the t2g model on
    ``batch_t2g`` and afterwards the g2t model on ``batch_g2t``.

    Returns:
        ``(t2g_loss, g2t_loss, kld)``. The first value comes from
        ``train_t2g_one_step``. The other two are the pair that
        ``train_g2t_one_step`` returns; ``kld`` is presumably a KL-divergence
        term, TODO confirm.
    """
    for net in (model_g2t, model_t2g):
        net.blind = False

    t2g_cfg = config["t2g"]
    t2g_loss = train_t2g_one_step(
        batch2tensor_t2g(batch_t2g, t2g_cfg["device"], vocab),
        model_t2g,
        optimizerT2G,
        t2g_cfg,
        t2g_weight=t2g_weight,
    )

    g2t_cfg = config["g2t"]
    g2t_loss, kld = train_g2t_one_step(
        batch2tensor_g2t(batch_g2t, g2t_cfg["device"], vocab),
        model_g2t,
        optimizerG2T,
        g2t_cfg,
    )
    return t2g_loss, g2t_loss, kld
Esempio n. 6
0
def warmup_step1(
    batch_g2t,
    batch_t2g,
    model_g2t,
    model_t2g,
    optimizerG2T,
    optimizerT2G,
    config,
    t2g_weight,
    vocab,
):
    """Warm-up step: train both directions with ``blind`` enabled.

    Sets ``blind = True`` on both models. It then trains the t2g model on
    ``batch_t2g`` and afterwards the g2t model on ``batch_g2t``.

    Returns:
        ``(t2g_loss, g2t_loss, kld)``. The first value comes from
        ``train_t2g_one_step``. The other two are the pair that
        ``train_g2t_one_step`` returns; ``kld`` is presumably a KL-divergence
        term, TODO confirm.
    """
    for net in (model_g2t, model_t2g):
        net.blind = True

    t2g_cfg = config["t2g"]
    t2g_tensors = batch2tensor_t2g(batch_t2g, t2g_cfg["device"], vocab)
    t2g_loss = train_t2g_one_step(t2g_tensors,
                                  model_t2g,
                                  optimizerT2G,
                                  t2g_cfg,
                                  t2g_weight=t2g_weight)

    g2t_cfg = config["g2t"]
    g2t_tensors = batch2tensor_g2t(batch_g2t, g2t_cfg["device"], vocab)
    g2t_loss, kld = train_g2t_one_step(g2t_tensors, model_g2t, optimizerG2T,
                                       g2t_cfg)
    return t2g_loss, g2t_loss, kld
Esempio n. 7
0
def _t2g_rel_triples(rel_mat, n_ents):
    """Return ``(e1, rel_id, e2)`` triples for the occupied cells of ``rel_mat``.

    Only the top-left ``n_ents x n_ents`` corner is scanned. Cells holding 3
    (no relation) or 0 (<PAD>) are skipped.

    Args:
        rel_mat: 2-D array indexable as ``rel_mat[e1, e2]`` holding relation ids.
        n_ents: number of real entities in the sample.
    """
    triples = []
    for e1 in range(n_ents):
        for e2 in range(n_ents):
            rel_id = rel_mat[e1, e2]
            if rel_id != 3 and rel_id != 0:
                triples.append((e1, int(rel_id), e2))
    return triples


def eval_t2g(pool, _type, vocab, model, config, display=True):
    """Evaluate a t2g model on the ``_type`` split of ``pool``.

    For every sample, the predicted and gold relation triples are decoded
    through ``vocab`` and written to ``t2g_show.txt`` for inspection. F1 is
    computed over all cells whose gold id is not 0 (<PAD>). The scored labels
    are the gold ids seen that are not 3 (no relation).

    Args:
        pool: data source; ``pool.draw_with_type(config['batch_size'], False,
            _type)`` is called to obtain the batches.
        _type: split name forwarded to ``pool.draw_with_type``.
        vocab: vocabulary mapping; ``vocab['entity']`` and
            ``vocab['relation']`` are called to decode ids for the dump file.
        model: t2g model; ``model(batch)`` returns scores whose argmax over
            the last axis is a relation id per entity pair.
        config: dict with ``'batch_size'`` and ``'device'``.
        display: show a tqdm progress bar when True.

    Returns:
        The micro-averaged F1. Macro-F1 is logged alongside it.
    """
    logging.info('Eval on %s', _type)
    hyp, ref, pos_label = [], [], set()
    model.eval()
    # ``with`` guarantees the dump file is closed even if a batch raises.
    with open('t2g_show.txt', 'w', encoding='utf-8') as wf, tqdm.tqdm(
            list(pool.draw_with_type(config['batch_size'], False, _type)),
            disable=not display) as tqb:
        for _batch in tqb:
            with torch.no_grad():
                batch = batch2tensor_t2g(_batch, config['device'], vocab)
                pred = model(batch)
            _pred = pred.view(-1,
                              pred.shape[-1]).argmax(-1).cpu().long().tolist()
            _gold = batch['tgt'].view(-1).cpu().long().tolist()
            tpred = pred.argmax(-1).cpu().numpy()
            tgold = batch['tgt'].cpu().numpy()

            for j, sample in enumerate(_batch):
                # Decoded entity tokens, dropping special tokens like <PAD>.
                ents = [[y for y in vocab['entity'](x) if y[0] != '<']
                        for x in sample['ent_text']]
                wf.write('=====================\n')
                # First line: predicted triples; second line: gold triples.
                for rel_mat in (tpred[j], tgold[j]):
                    rels = _t2g_rel_triples(rel_mat, len(ents))
                    wf.write(
                        str([(ents[e1], vocab['relation'](r), ents[e2])
                             for e1, r, e2 in rels]) + '\n')

            for p, g in zip(_pred, _gold):
                if g > 0:  # not the <PAD>
                    hyp.append(p)
                    ref.append(g)
                    if g != 3:  # 3 is no relation, not a positive label
                        pos_label.add(g)

    labels = sorted(pos_label)
    f1_micro = f1_score(ref,
                        hyp,
                        average='micro',
                        labels=labels,
                        zero_division=0)
    f1_macro = f1_score(ref,
                        hyp,
                        average='macro',
                        labels=labels,
                        zero_division=0)

    logging.info('F1 micro {0:} F1 macro {1:}'.format(f1_micro, f1_macro))
    return f1_micro