Code example #1
File: inverse_mdl.py  Project: orybkin/video-gcp
    def full_seq_forward(self, inputs):
        if 'model_enc_seq' in inputs:
            enc_seq_1 = inputs.model_enc_seq[:, 1:]
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                enc_seq_0 = inputs.enc_traj_seq.reshape(
                    inputs.enc_traj_seq.shape[:2] +
                    (self._hp.nz_enc, ))[:, :-1]
                enc_seq_0 = enc_seq_0[:, :enc_seq_1.shape[1]]
            else:
                enc_seq_0 = inputs.model_enc_seq[:, :-1]
        else:
            enc_seq = batch_apply(self.encoder, inputs.traj_seq)
            enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]

        if self.detach_enc:
            # block gradients from flowing back into the encoder through the encodings
            enc_seq_0 = enc_seq_0.detach()
            enc_seq_1 = enc_seq_1.detach()

        # TODO quite sure the concatenation is automatic
        actions_pred = batch_apply(self.action_pred,
                                   torch.cat([enc_seq_0, enc_seq_1], dim=2))

        output = AttrDict()
        output.actions = actions_pred  #remove_spatial(actions_pred)
        if 'actions' in inputs:
            output.action_targets = inputs.actions
            output.pad_mask = inputs.pad_mask
        return output
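
The method above pairs each encoding with the encoding of the next timestep, concatenates the pair along the feature dimension, and predicts one action per transition, with batch_apply folding the time axis into the batch. Below is a minimal standalone sketch of that pairing-and-concatenation step in plain PyTorch; the tensor sizes and the linear action head are illustrative assumptions, not the project's actual hyperparameters or modules.

import torch
import torch.nn as nn

B, T, nz_enc, n_actions = 4, 10, 32, 2          # assumed sizes
enc_seq = torch.randn(B, T, nz_enc)             # stands in for inputs.model_enc_seq
enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]   # (B, T-1, nz_enc) each

action_pred = nn.Linear(2 * nz_enc, n_actions)  # stands in for self.action_pred

# batch_apply folds the time dimension into the batch; emulate it with reshapes
pairs = torch.cat([enc_seq_0, enc_seq_1], dim=2)       # (B, T-1, 2*nz_enc)
flat = action_pred(pairs.reshape(-1, 2 * nz_enc))      # (B*(T-1), n_actions)
actions_pred = flat.reshape(B, T - 1, n_actions)       # one action per transition
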
Code example #2
File: inverse_mdl.py  Project: orybkin/video-gcp
    def forward(self, inputs, full_seq=None):
        """
        forward pass at training time
        :arg full_seq: if True, outputs actions for the full sequence, expects input encodings
        """
        if full_seq is None:
            full_seq = self._hp.train_full_seq

        if full_seq:
            return self.full_seq_forward(inputs)

        t0, t1 = self.sample_offsets(
            inputs.norep_end_ind if 'norep_end_ind' in inputs else inputs.end_ind)
        im0 = self.index_input(inputs.traj_seq, t0)
        im1 = self.index_input(inputs.traj_seq, t1)
        if 'model_enc_seq' in inputs:
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                enc_im0 = self.index_input(inputs.enc_traj_seq, t0).reshape(
                    inputs.enc_traj_seq.shape[:1] + (self._hp.nz_enc, ))
            else:
                enc_im0 = self.index_input(inputs.model_enc_seq, t0)
            enc_im1 = self.index_input(inputs.model_enc_seq, t1)
        else:
            assert self._hp.build_encoder  # need encoder if no encoded latents are given
            enc_im0 = self.encoder.forward(im0)[0]
            enc_im1 = self.encoder.forward(im1)[0]

        if self.detach_enc:
            enc_im0 = enc_im0.detach()
            enc_im1 = enc_im1.detach()

        selected_actions = self.index_input(
            inputs.actions, t0, aggregate=self._hp.aggregate_actions, t1=t1)
        selected_states = self.index_input(inputs.traj_seq_states, t0)

        if self._hp.pred_states:
            actions_pred, states_pred = torch.split(
                self.action_pred(enc_im0, enc_im1), 2, 1)
        else:
            actions_pred = self.action_pred(enc_im0, enc_im1)

        output = AttrDict()
        output.actions = remove_spatial(actions_pred)
        output.action_targets = selected_actions
        output.state_targets = selected_states
        output.img_t0, output.img_t1 = im0, im1

        return output
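
In contrast to full_seq_forward, this variant trains on a single transition per sequence: sample_offsets draws a pair of timesteps t0 < t1 within the valid (unpadded) range, and index_input gathers the frames and encodings at those indices. Below is a standalone sketch of that sampling-and-gathering step; the shapes, the uniform sampling scheme, and the end_ind values are assumptions for illustration and may differ from the project's actual implementation.

import torch

B, T, C, H, W = 4, 10, 3, 16, 16               # assumed sizes
traj_seq = torch.randn(B, T, C, H, W)          # stands in for inputs.traj_seq
end_ind = torch.tensor([9, 7, 9, 5])           # last valid index of each padded sequence

# draw t0 uniformly in [0, end_ind) and t1 uniformly in (t0, end_ind]
t0 = (torch.rand(B) * end_ind).long()
t1 = t0 + (torch.rand(B) * (end_ind - t0)).long() + 1

batch_idx = torch.arange(B)
im0 = traj_seq[batch_idx, t0]                  # (B, C, H, W) frames at t0
im1 = traj_seq[batch_idx, t1]                  # (B, C, H, W) frames at t1
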