def teacherforcing_batch(self, batch: DictList, batch_lengths, sketch_lengths,
                         recurrence) -> (DictList, DictList):
    """
    Run the model over a padded batch of demonstrations with teacher forcing.

    :param batch: DictList object [bsz, seqlen]
    :param batch_lengths: [bsz] true length of each trajectory
    :param sketch_lengths: [bsz] true length of each sketch
    :param recurrence: int, detach the recurrent memory every `recurrence` steps
    :return:
        final_outputs: a DictList of [bsz, seqlen] tensors (logprobs, loss, ...)
        extra_info: a DictList of extra per-step info
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    extra_info = DictList({})

    # Initialize the recurrent memory from the sketch, if the model is recurrent.
    mems = None
    if self.is_recurrent:
        mems = self.init_memory(sketchs, sketch_lengths)

    for t in range(seqlen):
        final_output = DictList({})
        model_output = self.forward(batch.states[:, t], sketchs, sketch_lengths, mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # Marginalize over termination: p_end + (1 - p_end) * p_action,
            # computed in log space with logsumexp.
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1), dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        if 'p' in model_output:
            extra_info.append({'p': model_output.p})
        final_outputs.append(final_output)

        # Update memory, detaching it every `recurrence` steps to truncate BPTT.
        next_mems = None
        if self.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % recurrence == 0:
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack per-step outputs on the time dimension -> [bsz, seqlen, ...].
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
    extra_info.apply(lambda _tensors: torch.stack(_tensors, dim=1))

    # Mask out padded timesteps beyond each trajectory's true length.
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        # At the final valid timestep, replace the loss with the termination term.
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = \
            final_outputs.log_end[batch_ids, batch_lengths - 1]
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs, extra_info
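# --- Illustrative sketch (not part of the original module) ---------------------
# The per-step marginalization in teacherforcing_batch folds the termination head
# into the action likelihood: the step log-probability becomes
# log(p_end + (1 - p_end) * p_action), computed with logsumexp for numerical
# stability, and padded timesteps are zeroed with a length mask. The helper below
# is a minimal, self-contained sketch of those two pieces; all tensor values are
# made up for illustration and do not come from the model.
def _logsumexp_and_mask_sketch():
    import torch

    p_end = torch.tensor([0.1, 0.5, 0.9])              # hypothetical termination probs
    log_end = p_end.log()                               # stands in for model_output.log_end
    log_no_end = (1 - p_end).log()                      # stands in for model_output.log_no_end
    log_action = torch.tensor([0.7, 0.6, 0.2]).log()    # stands in for dist.log_prob(action)

    # log[ p_end + (1 - p_end) * p_action ], computed in log space.
    step_logprob = torch.logsumexp(
        torch.stack([log_end, log_no_end + log_action], dim=-1), dim=-1)

    # Same length mask as above: True on valid timesteps, False on padding.
    batch_lengths = torch.tensor([4, 2, 3])
    sequence_mask = (torch.arange(batch_lengths.max().item())[None, :]
                     < batch_lengths[:, None])          # [bsz, max_len]
    return step_logprob, sequence_mask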
def run_batch(batch: DictList, batch_lengths, sketch_lengths, bot: ModelBot,
              mode='train') -> DictList:
    """
    :param batch: DictList object [bsz, seqlen]
    :param batch_lengths: [bsz] true length of each trajectory
    :param sketch_lengths: [bsz] true length of each sketch
    :param bot: a ModelBot
    :param mode: 'train' or 'eval'; memory is only detached during training
    :return:
        final_outputs: a DictList of [bsz, seqlen] tensors (logprobs, loss, ...)
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})

    # Initialize the recurrent memory from the sketch, if the bot is recurrent.
    mems = None
    if bot.is_recurrent:
        mems = bot.init_memory(sketchs, sketch_lengths)

    for t in range(seqlen):
        final_output = DictList({})
        model_output = bot.forward(batch.states[:, t], sketchs, sketch_lengths, mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # Marginalize over termination: p_end + (1 - p_end) * p_action,
            # computed in log space with logsumexp.
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1), dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        final_outputs.append(final_output)

        # Update memory, truncating BPTT every FLAGS.il_recurrence steps in training.
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_output.mems
            if (t + 1) % FLAGS.il_recurrence == 0 and mode == 'train':
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack per-step outputs on the time dimension -> [bsz, seqlen, ...].
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))

    # Mask out padded timesteps beyond each trajectory's true length.
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
    if 'log_end' in final_outputs:
        # At the final valid timestep, replace the loss with the termination term.
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths - 1] = \
            final_outputs.log_end[batch_ids, batch_lengths - 1]
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs
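# --- Illustrative usage (not part of the original module) ----------------------
# A hedged sketch of how run_batch could be driven from an outer training loop.
# `make_batch` (a collation helper), `optimizer`, and the loss reduction below are
# assumptions for illustration, not part of this module.
def _train_step_sketch(bot: ModelBot, optimizer, demos, make_batch):
    # Assumed to return padded [bsz, seqlen] tensors plus per-example lengths.
    batch, batch_lengths, sketch_lengths = make_batch(demos)
    outputs = run_batch(batch, batch_lengths, sketch_lengths, bot, mode='train')

    # Padded positions are already zeroed by the sequence mask inside run_batch,
    # so dividing by the total number of valid timesteps gives a per-step loss.
    loss = outputs.loss.sum() / batch_lengths.sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()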