Example #1
    def teacherforcing_batch(self, batch: DictList, batch_lengths,
                             sketch_lengths,
                             recurrence) -> (DictList, DictList):
        """
        :param batch: DictList object [bsz, seqlen]
        :param batch_lengths: [bsz]
        :param sketch_lengths: [bsz]
        :param recurrence: an int
        :return:
            stats: A DictList of bsz, mem_size
            extra_info: A DictList of extra info
        """
        bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
        sketchs = batch.tasks
        final_outputs = DictList({})
        extra_info = DictList({})
        mems = None
        if self.is_recurrent:
            mems = self.init_memory(sketchs, sketch_lengths)

        for t in range(seqlen):
            final_output = DictList({})
            model_output = self.forward(batch.states[:, t], sketchs,
                                        sketch_lengths, mems)
            logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
            if 'log_end' in model_output:
                # p_end + (1 - p_end) * p_action, mixed in log space
                log_no_end_term = model_output.log_no_end + logprobs
                logprobs = torch.logsumexp(torch.stack(
                    [model_output.log_end, log_no_end_term], dim=-1),
                                           dim=-1)
                final_output.log_end = model_output.log_end
            final_output.logprobs = logprobs
            if 'p' in model_output:
                extra_info.append({'p': model_output.p})
            final_outputs.append(final_output)

            # Update memory
            next_mems = None
            if self.is_recurrent:
                next_mems = model_output.mems
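                # Detach memory every `recurrence` steps to truncate backpropagation through time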
                if (t + 1) % recurrence == 0:
                    next_mems = next_mems.detach()
            mems = next_mems

        # Stack on time dim
        final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
        extra_info.apply(lambda _tensors: torch.stack(_tensors, dim=1))
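        # True for valid time steps, False for padding beyond each trajectory's length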
        sequence_mask = torch.arange(
            batch_lengths.max().item(),
            device=batch_lengths.device)[None, :] < batch_lengths[:, None]
        final_outputs.loss = -final_outputs.logprobs
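        # At the last valid step of each trajectory, use the termination log-prob instead of the action loss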
        if 'log_end' in final_outputs:
            batch_ids = torch.arange(bsz, device=batch.states.device)
            final_outputs.loss[batch_ids, batch_lengths -
                               1] = final_outputs.log_end[batch_ids,
                                                          batch_lengths - 1]
        final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
        return final_outputs, extra_info
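
The pattern worth pulling out of this example is the per-step loss masking: log-probabilities are collected at every time step, stacked along the time dimension, negated into a loss, and steps beyond each trajectory's true length are zeroed out. Below is a minimal, self-contained sketch of that masking with made-up shapes; it does not use DictList or anything else from the surrounding repo.

import torch

# Hypothetical per-step log-probs for a batch of 3 trajectories, max length 5
logprobs = torch.randn(3, 5)                    # [bsz, seqlen]
batch_lengths = torch.tensor([5, 3, 4])         # true length of each trajectory

# Same construction as above: True for valid steps, False for padding
sequence_mask = torch.arange(logprobs.shape[1])[None, :] < batch_lengths[:, None]

loss = (-logprobs).masked_fill(~sequence_mask, 0.)  # padded steps contribute 0
mean_loss = loss.sum() / batch_lengths.sum()        # average over valid steps only
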
Example #2
def run_batch(batch: DictList,
              batch_lengths,
              sketch_lengths,
              bot: ModelBot,
              mode='train') -> DictList:
    """
    :param batch: DictList object [bsz, seqlen]
    :param bot:  A model Bot
    :param mode: 'train' or 'eval'
    :return:
        stats: A DictList of bsz, mem_size
    """
    bsz, seqlen = batch.actions.shape[0], batch.actions.shape[1]
    sketchs = batch.tasks
    final_outputs = DictList({})
    mems = None
    if bot.is_recurrent:
        mems = bot.init_memory(sketchs, sketch_lengths)

    for t in range(seqlen):
        final_output = DictList({})
        model_output = bot.forward(batch.states[:, t], sketchs, sketch_lengths,
                                   mems)
        logprobs = model_output.dist.log_prob(batch.actions[:, t].float())
        if 'log_end' in model_output:
            # p_end + (1 - p_end) * p_action, mixed in log space
            log_no_end_term = model_output.log_no_end + logprobs
            logprobs = torch.logsumexp(torch.stack(
                [model_output.log_end, log_no_end_term], dim=-1),
                                       dim=-1)
            final_output.log_end = model_output.log_end
        final_output.logprobs = logprobs
        final_outputs.append(final_output)

        # Update memory
        next_mems = None
        if bot.is_recurrent:
            next_mems = model_output.mems
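            # Detach memory every FLAGS.il_recurrence steps (training only) to truncate backpropagation through time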
            if (t + 1) % FLAGS.il_recurrence == 0 and mode == 'train':
                next_mems = next_mems.detach()
        mems = next_mems

    # Stack on time dim
    final_outputs.apply(lambda _tensors: torch.stack(_tensors, dim=1))
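    # True for valid time steps, False for padding beyond each trajectory's length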
    sequence_mask = torch.arange(
        batch_lengths.max().item(),
        device=batch_lengths.device)[None, :] < batch_lengths[:, None]
    final_outputs.loss = -final_outputs.logprobs
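    # At the last valid step of each trajectory, use the termination log-prob instead of the action loss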
    if 'log_end' in final_outputs:
        batch_ids = torch.arange(bsz, device=batch.states.device)
        final_outputs.loss[batch_ids, batch_lengths -
                           1] = final_outputs.log_end[batch_ids,
                                                      batch_lengths - 1]
    final_outputs.apply(lambda _t: _t.masked_fill(~sequence_mask, 0.))
    return final_outputs
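
Both examples fold an optional termination signal into the step log-probability: the model's log_end and log_no_end outputs are combined with the action log-prob as log(p_end + (1 - p_end) * p_action) via torch.logsumexp, which keeps the mixture in log space. Below is a small standalone sketch of that identity with hypothetical numbers; the variable names only mirror the model outputs used above.

import torch

p_end = torch.tensor([0.1, 0.6])       # hypothetical termination probabilities
p_action = torch.tensor([0.5, 0.2])    # hypothetical action probabilities

log_end = p_end.log()                  # plays the role of model_output.log_end
log_no_end = torch.log1p(-p_end)       # plays the role of model_output.log_no_end
action_logprob = p_action.log()

# log(p_end + (1 - p_end) * p_action), computed stably in log space
mixed = torch.logsumexp(
    torch.stack([log_end, log_no_end + action_logprob], dim=-1), dim=-1)

assert torch.allclose(mixed.exp(), p_end + (1 - p_end) * p_action)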