Example #1
class EarlyFusionInverseModel(InverseModel):
    def __init__(self, params, logger):
        super().__init__(params, logger)

    def build_network(self, build_encoder=True):
        self._hp.input_nc = 6  # two RGB frames stacked along the channel dimension
        self.encoder = Encoder(self._hp)

        if self._hp.pred_states:
            outdim = self._hp.n_actions + self._hp.state_dim
        else:
            outdim = self._hp.n_actions
        self.action_pred = Predictor(self._hp, self._hp.nz_enc, outdim, 3)

    def forward(self, inputs):
        self.get_timesteps(inputs)
        enc = self.encoder.forward(
            torch.cat([inputs.img_t0, inputs.img_t1], dim=1))[0]
        output = AttrDict()
        out = self.action_pred(enc)
        if self._hp.pred_states:
            # split the joint prediction into its action and state parts
            output.actions, output.states = torch.split(
                torch.squeeze(out), [self._hp.n_actions, self._hp.state_dim], 1)
        else:
            output.actions = torch.squeeze(out)
        return output
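The early-fusion variant above stacks the two frames along the channel axis (hence input_nc = 6) and feeds the result through a single encoder. Below is a minimal, self-contained sketch of that fusion step, assuming 3-channel 64x64 frames and a batch size of 8; the real Encoder and Predictor come from the surrounding repository.

import torch

# Hypothetical batch of frame pairs; shapes are illustrative only.
img_t0 = torch.randn(8, 3, 64, 64)  # frames at time t0
img_t1 = torch.randn(8, 3, 64, 64)  # frames at time t1

# Channel-wise concatenation produces the 6-channel input expected by an
# encoder built with self._hp.input_nc = 6.
fused = torch.cat([img_t0, img_t1], dim=1)
assert fused.shape == (8, 6, 64, 64)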
Example #2
class InverseModel(BaseModel):
    def __init__(self, params, logger):
        super().__init__(logger)
        self._hp = self._default_hparams()
        self.override_defaults(params)  # override defaults with config file
        self.postprocess_params()
        assert self._hp.n_actions != -1  # make sure action dimensionality was overridden
        self.build_network()

        # load only the encoder params during training
        if self._hp.enc_params_checkpoint is not None:
            assert self._hp.build_encoder  # provided checkpoint but did not build encoder
            self._load_weights([
                (self.encoder, 'encoder', self._hp.enc_params_checkpoint),
            ])
        self.detach_enc = not self._hp.finetune_enc

    def _default_hparams(self):
        # put new parameters in here:
        default_dict = {
            'ngf': 4,  # number of feature maps in shallowest level
            'nz_enc': 128,  # number of dimensions in encoder-latent space
            'nz_mid': 128,  # number of hidden units in fully connected layer
            'n_processing_layers': 3,  # number of layers in MLPs
            'temp_dist': 1,  # sample temporal distances between 1 and temp_dist, regress only first action
            'enc_params_checkpoint': None,  # pretrained encoder weights to load for training
            'take_first_tstep': False,  # take only the first and second time step, no shuffling
            'use_states': False,
            'aggregate_actions': False,  # if the two frames are more than one step apart, sum the actions in between
            'pred_states': False,
            'finetune_enc': False,
            'checkpt_path': None,
            'build_encoder': True,  # if False, does not build an encoder; assumes inputs are already encoded by the model
            'add_lstm_state_enc': False,  # if True, expects LSTM state as an additional encoded input
            'log_topdown_maze': False,
            'train_full_seq': False,
            'train_im0_enc': True,  # if True, the first-frame latent is passed in as `enc_traj_seq`
        }

        # loss weights
        default_dict.update({
            'action_rec_weight': 1.0,
            'state_rec_weight': 1.0,
        })

        # misc params
        default_dict.update({
            'use_skips': False,
            'dense_rec_type': None,
            'device': None,
            'randomize_length': False,
        })

        # add new params to parent params
        parent_params = super()._default_hparams()
        for k in default_dict.keys():
            parent_params.add_hparam(k, default_dict[k])
        return parent_params

    def build_network(self, build_encoder=True):
        if self._hp.build_encoder:
            self.encoder = Encoder(self._hp)
        input_sz = self._hp.nz_enc * 3 if self._hp.add_lstm_state_enc else self._hp.nz_enc * 2
        self.action_pred = Predictor(self._hp, input_sz, self._hp.n_actions)

    def sample_offsets(self, end_ind):
        """
         # sample temporal distances between 1 and temp_dist, regress only first action
        :return:  None, call by reference
        """
        bs = end_ind.shape[0]
        if self._hp.take_first_tstep:
            t0 = torch.zeros(bs, device=self._hp.device).long()
            t1 = torch.ones_like(t0)
        else:
            t0 = np.zeros(bs)
            for b in range(bs):
                assert end_ind[b].cpu().numpy() >= self._hp.temp_dist
                t0[b] = np.random.randint(
                    0, end_ind[b].cpu().numpy() - self._hp.temp_dist + 1, 1)
            delta_t = np.random.randint(1, self._hp.temp_dist + 1, bs)
            t1 = t0 + delta_t
            t0 = torch.tensor(t0, device=self._hp.device, dtype=torch.long)
            t1 = torch.tensor(t1, device=self._hp.device, dtype=torch.long)
        return t0, t1

    def index_input(self, input, t, aggregate=False, t1=None):
        """Selects time step t from each batch element; with aggregate=True, sums the entries in [t, t1) instead."""
        if aggregate:
            assert t1 is not None  # need end time step for aggregation
            selected = torch.zeros_like(input[:, 0])
            for b in range(input.shape[0]):
                selected[b] = torch.sum(input[b, t[b]:t1[b]], dim=0)
        else:
            selected = batchwise_index(input, t)
        return selected

    def full_seq_forward(self, inputs):
        if 'model_enc_seq' in inputs:
            enc_seq_1 = inputs.model_enc_seq[:, 1:]
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                enc_seq_0 = inputs.enc_traj_seq.reshape(
                    inputs.enc_traj_seq.shape[:2] +
                    (self._hp.nz_enc, ))[:, :-1]
                enc_seq_0 = enc_seq_0[:, :enc_seq_1.shape[1]]
            else:
                enc_seq_0 = inputs.model_enc_seq[:, :-1]
        else:
            enc_seq = batch_apply(self.encoder, inputs.traj_seq)
            enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]

        if self.detach_enc:
            enc_seq_0 = enc_seq_0.detach()
            enc_seq_1 = enc_seq_1.detach()

        # TODO quite sure the concatenation is automatic
        actions_pred = batch_apply(self.action_pred,
                                   torch.cat([enc_seq_0, enc_seq_1], dim=2))

        output = AttrDict()
        output.actions = actions_pred  #remove_spatial(actions_pred)
        if 'actions' in inputs:
            output.action_targets = inputs.actions
            output.pad_mask = inputs.pad_mask
        return output

    def forward(self, inputs, full_seq=None):
        """
        forward pass at training time
        :arg full_seq: if True, outputs actions for the full sequence, expects input encodings
        """
        if full_seq is None:
            full_seq = self._hp.train_full_seq

        if full_seq:
            return self.full_seq_forward(inputs)

        end_ind = inputs.norep_end_ind if 'norep_end_ind' in inputs else inputs.end_ind
        t0, t1 = self.sample_offsets(end_ind)
        im0 = self.index_input(inputs.traj_seq, t0)
        im1 = self.index_input(inputs.traj_seq, t1)
        if 'model_enc_seq' in inputs:
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                enc_im0 = self.index_input(inputs.enc_traj_seq, t0).reshape(
                    inputs.enc_traj_seq.shape[:1] + (self._hp.nz_enc,))
            else:
                enc_im0 = self.index_input(inputs.model_enc_seq, t0)
            enc_im1 = self.index_input(inputs.model_enc_seq, t1)
        else:
            assert self._hp.build_encoder  # need encoder if no encoded latents are given
            enc_im0 = self.encoder.forward(im0)[0]
            enc_im1 = self.encoder.forward(im1)[0]

        if self.detach_enc:
            enc_im0 = enc_im0.detach()
            enc_im1 = enc_im1.detach()

        selected_actions = self.index_input(
            inputs.actions, t0, aggregate=self._hp.aggregate_actions, t1=t1)
        selected_states = self.index_input(inputs.traj_seq_states, t0)

        if self._hp.pred_states:
            actions_pred, states_pred = torch.split(
                self.action_pred(enc_im0, enc_im1),
                [self._hp.n_actions, self._hp.state_dim], 1)
        else:
            actions_pred = self.action_pred(enc_im0, enc_im1)

        output = AttrDict()
        output.actions = remove_spatial(actions_pred)
        if self._hp.pred_states:
            output.states = remove_spatial(states_pred)  # exposed for the state reconstruction loss
        output.action_targets = selected_actions
        output.state_targets = selected_states
        output.img_t0, output.img_t1 = im0, im1

        return output

    def loss(self, inputs, outputs, add_total=True):
        losses = AttrDict()

        # action reconstruction loss
        n_action_output = outputs.actions.shape[1]
        loss_weights = broadcast_final(
            outputs.pad_mask[:, :n_action_output],
            inputs.actions) if 'pad_mask' in outputs else 1
        losses.action_reconst = L2Loss(self._hp.action_rec_weight)(
            outputs.actions,
            outputs.action_targets[:, :n_action_output],
            weights=loss_weights)
        if self._hp.pred_states:
            losses.state_reconst = L2Loss(self._hp.state_rec_weight)(
                outputs.states, outputs.state_targets)

        return losses

    def log_outputs(self, outputs, inputs, losses, step, log_images, phase):
        super()._log_losses(losses, step, log_images, phase)

        if 'actions' not in outputs:
            # TODO figure out why this happens
            return

        if log_images and len(inputs.traj_seq.shape) == 5:
            self._logger.log_pred_actions(outputs, inputs, 'pred_actions',
                                          step, phase)
        if self._hp.pred_states:
            self._logger.log_pred_states(outputs, inputs, 'pred_states', step,
                                         phase)

        if log_images:
            dataset = self._hp.dataset_class
            if len(outputs.actions.shape) == 3:
                actions = outputs.actions
            else:
                # Training, need to get the action sequence
                actions = self(inputs, full_seq=True).actions

            cum_action_traj = torch.cat(
                (inputs.traj_seq_states[:, :1], actions), dim=1).cumsum(1)
            self._logger.log_dataset_specific_trajectory(
                outputs,
                inputs,
                "action_traj_topdown",
                step,
                phase,
                dataset,
                predictions=cum_action_traj,
                end_inds=inputs.end_ind)

            cum_action_traj = torch.cat(
                (inputs.traj_seq_states[:, :1], inputs.actions),
                dim=1).cumsum(1)
            self._logger.log_dataset_specific_trajectory(
                outputs,
                inputs,
                "action_traj_gt_topdown",
                step,
                phase,
                dataset,
                predictions=cum_action_traj,
                end_inds=inputs.end_ind)

    def run_single(self, enc_latent_img0, model_latent_img1):
        """Runs inverse model on first input encoded by encoded and second input produced by model."""
        assert self._hp.train_im0_enc  # inv model needs to be trained from
        return remove_spatial(
            self.action_pred(enc_latent_img0, model_latent_img1))

    @contextmanager
    def val_mode(self, *args, **kwargs):
        yield
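The pairwise training scheme hinges on sample_offsets: for each trajectory it draws a start index t0 and an end index t1 = t0 + dt with dt in [1, temp_dist], and the model regresses the action taken at t0 (or, with aggregate_actions, the sum of actions over [t0, t1)). The following is a standalone sketch of that sampling rule, assuming temp_dist = 3 and hypothetical per-trajectory end indices; the class above additionally handles the take_first_tstep case and device placement.

import numpy as np
import torch

temp_dist = 3
end_ind = torch.tensor([10, 12, 9, 15])  # hypothetical trajectory end indices
bs = end_ind.shape[0]

t0 = np.zeros(bs, dtype=int)
for b in range(bs):
    # leave room for a temporal gap of at most temp_dist steps
    t0[b] = np.random.randint(0, int(end_ind[b]) - temp_dist + 1)
delta_t = np.random.randint(1, temp_dist + 1, bs)  # gap sampled from [1, temp_dist]
t1 = t0 + delta_t

t0 = torch.as_tensor(t0, dtype=torch.long)
t1 = torch.as_tensor(t1, dtype=torch.long)
# (t0, t1) index the frame pairs im0, im1 used in InverseModel.forward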