def _create_batch(self):
    """Build a single, unsorted Batch from every turn of ``self.dialogue`` so far.

    The dialogue is converted via ``self.convert_to_int()`` first, then its
    encoder-side turns are fetched from the batcher and packaged into
    encoder, decoder and context argument dicts.

    Returns:
        A ``Batch`` constructed with ``self.vocab``, ``sort_by_length=False``,
        ``num_context=Dialogue.num_context`` and ``cuda=self.cuda``.
    """
    context_len = Dialogue.num_context
    # Conversion must happen before the batcher reads the dialogue's turns.
    self.convert_to_int()
    turns = self.batcher._get_turn_batch_at([self.dialogue], Dialogue.ENC, None)

    encoder_args = dict(
        inputs=self.batcher.get_encoder_inputs(turns),
        context=self.batcher.get_encoder_context(turns, context_len),
    )
    decoder_args = dict(
        inputs=self.get_decoder_inputs(),
        # NOTE(review): targets are a copy of the first encoder turn batch,
        # not something decoder-specific — confirm this is intended.
        targets=np.copy(turns[0]),
        scenarios=np.array([self.dialogue.scenario]),
        selections=np.array([self.dialogue.selection]),
    )
    context_data = dict(agents=[self.agent], kbs=[self.kb])

    return Batch(
        encoder_args,
        decoder_args,
        context_data,
        self.vocab,
        sort_by_length=False,
        num_context=context_len,
        cuda=self.cuda,
    )
def iter_batches(self):
    """Generate the batches for the current dialogue, preceded by their count.

    The dialogue is converted with ``self.convert_to_int()`` and split into raw
    batch dicts by ``self.batcher.create_batch``. Each raw dict (with keys
    ``'encoder_args'``, ``'decoder_args'`` and ``'context_data'``) is wrapped in
    a ``Batch`` using ``self.env.vocab``, ``Dialogue.num_context`` and
    ``self.env.cuda``.

    Yields:
        First, ``len(...)`` of the raw batches (an int), so callers know how
        many batches follow; then one ``Batch`` per raw batch, in order.
    """
    self.convert_to_int()
    raw_batches = self.batcher.create_batch([self.dialogue])
    # Protocol: the first yielded value is the number of batches, not a batch.
    yield len(raw_batches)
    for raw in raw_batches:
        # TODO: this should be in batcher
        yield Batch(
            raw['encoder_args'],
            raw['decoder_args'],
            raw['context_data'],
            self.env.vocab,
            num_context=Dialogue.num_context,
            cuda=self.env.cuda,
        )