コード例 #1
0
    def test(self, task_name=''):
        """Evaluate planning on ``self.val_loader`` by exhaustive action search.

        For every batch, each candidate action from ``self.enumerate_actions()``
        is applied to each object in turn. The rolled-out trajectories are
        scored with ``self.get_act_conf`` and, per sample, the highest-scoring
        action array is kept. The chosen actions are executed with
        ``self.simulate_action`` and the running success rate is printed.

        Args:
            task_name: Label prefixed to the progress / result messages.
        """
        if self.pred_model is not None:
            self.pred_model.eval()

        pos, num_sample = 0, 0
        # candidate actions; ``acts.shape[0]`` is the number of candidates
        acts = self.enumerate_actions()
        for batch_idx, (data, boxes, gt_acts, labels, _, _, _, _) in enumerate(self.val_loader):
            batch_size = data.shape[0]
            num_objs = gt_acts.shape[1]
            pos_feat = xyxy_to_posf(boxes, data.shape)
            rois = xyxy_to_rois(boxes, batch_size, data.shape[1], self.num_gpus)
            gt_rois = boxes.cpu().numpy().copy()

            # best action array found so far for each sample, and its confidence
            pred_acts = np.zeros((batch_size, num_objs, 3))
            conf_acts = -np.inf * np.ones((batch_size,))

            num_candidates = acts.shape[0] * num_objs
            for idx, (act, obj_id) in enumerate(itertools.product(acts, range(num_objs))):
                tprint(f'current batch: {idx} / {num_candidates}' + ' ' * 10)
                # apply `act` to object `obj_id` only; every other object gets zeros
                act_array = torch.zeros((batch_size, num_objs, 3), dtype=torch.float32)
                act_array[:, obj_id, :] = torch.from_numpy(act)
                traj_array = self.generate_trajs(data, rois, pos_feat, act_array, boxes)
                conf = self.get_act_conf(traj_array, gt_rois, obj_id)
                # samples whose confidence improved; computed once, used for both updates
                better = conf > conf_acts
                pred_acts[better] = act_array[better]
                conf_acts[better] = conf[better]

            sim_rst, _ = self.simulate_action(gt_rois, pred_acts)
            pos += sim_rst.sum()
            num_sample += sim_rst.shape[0]
            pprint(f'{task_name} {batch_idx}/{len(self.val_loader)}: {pos / num_sample:.4f}' + ' ' * 10)
        # an empty loader would otherwise raise ZeroDivisionError here
        rate = pos / num_sample if num_sample else float('nan')
        pprint(f'{task_name}: {rate:.4f}' + ' ' * 10)
コード例 #2
0
    def __init__(self, data_root, split, image_ext='.jpg'):
        """Index every sliding window of every video under ``{data_root}/{split}``.

        Args:
            data_root: Dataset root directory.
            split: Split sub-directory name; for ``'test'`` at most one window
                per video is indexed.
            image_ext: Extension of the frame images inside each video folder.

        Sets ``self.video_list`` (video directories), ``self.anno_list``
        (matching ``*_boxes.pkl`` files) and ``self.video_info``, an int32
        array of ``[video index, sliding-window index]`` rows.
        """
        super().__init__(data_root, split, image_ext)

        # the glob pattern ends in '/', so each entry ends in '/'; v[:-1] drops it
        self.video_list = sorted(glob(f'{self.data_root}/{self.split}/*/'))
        self.anno_list = [v[:-1] + '_boxes.pkl' for v in self.video_list]

        # Collect per-video chunks and stack once; calling np.vstack inside the
        # loop copies the whole accumulated array every iteration (quadratic).
        chunks = []
        for idx, video_name in enumerate(self.video_list):
            tprint(f'loading progress: {idx}/{len(self.video_list)}')
            num_im = len(glob(f'{video_name}/*{image_ext}'))
            if self.split == 'test':
                # test split: only the first window (if any) of each video
                num_sw = min(1, num_im - self.seq_size + 1)
            else:
                num_sw = num_im - self.seq_size + 1  # number of sliding windows
            if num_sw <= 0:
                continue
            video_info_t = np.zeros((num_sw, 2), dtype=np.int32)
            video_info_t[:, 0] = idx  # video index
            video_info_t[:, 1] = np.arange(num_sw)  # sliding window index
            chunks.append(video_info_t)
        self.video_info = (np.vstack(chunks) if chunks else
                           np.zeros((0, 2), dtype=np.int32))
コード例 #3
0
def _eval_and_score_actions(cache, model, simulator, num_actions, batch_size,
                            observations):
    """Score ``model`` by AUCCESS on a fixed random subset of tasks.

    For each selected task, the first ``num_actions`` cached actions are
    scored with ``eval_actions`` and simulated best-first until the evaluator
    reports ``phyre.MAX_TEST_ATTEMPTS`` attempts for that task.

    Args:
        cache: Object exposing ``action_array`` (candidate actions).
        model: Model passed through to ``eval_actions``.
        simulator: Simulator providing ``task_ids`` and ``simulate_action``.
        num_actions: Number of leading cached actions to consider.
        batch_size: Batch size passed through to ``eval_actions``.
        observations: Per-task observations, indexed like ``simulator.task_ids``.

    Returns:
        ``evaluator.get_aucess()`` over the evaluated tasks.
    """
    actions = cache.action_array[:num_actions]
    # fixed seed: the same task subset is evaluated on every call
    indices = np.random.RandomState(1).permutation(
        len(observations))[:AUCCESS_EVAL_TASKS]
    evaluator = phyre.Evaluator(
        [simulator.task_ids[index] for index in indices])
    for i, task_index in enumerate(indices):
        tprint(f'{i}/{len(indices)} {task_index}')
        scores = eval_actions(model, actions, batch_size,
                              observations[task_index]).tolist()

        # Best score first; ties broken deterministically by the action values.
        # A comprehension (instead of ``_, xs = zip(*...)``) also works when the
        # action set is empty, where the unpacking raised ValueError.
        ranked = sorted(zip(scores, actions),
                        key=lambda pair: (-pair[0], tuple(pair[1])))
        sorted_actions = [action for _, action in ranked]
        for action in sorted_actions:
            if evaluator.get_attempts_for_task(i) >= phyre.MAX_TEST_ATTEMPTS:
                break
            status = simulator.simulate_action(task_index,
                                               action,
                                               need_images=False).status
            evaluator.maybe_log_attempt(i, status)
    return evaluator.get_aucess()
コード例 #4
0
    def __init__(self, data_root, split, image_ext='.jpg'):
        """Index the sliding windows of a split, optionally preloading it to memory.

        Args:
            data_root: Dataset root directory.
            split: Split name (sub-directory and ``{split}.pkl`` file name).
            image_ext: Extension of frame images (used when not preloading).

        With ``C.INPUT.PRELOAD_TO_MEMORY`` the images are read from
        ``{data_root}/{split}.pkl`` into ``self.total_img``. The per-video box
        pickles listed in ``self.anno_list`` are read into ``self.total_box``.
        Otherwise window counts come from the number of frame images per video.
        ``self.video_info`` holds int32 ``[video index, window index]`` rows.
        """
        super().__init__(data_root, split, image_ext)

        # Both modes need these lists: the preload branch reads one box pickle
        # per entry of anno_list and the progress message uses video_list.
        # They used to be built only in the non-preload branch.
        self.video_list = sorted(glob(f'{self.data_root}/{self.split}/*/'))
        self.anno_list = [v[:-1] + '.pkl' for v in self.video_list]

        if C.INPUT.PRELOAD_TO_MEMORY:
            print('loading data from pickle...')
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(f'{self.data_root}/{self.split}.pkl', 'rb') as f:
                data = pickle.load(f)
            self.total_img = np.transpose(data['X'], (0, 1, 4, 2, 3))
            self.total_box = np.zeros((data['y'].shape[:3] + (5, )))
            assert len(self.anno_list) > 0
            for anno_idx, anno_name in enumerate(self.anno_list):
                tprint(f'loading progress: {anno_idx}/{len(self.anno_list)}')
                with open(anno_name, 'rb') as f:
                    boxes = pickle.load(f)
                self.total_box[anno_idx] = boxes
            num_videos = len(self.total_box)
        else:
            num_videos = len(self.video_list)

        # collect chunks and stack once (np.vstack in the loop is quadratic)
        chunks = []
        for idx in range(num_videos):
            tprint(f'loading progress: {idx}/{len(self.video_list)}')
            if C.INPUT.PRELOAD_TO_MEMORY:
                num_sw = self.total_box[idx].shape[0] - self.seq_size + 1
            else:
                num_im = len(glob(f'{self.video_list[idx]}/*{image_ext}'))
                num_sw = num_im - self.seq_size + 1  # number of sliding windows

            if num_sw <= 0:
                continue
            video_info_t = np.zeros((num_sw, 2), dtype=np.int32)
            video_info_t[:, 0] = idx  # video index
            video_info_t[:, 1] = np.arange(num_sw)  # sliding window index
            chunks.append(video_info_t)
        self.video_info = (np.vstack(chunks) if chunks else
                           np.zeros((0, 2), dtype=np.int32))
コード例 #5
0
    def __init__(self, data_root, split, protocal='within', image_ext='.jpg'):
        """Build the video/annotation index for one fold-0 split.

        Args:
            data_root: Dataset root containing ``images/``, ``labels/``, the
                split text files and the cached ``*_info.npy`` indices.
            split: Split name used in the file names.
            protocal: Protocol prefix of the split file (default ``'within'``).
            image_ext: Forwarded to the base class.

        The sliding-window index ``self.video_info`` (int32
        ``[video index, window index]`` rows) is cached in ``data_root`` and
        reloaded from there when the cache file exists. Module info, object
        info and GT indicator arrays are loaded from precomputed ``.npy``
        files.
        """
        super().__init__(data_root, split, image_ext)

        # One environment id per line; ':' in an id maps to a directory level
        # under images/. Blank lines (e.g. a trailing newline) are skipped;
        # they used to produce a bogus ``images//*.npy`` glob.
        with open(f'{data_root}/{protocal}_{split}_fold_0.txt', 'r') as f:
            env_list = [env for env in f.read().split('\n') if env]
        # flatten per-env sorted file lists (sum(..., []) is quadratic)
        self.video_list = [
            video for env in env_list for video in sorted(
                glob(f'{data_root}/images/{env.replace(":", "/")}/*.npy'))
        ]
        self.anno_list = [(v[:-4] + '_boxes.hkl').replace('images', 'labels')
                          for v in self.video_list]

        # just for plot images
        # NOTE(review): `plot` is not defined in this method; presumably a
        # module-level flag -- confirm.
        if plot:
            self.video_list = [
                k for k in self.video_list
                if int(k.split('/')[-1].split('.')[0]) < 40
            ]
            self.anno_list = [
                k for k in self.anno_list
                if int(k.split('/')[-1].split('_')[0]) < 40
            ]
            assert len(self.video_list) == len(self.anno_list)
            self.video_list = self.video_list[::80]
            self.anno_list = self.anno_list[::80]

        video_info_name = f'{data_root}/{protocal}_{split}_{self.input_size}_{self.pred_size}_fold_0_info.npy'
        if os.path.exists(video_info_name):
            print(f'loading info from: {video_info_name}')
            self.video_info = np.load(video_info_name)
        else:
            # collect chunks and stack once (np.vstack in the loop is quadratic)
            chunks = []
            for idx, video_name in enumerate(self.video_list):
                tprint(f'loading progress: {idx}/{len(self.video_list)}')
                num_im = hickle.load(
                    video_name.replace('images', 'labels').replace(
                        '.npy', '_boxes.hkl')).shape[0]
                num_sw = num_im - self.seq_size + 1  # number of sliding windows
                if plot:
                    num_sw = 1
                if self.input_size == 1:
                    # single-frame input: at most the first window per video
                    num_sw = min(1, num_im - self.seq_size + 1)

                if num_sw <= 0:
                    continue
                video_info_t = np.zeros((num_sw, 2), dtype=np.int32)
                video_info_t[:, 0] = idx  # video index
                video_info_t[:, 1] = np.arange(num_sw)  # sliding window index
                chunks.append(video_info_t)
            self.video_info = (np.vstack(chunks) if chunks else
                               np.zeros((0, 2), dtype=np.int32))
            np.save(video_info_name, self.video_info)

        for module in C.SINGULAR_MODULES:
            self.module_dict[module] = np.load(
                f'{data_root}/{module}/thresh_{C.MASK_THRESH}/{protocal}_{split}_{self.input_size}_{self.pred_size}_fold_0_info.npy'
            )
        for module in C.DUAL_MODULES:
            self.module_dict[module] = np.load(
                f'{data_root}/{module}/{C.DUAL_DATATYPE}/thresh_{C.MASK_THRESH}/{protocal}_{split}_{self.input_size}_{self.pred_size}_fold_0_info.npy'
            )

        # load GT indicators and object info tags
        self.objinfo = np.load(
            f'{data_root}/thresh/{C.MASK_THRESH}/{protocal}_{split}_{self.input_size}_{self.pred_size}_fold_0_objinfo.npy'
        )
        self.gtindicatorinfo = np.load(
            f'{data_root}/thresh/{C.MASK_THRESH}/{protocal}_{split}_{self.input_size}_{self.pred_size}_fold_0_gtindicatorinfo.npy'
        )
コード例 #6
0
    def train_epoch(self):
        """Run one pass over ``self.train_loader`` with one optimiser step per batch.

        After every step a single-line progress message (running losses,
        speed, ETA) is printed. Every ``self.val_interval`` iterations a
        snapshot is written and ``self.val()`` runs. The loop stops early
        once ``self.iterations`` reaches ``self.max_iters``.
        """
        # shutil.get_terminal_size falls back to a default size when stdout is
        # not a terminal (redirected log, nohup, batch job); os.get_terminal_size
        # raises OSError there and would abort training.
        import shutil

        for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes,
                        gt_masks, valid, module_valid, g_idx, seq_l, objinfo,
                        gtindicator) in enumerate(self.train_loader):
            self._adjust_learning_rate()

            if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                # data should be (b x t x o x c x h x w)
                data = data.permute((0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                data = data.reshape((data.shape[0] * data.shape[1], ) +
                                    data.shape[2:])  # (b*o, t, c, h, w)

            data, data_t = data.to(self.device), data_t.to(self.device)
            pos_feat = xyxy_to_posf(rois, data.shape)
            rois = xyxy_to_rois(rois,
                                batch=data.shape[0],
                                time_step=data.shape[1],
                                num_devices=self.num_gpus)
            self.optim.zero_grad()

            outputs = self.model(data,
                                 rois,
                                 pos_feat,
                                 valid,
                                 num_rollouts=self.ptrain_size,
                                 g_idx=g_idx,
                                 x_t=data_t,
                                 phase='train')
            labels = {
                'boxes': gt_boxes.to(self.device),
                'masks': gt_masks.to(self.device),
                'valid': valid.to(self.device),
                'module_valid': module_valid.to(self.device),
                'seq_l': seq_l.to(self.device),
                'gt_indicators': gtindicator.to(self.device),
            }
            loss = self.loss(outputs, labels, 'train')
            loss.backward()
            self.optim.step()
            # this is an approximation for printing; the dataset size may not divide the batch size
            self.iterations += self.batch_size

            print_msg = ""
            print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
            print_msg += " | "
            mean_loss = np.mean(
                np.array(self.box_p_step_losses[:self.ptrain_size]) /
                self.loss_cnt) * 1e3
            print_msg += f"{mean_loss:.3f} | "
            print_msg += " | ".join([
                "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
                for name in self.loss_name
            ])
            print_msg += " || {:.4f}".format(self.loss_ind)
            speed = self.loss_cnt / (timer() - self.time)
            eta = (self.max_iters - self.iterations) / speed / 3600
            print_msg += f" | speed: {speed:.1f} | eta: {eta:.2f} h"
            print_msg += (
                " " * (shutil.get_terminal_size().columns - len(print_msg) - 10))
            tprint(print_msg)

            # NOTE: iterations advances by batch_size, so this only fires when
            # val_interval is a multiple of batch_size.
            if self.iterations % self.val_interval == 0:
                self.snapshot()
                self.val()
                self._init_loss()
                self.model.train()

            if self.iterations >= self.max_iters:
                print('\r', end='')
                print(f'{self.best_mean:.3f}')
                break
コード例 #7
0
    def val(self):
        """Evaluate on ``self.val_loader``, log the losses and keep the best checkpoint.

        Losses are accumulated through ``self.loss(..., 'test')``. With
        ``C.RIN.VAE`` the model is sampled 10 times per batch, and only the
        run with the lowest mean box-prediction loss contributes to the
        reported statistics. If the mean box loss beats ``self.best_mean``, a
        ``'ckpt_best.path.tar'`` snapshot is written.
        """
        # shutil.get_terminal_size falls back instead of raising OSError when
        # stdout is not a terminal.
        import shutil

        self.model.eval()
        self._init_loss()

        if C.RIN.VAE:
            # per-batch best-of-N statistics are summed into these
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes,
                        gt_masks, valid, module_valid, g_idx, seq_l, objinfo,
                        gtindicator) in enumerate(self.val_loader):
            tprint(f'eval: {batch_idx}/{len(self.val_loader)}')
            with torch.no_grad():

                if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                    data = data.permute((0, 2, 1, 3, 4, 5))
                    data = data.reshape((data.shape[0] * data.shape[1], ) +
                                        data.shape[2:])

                data = data.to(self.device)
                pos_feat = xyxy_to_posf(rois, data.shape)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                    'module_valid': module_valid.to(self.device),
                    'seq_l': seq_l.to(self.device),
                    'gt_indicators': gtindicator.to(self.device),
                }

                outputs = self.model(data,
                                     rois,
                                     pos_feat,
                                     valid,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')
                self.loss(outputs, labels, 'test')
                # VAE multiple runs: keep the run with the lowest mean box loss
                if C.RIN.VAE:
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    # Reset before resampling. Without this, the first rerun's
                    # losses were added on top of the run just recorded, so
                    # that rerun was compared (and possibly kept) as a sum.
                    self._init_loss()
                    for _ in range(9):
                        outputs = self.model(data,
                                             rois,
                                             pos_feat,
                                             valid,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        if mean_loss < vae_best_mean:
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    for k in losses:
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

        if C.RIN.VAE:
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            # was accumulated but never written back before
            self.masks_step_losses = masks_step_losses.copy()
            # each batch contributed exactly one (best) run
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
        print_msg += " | "
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "

        if mean_loss < self.best_mean:
            self.snapshot('ckpt_best.path.tar')
            self.best_mean = mean_loss

        print_msg += " | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        print_msg += " || {:.4f}".format(self.loss_ind)
        print_msg += (" " *
                      (shutil.get_terminal_size().columns - len(print_msg) - 10))
        self.logger.info(print_msg)
コード例 #8
0
    def test(self):
        """Evaluate on ``self.val_loader``, optionally plot rollouts, print losses.

        Each batch is run through the model in ``'test'`` phase and
        accumulated with ``self.loss(..., 'test')``. With ``C.RIN.VAE`` the
        model is sampled 100 times per batch and only the run with the lowest
        mean box-prediction loss contributes. When ``self.plot_image > 0``,
        the first ``self.plot_image`` dataset samples (subject to the filters
        below) are rendered with ``plot_rollouts``.
        """
        self.model.eval()

        if C.RIN.VAE:
            # per-batch best-of-N statistics are summed into these
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        # samples per batch, used to map (batch_idx, i) to a dataset index
        plot_batch_size = C.SOLVER.BATCH_SIZE if not C.RIN.VAE else 1

        # NOTE(review): train_epoch()/val() unpack 13 fields per batch; confirm
        # that this loader yields 11.
        for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes,
                        gt_masks, valid, module_valid, g_idx,
                        _) in enumerate(self.val_loader):
            with torch.no_grad():

                # decide module_valid here for evaluation
                mid = 0  # ball-only
                module_valid = module_valid[:, mid, :, :]

                if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                    # data should be (b x t x o x c x h x w)
                    data = data.permute(
                        (0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                    data = data.reshape((data.shape[0] * data.shape[1], ) +
                                        data.shape[2:])  # (b*o, t, c, h, w)

                data = data.to(self.device)
                pos_feat = xyxy_to_posf(rois, data.shape)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                    'module_valid': module_valid.to(self.device),
                }
                outputs = self.model(data,
                                     rois,
                                     pos_feat,
                                     valid,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')

                self.loss(outputs, labels, 'test')
                # VAE multiple runs: keep the run with the lowest mean box loss
                if C.RIN.VAE:
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    # Reset before resampling so the first rerun is not added
                    # on top of the run just recorded.
                    self._init_loss()
                    for _ in range(99):
                        # same arguments as the primary call above (the old
                        # rerun passed None for pos_feat and dropped `valid`)
                        outputs = self.model(data,
                                             rois,
                                             pos_feat,
                                             valid,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        if mean_loss < vae_best_mean:
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    for k in losses:
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

                tprint(f'eval: {batch_idx}/{len(self.val_loader)}:' + ' ' * 20)

            if self.plot_image > 0:
                outputs = {
                    'boxes':
                    outputs['boxes'].cpu().numpy(),
                    'masks':
                    outputs['masks'].cpu().numpy()
                    if C.RIN.MASK_LOSS_WEIGHT else None,
                }
                # scale normalised xywh boxes to pixels, then convert to xyxy
                outputs['boxes'][..., 0::2] *= self.input_width
                outputs['boxes'][..., 1::2] *= self.input_height
                outputs['boxes'] = xywh2xyxy(outputs['boxes'].reshape(
                    -1, 4)).reshape((data.shape[0], -1, C.RIN.NUM_OBJS, 4))

                labels = {
                    'boxes': labels['boxes'].cpu().numpy(),
                    'masks': labels['masks'].cpu().numpy(),
                }
                labels['boxes'][..., 0::2] *= self.input_width
                labels['boxes'][..., 1::2] *= self.input_height
                labels['boxes'] = xywh2xyxy(labels['boxes'].reshape(
                    -1, 4)).reshape((data.shape[0], -1, C.RIN.NUM_OBJS, 4))

                for i in range(rois.shape[0]):
                    plot_image_idx = plot_batch_size * batch_idx + i
                    if plot_image_idx < self.plot_image:
                        tprint(f'plotting: {plot_image_idx}' + ' ' * 20)
                        video_idx, img_idx = self.val_loader.dataset.video_info[
                            plot_image_idx]
                        video_name = self.val_loader.dataset.video_list[
                            video_idx]

                        # builtin bool: np.bool was removed in NumPy 1.24
                        v = valid[i].numpy().astype(bool)
                        pred_boxes_i = outputs['boxes'][i][:, v]
                        gt_boxes_i = labels['boxes'][i][:, v]

                        if 'PHYRE' in C.DATA_ROOT:
                            # [::-1] is to make it consistency with others where opencv is used
                            im_data = phyre.observations_to_float_rgb(
                                np.load(video_name).astype(
                                    np.uint8))[..., ::-1]
                            # NOTE(review): relies on a fixed path depth -- confirm
                            a, b, c = video_name.split('/')[5:8]
                            output_name = f'{a}_{b}_{c.replace(".npy", "")}'

                            # background: zero out the listed foreground ids
                            bg_image = np.load(video_name).astype(np.uint8)
                            for fg_id in [1, 2, 3, 5]:
                                bg_image[bg_image == fg_id] = 0
                            bg_image = phyre.observations_to_float_rgb(
                                bg_image)
                        else:
                            bg_image = None
                            image_list = sorted(
                                glob(
                                    f'{video_name}/*{self.val_loader.dataset.image_ext}'
                                ))
                            im_name = image_list[img_idx]
                            output_name = '_'.join(
                                im_name.split('.')[0].split('/')[-2:])
                            # deal with image data here
                            # plot rollout function only take care of the usage of plt
                            # only these hard-coded samples are plotted
                            if output_name not in [
                                    '0016_000', '0045_000', '0120_000',
                                    '0163_000'
                            ]:
                                continue
                            im_data = get_im_data(im_name, gt_boxes_i[None,
                                                                      0:1],
                                                  C.DATA_ROOT,
                                                  self.high_resolution_plot)

                        if self.high_resolution_plot:
                            scale_w = im_data.shape[1] / self.input_width
                            scale_h = im_data.shape[0] / self.input_height
                            pred_boxes_i[..., [0, 2]] *= scale_w
                            pred_boxes_i[..., [1, 3]] *= scale_h
                            gt_boxes_i[..., [0, 2]] *= scale_w
                            gt_boxes_i[..., [1, 3]] *= scale_h

                        pred_masks_i = None
                        if C.RIN.MASK_LOSS_WEIGHT:
                            pred_masks_i = outputs['masks'][i][:, v]

                        plot_rollouts(im_data,
                                      pred_boxes_i,
                                      gt_boxes_i,
                                      pred_masks_i,
                                      labels['masks'][i][:, v],
                                      output_dir=self.output_dir,
                                      output_name=output_name,
                                      bg_image=bg_image)

        if C.RIN.VAE:
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            # was accumulated but never written back before
            self.masks_step_losses = masks_step_losses.copy()
            # each batch contributed exactly one (best) run
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "
        print_msg += " | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        pprint(print_msg)