def run_encoder(self, inputs, start_ind):
    if 'demo_seq' in inputs:
        # encode the full demo sequence (once) and keep the skip activations
        if 'enc_demo_seq' not in inputs:
            inputs.enc_demo_seq, inputs.skips = batch_apply(inputs.demo_seq, self.encoder)
            if self._hp.use_convs and self._hp.use_skips:
                inputs.skips = map_recursive(lambda s: s[:, 0], inputs.skips)  # only use start image activations

    # encode start and goal images
    if self._hp.separate_cnn_start_goal_encoder:
        enc_e_0, inputs.skips = self.start_goal_enc(inputs.I_0_image)
        inputs.enc_e_0 = remove_spatial(enc_e_0)
        inputs.enc_e_g = remove_spatial(self.start_goal_enc(inputs.I_g_image)[0])
    else:
        inputs.enc_e_0, inputs.skips = self.encoder(inputs.I_0)
        inputs.enc_e_g = self.encoder(inputs.I_g)[0]

    if 'demo_seq' in inputs:
        # run inference encoders on the encoded demo sequence
        if self._hp.act_cond_inference:
            inputs.inf_enc_seq = self.inf_encoder(inputs.enc_demo_seq, inputs.actions)
        elif self._hp.states_inference:
            inputs.inf_enc_seq = batch_apply((inputs.enc_demo_seq, inputs.demo_seq_states[..., None, None]),
                                             self.inf_encoder, separate_arguments=True)
        else:
            inputs.inf_enc_seq = self.inf_encoder(inputs.enc_demo_seq)
        inputs.inf_enc_key_seq = self.inf_key_encoder(inputs.enc_demo_seq)

    if self._hp.action_conditioned_pred:
        inputs.enc_action_seq = batch_apply(inputs.actions, self.action_encoder)

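# NOTE: illustrative sketch only, not the repository's implementation. The batch_apply
# helper used above is assumed to apply a module over a [batch, time, ...] tensor by
# folding the time axis into the batch axis; the sketch below shows that assumed
# behavior with hypothetical names.
import torch
import torch.nn as nn

def _batch_apply_sketch(tensor, module):
    b, t = tensor.shape[:2]
    flat = tensor.reshape(b * t, *tensor.shape[2:])    # merge batch and time
    out = module(flat)
    return out.reshape(b, t, *out.shape[1:])           # restore the time axis

# e.g. _batch_apply_sketch(torch.randn(4, 10, 32), nn.Linear(32, 16)) -> shape [4, 10, 16]
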
def forward(self, inputs, phase='train'):
    """ forward pass at training time """
    if 'enc_traj_seq' not in inputs:
        enc_traj_seq, _ = self.encoder(inputs.traj_seq[:, 0]) if self._hp.train_first_action_only \
                            else batch_apply(self.encoder, inputs.traj_seq)
        if self._hp.train_first_action_only:
            enc_traj_seq = enc_traj_seq[:, None]    # add back the time dimension

    enc_traj_seq = enc_traj_seq.detach() if self.detach_enc else enc_traj_seq
    enc_goal, _ = self.encoder(inputs.I_g)

    # broadcast the goal encoding along the time axis and concatenate it to every step
    n_dim = len(enc_goal.shape)
    fused_enc = torch.cat((enc_traj_seq,
                           enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))), dim=2)

    if self._hp.reactive:
        actions_pred = batch_apply(self.policy, fused_enc)
    else:
        policy_output = self.policy(fused_enc)
        actions_pred = policy_output

    # remove last time step to match ground truth if training on full sequence
    actions_pred = actions_pred[:, :-1] if not self._hp.train_first_action_only else actions_pred

    output = AttrDict()
    output.actions = remove_spatial(actions_pred) if len(actions_pred.shape) > 3 else actions_pred
    return output

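# Shape sketch (illustrative only, made-up sizes): the goal encoding is broadcast along
# the time axis and concatenated to every per-step encoding, as in the forward above.
import torch

enc_traj_seq = torch.randn(4, 10, 64)    # [batch, time, nz_enc]
enc_goal = torch.randn(4, 64)            # [batch, nz_enc]
n_dim = len(enc_goal.shape)
fused_enc = torch.cat((enc_traj_seq,
                       enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))), dim=2)
assert fused_enc.shape == (4, 10, 128)
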
def forward(self, inputs):
    # convert inputs to tensors and move them to the model's device
    for k in inputs:
        if inputs[k] is None:
            continue
        if not isinstance(inputs[k], torch.Tensor):
            inputs[k] = torch.Tensor(inputs[k])
        if inputs[k].device != self.device:
            inputs[k] = inputs[k].to(self.device)

    enc, _ = self.encoder(inputs['I_0'])
    enc_goal, _ = self.encoder(inputs['I_g'])
    fused_enc = torch.cat((enc, enc_goal), dim=1)

    if self._hp.reactive:
        action_pred = self.policy(fused_enc)
        hidden_var = None
    else:
        # recurrent policy: carry the hidden state across calls
        hidden_var = self.init_hidden_var if inputs.hidden_var is None else inputs.hidden_var
        policy_output = self.policy(fused_enc[:, None], hidden_var)
        action_pred, hidden_var = policy_output.output, policy_output.hidden_state[:, 0]

    if self._hp.use_convs:
        return remove_spatial(action_pred if len(action_pred.shape) == 4 else action_pred[:, 0]), hidden_var
    else:
        return action_pred, hidden_var

def forward(self, e0, eg):
    """Returns the logits of a OneHotCategorical distribution."""
    output = AttrDict()
    output.seq_len_logits = remove_spatial(self.p(e0, eg))
    output.seq_len_pred = OneHotCategorical(logits=output.seq_len_logits)
    return output

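# Usage sketch (not part of the model code): the OneHotCategorical returned above can
# be sampled or reduced to a most-likely sequence length. The tensor shapes here are
# made up for illustration.
import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(2, 30)                    # [batch, max_seq_len]
dist = OneHotCategorical(logits=logits)
one_hot_sample = dist.sample()                 # one-hot sample, shape [2, 30]
pred_len = torch.argmax(dist.probs, dim=-1)    # most likely length per example
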
def forward(self, inputs):
    # convert inputs to tensors and move them to the model's device
    for k in inputs:
        if not isinstance(inputs[k], torch.Tensor):
            inputs[k] = torch.Tensor(inputs[k])
        if inputs[k].device != self.device:
            inputs[k] = inputs[k].to(self.device)

    enc_im0 = self.encoder.forward(inputs['img_t0'])[0]
    enc_im1 = self.encoder.forward(inputs['img_t1'])[0]
    return remove_spatial(self.action_pred(enc_im0, enc_im1))

def forward(self, inputs, full_seq=None):
    """
    forward pass at training time
    :arg full_seq: if True, outputs actions for the full sequence, expects input encodings
    """
    if full_seq is None:
        full_seq = self._hp.train_full_seq
    if full_seq:
        return self.full_seq_forward(inputs)

    # sample a pair of time steps and fetch the corresponding frames
    t0, t1 = self.sample_offsets(inputs.norep_end_ind if 'norep_end_ind' in inputs else inputs.end_ind)
    im0 = self.index_input(inputs.traj_seq, t0)
    im1 = self.index_input(inputs.traj_seq, t1)

    if 'model_enc_seq' in inputs:
        if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
            enc_im0 = self.index_input(inputs.enc_traj_seq, t0) \
                          .reshape(inputs.enc_traj_seq.shape[:1] + (self._hp.nz_enc,))
        else:
            enc_im0 = self.index_input(inputs.model_enc_seq, t0)
        enc_im1 = self.index_input(inputs.model_enc_seq, t1)
    else:
        assert self._hp.build_encoder   # need encoder if no encoded latents are given
        enc_im0 = self.encoder.forward(im0)[0]
        enc_im1 = self.encoder.forward(im1)[0]

    if self.detach_enc:
        enc_im0 = enc_im0.detach()
        enc_im1 = enc_im1.detach()

    # regression targets: (optionally aggregated) actions between t0 and t1, and the state at t0
    selected_actions = self.index_input(inputs.actions, t0, aggregate=self._hp.aggregate_actions, t1=t1)
    selected_states = self.index_input(inputs.traj_seq_states, t0)

    if self._hp.pred_states:
        actions_pred, states_pred = torch.split(self.action_pred(enc_im0, enc_im1), 2, 1)
    else:
        actions_pred = self.action_pred(enc_im0, enc_im1)

    output = AttrDict()
    output.actions = remove_spatial(actions_pred)
    output.action_targets = selected_actions
    output.state_targets = selected_states
    output.img_t0, output.img_t1 = im0, im1
    return output

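# NOTE: illustrative sketch only. sample_offsets and index_input are assumed to
# (a) draw an ordered pair of time indices per sequence and (b) gather the
# corresponding entries from a [batch, time, ...] tensor; the names below are
# placeholders, not the repository's helpers.
import torch

def _sample_offsets_sketch(end_ind):
    """Sample t0 < t1 within each sequence's valid length (end_ind inclusive)."""
    t0 = (torch.rand(end_ind.shape) * end_ind).long()                    # t0 in [0, end_ind - 1]
    t1 = t0 + 1 + (torch.rand(end_ind.shape) * (end_ind - t0)).long()    # t1 in [t0 + 1, end_ind]
    return t0, t1

def _index_input_sketch(seq, t):
    """Pick seq[b, t[b]] for every batch element b."""
    return seq[torch.arange(seq.shape[0]), t]

# e.g. for seq of shape [4, 10, 3, 64, 64] and end_ind = tensor([9, 9, 7, 5]),
# _index_input_sketch(seq, t0) returns a [4, 3, 64, 64] batch of frames.
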
def run_single(self, enc_latent_img0, model_latent_img1):
    """Runs inverse model on a first input encoded by the encoder and a second input produced by the model."""
    assert self._hp.train_im0_enc   # inv model needs to be trained from encoder latents for img0
    return remove_spatial(self.action_pred(enc_latent_img0, model_latent_img1))

def forward(self, *inp):
    out = super().forward(*inp)
    return remove_spatial(out, yes=not self.spatial)

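# NOTE: illustrative sketch only. remove_spatial is assumed to flatten the trailing
# spatial dimensions of a conv feature map (e.g. [B, C, 1, 1]) into a single feature
# axis, with a flag to pass the tensor through unchanged, as in the forward above;
# the name below is a placeholder, not the repository's helper.
import torch

def _remove_spatial_sketch(tensor, yes=True):
    if not yes:
        return tensor
    return tensor.reshape(tensor.shape[0], -1)    # [B, C, H, W] -> [B, C*H*W]

# e.g. _remove_spatial_sketch(torch.randn(4, 128, 1, 1)).shape == (4, 128)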