Esempio n. 1
0
    def _create_batch(self):
        """Package every turn of ``self.dialogue`` so far into one ``Batch``.

        Calls ``self.convert_to_int()``, fetches the encoder turns for the
        single dialogue via ``self.batcher``, and assembles encoder args,
        decoder args (inputs, targets, scenarios, selections) and context
        data (agents, kbs) into a ``Batch`` with length sorting disabled.

        Returns:
            Batch: a batch built from ``[self.dialogue]`` only.
        """
        n_ctx = Dialogue.num_context

        # Convert the dialogue, then gather all encoder turns up to now.
        self.convert_to_int()
        turns = self.batcher._get_turn_batch_at(
            [self.dialogue], Dialogue.ENC, None)

        enc_inputs = self.batcher.get_encoder_inputs(turns)
        enc_context = self.batcher.get_encoder_context(turns, n_ctx)
        encoder_args = {'inputs': enc_inputs, 'context': enc_context}

        dec_inputs = self.get_decoder_inputs()
        # NOTE(review): targets are a copy of the first encoder turn batch;
        # presumably a placeholder when no gold targets exist — confirm.
        dec_targets = np.copy(turns[0])
        scenario_arr = np.array([self.dialogue.scenario])
        selection_arr = np.array([self.dialogue.selection])
        decoder_args = {
            'inputs': dec_inputs,
            'targets': dec_targets,
            'scenarios': scenario_arr,
            'selections': selection_arr,
        }

        context_data = {'agents': [self.agent], 'kbs': [self.kb]}

        return Batch(
            encoder_args,
            decoder_args,
            context_data,
            self.vocab,
            sort_by_length=False,
            num_context=n_ctx,
            cuda=self.cuda,
        )
 def iter_batches(self):
     """Yield the batch count for ``self.dialogue``, then each wrapped batch.

     Calls ``self.convert_to_int()`` and ``self.batcher.create_batch`` on the
     single dialogue. The FIRST item yielded is ``len(...)`` of that result
     (an int), so callers must consume or skip it before the batches. Each
     raw batch dict is then wrapped in a ``Batch`` using ``self.env.vocab``,
     ``self.env.cuda`` and ``Dialogue.num_context``.

     Yields:
         int: the number of batches (first item only).
         Batch: one per raw batch, in the order the batcher returned them.
     """
     self.convert_to_int()
     raw_batches = self.batcher.create_batch([self.dialogue])
     yield len(raw_batches)
     for raw in raw_batches:
         # TODO: this wrapping should live in the batcher.
         yield Batch(
             raw['encoder_args'],
             raw['decoder_args'],
             raw['context_data'],
             self.env.vocab,
             num_context=Dialogue.num_context,
             cuda=self.env.cuda,
         )