コード例 #1
0
ファイル: main.py プロジェクト: zxlzr/CycleGT
def g2t_teach_t2g_one_step(raw_batch,
                           model_g2t,
                           model_t2g,
                           optimizer,
                           config,
                           vocab,
                           t2g_weight=None):
    """Train the t2g model on synthetic input generated by the g2t model.

    The g2t model is switched to eval mode and, without gradients, greedily
    decodes (``beam_size=1``) one sequence per sample of ``raw_batch``. Each
    decoded token list is cut at the first token id 2 (<EOS>) and paired with
    its source sample via ``tensor2data_g2t``. The resulting synthetic batch
    is converted with ``batch2tensor_t2g(..., add_inp=True)`` and used for one
    ``train_t2g_one_step`` update of ``model_t2g`` (switched to train mode).

    Args:
        raw_batch: list of raw samples fed to the g2t model.
        model_g2t: graph-to-text model, used for inference only here.
        model_t2g: text-to-graph model being trained.
        optimizer: optimizer forwarded to ``train_t2g_one_step``.
        config: dict with ``'g2t'`` and ``'t2g'`` sub-configs; each must
            provide a ``'device'`` entry.
        vocab: vocabulary object forwarded to the conversion helpers.
        t2g_weight: optional weight forwarded to ``train_t2g_one_step``.

    Returns:
        The value returned by ``train_t2g_one_step`` (the step loss), or
        ``None`` when ``raw_batch`` is empty, mirroring
        ``t2g_teach_g2t_one_step`` when it has no synthetic data.
    """
    model_g2t.eval()
    model_t2g.train()
    if len(raw_batch) == 0:
        # Nothing to synthesize from: skip the update rather than pushing an
        # empty batch through batch2tensor_t2g / train_t2g_one_step.
        return None

    batch_g2t = batch2tensor_g2t(raw_batch, config['g2t']['device'], vocab)
    with torch.no_grad():
        g2t_pred = model_g2t(batch_g2t, beam_size=1).cpu()

    syn_batch = []
    for raw_sample, pred in zip(raw_batch, g2t_pred):
        tokens = pred.tolist()
        if 2 in tokens:  # 2 is <EOS>: drop it and everything after it
            tokens = tokens[:tokens.index(2)]
        syn_batch.append(tensor2data_g2t(raw_sample, tokens))

    batch_t2g = batch2tensor_t2g(syn_batch,
                                 config['t2g']['device'],
                                 vocab,
                                 add_inp=True)
    loss = train_t2g_one_step(batch_t2g,
                              model_t2g,
                              optimizer,
                              config['t2g'],
                              t2g_weight=t2g_weight)
    return loss
コード例 #2
0
ファイル: main.py プロジェクト: zxlzr/CycleGT
def t2g_teach_g2t_one_step(raw_batch, model_t2g, model_g2t, optimizer, config,
                           vocab):
    """Train the g2t model on synthetic input generated by the t2g model.

    The t2g model is switched to eval mode and, without gradients, predicts a
    label for every entity pair (``argmax`` over the last output dimension).
    For each sample, every ordered pair ``(e1, e2)`` of its ``ent_text``
    entries whose label is neither 3 (no relation) nor 0 (<PAD>) becomes a
    ``[e1, label, e2]`` triple. ``tensor2data_t2g`` turns each sample's
    triples into a synthetic example. The synthetic batch is then used for one
    ``train_g2t_one_step`` update of ``model_g2t`` (switched to train mode).

    If a prediction cannot be indexed at some ``(e1, e2)``, that pair is
    skipped. This presumably happens when the prediction matrix is smaller
    than the sample's entity list. A warning with the entity texts and the
    prediction size is logged once for that sample.

    Args:
        raw_batch: list of raw sample dicts, each with an ``'ent_text'`` entry.
        model_t2g: text-to-graph model, used for inference only here.
        model_g2t: graph-to-text model being trained.
        optimizer: optimizer forwarded to ``train_g2t_one_step``.
        config: dict with ``'t2g'`` and ``'g2t'`` sub-configs; each must
            provide a ``'device'`` entry.
        vocab: vocabulary object; ``vocab['entity']`` is only used to render
            entity texts in warnings.

    Returns:
        The value returned by ``train_g2t_one_step`` (the step loss), or
        ``None`` when there is no synthetic data to train on.
    """
    model_t2g.eval()
    model_g2t.train()
    if len(raw_batch) == 0:
        # No samples, hence no synthetic data: skip the model call entirely.
        return None

    batch_t2g = batch2tensor_t2g(raw_batch, config['t2g']['device'], vocab)
    with torch.no_grad():
        t2g_pred = model_t2g(batch_t2g).argmax(-1).cpu()

    syn_batch = []
    for raw_sample, sample in zip(raw_batch, t2g_pred):
        n_ent = len(raw_sample['ent_text'])
        rel = []
        warned = False
        for e1 in range(n_ent):
            for e2 in range(n_ent):
                try:
                    label = int(sample[e1, e2])
                except IndexError:
                    # Out-of-range pair: skip it, but warn only once per
                    # sample instead of once per failing pair.
                    if not warned:
                        logging.warning('{0:}'.format(
                            [[vocab['entity'](x) for x in y]
                             for y in raw_sample['ent_text']]))
                        logging.warning('{0:}'.format(sample.size()))
                        warned = True
                    continue
                if label != 3 and label != 0:  # 3 is no relation, 0 is <PAD>
                    rel.append([e1, label, e2])
        syn_batch.append(tensor2data_t2g(raw_sample, rel, vocab))

    if len(syn_batch) == 0:
        return None
    batch_g2t = batch2tensor_g2t(syn_batch, config['g2t']['device'], vocab)
    loss = train_g2t_one_step(batch_g2t, model_g2t, optimizer, config['g2t'])
    return loss
コード例 #3
0
ファイル: main.py プロジェクト: zxlzr/CycleGT
def supervise(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T,
              optimizerT2G, config, t2g_weight, vocab):
    """Run one supervised training step for both directions.

    Clears the ``blind`` flag on both models. It then trains the t2g model on
    ``batch_t2g`` with ``optimizerT2G``, and afterwards the g2t model on
    ``batch_g2t`` with ``optimizerG2T``. Each direction uses its own
    sub-config (``config['t2g']`` / ``config['g2t']``).

    Returns:
        ``(t2g_loss, g2t_loss)``: the values returned by
        ``train_t2g_one_step`` and ``train_g2t_one_step`` respectively.
    """
    model_g2t.blind = False
    model_t2g.blind = False

    t2g_conf = config['t2g']
    t2g_loss = train_t2g_one_step(
        batch2tensor_t2g(batch_t2g, t2g_conf['device'], vocab),
        model_t2g,
        optimizerT2G,
        t2g_conf,
        t2g_weight=t2g_weight,
    )

    g2t_conf = config['g2t']
    g2t_loss = train_g2t_one_step(
        batch2tensor_g2t(batch_g2t, g2t_conf['device'], vocab),
        model_g2t,
        optimizerG2T,
        g2t_conf,
    )
    return t2g_loss, g2t_loss
コード例 #4
0
ファイル: main.py プロジェクト: zxlzr/CycleGT
def warmup_step1(batch_g2t, batch_t2g, model_g2t, model_t2g, optimizerG2T,
                 optimizerT2G, config, t2g_weight, vocab):
    """Run one warm-up training step for both directions in blind mode.

    Sets the ``blind`` flag on both models to True. It then trains the t2g
    model on ``batch_t2g`` with ``optimizerT2G``, and afterwards the g2t model
    on ``batch_g2t`` with ``optimizerG2T``. Each direction uses its own
    sub-config (``config['t2g']`` / ``config['g2t']``).

    Returns:
        ``(t2g_loss, g2t_loss)``: the values returned by
        ``train_t2g_one_step`` and ``train_g2t_one_step`` respectively.
    """
    model_g2t.blind = True
    model_t2g.blind = True

    t2g_conf = config['t2g']
    t2g_loss = train_t2g_one_step(
        batch2tensor_t2g(batch_t2g, t2g_conf['device'], vocab),
        model_t2g,
        optimizerT2G,
        t2g_conf,
        t2g_weight=t2g_weight,
    )

    g2t_conf = config['g2t']
    g2t_loss = train_g2t_one_step(
        batch2tensor_g2t(batch_g2t, g2t_conf['device'], vocab),
        model_g2t,
        optimizerG2T,
        g2t_conf,
    )
    return t2g_loss, g2t_loss
コード例 #5
0
ファイル: main.py プロジェクト: alexandrethm/CycleGT
def eval_g2t(pool, _type, vocab, model, config, display=True):
    """Evaluate a graph-to-text model on one split of ``pool`` and log metrics.

    Samples of split ``_type`` are drawn in batches of
    ``8 * config["batch_size"]`` and decoded without gradients using
    ``config["beam_size"]``. Samples with the same input key
    (``str(raw_relation) + str(ent_text)``) are merged into one evaluation
    item. That item keeps the hypothesis of the last such sample (in draw
    order) and collects all of their reference texts as its reference set.
    Lowercased hypotheses and references are scored with BLEU, METEOR,
    ROUGE_L and CIDEr, and the scores are logged.

    Side effect: writes the merged (not lowercased) hypotheses, one per line,
    to ``hyp.txt`` in the current working directory.

    Args:
        pool: data pool providing ``draw_with_type(batch_size, shuffle, type)``.
        _type: name of the split to evaluate.
        vocab: vocabulary dict; ``vocab["text"]`` is used to render text.
        model: g2t model; switched to eval mode here.
        config: g2t config with ``"batch_size"``, ``"device"``, ``"beam_size"``.
        display: show the tqdm progress bar when truthy.

    Returns:
        ``ret[0][-1]`` from ``bleu.compute_score``: presumably BLEU-4, given
        that ``ret[0]`` is logged as "BLEU 1-4".

    Raises:
        ValueError: if the split yields no samples (the key-sorting unpack
            below has nothing to unpack).
    """
    logging.info("Eval on {0:}".format(_type))
    model.eval()
    hyp, ref, _same = [], [], []
    unq_hyp = {}
    unq_ref = defaultdict(list)
    batch_size = 8 * config["batch_size"]
    with tqdm.tqdm(
            list(pool.draw_with_type(batch_size, False, _type)),
            disable=not display,
    ) as tqb:
        for _batch in tqb:
            with torch.no_grad():
                batch = batch2tensor_g2t(_batch, config["device"], vocab)
                seq = model(batch, beam_size=config["beam_size"])
            r = write_txt(batch, batch["tgt"], vocab["text"])
            h = write_txt(batch, seq, vocab["text"])
            # Key identifying each sample's input graph; samples sharing a key
            # are scored as a single item with multiple references.
            _same.extend(
                [str(x["raw_relation"]) + str(x["ent_text"]) for x in _batch])
            hyp.extend(h)
            ref.extend(r)
        hyp = [x[0] for x in hyp]
        ref = [x[0] for x in ref]
        # Stable sort of sample indices by key so equal keys become adjacent
        # while keeping their original relative order.
        idxs, _same = list(zip(*sorted(enumerate(_same), key=lambda x: x[1])))

        # Walk the sorted keys; bump the group id whenever the key changes.
        ptr = 0
        for i in range(len(hyp)):
            if i > 0 and _same[i] != _same[i - 1]:
                ptr += 1
            unq_hyp[ptr] = hyp[idxs[i]]
            unq_ref[ptr].append(ref[idxs[i]])

        unq_hyp = sorted(unq_hyp.items(), key=lambda x: x[0])
        unq_ref = sorted(unq_ref.items(), key=lambda x: x[0])
        hyp = [x[1] for x in unq_hyp]
        ref = [[x.lower() for x in y[1]] for y in unq_ref]

    # Context manager guarantees the file is closed even if a write fails.
    with open("hyp.txt", "w") as wf_h:
        for h in hyp:
            wf_h.write(str(h) + "\n")
    hyp = {i: [x.lower()] for i, x in enumerate(hyp)}
    ref = dict(enumerate(ref))
    ret = bleu.compute_score(ref, hyp)
    logging.info("BLEU INP {0:}".format(len(hyp)))
    logging.info("BLEU 1-4 {0:}".format(ret[0]))
    logging.info("METEOR {0:}".format(meteor.compute_score(ref, hyp)[0]))
    logging.info("ROUGE_L {0:}".format(rouge.compute_score(ref, hyp)[0]))
    logging.info("Cider {0:}".format(cider.compute_score(ref, hyp)[0]))

    return ret[0][-1]
コード例 #6
0
ファイル: main.py プロジェクト: alexandrethm/CycleGT
def supervise(
    batch_g2t,
    batch_t2g,
    model_g2t,
    model_t2g,
    optimizerG2T,
    optimizerT2G,
    config,
    t2g_weight,
    vocab,
):
    """Run one supervised training step for both directions.

    Clears the ``blind`` flag on both models. It then trains the t2g model on
    ``batch_t2g`` with ``optimizerT2G``, and afterwards the g2t model on
    ``batch_g2t`` with ``optimizerG2T``. Each direction uses its own
    sub-config (``config["t2g"]`` / ``config["g2t"]``).

    Returns:
        ``(t2g_loss, g2t_loss, kld)``: the t2g step result, followed by the
        two values unpacked from ``train_g2t_one_step``'s return.
    """
    model_g2t.blind = False
    model_t2g.blind = False

    t2g_conf = config["t2g"]
    t2g_loss = train_t2g_one_step(
        batch2tensor_t2g(batch_t2g, t2g_conf["device"], vocab),
        model_t2g,
        optimizerT2G,
        t2g_conf,
        t2g_weight=t2g_weight,
    )

    g2t_conf = config["g2t"]
    g2t_loss, kld = train_g2t_one_step(
        batch2tensor_g2t(batch_g2t, g2t_conf["device"], vocab),
        model_g2t,
        optimizerG2T,
        g2t_conf,
    )
    return t2g_loss, g2t_loss, kld
コード例 #7
0
ファイル: main.py プロジェクト: alexandrethm/CycleGT
def warmup_step1(
    batch_g2t,
    batch_t2g,
    model_g2t,
    model_t2g,
    optimizerG2T,
    optimizerT2G,
    config,
    t2g_weight,
    vocab,
):
    """Run one warm-up training step for both directions in blind mode.

    Sets the ``blind`` flag on both models to True. It then trains the t2g
    model on ``batch_t2g`` with ``optimizerT2G``, and afterwards the g2t model
    on ``batch_g2t`` with ``optimizerG2T``. Each direction uses its own
    sub-config (``config["t2g"]`` / ``config["g2t"]``).

    Returns:
        ``(t2g_loss, g2t_loss, kld)``: the t2g step result, followed by the
        two values unpacked from ``train_g2t_one_step``'s return.
    """
    model_g2t.blind = True
    model_t2g.blind = True

    t2g_conf = config["t2g"]
    t2g_loss = train_t2g_one_step(
        batch2tensor_t2g(batch_t2g, t2g_conf["device"], vocab),
        model_t2g,
        optimizerT2G,
        t2g_conf,
        t2g_weight=t2g_weight,
    )

    g2t_conf = config["g2t"]
    g2t_loss, kld = train_g2t_one_step(
        batch2tensor_g2t(batch_g2t, g2t_conf["device"], vocab),
        model_g2t,
        optimizerG2T,
        g2t_conf,
    )
    return t2g_loss, g2t_loss, kld