Code example #1
File: gcp_model.py Project: codeaudit/video-gcp
    def log_outputs(self, model_output, inputs, losses, step, log_images, phase):
        super().log_outputs(model_output, inputs, losses, step, log_images, phase)
        if self._hp.attach_inv_mdl:
            self.inv_mdl.log_outputs(model_output, inputs, losses, step, log_images, phase)
        
        if log_images:
            if 'regressed_state' in model_output:
                self._logger.log_maze_topdown(model_output, inputs, "regressed_state_topdown", step, phase,
                                              predictions=model_output.regressed_state, end_inds=inputs.end_ind)
                
            if 'regressed_state' in model_output and self._hp.attach_inv_mdl:
                if len(model_output.actions.shape) == 3:
                    actions = model_output.actions
                else:
                    # Training, need to get the action sequence
                    actions = self.inv_mdl(inputs, full_seq=True).actions
                    
                cum_action_traj = torch.cat((model_output.regressed_state[:, :1], actions), dim=1).cumsum(1)
                self._logger.log_maze_topdown(model_output, inputs, "action_traj_topdown", step, phase,
                                              predictions=cum_action_traj, end_inds=inputs.end_ind)

            if not self._hp.use_convs:
                if self._hp.log_maze_topdown:
                    self._logger.log_maze_topdown(model_output, inputs, "prediction_topdown", step, phase)
                if self._hp.log_states_2d:
                    self._logger.log_states_2d(model_output, inputs, "prediction_states_2d", step, phase)
                if self._hp.log_sawyer:
                    self._logger.log_sawyer(model_output, inputs, "sawyer_from_states", step, phase, self._hp.data_dir)
                if self._hp.log_cartgripper:
                    self._logger.log_cartgripper(model_output, inputs, "cartgripper_from_states", step, phase,
                                                 self._hp.data_dir)
                if self._hp.train_on_action_seqs:
                    action_seq = model_output.dense_rec.images
                    cum_action_seq = torch.cumsum(action_seq, dim=1)
                    self._logger.log_maze_topdown(model_output, inputs, "cum_action_prediction_topdown", step, phase, 
                                                  predictions=cum_action_seq, end_inds=inputs.end_ind)

        if self._hp.dump_encodings:
            os.makedirs(self._logger._log_dir + '/stored_data/', exist_ok=True)
            torch.save(subdict(inputs, ['enc_demo_seq', 'demo_seq', 'demo_seq_states', 'actions']),
                       self._logger._log_dir + '/stored_data/encodings_{}'.format(step))

        if self._hp.dump_encodings_inv_model:
            os.makedirs(self._logger._log_dir + '/stored_data_inv_model/', exist_ok=True)
            torch.save(subdict(inputs, ['model_enc_seq', 'demo_seq_states', 'actions']),
                       self._logger._log_dir + '/stored_data_inv_model/encodings_{}.th'.format(step))
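
Every snippet on this page funnels dict-like containers through subdict, whose definition is not shown here. A minimal sketch of what such a helper might look like, assuming it simply filters a mapping by key (the parameter names and exact behavior are an assumption, not the project's actual implementation):

def subdict(dict_, keys, strict=True):
    # Hypothetical sketch: return a new dict holding only the requested keys.
    # With strict=False (as in examples #4 and #5), keys missing from dict_
    # are silently skipped instead of raising a KeyError.
    if not strict:
        keys = [k for k in keys if k in dict_]
    return {k: dict_[k] for k in keys}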
Code example #2
File: hedge.py Project: codeaudit/video-gcp
def _log_outputs(model_output, inputs, losses, step, log_images, phase, logger):
    if log_images:
        # Log layerwise loss
        layerwise_keys = ['dense_img_rec', 'kl'] & losses.keys()
        for name, loss in subdict(losses, layerwise_keys).items():
            if len(loss.error_mat.shape) > 2:   # reduce to two dimensions
                loss.error_mat = loss.error_mat.mean(list(range(2, len(loss.error_mat.shape))))
            layerwise_loss = SubgoalTreeLayer.split_by_layer_bf(loss.error_mat, dim=1)
            layerwise_loss = torch.tensor([l[l != 0].mean() for l in layerwise_loss])
            logger.log_graph(layerwise_loss, '{}_{}'.format(name, 'loss_layerwise'), step, phase)
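
SubgoalTreeLayer.split_by_layer_bf regroups the breadth-first node dimension into per-layer chunks before averaging; its implementation is not part of this listing. A hypothetical sketch, assuming a balanced binary tree stored breadth-first so that layer i occupies the 2**i indices starting at offset 2**i - 1 (the real method may differ):

def split_by_layer_bf(tensor, dim):
    # Layer i of a balanced binary tree stored breadth-first spans
    # width 2**i starting at offset 2**i - 1 along `dim` (assumed layout).
    n = tensor.shape[dim]
    layers, start, width = [], 0, 1
    while start < n:
        layers.append(tensor.narrow(dim, start, min(width, n - start)))
        start += width
        width *= 2
    return layers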
Code example #3
    def forward(self,
                inputs,
                length,
                initial_inputs=None,
                static_inputs=None,
                initial_seq_inputs={}):
        """
        
        :param inputs: These are sliced by time. Time is the second dimension
        :param length: Rollout length
        :param initial_inputs: These are not sliced and are overridden by cell output
        :param initial_seq_inputs: These can contain partial sequences. Cell output is used after these end.
        :param static_inputs: These are not sliced and can't be overridden by cell output
        :return:
        """
        # NOTE! Unrolling the cell directly will result in a crash, as the hidden state is not being reset.
        # Use this function or CustomLSTMCell.unroll instead.
        initial_inputs, static_inputs = self.assert_begin(
            inputs, initial_inputs, static_inputs)

        step_inputs = initial_inputs.copy()
        step_inputs.update(static_inputs)
        lstm_outputs = []
        for t in range(length):
            step_inputs.update(map_dict(lambda x: x[:, t], inputs))  # Slicing
            step_inputs.update(
                map_dict(
                    lambda x: x[:, t],
                    filter_dict(lambda x: t < x[1].shape[1],
                                initial_seq_inputs)))
            output = self.cell(**step_inputs)

            self.assert_post(output, inputs, initial_inputs, static_inputs)
            # TODO Test what signature does with *args
            autoregressive_output = subdict(
                output,
                output.keys() & signature(self.cell.forward).parameters)
            step_inputs.update(autoregressive_output)
            lstm_outputs.append(output)

        lstm_outputs = rmap_list(lambda *x: stack(x, dim=1), lstm_outputs)

        self.cell.reset()
        return lstm_outputs
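
The rollout relies on a few helpers that are not shown: map_dict and filter_dict for per-step slicing, and signature (presumably inspect.signature) to feed back only those outputs whose names match parameters of the cell's forward. Plausible sketches of the two dict utilities, consistent with how they are called above (note the filter predicate receives the whole (key, value) item, since the lambda indexes x[1]):

def map_dict(fn, d):
    # Apply fn to every value, keeping the keys (assumed behavior).
    return {k: fn(v) for k, v in d.items()}

def filter_dict(fn, d):
    # Keep only the items for which fn((key, value)) is truthy (assumed behavior).
    return {k: v for k, v in d.items() if fn((k, v))}

Under this reading, each step slices all time-major inputs at t, overlays any initial sequences that are still running, and feeds matching outputs back in, which is what makes the unroll autoregressive.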
Code example #4
File: tree.py Project: orybkin/video-gcp
    def filter_layerwise_inputs(self, inputs):
        # these inputs are assumed to be depth-first inputs per node in dim 1
        layerwise_input_keys = ['z']
        layerwise_inputs = subdict(inputs, layerwise_input_keys, strict=False)
        return layerwise_inputs
Code example #5
File: hedge.py Project: codeaudit/video-gcp
    def _filter_inputs_for_model(self, inputs, phase):
        keys = ['I_0', 'I_g', 'skips', 'start_ind', 'end_ind', 'enc_e_0', 'enc_e_g', 'z']
        if phase == 'train':
            keys += ['inf_enc_seq', 'inf_enc_key_seq']
        return subdict(inputs, keys, strict=False)
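
Examples #4 and #5 both pass strict=False, so requested keys that are absent from inputs are dropped rather than raising. A toy illustration, assuming the hypothetical subdict sketch from example #1 (the values here are placeholders):

inputs = {'I_0': 'img0', 'I_g': 'img_goal', 'z': 'latent'}
print(subdict(inputs, ['I_0', 'I_g', 'skips'], strict=False))
# -> {'I_0': 'img0', 'I_g': 'img_goal'}   ('skips' is absent and silently dropped)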