Example #1
0
    def forward(self, inputs, full_seq=None):
        """Encodes start and sampled goal frames, predicts the goal latent,
        and decodes it back into an image.

        :param inputs: batch dict; reads I_0, demo_seq, end_ind and is mutated
                       in place (I_0 tiled, I_target attached).
        :param full_seq: unused here; kept for interface compatibility.
        :return: AttrDict with the net outputs plus e_target_prime and
                 I_target_prime.
        """
        outputs = AttrDict()

        repeats = 32
        # Tile the start frame once per batch so every sampled target gets a
        # matching conditioning frame (keeps visualizations aligned).
        if inputs.I_0.shape[0] == inputs.demo_seq.shape[0]:
            inputs.I_0 = inputs.I_0.repeat_interleave(repeats, 0)

        sampled = self.sample_target(inputs.demo_seq, inputs.end_ind, repeats)
        # Collapse the (batch, repeat) leading dims into one batch dimension.
        I_target = sampled.reshape((-1,) + sampled.shape[2:])
        inputs.I_target = I_target

        # The decoder supports skip connections and pixel copying.
        enc_start, _ = self.encoder(inputs.I_0)
        enc_goal, _ = self.encoder(I_target)

        outputs.update(self.net(enc_goal, enc_start))
        e_target_prime = outputs.pop('mu')
        outputs.e_target_prime = e_target_prime
        outputs.I_target_prime = self.decoder(e_target_prime)

        return outputs
Example #2
0
    'n_lstm_layers': 3,
    'nz_mid': 128,
    'nz_enc': 128,
    'nz_vae': 256,
    'regress_length': True,
    'attach_state_regressor': True,
    'attach_inv_mdl': True,
    'inv_mdl_params': AttrDict(
        n_actions=2,
        use_convs=False,
        build_encoder=False,
    ),
    'untied_layers': True,
    'decoder_distribution': 'discrete_logistic_mixture',
})
h_config.pop("add_weighted_pixel_copy")

cem_params = AttrDict(
    prune_final=True,
    horizon=200,
    action_dim=256,
    verbose=True,
    n_iters=1,
    batch_size=5,
    elite_frac=1.0,
    n_level_hierarchy=8,
    sampler=SimpleTreeCEMSampler,
    cost_fcn=ImageWrappedLearnedCostFcn,
    cost_config=AttrDict(
        checkpt_path=os.environ['GCP_EXP_DIR'] + '/prediction/25room/gcp_tree/weights'
    ),
Example #3
0
class Evaluator:
    """Performs evaluation of metrics etc."""
    # How many sequences / trees / trajectories are visualized per dump.
    N_PLOTTED_ELEMENTS = 5
    # Metric names deciding the comparison direction in _is_better /
    # _get_best_idxs.
    LOWER_IS_BETTER_METRICS = ['mse']
    HIGHER_IS_BETTER_METRICS = ['psnr', 'ssim']

    def __init__(self,
                 model,
                 logdir,
                 hp,
                 log_to_file,
                 tb_logger,
                 top_comp_metric='mse'):
        """
        :param model: model whose dense_rec module provides predicted
                      sequences; model._hp.use_convs toggles image metrics.
        :param logdir: base directory; metrics go to <logdir>/metrics.
        :param hp: hyperparameter struct (reads metric_pruning_scheme,
                   top_of_100_eval, batch_size, n_rooms).
        :param log_to_file: if True log to files, otherwise to tensorboard.
        :param tb_logger: tensorboard logger used when log_to_file is False.
        :param top_comp_metric: metric used to pick the best of N samples.
        """
        self._logdir = logdir + '/metrics'
        self._logger = FileEvalLogger(
            self._logdir) if log_to_file else TBEvalLogger(logdir, tb_logger)
        self._hp = hp
        self._pruning_scheme = hp.metric_pruning_scheme
        self._dense_rec_module = model.dense_rec
        self.use_images = model._hp.use_convs
        self._top_of_100 = hp.top_of_100_eval
        self._top_of = 100
        self._top_comp_metric = top_comp_metric
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self._logdir, exist_ok=True)
        self.evaluation_buffer = None
        self.full_evaluation = None
        self.dummy_env = None

    def reset(self):
        """Drops all accumulated evaluation results."""
        self.evaluation_buffer = None
        self.full_evaluation = None

    def _erase_eval_buffer(self):
        """(Re)initializes the per-batch evaluation buffer."""
        def get_init_array(val):
            n_eval_samples = self._top_of if self._top_of_100 else 1
            return val * np.ones((self._hp.batch_size, n_eval_samples))

        # mse starts at +inf so any real value improves on it; psnr/ssim
        # start at 0 (higher is better).
        # NOTE: `np.object` was removed in NumPy 1.24 -- use builtin `object`.
        self.evaluation_buffer = AttrDict(
            ssim=get_init_array(0.),
            psnr=get_init_array(0.),
            mse=get_init_array(np.inf),
            gen_images=np.empty(self._hp.batch_size, dtype=object),
            rand_seqs=np.empty(self._hp.batch_size, dtype=object))
        for b in range(self._hp.batch_size):
            self.evaluation_buffer.rand_seqs[b] = []
        if not self.use_images:
            # image-space metrics are meaningless for state-space models
            self.evaluation_buffer.pop('ssim')
            self.evaluation_buffer.pop('psnr')

    def eval_single(self, inputs, model_output, sample_n=0):
        """Evaluates one model sample against the ground-truth batch.

        :param inputs: batch with demo_seq, end_ind (and optionally
                       demo_seq_states).
        :param model_output: one forward pass of the model.
        :param sample_n: index of this sample in the top-of-N evaluation.
        """
        input_images = inputs.demo_seq
        bsize = input_images.shape[0]
        # states are only stored for 2D (pos) or 5D (pos+rot+...) state spaces
        # -- presumably the formats the top-down renderer supports; verify.
        store_states = "demo_seq_states" in inputs and (
            inputs.demo_seq_states.shape[-1] == 2
            or inputs.demo_seq_states.shape[-1] == 5)
        # TODO paralellize DTW
        for b in range(bsize):
            input_seq = input_images[b, :inputs.end_ind[b] + 1]
            input_len = input_seq.shape[0]
            gen_seq, matching_output = self._dense_rec_module.get_sample_with_len(
                b, input_len, model_output, inputs, self._pruning_scheme)
            input_seq, gen_seq = input_seq[1:-1], gen_seq[
                1:
                -1]  # crop first and last frame for eval (conditioning frames)
            state_seq = inputs.demo_seq_states[
                b, :input_len] if store_states else None

            full_gen_seq, gen_seq = self.compute_metrics(
                b, gen_seq, input_seq, model_output, sample_n)

            if self._is_better(
                    self.evaluation_buffer[self._top_comp_metric][b, sample_n],
                    self.evaluation_buffer[self._top_comp_metric][b]):
                # log visualization results for the best sample only, replace if better
                self.evaluation_buffer.gen_images[b] = AttrDict(
                    gt_seq=input_images.cpu().numpy()[b],
                    gen_images=gen_seq,
                    full_gen_seq=full_gen_seq,
                    matching_outputs=matching_output,
                    state_seq=state_seq)

            if sample_n < self.N_PLOTTED_ELEMENTS:
                # additionally store a few samples at the model's own
                # predicted length for qualitative inspection
                pred_len = model_output.end_ind[b].data.cpu().numpy(
                ) + 1 if 'end_ind' in model_output else input_len
                pred_len_seq, _ = self._dense_rec_module.get_sample_with_len(
                    b, pred_len, model_output, inputs, self._pruning_scheme)
                self.evaluation_buffer.rand_seqs[b].append(
                    pred_len_seq.data.cpu().numpy())

    def compute_metrics(self, b, gen_seq, input_seq, model_output, sample_n):
        """Computes mse (and psnr/ssim for images) for one batch element and
        stores them in the evaluation buffer.

        :return: (full_gen_seq, gen_seq) as numpy arrays; full_gen_seq is the
                 unpruned tree sequence when available, else gen_seq.
        """
        input_seq = input_seq.detach().cpu().numpy()
        gen_seq = gen_seq.detach().cpu().numpy()
        full_gen_seq = torch.stack([n.subgoal.images[b] for n in model_output.tree.depth_first_iter()]) \
            .detach().cpu().numpy() if 'tree' in model_output \
                                       and model_output.tree.subgoals is not None else gen_seq
        self.evaluation_buffer.mse[b, sample_n] = mse(gen_seq, input_seq)
        if 'psnr' in self.evaluation_buffer:
            self.evaluation_buffer.psnr[b, sample_n] = psnr(gen_seq, input_seq)
        if 'ssim' in self.evaluation_buffer:
            self.evaluation_buffer.ssim[b, sample_n] = ssim(gen_seq, input_seq)
        return full_gen_seq, gen_seq

    @timed("Eval time for batch: ")
    def eval(self, inputs, model_output, model):
        """Evaluates one batch; re-samples the model N times in top-of-100
        mode, otherwise uses the provided model_output once."""
        self._erase_eval_buffer()
        if self._top_of_100:
            for n in range(self._top_of):
                model_output = model(inputs)
                self.eval_single(inputs, model_output, sample_n=n)
        else:
            self.eval_single(inputs, model_output)
        self._flush_eval_buffer()

    def _flush_eval_buffer(self):
        """Appends the per-batch buffer to the running full evaluation."""
        if self.full_evaluation is None:
            self.full_evaluation = self.evaluation_buffer
        else:
            dict_concat(self.full_evaluation, self.evaluation_buffer)

    def dump_results(self, it):
        """Writes out metrics (and, for image models, sequence dumps and
        matching visualizations), then resets the accumulators."""
        self.dump_metrics(it)
        if self.use_images:
            self.dump_seqs(it)
            if 'matching_outputs' in self.full_evaluation.gen_images[0] \
                    and self.full_evaluation.gen_images[0].matching_outputs is not None:
                self.dump_matching_vis(it)
            # if 'tree' in self.full_evaluation.gen_images[0].model_output:
            #     # TODO dump trees for top of 100
            #     self.dump_trees(it)
        self.reset()

    def dump_trees(self, it):
        """Logs tree visualizations of the first N_PLOTTED_ELEMENTS samples,
        optionally flanked by ground-truth and predicted sequence strips."""
        no_pruning = lambda x, b: False  # show full tree, not pruning anything
        img_dict = self.full_evaluation.gen_images[0]
        plot_matched = img_dict.model_output.tree.match_eval_idx is not None
        assert Evaluator.N_PLOTTED_ELEMENTS <= len(
            img_dict.gen_images
        )  # can currently only max plot as many trees as in batch

        def make_padded_seq_img(tensor, target_width, prepend=0):
            # Lays the frames out horizontally and zero-pads to target_width,
            # optionally prepending `prepend` blank frames.
            assert len(
                tensor.shape) == 4  # assume [n_frames, channels, res, res]
            n_frames, channels, res, _ = tensor.shape
            seq_im = np.transpose(tensor, (1, 2, 0, 3)).reshape(
                channels, res, n_frames * res)
            concats = [
                np.zeros((channels, res, prepend * res), dtype=np.float32)
            ] if prepend > 0 else []
            concats.extend([
                seq_im,
                np.zeros((channels, res,
                          target_width - seq_im.shape[2] - prepend * res),
                         dtype=np.float32)
            ])
            seq_im = np.concatenate(concats, axis=-1)
            return seq_im

        with self._logger.log_to('trees', it, 'image'):
            tree_imgs = plot_pruned_tree(img_dict.model_output.tree,
                                         no_pruning,
                                         plot_matched).detach().cpu().numpy()
            for i in range(Evaluator.N_PLOTTED_ELEMENTS):
                im = tree_imgs[i]
                if plot_matched:
                    gt_seq_im = make_padded_seq_img(img_dict.gt_seq[i],
                                                    im.shape[-1])
                    pred_seq_im = make_padded_seq_img(
                        img_dict.gen_images[i], im.shape[-1],
                        prepend=1)  # prepend for cropped first frame
                    im = np.concatenate((gt_seq_im, im, pred_seq_im), axis=1)
                im = np.transpose(im, [1, 2, 0])
                self._logger.log(im)

    def dump_metrics(self, it):
        """Logs per-metric best values (over the top-of-N samples) and prints
        a csv-style summary line: mean, std, mean per-sample std."""
        with self._logger.log_to('results', it, 'metric'):
            best_idxs = self._get_best_idxs(
                self.full_evaluation[self._top_comp_metric])
            print_st = []
            for metric in sorted(self.full_evaluation):
                vals = self.full_evaluation[metric]
                if metric in ['psnr', 'ssim', 'mse']:
                    if metric not in self.evaluation_buffer: continue
                    best_vals = batchwise_index(vals, best_idxs)
                    print_st.extend([
                        best_vals.mean(),
                        best_vals.std(),
                        vals.std(axis=1).mean()
                    ])
                    self._logger.log(metric,
                                     vals if self._top_of_100 else None,
                                     best_vals)
            print(*print_st, sep=',')

    def dump_seqs(self, it):
        """Dumps all predicted sequences and all ground truth sequences in separate .npy files"""
        DUMP_KEYS = ['gt_seq', 'gen_images', 'full_gen_seq']
        batch = len(self.full_evaluation.gen_images)
        _, c, h, w = self.full_evaluation.gen_images[0].gt_seq.shape
        stacked_seqs = AttrDict()

        for key in DUMP_KEYS:
            if key == 'full_gen_seq':
                # tree sequences can differ in length -> pad to the longest
                time = max(
                    [i[key].shape[0] for i in self.full_evaluation.gen_images])
            else:
                time = self.full_evaluation.gen_images[0]['gt_seq'].shape[0] - 1
            stacked_seqs[key] = np.zeros(
                (batch, time, c, h, w),
                dtype=self.full_evaluation.gen_images[0][key].dtype)
        for b, seqs in enumerate(self.full_evaluation.gen_images):
            stacked_seqs['gt_seq'][b] = seqs['gt_seq'][
                1:]  # skip the first (conditioning frame)
            stacked_seqs['gen_images'][
                b, :seqs['gen_images'].shape[0]] = seqs['gen_images']
            stacked_seqs['full_gen_seq'][
                b, :seqs['full_gen_seq'].shape[0]] = seqs['full_gen_seq']
        for b, seqs in enumerate(
                self.full_evaluation.rand_seqs[:self.N_PLOTTED_ELEMENTS]):
            key = 'seq_samples_{}'.format(b)
            time = self.full_evaluation.gen_images[0]['gt_seq'].shape[0] - 1
            stacked_seqs[key] = np.zeros(
                (self.N_PLOTTED_ELEMENTS, time, c, h, w),
                dtype=self.full_evaluation.rand_seqs[0][0].dtype)
            for i, seq_i in enumerate(seqs):
                stacked_seqs[key][i, :seq_i.shape[0]] = seq_i[:time]
        for key in DUMP_KEYS:
            with self._logger.log_to(key, it, 'array'):
                self._logger.log(stacked_seqs[key])
        self.dump_gifs(stacked_seqs, it)

        if self._hp.n_rooms is not None and self.full_evaluation.gen_images[
                0].state_seq is not None:
            self.dump_traj_overview(it)

    def dump_matching_vis(self, it):
        """Dumps some visualization of the matching procedure."""
        with self._logger.log_to('matchings', it, 'image'):
            # best-effort: not every matcher implements vis_matching
            try:
                for i in range(
                        min(Evaluator.N_PLOTTED_ELEMENTS,
                            self.full_evaluation.gen_images.shape[0])):
                    im = self._dense_rec_module.eval_matcher.vis_matching(
                        self.full_evaluation.gen_images[i].matching_outputs)
                    self._logger.log(im)
            except AttributeError:
                print("Matcher does not provide matching visualization")

    def dump_gifs(self, seqs, it):
        """Dumps gif visualizations of pruned and full sequences."""
        with self._logger.log_to('pruned_seq', it, 'gif'):
            im = make_gif(
                [torch.Tensor(seqs.gt_seq), (torch.Tensor(seqs.gen_images))])
            self._logger.log(im)
        with self._logger.log_to('full_gen_seq', it, 'gif'):
            im = make_gif([torch.Tensor(seqs.full_gen_seq)])
            self._logger.log(im)
        for key in seqs:
            if 'seq_samples' in key:
                with self._logger.log_to(key, it, 'gif'):
                    im = make_gif([torch.Tensor(seqs[key])])
                    self._logger.log(im)

    def dump_traj_overview(self, it):
        """Dumps top-down overview of trajectories in Multiroom datasets."""
        from gcp.infra.envs.miniworld_env.multiroom3d.multiroom3d_env import Multiroom3dEnv
        if self.dummy_env is None:
            # env used only for rendering, so no simulation backend is needed
            self.dummy_env = Multiroom3dEnv({'n_rooms': self._hp.n_rooms},
                                            no_env=True)
        with self._logger.log_to('trajectories', it, 'image'):
            for b in range(
                    min(Evaluator.N_PLOTTED_ELEMENTS,
                        self.full_evaluation.gen_images.shape[0])):
                im = self.dummy_env.render_top_down(
                    self.full_evaluation.gen_images[b].state_seq.data.cpu(
                    ).numpy())
                # rescale [0, 1] -> [-1, 1] for the logger
                self._logger.log(im * 2 - 1)

    def _is_better(self, val, other):
        """Comparison function for different metrics.
           returns True if val is "better" than any of the values in the array other
        """
        if self._top_comp_metric in self.LOWER_IS_BETTER_METRICS:
            return np.all(val <= other)
        elif self._top_comp_metric in self.HIGHER_IS_BETTER_METRICS:
            return np.all(val >= other)
        else:
            raise ValueError(
                "Currently only support comparison on the following metrics: {}. Got {}."
                .format(
                    self.LOWER_IS_BETTER_METRICS +
                    self.HIGHER_IS_BETTER_METRICS, self._top_comp_metric))

    def _get_best_idxs(self, vals):
        """Returns, per batch element, the index of the best sample."""
        assert len(
            vals.shape
        ) == 2  # assumes batch in first dimension, N samples in second dim
        if self._top_comp_metric in self.LOWER_IS_BETTER_METRICS:
            return np.argmin(vals, axis=1)
        else:
            return np.argmax(vals, axis=1)