Example #1
def next(feeder, batch_size):
    # Pull the next batch of (passage, question, target) id sequences;
    # the feeder may return fewer than batch_size items at the end of
    # an epoch.
    pids, qids, tids = feeder.next(batch_size)
    x = nu.tensor(pids)
    t = nu.tensor(tids)
    # Count non-padding positions to get each passage's true length.
    lengths = (x != NULL_ID).sum(-1)
    # Transpose to the [seq_len, batch] layout the model expects.
    return x.transpose(0, 1), t.transpose(0, 1), lengths, pids, qids
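Note: every example here builds tensors through a `nu` helper module whose source is not included. A minimal sketch of what it plausibly looks like, assuming it is just a device-aware wrapper over torch (the contents below are an assumption, not taken from the project):

import torch

# Hypothetical stand-in for the project's `nu` module: create tensors
# on the active device; callers cast with .float() where needed.
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def tensor(data):
    return torch.tensor(data, device=_device)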
Example #2
def __init__(self, beam_size, min_length):
    self.beam_size = beam_size
    self.min_length = min_length
    # Index of each hypothesis slot within the beam.
    self.sid = nu.tensor(range(beam_size))
    # Current token per slot: only slot 0 starts from SOS; the other
    # slots hold NULL until beam expansion fills them.
    self.cid = nu.tensor([
        data.NULL_ID if i != 0 else data.SOS_ID for i in range(beam_size)
    ])
    # Decoded token sequence accumulated per slot.
    self.seq = [list() for _ in range(beam_size)]
    # Cumulative score of each hypothesis.
    self.scores = nu.tensor([0] * beam_size).float()
    self.length = 0
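Assuming this __init__ belongs to a beam-search helper class (call it Beam; the enclosing class is not shown), construction seeds slot 0 with the start-of-sequence token and pads the remaining slots:

# Hypothetical usage; `Beam` is an assumed name for the enclosing class.
beam = Beam(beam_size=5, min_length=3)
beam.cid.tolist()     # slot 0 holds data.SOS_ID, slots 1-4 data.NULL_ID
beam.scores.tolist()  # [0.0, 0.0, 0.0, 0.0, 0.0]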
Example #3
def run_epoch(opt, model, feeder, optimizer, batches):
    model.train()
    nbatch = 0
    vocab_size = feeder.dataset.vocab_size
    criterion = models.make_loss_compute(vocab_size)
    while nbatch < batches:
        x, t, lengths, _, qids = data.next(feeder, opt.batch_size)
        batch_size = lengths.shape[0]
        nbatch += 1
        outputs, _, _, _ = model(x, t, lengths)
        # t[1:] drops the leading SOS token so each target position
        # lines up with the decoder output that predicts it; the summed
        # token loss is then normalized by the batch size.
        loss = criterion(
            outputs.view(-1, vocab_size),
            t[1:].contiguous().view(-1)) / nu.tensor(batch_size).float()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('------ITERATION {}, {}/{}, loss: {:>.4f}'.format(
            feeder.iteration, feeder.cursor, feeder.size, loss.tolist()))
        # Every 10 batches, decode the argmax predictions and print them
        # next to the ground-truth questions.
        if nbatch % 10 == 0:
            logit = outputs.transpose(0, 1)
            gids = logit.argmax(-1).tolist()
            for k in range(len(gids)):
                question = feeder.ids_to_sent(gids[k])
                print('truth:   {}'.format(feeder.ids_to_sent(qids[k])))
                print('predict: {}'.format(question))
                print('----------')
    return loss
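models.make_loss_compute is project code that is not reproduced here. Given how run_epoch uses it (log-probabilities in, a scalar summed loss out, divided by batch size afterwards), a plausible stand-in is an NLLLoss that sums over tokens and skips padding; this definition, including the pad_id default, is an assumption:

import torch

def make_loss_compute(vocab_size, pad_id=0):  # pad_id would be NULL_ID
    # vocab_size only mirrors the call site; NLLLoss does not need it.
    return torch.nn.NLLLoss(ignore_index=pad_id, reduction='sum')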
Example #4
def evaluate_policy_docs():
    opt = make_options()
    dataset = data.Dataset()
    feeder = data.Feeder(dataset)
    model, _ = models.load_or_create_models(opt, False)
    translator = Translator(model, opt.beam_size, opt.min_length,
                            opt.max_length)
    docs = data.load_policy_documents()
    for doc in docs:
        data.parse_paragraphs(doc)
    lines = []
    for doc in docs:
        # Keep only paragraphs of a workable length (50-400 characters).
        paras = [p for p in doc.paragraphs if 50 <= len(p) <= 400]
        if not paras:
            continue
        lines.append('=================================')
        lines.append(doc.title)
        # Sample at most 16 paragraphs per document, sorted longest
        # first so the batch is ordered by sequence length.
        if len(paras) > 16:
            paras = random.sample(paras, 16)
        paras = sorted(paras, key=lambda p: -len(p))
        pids = [feeder.sent_to_ids(p) for p in paras]
        pids = data.align2d(pids)
        src = nu.tensor(pids)
        lengths = (src != data.NULL_ID).sum(-1)
        tgt = translator.translate(src.transpose(0, 1), lengths,
                                   opt.best_k_questions)
        questions = [[feeder.ids_to_sent(t) for t in qs] for qs in tgt]
        for p, qs in zip(paras, questions):
            lines.append('--------------------------------')
            lines.append(p)
            for k, q in enumerate(qs):
                lines.append('predict {}: {}'.format(k, q))
    utils.write_all_lines(opt.output_file, lines)
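data.align2d has to turn the ragged list of id sequences into a rectangle before nu.tensor can stack it. Its real implementation is not shown; a minimal sketch under that assumption:

def align2d(seqs, pad_id=0):  # pad_id would be data.NULL_ID in context
    # Right-pad every sequence so all rows share the longest length,
    # giving a rectangular [batch, max_len] structure.
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]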
Example #5
def run_gan_epoch(opt, generator, discriminator, feeder, optimizer, batches,
                  step):
    generator.train()
    discriminator.train()
    nbatch = 0
    vocab_size = feeder.dataset.vocab_size
    g_criterion = models.make_loss_compute(vocab_size)
    d_criterion = torch.nn.NLLLoss()
    while nbatch < batches:
        x, t, lengths, _, qids = data.next(feeder, opt.batch_size)
        batch_size = lengths.shape[0]
        nbatch += 1
        # Teacher-forced pass (targets supplied) and free-running pass
        # (targets omitted); the discriminator sees their hidden states.
        y, tc_hidden, _, _ = generator(x, t, lengths)
        z, fr_hidden, _, _ = generator(x, None, lengths)
        if step == 'generator':
            # Maximum-likelihood loss on the teacher-forced output...
            g_loss = g_criterion(
                y.view(-1, vocab_size),
                t[1:].contiguous().view(-1)) / nu.tensor(batch_size).float()
            # ...plus an adversarial term: make the free-running hidden
            # states look teacher-forced (label 1) to the discriminator.
            d_logit = discriminator(fr_hidden)
            flag = nu.tensor([1] * batch_size)
            d_loss = d_criterion(d_logit, flag)
            loss = g_loss + d_loss
            print('------{} {}, {}/{}, loss: {:>.4f}+{:>.4f}={:>.4f}'.format(
                step, feeder.iteration, feeder.cursor, feeder.size,
                g_loss.tolist(), d_loss.tolist(), loss.tolist()))
        else:
            # Discriminator step: teacher-forced states are labeled 1,
            # free-running states 0.
            tc_logit = discriminator(tc_hidden)
            fr_logit = discriminator(fr_hidden)
            logit = torch.cat([tc_logit, fr_logit], dim=0)
            flag = nu.tensor([1] * batch_size + [0] * batch_size)
            loss = d_criterion(logit, flag)
            print('------{} {}, {}/{}, loss: {:>.4f}'.format(
                step, feeder.iteration, feeder.cursor, feeder.size,
                loss.tolist()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if nbatch % 10 == 0:
            # Decode the free-running output next to the ground truth.
            logit = z.transpose(0, 1)
            gids = logit.argmax(-1).tolist()
            for k in range(len(gids)):
                question = feeder.ids_to_sent(gids[k])
                print('truth:   {}'.format(feeder.ids_to_sent(qids[k])))
                print('predict: {}'.format(question))
                print('----------')
    return loss
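The step argument switches the same loop between the two adversarial phases. A driver alternating them could look like the following; num_epochs, d_optimizer, and g_optimizer are hypothetical names, since the real training script is not shown:

for epoch in range(num_epochs):
    # First teach the discriminator to tell teacher-forced hidden
    # states (label 1) from free-running ones (label 0)...
    run_gan_epoch(opt, generator, discriminator, feeder,
                  d_optimizer, batches, step='discriminator')
    # ...then train the generator to fool it while staying close to
    # the ground-truth questions.
    run_gan_epoch(opt, generator, discriminator, feeder,
                  g_optimizer, batches, step='generator')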
Example #6
def generate_question(document):
    # The document arrives URL-encoded (e.g. from a web request).
    document = unquote(document)
    pids = [feeder.sent_to_ids(document)]
    src = nu.tensor(pids)
    lengths = (src != data.NULL_ID).sum(-1)
    # Translate the single-document batch, keeping the 3 best questions.
    tgt = translator.translate(src.transpose(0, 1), lengths, 3)
    questions = [feeder.ids_to_sent(t) for t in tgt[0]]
    questions = unique(questions)
    obj = {'document': document, 'questions': questions}
    return json.dumps(obj, ensure_ascii=False, indent=4)
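unique is another helper whose body is not shown; assuming it deduplicates while preserving the beam-score order of the decoded questions, a minimal version would be:

def unique(items):
    # Drop duplicates, keeping the first (best-scored) occurrence.
    seen = set()
    return [x for x in items if not (x in seen or seen.add(x))]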
Example #7
def _run_free_pass(self, generator, memory_bank, state, memory_lengths):
    input_feed = state.input_feed.squeeze(0)
    decoder_outputs = []
    attns = {'std': []}
    batch_size = memory_lengths.shape[0]
    # Every sequence starts from the SOS token.
    batch_sos = nu.tensor([[data.SOS_ID] * batch_size])
    emb_t = self.embeddings(batch_sos).squeeze(0)  # [batch, dim]
    hidden = state.hidden
    for _ in range(20):
        decoder_input = torch.cat([emb_t, input_feed], 1)
        rnn_output, hidden = self.rnn(decoder_input, hidden)
        decoder_output, p_attn = self.attn(
            rnn_output, memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths)
        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output
        decoder_outputs.append(decoder_output)
        attns['std'].append(p_attn)
        # Feed back the decoder's own argmax prediction; stop early once
        # every sequence in the batch has emitted NULL.
        _, ids = generator(decoder_output).max(-1)
        if ids.eq(data.NULL_ID).sum().tolist() == batch_size:
            break
        emb_t = self.embeddings(ids.unsqueeze(0)).squeeze(0)
    return hidden, decoder_outputs, attns
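This is presumably the path taken when Example #5 calls generator(x, None, lengths): with no targets to feed, the decoder starts from SOS and feeds back its own argmax prediction for at most 20 steps, stopping early once the whole batch emits NULL. A hypothetical caller would stack the collected outputs to recover the usual layout:

# Hypothetical caller-side use; `decoder` is an assumed instance name.
hidden, decoder_outputs, attns = decoder._run_free_pass(
    generator, memory_bank, state, memory_lengths)
outputs = torch.stack(decoder_outputs)  # [steps, batch, dim]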