Example #1
    def run_encoder(self, inputs, start_ind):
        """Encodes demo sequence, start/goal images and (optionally) actions in-place into inputs."""
        if 'demo_seq' in inputs and 'enc_demo_seq' not in inputs:
            inputs.enc_demo_seq, inputs.skips = batch_apply(inputs.demo_seq, self.encoder)
            if self._hp.use_convs and self._hp.use_skips:
                inputs.skips = map_recursive(lambda s: s[:, 0], inputs.skips)  # only use start image activations

        # encode start (I_0) and goal (I_g) images, optionally with a dedicated encoder
        if self._hp.separate_cnn_start_goal_encoder:
            enc_e_0, inputs.skips = self.start_goal_enc(inputs.I_0_image)
            inputs.enc_e_0 = remove_spatial(enc_e_0)
            inputs.enc_e_g = remove_spatial(self.start_goal_enc(inputs.I_g_image)[0])
        else:
            inputs.enc_e_0, inputs.skips = self.encoder(inputs.I_0)
            inputs.enc_e_g = self.encoder(inputs.I_g)[0]

        if 'demo_seq' in inputs:
            # inference encodings, optionally conditioned on actions or states
            if self._hp.act_cond_inference:
                inputs.inf_enc_seq = self.inf_encoder(inputs.enc_demo_seq, inputs.actions)
            elif self._hp.states_inference:
                inputs.inf_enc_seq = batch_apply((inputs.enc_demo_seq, inputs.demo_seq_states[..., None, None]),
                                                 self.inf_encoder, separate_arguments=True)
            else:
                inputs.inf_enc_seq = self.inf_encoder(inputs.enc_demo_seq)
            inputs.inf_enc_key_seq = self.inf_key_encoder(inputs.enc_demo_seq)

        if self._hp.action_conditioned_pred:
            inputs.enc_action_seq = batch_apply(inputs.actions, self.action_encoder)
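
The batch_apply helper used above runs a per-frame module over a whole sequence by folding time into the batch dimension. Below is a minimal sketch of that fold/unfold idea, assuming a tensor-first calling convention and a single tensor output; the real helper evidently supports more (both argument orders appear in these snippets, plus tuple inputs via separate_arguments):

    import torch

    def batch_apply_sketch(x, fn):
        """Fold [B, T, ...] into [B*T, ...], apply fn frame-wise, unfold again."""
        b, t = x.shape[:2]
        out = fn(x.reshape(b * t, *x.shape[2:]))
        return out.reshape(b, t, *out.shape[1:])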
Example #2
    def forward(self, inputs, phase='train'):
        """Forward pass at training time."""
        if 'enc_traj_seq' not in inputs:
            enc_traj_seq, _ = self.encoder(inputs.traj_seq[:, 0]) if self._hp.train_first_action_only \
                                    else batch_apply(self.encoder, inputs.traj_seq)
            if self._hp.train_first_action_only:
                enc_traj_seq = enc_traj_seq[:, None]    # add back the time dimension
            if self.detach_enc:
                enc_traj_seq = enc_traj_seq.detach()    # don't backprop into the encoder
        else:
            enc_traj_seq = inputs.enc_traj_seq          # encodings were precomputed upstream

        # tile the goal encoding over time and fuse it with the per-timestep encodings
        enc_goal, _ = self.encoder(inputs.I_g)
        n_dim = len(enc_goal.shape)
        fused_enc = torch.cat((enc_traj_seq,
                               enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))),
                              dim=2)

        if self._hp.reactive:
            actions_pred = batch_apply(self.policy, fused_enc)
        else:
            actions_pred = self.policy(fused_enc)

        # remove last time step to match ground truth if training on the full sequence
        if not self._hp.train_first_action_only:
            actions_pred = actions_pred[:, :-1]

        output = AttrDict()
        output.actions = remove_spatial(actions_pred) if len(actions_pred.shape) > 3 else actions_pred
        return output
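
The torch.cat/repeat fusion above broadcasts a single goal encoding across every time step before concatenating along the channel dimension. A minimal shape check of that pattern for the non-convolutional case (sizes chosen arbitrarily for illustration):

    import torch

    b, t, c = 2, 5, 8
    enc_traj_seq = torch.randn(b, t, c)  # per-timestep encodings [B, T, C]
    enc_goal = torch.randn(b, c)         # single goal encoding   [B, C]
    # tile the goal over time, then concatenate along the channel dim
    fused = torch.cat((enc_traj_seq, enc_goal[:, None].repeat(1, t, 1)), dim=2)
    assert fused.shape == (b, t, 2 * c)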
Example #3
    def forward(self, inputs):
        # coerce all inputs to tensors on the model's device
        for k in inputs:
            if inputs[k] is None:
                continue
            if not isinstance(inputs[k], torch.Tensor):
                inputs[k] = torch.Tensor(inputs[k])
            if inputs[k].device != self.device:
                inputs[k] = inputs[k].to(self.device)

        # encode current and goal image, fuse along the channel dimension
        enc, _ = self.encoder(inputs['I_0'])
        enc_goal, _ = self.encoder(inputs['I_g'])
        fused_enc = torch.cat((enc, enc_goal), dim=1)
        if self._hp.reactive:
            action_pred = self.policy(fused_enc)
            hidden_var = None
        else:
            # recurrent policy: thread the hidden state through successive calls
            hidden_var = self.init_hidden_var if inputs.hidden_var is None else inputs.hidden_var
            policy_output = self.policy(fused_enc[:, None], hidden_var)
            action_pred, hidden_var = policy_output.output, policy_output.hidden_state[:, 0]
        if self._hp.use_convs:
            # strip spatial dims (and a singleton time dim, if present) from conv outputs
            return remove_spatial(action_pred if len(action_pred.shape) == 4 else action_pred[:, 0]), hidden_var
        else:
            return action_pred, hidden_var
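
In the non-reactive branch, the recurrent hidden state has to be threaded through successive calls. A hypothetical step-by-step rollout sketch (the rollout_pairs iterable and the AttrDict construction are assumptions, not part of the snippet above):

    hidden = None
    for obs, goal in rollout_pairs:  # hypothetical iterable of (I_0, I_g) pairs
        action, hidden = model(AttrDict(I_0=obs, I_g=goal, hidden_var=hidden))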
Example #4
    def forward(self, e0, eg):
        """Returns the logits of a OneHotCategorical distribution."""
        output = AttrDict()
        output.seq_len_logits = remove_spatial(self.p(e0, eg))
        output.seq_len_pred = OneHotCategorical(logits=output.seq_len_logits)

        return output
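
OneHotCategorical here is the standard torch.distributions class: samples are one-hot vectors whose argmax gives the predicted sequence length, and log_prob can score ground-truth lengths for training. For instance:

    import torch
    from torch.distributions import OneHotCategorical

    logits = torch.randn(4, 10)        # batch of 4, 10 candidate sequence lengths
    dist = OneHotCategorical(logits=logits)
    sample = dist.sample()             # one-hot, shape [4, 10]
    lengths = sample.argmax(dim=-1)    # integer length per example
    nll = -dist.log_prob(sample)       # negative log-likelihood, shape [4]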
Example #5
    def forward(self, inputs):
        # coerce all inputs to tensors on the model's device
        for k in inputs:
            if not isinstance(inputs[k], torch.Tensor):
                inputs[k] = torch.Tensor(inputs[k])
            if inputs[k].device != self.device:
                inputs[k] = inputs[k].to(self.device)

        # inverse model: predict the action that takes img_t0 to img_t1
        enc_im0 = self.encoder.forward(inputs['img_t0'])[0]
        enc_im1 = self.encoder.forward(inputs['img_t1'])[0]
        return remove_spatial(self.action_pred(enc_im0, enc_im1))
Example #6
    def forward(self, inputs, full_seq=None):
        """
        Forward pass at training time.
        :arg full_seq: if True, outputs actions for the full sequence; expects input encodings
        """
        if full_seq is None:
            full_seq = self._hp.train_full_seq

        if full_seq:
            return self.full_seq_forward(inputs)

        # sample a pair of timesteps per example and fetch the corresponding frames
        t0, t1 = self.sample_offsets(inputs.norep_end_ind if 'norep_end_ind' in inputs else inputs.end_ind)
        im0 = self.index_input(inputs.traj_seq, t0)
        im1 = self.index_input(inputs.traj_seq, t1)
        if 'model_enc_seq' in inputs:
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                enc_im0 = self.index_input(inputs.enc_traj_seq, t0) \
                              .reshape(inputs.enc_traj_seq.shape[:1] + (self._hp.nz_enc,))
            else:
                enc_im0 = self.index_input(inputs.model_enc_seq, t0)
            enc_im1 = self.index_input(inputs.model_enc_seq, t1)
        else:
            assert self._hp.build_encoder  # need encoder if no encoded latents are given
            enc_im0 = self.encoder.forward(im0)[0]
            enc_im1 = self.encoder.forward(im1)[0]

        if self.detach_enc:
            # don't backprop inverse-model gradients into the encoder
            enc_im0 = enc_im0.detach()
            enc_im1 = enc_im1.detach()

        selected_actions = self.index_input(inputs.actions, t0, aggregate=self._hp.aggregate_actions, t1=t1)
        selected_states = self.index_input(inputs.traj_seq_states, t0)

        if self._hp.pred_states:
            # prediction head outputs actions and states; split them along dim 1
            actions_pred, states_pred = torch.split(self.action_pred(enc_im0, enc_im1), 2, 1)
        else:
            actions_pred = self.action_pred(enc_im0, enc_im1)

        output = AttrDict()
        output.actions = remove_spatial(actions_pred)
        output.action_targets = selected_actions
        output.state_targets = selected_states
        output.img_t0, output.img_t1 = im0, im1

        return output
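
index_input selects a (possibly different) timestep for every batch element. A minimal sketch of that per-example time indexing via torch.gather, omitting the aggregate/t1 options of the real helper:

    import torch

    def index_time(seq, t):
        """Select per-example timesteps: seq [B, T, ...], t [B] -> [B, ...]."""
        idx = t.long().view(-1, *([1] * (seq.dim() - 1)))
        idx = idx.expand(-1, 1, *seq.shape[2:])
        return seq.gather(1, idx).squeeze(1)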
Example #7
    def run_single(self, enc_latent_img0, model_latent_img1):
        """Runs inverse model on a first input encoded by the encoder and a second input produced by the model."""
        assert self._hp.train_im0_enc  # inv model needs to be trained on encoder latents for this to be valid
        return remove_spatial(self.action_pred(enc_latent_img0, model_latent_img1))
Example #8
    def forward(self, *inp):
        out = super().forward(*inp)
        # flatten spatial dims only when this module is configured as non-spatial
        return remove_spatial(out, yes=not self.spatial)
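
remove_spatial appears throughout these snippets to collapse the trailing spatial dimensions of a convolutional feature map into a flat vector, with the yes flag turning it into a pass-through. A minimal sketch of that assumed behavior (not the shared utility's exact implementation):

    def remove_spatial_sketch(x, yes=True):
        """Flatten trailing spatial dims: [B, C, H, W] -> [B, C*H*W]; pass through otherwise."""
        if not yes or x.dim() <= 2:
            return x
        return x.reshape(x.shape[0], -1)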