Example #1
  def predict(self, enc_hidden, context, context_lengths, batch, beam_size, max_code_length, generator, replace_unk, vis_params):

      decState = DecoderState(
        enc_hidden,
        Variable(torch.zeros(1, 1, self.opt.rnn_size).cuda(), requires_grad=False)
      )

      # Repeat everything beam_size times.
      def rvar(a, beam_size):
        return Variable(a.repeat(beam_size, 1, 1), volatile=True)
      context = rvar(context.data, beam_size)
      context_lengths = context_lengths.repeat(beam_size)
      decState.repeat_beam_size_times(beam_size)

      beam = Beam(beam_size,
                  cuda=True,
                  vocab=self.vocabs['code'])

      for i in range(max_code_length):
        if beam.done():
          break

        # Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        # Uses the start symbol at the beginning.
        inp = beam.getCurrentState() # Should return a batch of the frontier
        # Turn any copied words to UNKs
        if self.opt.copy_attn:
            inp['code'] = inp['code'].masked_fill_(
                inp['code'].gt(len(self.vocabs['code']) - 1),
                self.vocabs['code'].stoi['<unk>'])
        # Run one step; decState gets updated automatically
        decOut, attn, copy_attn = self.forward(inp, context, context_lengths, decState)

        # decOut: beam x rnn_size
        decOut = decOut.squeeze(1)

        out = generator(decOut,
                        copy_attn.squeeze(1) if copy_attn is not None else None,
                        batch['src_map'], inp).data
        out = out.unsqueeze(1)
        if self.opt.copy_attn:
          out = generator.collapseCopyScores(out, batch)
          out = out.log()

        # beam x tgt_vocab
        beam.advance(out[:, 0], attn.data[:, 0])
        decState.beam_update(beam.getCurrentOrigin(), beam_size)

      score, times, k = beam.getFinal() # times is the length of the prediction
      hyp, att = beam.getHyp(times, k)
      goldNl = self.vocabs['seq2seq'].addStartOrEnd(batch['raw_seq2seq'][0])
      goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
      predSent = self.buildTargetTokens(
        hyp,
        self.vocabs,
        goldNl,
        att,
        batch['seq2seq_vocab'][0],
        replace_unk
      )
      return Prediction(goldNl, goldCode, predSent, att)
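
Note: the Variable(..., volatile=True) idiom above targets pre-0.4 PyTorch, where volatile=True disabled autograd for inference. On PyTorch 0.4+ both Variable and volatile are deprecated; a minimal sketch of the same beam-expansion step under modern PyTorch (same tensor shapes assumed) would use torch.no_grad() instead:

import torch

def rvar(a, beam_size):
    # Tile a (1 x batch x dim) tensor beam_size times along the first dim;
    # torch.no_grad() at the call site replaces volatile=True.
    return a.repeat(beam_size, 1, 1)

with torch.no_grad():
    context = rvar(context, beam_size)
    context_lengths = context_lengths.repeat(beam_size)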
Example #2
  def predict(self, enc_hidden, context, context_lengths, batch, beam_size, max_code_length, generator, replace_unk, vis_params):
    # This decoder does not have input feeding; the parent state replaces it.
    decState = DecoderState(
      enc_hidden, #encoder hidden
      Variable(torch.zeros(1, 1, self.opt.rnn_size).cuda(), requires_grad=False) # parent state
    )
    # Repeat everything beam_size times.
    def rvar(a, beam_size):
      return Variable(a.repeat(beam_size, 1, 1), volatile=True)
    context = rvar(context.data, beam_size)
    context_lengths = context_lengths.repeat(beam_size)
    decState.repeat_beam_size_times(beam_size) # TODO: get back to this

    # Use only one beam
    beam = TreeBeam(beam_size, True, self.vocabs, self.opt.rnn_size)

    # We will break when we have the required number of terminals,
    # to be consistent with seq2seq.
    for count in range(0, max_code_length):

      if beam.done(): # TODO: fix b.done
        break

      # Construct batch x beam_size next words.
      # Get all the pending current beam words and arrange for forward.
      # Uses the start symbol at the beginning.
      inp = beam.getCurrentState() # Should return a batch of the frontier

      # Run one step; decState gets updated automatically
      output, attn, copy_attn = self.forward(inp, context, context_lengths, decState)
      scores = generator(bottle(output), bottle(copy_attn),
                         batch['src_map'], inp) # generator needs the non-terminals

      out = generator.collapseCopyScores(unbottle(scores.data.clone(), beam_size), batch) # needs seq2seq from batch
      out = out.log()

      # beam x tgt_vocab
      beam.advance(out[:, 0], attn.data[:, 0], output)
      decState.beam_update(beam.getCurrentOrigin(), beam_size)

    score, times, k = beam.getFinal() # times is the length of the prediction
    #hyp, att = beam.getHyp(times, k)
    goldNl = self.vocabs['seq2seq'].addStartOrEnd(batch['raw_seq2seq'][0]) # because batch = 1
    goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
    # goldProd = self.vocabs['next_rules'].addStartOrEnd(batch['raw_next_rules'][0])
    predictions = []
    for score, times, k in beam.finished:
      hyp, att = beam.getHyp(times, k)
      predSent = self.buildTargetTokens(
        hyp,
        self.vocabs,
        goldNl,
        att,
        batch['seq2seq_vocab'][0],
        replace_unk
      )
      predSent = ProdDecoder.rulesToCode(predSent)
      predictions.append(Prediction(goldNl, goldCode, predSent, att, score))
    return predictions
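
Note: bottle and unbottle are used above but not defined in this snippet. In OpenNMT-style beam search they are usually small reshape helpers that fold the beam dimension into the batch for the generator and unfold it again for scoring; a sketch under that assumption:

def bottle(v):
    # (beam, batch, dim) -> (beam * batch, dim): present one flat batch
    # of decoder states to the generator.
    return v.view(-1, v.size(2))

def unbottle(v, beam_size):
    # Inverse of bottle for 2-D score tensors:
    # (beam * batch, vocab) -> (beam, batch, vocab).
    return v.view(beam_size, -1, v.size(1))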
Example #3
import sys

import torch

# parser (the argparse parser), CDDataset, S2SModel, and Prediction come from
# the surrounding project and are assumed to be imported at module level.

def main():
    opt = parser.parse_args()
    torch.cuda.set_device(opt.gpu)
    checkpoint = torch.load(opt.model,
                            map_location=lambda storage, loc: storage)
    vocabs = checkpoint['vocab']
    vocabs['mask'] = vocabs['mask'].cuda()

    test = CDDataset(opt.src, None, test=True, trunc=opt.trunc)
    test.toNumbers(checkpoint['vocab'])
    total_test = test.compute_batches(opt.batch_size,
                                      checkpoint['vocab'],
                                      checkpoint['opt'].max_camel,
                                      0,
                                      1,
                                      checkpoint['opt'].decoder_type,
                                      randomize=False,
                                      no_filter=True)
    sys.stderr.write('Total test: {}\n'.format(total_test))
    sys.stderr.flush()

    model = S2SModel(checkpoint['opt'], vocabs)
    model.load_state_dict(checkpoint['model'])
    model.cuda()
    model.eval()

    predictions = []
    count = 0
    num_skipped = 0
    for idx, batch in enumerate(test.batches):  # For each batch
        try:
            hyps = model.predict(batch, opt, None)
            hyps = hyps[:opt.beam_size]
            predictions.extend(hyps)
            count += len(hyps)
            #print('predicted successfully')
        except Exception as ex:
            dummy_pred = Prediction(' '.join(batch['raw_src'][0]),
                                    ' '.join(batch['raw_code'][0]), 'Failed',
                                    'Failed', 0)
            #print('Skipping:', ' '.join(batch['raw_src'][0]), ' '.join(batch['raw_code'][0]))
            predictions.extend([dummy_pred] * opt.beam_size)
            count += opt.beam_size
            num_skipped += 1

    print('Count: ', count)
    print('Num skipped: ', num_skipped)
    with open(opt.output, 'w') as outfile:
        with open(opt.output + '.scores.txt', 'w') as scores_file:
            for idx, prediction in enumerate(predictions):
                prediction.output(outfile=outfile, scorefile=scores_file)
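
Note: the parser that main() reads from is defined elsewhere in the project. A hypothetical argparse setup covering only the options this function actually touches (flag names inferred from the opt.* accesses above; defaults and help strings are assumptions):

import argparse

parser = argparse.ArgumentParser(description='Run beam-search prediction')
parser.add_argument('--model', required=True, help='path to the trained checkpoint')
parser.add_argument('--src', required=True, help='test source file')
parser.add_argument('--output', required=True, help='where to write predictions')
parser.add_argument('--gpu', type=int, default=0, help='CUDA device id')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--beam_size', type=int, default=5)
parser.add_argument('--trunc', type=int, default=-1, help='truncation length')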
Example #4
    def predict(self, enc_hidden, context, context_lengths, batch, beam_size,
                max_code_length, generator, replace_unk, vis_params):

        # This decoder does not have input feeding; the parent state replaces it.
        decState = DecoderState(
            enc_hidden,  #encoder hidden
            Variable(torch.zeros(1, 1, self.opt.decoder_rnn_size).cuda(),
                     requires_grad=False)  # parent state
        )

        # Repeat everything beam_size times.
        def rvar(a, beam_size):
            return Variable(a.repeat(beam_size, 1, 1), volatile=True)

        context = tuple(
            rvar(context[i].data, beam_size) for i in range(0, len(context)))
        context_lengths = tuple(context_lengths[i].repeat(beam_size, 1)
                                for i in range(0, len(context_lengths)))

        decState.repeat_beam_size_times(beam_size)

        # Use only one beam
        beam = TreeBeam(beam_size, True, self.vocabs,
                        self.opt.decoder_rnn_size)

        # We will break when we have the required number of terminals,
        # to be consistent with seq2seq.
        for count in range(0, max_code_length):

            if beam.done():
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            # Uses the start symbol at the beginning.
            inp = beam.getCurrentState()  # Should return a batch of the frontier

            # Run one step; decState gets updated automatically
            output, attn, copy_attn = self.forward(inp, context,
                                                   context_lengths, decState)
            # Build the copy-attention source map. Concatenating onto an empty
            # tensor relies on pre-0.4 torch.cat semantics, where zero-element
            # tensors were skipped.
            src_map = torch.zeros(0, 0)
            if self.opt.var_names:
                src_map = torch.cat((src_map, batch['concode_src_map_vars']), 1)
            if self.opt.method_names:
                src_map = torch.cat((src_map, batch['concode_src_map_methods']), 1)

            scores = generator(bottle(output), bottle(copy_attn), src_map,
                               inp)  # generator needs the non-terminals

            out = generator.collapseCopyScores(
                unbottle(scores.data.clone(), beam_size),
                batch)  # needs seq2seq from batch
            out = out.log()

            # beam x tgt_vocab
            beam.advance(out[:, 0], attn.data[:, 0], output)
            decState.beam_update(beam.getCurrentOrigin(), beam_size)

        pred_score_total = 0
        pred_words_total = 0

        score, times, k = beam.getFinal()  # times is the length of the prediction
        hyp, att = beam.getHyp(times, k)
        goldNl = []
        if self.opt.var_names:
            goldNl += batch['concode_var'][0]  # because batch = 1
        if self.opt.method_names:
            goldNl += batch['concode_method'][0]  # because batch = 1

        goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
        predSent, copied_tokens, replaced_tokens = self.buildTargetTokens(
            hyp, self.vocabs, goldNl, att, batch['concode_vocab'][0],
            replace_unk)
        predSent = ConcodeDecoder.rulesToCode(predSent)
        pred_score_total += score
        pred_words_total += len(predSent)

        return Prediction(goldNl, goldCode, predSent, att)
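
Note: the Prediction class used throughout these examples is not shown. Its call sites take (goldNl, goldCode, predSent, att) with an optional trailing score (examples #2 and #3), and example #3 calls prediction.output(outfile=..., scorefile=...). A minimal container consistent with those call sites (field names and serialization format are assumptions):

class Prediction(object):
    def __init__(self, gold_nl, gold_code, pred_sent, attn, score=0.0):
        self.gold_nl = gold_nl      # input NL description
        self.gold_code = gold_code  # reference code tokens
        self.pred_sent = pred_sent  # predicted code tokens
        self.attn = attn            # attention weights (used by replace_unk)
        self.score = score          # beam score; optional in examples #1 and #4

    def output(self, outfile, scorefile):
        # Write one prediction per line; the exact format is a guess.
        tokens = self.pred_sent if isinstance(self.pred_sent, str) \
            else ' '.join(map(str, self.pred_sent))
        outfile.write(tokens + '\n')
        scorefile.write('{}\n'.format(self.score))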