def full_seq_forward(self, inputs):
    """Predict actions for every transition in the sequence at once.

    Builds per-timestep encoding pairs (e_t, e_{t+1}), concatenates them along
    the feature dimension, and runs the action predictor over the whole batch
    of pairs.

    :param inputs: dict-like batch; may carry precomputed 'model_enc_seq' /
        'enc_traj_seq' encodings, otherwise 'traj_seq' images are encoded here.
    :return: AttrDict with predicted 'actions' and, when ground-truth actions
        are available, 'action_targets' and 'pad_mask'.
    """
    if 'model_enc_seq' in inputs:
        # Targets (t+1) always come from the model encodings.
        next_enc = inputs.model_enc_seq[:, 1:]
        if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
            # Source (t) encodings taken from the image-encoder trajectory,
            # flattened to (batch, time, nz_enc) and trimmed to align with
            # the shifted target sequence.
            flat_shape = inputs.enc_traj_seq.shape[:2] + (self._hp.nz_enc,)
            prev_enc = inputs.enc_traj_seq.reshape(flat_shape)[:, :-1]
            prev_enc = prev_enc[:, :next_enc.shape[1]]
        else:
            prev_enc = inputs.model_enc_seq[:, :-1]
    else:
        # No precomputed encodings: encode raw trajectory images per step.
        encoded = batch_apply(self.encoder, inputs.traj_seq)
        prev_enc, next_enc = encoded[:, :-1], encoded[:, 1:]

    if self.detach_enc:
        # Block gradient flow into the encoder.
        prev_enc = prev_enc.detach()
        next_enc = next_enc.detach()

    # TODO quite sure the concatenation is automatic
    actions_pred = batch_apply(self.action_pred,
                               torch.cat([prev_enc, next_enc], dim=2))

    output = AttrDict()
    # NOTE(review): unlike forward(), remove_spatial() is intentionally not
    # applied here (it was commented out in the original) — confirm this
    # asymmetry is deliberate.
    output.actions = actions_pred
    if 'actions' in inputs:
        output.action_targets = inputs.actions
        output.pad_mask = inputs.pad_mask
    return output
def forward(self, inputs, full_seq=None):
    """ forward pass at training time
    :arg full_seq: if True, outputs actions for the full sequence, expects input encodings
    """
    if full_seq is None:
        full_seq = self._hp.train_full_seq
    if full_seq:
        return self.full_seq_forward(inputs)

    # Sample a pair of timesteps (t0, t1) to predict the action between.
    end_indices = inputs.norep_end_ind if 'norep_end_ind' in inputs else inputs.end_ind
    t0, t1 = self.sample_offsets(end_indices)

    im0 = self.index_input(inputs.traj_seq, t0)
    im1 = self.index_input(inputs.traj_seq, t1)

    if 'model_enc_seq' in inputs:
        # Use precomputed encodings instead of running the encoder.
        if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
            # t0 encoding comes from the image-encoder trajectory,
            # flattened to (batch, nz_enc).
            enc_im0 = self.index_input(inputs.enc_traj_seq, t0).reshape(
                inputs.enc_traj_seq.shape[:1] + (self._hp.nz_enc,))
        else:
            enc_im0 = self.index_input(inputs.model_enc_seq, t0)
        enc_im1 = self.index_input(inputs.model_enc_seq, t1)
    else:
        assert self._hp.build_encoder  # need encoder if no encoded latents are given
        enc_im0 = self.encoder.forward(im0)[0]
        enc_im1 = self.encoder.forward(im1)[0]

    if self.detach_enc:
        # Block gradient flow into the encoder.
        enc_im0 = enc_im0.detach()
        enc_im1 = enc_im1.detach()

    selected_actions = self.index_input(inputs.actions, t0,
                                        aggregate=self._hp.aggregate_actions,
                                        t1=t1)
    selected_states = self.index_input(inputs.traj_seq_states, t0)

    if self._hp.pred_states:
        # NOTE(review): torch.split(x, 2, 1) splits dim 1 into chunks of
        # SIZE 2 — if the intent was "two halves", torch.chunk(x, 2, 1)
        # would be needed. Also states_pred is computed but never returned.
        # Confirm both against the trainer before changing.
        actions_pred, states_pred = torch.split(
            self.action_pred(enc_im0, enc_im1), 2, 1)
    else:
        actions_pred = self.action_pred(enc_im0, enc_im1)

    output = AttrDict()
    output.actions = remove_spatial(actions_pred)
    output.action_targets = selected_actions
    output.state_targets = selected_states
    output.img_t0, output.img_t1 = im0, im1
    return output