Example #1
  def predict(self, enc_hidden, context, context_lengths, batch, beam_size, max_code_length, generator, replace_unk, vis_params):

      decState = DecoderState(
        enc_hidden,
        Variable(torch.zeros(1, 1, self.opt.rnn_size).cuda(), requires_grad=False)
      )

      # Repeat everything beam_size times.
      def rvar(a, beam_size):
        return Variable(a.repeat(beam_size, 1, 1), volatile=True)
      context = rvar(context.data, beam_size)
      context_lengths = context_lengths.repeat(beam_size)
      decState.repeat_beam_size_times(beam_size)

      beam = Beam(beam_size,
                  cuda=True,
                  vocab=self.vocabs['code'])

      for i in range(max_code_length):
        if beam.done():
          break

        # Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        # Uses the start symbol in the beginning
        inp = beam.getCurrentState() # Should return a batch of the frontier
        # Turn any copied words to UNKs
        if self.opt.copy_attn:
            inp['code'] = inp['code'].masked_fill_(
                inp['code'].gt(len(self.vocabs["code"]) - 1),
                self.vocabs["code"].stoi['<unk>'])
        # Run one step; decState gets updated automatically
        decOut, attn, copy_attn = self.forward(inp, context, context_lengths, decState)

        # decOut: beam x rnn_size
        decOut = decOut.squeeze(1)

        out = generator(decOut,
                        copy_attn.squeeze(1) if copy_attn is not None else None,
                        batch['src_map'], inp).data
        out = out.unsqueeze(1)
        if self.opt.copy_attn:
          out = generator.collapseCopyScores(out, batch)
          out = out.log()

        # beam x tgt_vocab
        beam.advance(out[:, 0], attn.data[:, 0])
        decState.beam_update(beam.getCurrentOrigin(), beam_size)

      score, times, k = beam.getFinal() # times is the length of the prediction
      hyp, att = beam.getHyp(times, k)
      goldNl = self.vocabs['seq2seq'].addStartOrEnd(batch['raw_seq2seq'][0])
      goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
      predSent = self.buildTargetTokens(
        hyp,
        self.vocabs,
        goldNl,
        att,
        batch['seq2seq_vocab'][0],
        replace_unk
      )
      return Prediction(goldNl, goldCode, predSent, att)
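
The `Variable(..., volatile=True)` idiom above dates from pre-0.4 PyTorch, where `volatile` marked inference-only tensors; it was removed in PyTorch 0.4. A minimal sketch of the same beam-tiling step in current PyTorch, assuming `context` is shaped `(batch, src_len, rnn_size)` (the shape is an inference, not stated in the original):

    import torch

    def tile_for_beam(context, context_lengths, beam_size):
      # Modern replacement for the rvar/Variable(volatile=True) idiom:
      # running the decode loop under torch.no_grad() disables autograd
      # the same way volatile=True did in pre-0.4 PyTorch.
      with torch.no_grad():
        # (batch, src_len, rnn_size) -> (batch * beam_size, src_len, rnn_size)
        tiled_context = context.repeat(beam_size, 1, 1)
        tiled_lengths = context_lengths.repeat(beam_size)
      return tiled_context, tiled_lengths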
Example #2
  def predict(self, enc_hidden, context, context_lengths, batch, beam_size, max_code_length, generator, replace_unk, vis_params):
    # This decoder does not have input feeding; the parent state replaces it
    decState = DecoderState(
      enc_hidden, #encoder hidden
      Variable(torch.zeros(1, 1, self.opt.rnn_size).cuda(), requires_grad=False) # parent state
    )
    # Repeat everything beam_size times.
    def rvar(a, beam_size):
      return Variable(a.repeat(beam_size, 1, 1), volatile=True)
    context = rvar(context.data, beam_size)
    context_lengths = context_lengths.repeat(beam_size)
    decState.repeat_beam_size_times(beam_size) # TODO: get back to this

    # Use only one beam
    beam = TreeBeam(beam_size, True, self.vocabs, self.opt.rnn_size)

    for count in range(0, max_code_length):
      # We will break when we have the required number of terminals,
      # to be consistent with seq2seq

      if beam.done(): # TODO: fix b.done
        break

      # Construct batch x beam_size next words.
      # Get all the pending current beam words and arrange for forward.
      # Uses the start symbol in the beginning
      inp = beam.getCurrentState() # Should return a batch of the frontier

      # Run one step; decState gets updated automatically
      output, attn, copy_attn = self.forward(inp, context, context_lengths, decState)
      scores = generator(bottle(output), bottle(copy_attn), batch['src_map'],
                         inp)  # generator needs the non-terminals

      out = generator.collapseCopyScores(
          unbottle(scores.data.clone(), beam_size), batch)  # needs seq2seq from batch
      out = out.log()

      # beam x tgt_vocab
      beam.advance(out[:, 0], attn.data[:, 0], output)
      decState.beam_update(beam.getCurrentOrigin(), beam_size)

    score, times, k = beam.getFinal() # times is the length of the prediction
    #hyp, att = beam.getHyp(times, k)
    goldNl = self.vocabs['seq2seq'].addStartOrEnd(batch['raw_seq2seq'][0]) # because batch = 1
    goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
    # goldProd = self.vocabs['next_rules'].addStartOrEnd(batch['raw_next_rules'][0])
    predictions = []
    for score, times, k in beam.finished:
      hyp, att = beam.getHyp(times, k)
      predSent = self.buildTargetTokens(
        hyp,
        self.vocabs,
        goldNl,
        att,
        batch['seq2seq_vocab'][0],
        replace_unk
      )
      predSent = ProdDecoder.rulesToCode(predSent)
      predictions.append(Prediction(goldNl, goldCode, predSent, att, score))
    return predictions
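
`bottle` and `unbottle` are used here but never defined in these snippets. In OpenNMT-style decoders they flatten and restore the beam dimension so the generator sees a 2-D input; a plausible sketch under that assumption (the exact shapes are inferred from the `out[:, 0]` indexing above, not confirmed by the source):

    def bottle(v):
      # Merge beam and time steps into one batch dimension:
      # (beam, seq, dim) -> (beam * seq, dim), so the generator scores
      # every decoder position as an independent row.
      return v.view(-1, v.size(2))

    def unbottle(v, beam_size):
      # Inverse of bottle on the generator's output:
      # (beam * seq, vocab) -> (beam, seq, vocab).
      return v.view(beam_size, -1, v.size(1))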
Example #3
    def forward(self, batch):

        # initial parent states for Prod Decoder
        batch_size = batch['seq2seq'].size(0)

        if self.opt.decoder_type == "concode":
            batch['parent_states'] = {}
            for j in range(0, batch_size):
                batch['parent_states'][j] = {}
                if self.opt.decoder_type in ["prod", "concode"]:
                    batch['parent_states'][j][0] = Variable(
                        torch.zeros(1, 1, self.opt.decoder_rnn_size).cuda(),
                        requires_grad=False)

        context, context_lengths, enc_hidden = self.encoder(batch)

        decInitState = DecoderState(
            enc_hidden,
            Variable(torch.zeros(batch_size, 1,
                                 self.opt.decoder_rnn_size).cuda(),
                     requires_grad=False))

        output, attn, copy_attn = self.decoder(batch, context, context_lengths,
                                               decInitState)

        if self.opt.decoder_type == "concode":
            del batch['parent_states']

        # Other generators will not use the extra parameters
        # Let the generator put the src_map in cuda if it uses it
        # TODO: Make src_map a Variable again in the generator
        src_map = torch.zeros(0, 0)
        if self.opt.decoder_type == "concode":
            src_map = torch.cat((batch['concode_src_map_vars'],
                                 batch['concode_src_map_methods']), 1)

        scores = self.generator(
            bottle(output), bottle(copy_attn),
            src_map if self.opt.encoder_type == "concode" else batch['src_map'],
            batch)
        loss, total, correct = self.generator.computeLoss(scores, batch)

        return loss, Statistics(loss.data.item(), total.item(), correct.item(),
                                self.encoder.n_src_words)
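
Since `forward` returns the loss and a `Statistics` object directly rather than raw logits, a training step reduces to a single call per batch. A hypothetical loop for context; `model`, `optimizer`, and `train_iter` are illustrative names, not from the source:

    def train_epoch(model, optimizer, train_iter):
      model.train()
      for batch in train_iter:
        optimizer.zero_grad()
        loss, stats = model(batch)  # forward() above returns (loss, Statistics)
        loss.backward()
        optimizer.step()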
Example #4
    def predict(self, enc_hidden, context, context_lengths, batch, beam_size,
                max_code_length, generator, replace_unk, vis_params):

        # This decoder does not have input feeding; the parent state replaces it
        decState = DecoderState(
            enc_hidden,  #encoder hidden
            Variable(torch.zeros(1, 1, self.opt.decoder_rnn_size).cuda(),
                     requires_grad=False)  # parent state
        )

        # Repeat everything beam_size times.
        def rvar(a, beam_size):
            return Variable(a.repeat(beam_size, 1, 1), volatile=True)

        context = tuple(
            rvar(context[i].data, beam_size) for i in range(0, len(context)))
        context_lengths = tuple(context_lengths[i].repeat(beam_size, 1)
                                for i in range(0, len(context_lengths)))

        decState.repeat_beam_size_times(beam_size)

        # Use only one beam
        beam = TreeBeam(beam_size, True, self.vocabs,
                        self.opt.decoder_rnn_size)

        for count in range(0, max_code_length):
            # We will break when we have the required number of terminals,
            # to be consistent with seq2seq

            if beam.done():
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            # Uses the start symbol in the beginning
            inp = beam.getCurrentState()  # Should return a batch of the frontier

            # Run one step; decState gets updated automatically
            output, attn, copy_attn = self.forward(inp, context,
                                                   context_lengths, decState)
            src_map = torch.zeros(0, 0)
            if self.opt.var_names:
                src_map = torch.cat((src_map, batch['concode_src_map_vars']), 1)
            if self.opt.method_names:
                src_map = torch.cat((src_map, batch['concode_src_map_methods']), 1)

            scores = generator(bottle(output), bottle(copy_attn), src_map,
                               inp)  #generator needs the non-terminals

            out = generator.collapseCopyScores(
                unbottle(scores.data.clone(), beam_size),
                batch)  # needs seq2seq from batch
            out = out.log()

            # beam x tgt_vocab

            beam.advance(out[:, 0], attn.data[:, 0], output)
            decState.beam_update(beam.getCurrentOrigin(), beam_size)

        pred_score_total = 0
        pred_words_total = 0

        score, times, k = beam.getFinal()  # times is the length of the prediction
        hyp, att = beam.getHyp(times, k)
        goldNl = []
        if self.opt.var_names:
            goldNl += batch['concode_var'][0]  # because batch = 1
        if self.opt.method_names:
            goldNl += batch['concode_method'][0]  # because batch = 1

        goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
        predSent, copied_tokens, replaced_tokens = self.buildTargetTokens(
            hyp, self.vocabs, goldNl, att, batch['concode_vocab'][0],
            replace_unk)
        predSent = ConcodeDecoder.rulesToCode(predSent)
        pred_score_total += score
        pred_words_total += len(predSent)

        return Prediction(goldNl, goldCode, predSent, att)
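
All four `predict` variants share the same calling convention: encode once, then beam-decode a single example (batch size 1). A hypothetical end-to-end invocation for context; `encoder`, `decoder`, and `generator` stand in for the model components, and the `beam_size`/`max_code_length` defaults are illustrative assumptions:

    def decode_example(encoder, decoder, generator, batch,
                       beam_size=5, max_code_length=200):
      # Encode once (matching the encoder's return order in Example #3),
      # then let predict() run the beam search over the decoder.
      context, context_lengths, enc_hidden = encoder(batch)
      return decoder.predict(enc_hidden, context, context_lengths, batch,
                             beam_size, max_code_length, generator,
                             replace_unk=True, vis_params=None)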