def log_outputs(self, model_output, inputs, losses, step, log_images, phase):
    """
    Log model outputs: delegate to the parent class, then add optional visualizations and data dumps.

    All arguments are first forwarded unchanged to the parent's ``log_outputs`` and, if
    ``self._hp.attach_inv_mdl`` is set, to the attached inverse model's ``log_outputs``.
    When ``log_images`` is truthy, additional visualizations selected by ``self._hp`` flags
    are sent to ``self._logger``. Independently of that, selected entries of ``inputs`` are
    written to disk with ``torch.save`` when a ``dump_encodings*`` flag is set.

    :param model_output: model outputs; read via attribute access and ``in`` membership tests
    :param inputs: model inputs; ``inputs.end_ind`` is forwarded to the logger as ``end_inds``
    :param losses: only forwarded to the parent / inverse-model ``log_outputs``
    :param step: forwarded to every logger call and used in the dump file names
    :param log_images: if truthy, the visualization branch is executed
    :param phase: forwarded to every logger call
    """
    super().log_outputs(model_output, inputs, losses, step, log_images, phase)
    if self._hp.attach_inv_mdl:
        self.inv_mdl.log_outputs(model_output, inputs, losses, step, log_images, phase)

    if log_images:
        if 'regressed_state' in model_output:
            self._logger.log_maze_topdown(model_output, inputs, "regressed_state_topdown", step, phase,
                                          predictions=model_output.regressed_state, end_inds=inputs.end_ind)

        if 'regressed_state' in model_output and self._hp.attach_inv_mdl:
            # 3-dimensional actions are used as they are; otherwise the inverse model is
            # run on the inputs with full_seq=True to obtain them.
            if len(model_output.actions.shape) == 3:
                actions = model_output.actions
            else:  # Training, need to get the action sequence
                actions = self.inv_mdl(inputs, full_seq=True).actions

            # Prepend the first entry (along dim 1) of the regressed state to the actions and
            # take the running sum along dim 1.
            # NOTE(review): presumably actions are per-step state deltas, so this integrates
            # them into a trajectory starting at the first regressed state -- confirm.
            cum_action_traj = torch.cat((model_output.regressed_state[:, :1], actions), dim=1).cumsum(1)
            self._logger.log_maze_topdown(model_output, inputs, "action_traj_topdown", step, phase,
                                          predictions=cum_action_traj, end_inds=inputs.end_ind)

        if not self._hp.use_convs:
            # Each visualization below is enabled by its own hyperparameter flag.
            if self._hp.log_maze_topdown:
                self._logger.log_maze_topdown(model_output, inputs, "prediction_topdown", step, phase)
            if self._hp.log_states_2d:
                self._logger.log_states_2d(model_output, inputs, "prediction_states_2d", step, phase)
            if self._hp.log_sawyer:
                self._logger.log_sawyer(model_output, inputs, "sawyer_from_states", step, phase, self._hp.data_dir)
            if self._hp.log_cartgripper:
                self._logger.log_cartgripper(model_output, inputs, "cartgripper_from_states", step, phase,
                                             self._hp.data_dir)

            # NOTE(review): this block is placed under `not use_convs` -- confirm the intended nesting.
            if self._hp.train_on_action_seqs:
                # The dense reconstruction is read as an action sequence here and accumulated along dim 1.
                action_seq = model_output.dense_rec.images
                cum_action_seq = torch.cumsum(action_seq, dim=1)
                self._logger.log_maze_topdown(model_output, inputs, "cum_action_prediction_topdown", step, phase,
                                              predictions=cum_action_seq, end_inds=inputs.end_ind)

    # NOTE(review): the two dump blocks are placed at function level, i.e. they run whether or
    # not `log_images` is set -- confirm the intended nesting.
    if self._hp.dump_encodings:
        # One file per step under <log_dir>/stored_data/ (file name has no extension).
        os.makedirs(self._logger._log_dir + '/stored_data/', exist_ok=True)
        torch.save(subdict(inputs, ['enc_demo_seq', 'demo_seq', 'demo_seq_states', 'actions']),
                   self._logger._log_dir + '/stored_data/encodings_{}'.format(step))

    if self._hp.dump_encodings_inv_model:
        # One file per step under <log_dir>/stored_data_inv_model/ (file name ends in ".th").
        os.makedirs(self._logger._log_dir + '/stored_data_inv_model/', exist_ok=True)
        torch.save(subdict(inputs, ['model_enc_seq', 'demo_seq_states', 'actions']),
                   self._logger._log_dir + '/stored_data_inv_model/encodings_{}.th'.format(step))
def _log_outputs(model_output, inputs, losses, step, log_images, phase, logger):
    """
    Log the per-layer mean of the 'dense_img_rec' and 'kl' losses as graphs.

    Nothing is logged unless ``log_images`` is truthy. For each of the two loss names that
    is present in ``losses``, the loss's ``error_mat`` is averaged over all dimensions after
    the first two, split per layer along dim 1 with ``SubgoalTreeLayer.split_by_layer_bf``,
    and the mean over the non-zero entries of every layer is passed to ``logger.log_graph``
    under the name ``'<loss name>_loss_layerwise'``.

    The loss objects in ``losses`` are left unmodified.

    :param model_output: unused here
    :param inputs: unused here
    :param losses: mapping from loss name to an object with an ``error_mat`` attribute
    :param step: forwarded to ``logger.log_graph``
    :param log_images: if falsy, the function returns without logging
    :param phase: forwarded to ``logger.log_graph``
    :param logger: object providing ``log_graph(values, name, step, phase)``
    """
    if not log_images:
        return

    # Log layerwise loss
    # An ordered membership filter keeps the logging order deterministic.
    layerwise_keys = [key for key in ('dense_img_rec', 'kl') if key in losses]
    for name, loss in subdict(losses, layerwise_keys).items():
        # Reduce on a local variable so the caller's loss object keeps its original error_mat.
        error_mat = loss.error_mat
        n_dims = len(error_mat.shape)
        if n_dims > 2:
            # reduce to two dimensions
            error_mat = error_mat.mean(list(range(2, n_dims)))

        per_layer = SubgoalTreeLayer.split_by_layer_bf(error_mat, dim=1)
        # Entries equal to zero are excluded from each layer's mean.
        # NOTE(review): presumably zeros mark unused/padded nodes -- confirm.
        layerwise_loss = torch.tensor([layer_err[layer_err != 0].mean() for layer_err in per_layer])
        logger.log_graph(layerwise_loss, '{}_{}'.format(name, 'loss_layerwise'), step, phase)
def forward(self, inputs, length, initial_inputs=None, static_inputs=None, initial_seq_inputs=None):
    """
    Unroll the cell for ``length`` steps and return the per-step outputs stacked along dim 1.

    :param inputs: These are sliced by time. Time is the second dimension
    :param length: Rollout length
    :param initial_inputs: These are not sliced and are overridden by cell output
    :param initial_seq_inputs: These can contain partial sequences. Cell output is used after these end.
        ``None`` (the default) means that there are no partial sequences.
    :param static_inputs: These are not sliced and can't be overridden by cell output
    :return: the cell outputs of all steps, combined with ``stack(..., dim=1)``
    """
    # NOTE! Unrolling the cell directly will result in crash as the hidden state is not being reset
    # Use this function or CustomLSTMCell.unroll if needed
    if initial_seq_inputs is None:
        # A fresh dict per call instead of a mutable default argument shared between calls.
        initial_seq_inputs = {}

    # assert_begin hands back the initial/static inputs that are used from here on.
    # NOTE(review): presumably it replaces None defaults with empty containers -- confirm.
    initial_inputs, static_inputs = self.assert_begin(inputs, initial_inputs, static_inputs)

    step_inputs = initial_inputs.copy()
    step_inputs.update(static_inputs)

    # The argument names accepted by the cell are the same for every step, so look them up once.
    # TODO Test what signature does with *args
    cell_arg_names = signature(self.cell.forward).parameters

    lstm_outputs = []
    for t in range(length):
        step_inputs.update(map_dict(lambda x: x[:, t], inputs))  # Slicing
        # A partial sequence is only used while it still has an entry at time t.
        step_inputs.update(map_dict(lambda x: x[:, t],
                                    filter_dict(lambda x: t < x[1].shape[1], initial_seq_inputs)))

        output = self.cell(**step_inputs)

        self.assert_post(output, inputs, initial_inputs, static_inputs)
        # Feed back only those outputs whose names the cell accepts as arguments.
        autoregressive_output = subdict(output, output.keys() & cell_arg_names)
        step_inputs.update(autoregressive_output)
        lstm_outputs.append(output)

    lstm_outputs = rmap_list(lambda *x: stack(x, dim=1), lstm_outputs)

    # Reset the cell so that the next rollout does not start from this one's hidden state.
    self.cell.reset()
    return lstm_outputs
def filter_layerwise_inputs(self, inputs):
    """
    Return the layerwise entries of ``inputs``.

    :param inputs: mapping to select from
    :return: result of ``subdict`` restricted to the layerwise keys (non-strict lookup)
    """
    # these inputs are assumed to be depth-first inputs per node in dim 1
    depth_first_keys = ['z']
    return subdict(inputs, depth_first_keys, strict=False)
def _filter_inputs_for_model(self, inputs, phase):
    """
    Restrict ``inputs`` to a fixed list of keys, with two extra keys in the training phase.

    :param inputs: mapping to select from
    :param phase: the extra keys are added only if this equals ``'train'``
    :return: result of ``subdict`` for the selected keys (non-strict lookup)
    """
    base_keys = ['I_0', 'I_g', 'skips', 'start_ind', 'end_ind', 'enc_e_0', 'enc_e_g', 'z']
    train_only_keys = ['inf_enc_seq', 'inf_enc_key_seq'] if phase == 'train' else []
    return subdict(inputs, base_keys + train_only_keys, strict=False)