Example #1
0
    def test(self, task_name=''):
        """Evaluate the planning policy on the validation set.

        For every batch, exhaustively tries each (action, object) pair,
        rolls out the predicted trajectory, keeps the highest-confidence
        action per sample, then simulates the chosen actions and reports
        the running success rate.

        Args:
            task_name: label prefix used in the progress printouts.
        """
        if self.pred_model is not None:
            self.pred_model.eval()

        # running counts: successful simulations / total samples seen
        pos, num_sample = 0, 0
        # sample actions
        acts = self.enumerate_actions()
        for batch_idx, (data, boxes, gt_acts, labels, _, _, _, _) in enumerate(self.val_loader):
            batch_size = data.shape[0]
            num_objs = gt_acts.shape[1]
            # position features and ROIs are derived from the raw xyxy boxes
            pos_feat = xyxy_to_posf(boxes, data.shape)
            rois = xyxy_to_rois(boxes, batch_size, data.shape[1], self.num_gpus)
            gt_rois = boxes.cpu().numpy().copy()

            # best action found so far per sample, and its confidence
            pred_acts = np.zeros((batch_size, num_objs, 3))
            conf_acts = -np.inf * np.ones((batch_size,))

            # if self.random_policy:
            #     for i in range(batch_size):
            #         obj_id = np.random.randint(num_objs)
            #         pred_acts[i, obj_id] = acts[np.random.randint(len(acts))]

            # grid search over every candidate action applied to every object
            for idx, (act, obj_id) in enumerate(itertools.product(acts, range(num_objs))):
                tprint(f'current batch: {idx} / {acts.shape[0] * num_objs}' + ' ' * 10)
                # candidate act is applied to a single object; others stay zero
                act_array = torch.zeros((batch_size, num_objs, 3), dtype=torch.float32)
                act_array[:, obj_id, :] = torch.from_numpy(act)
                traj_array = self.generate_trajs(data, rois, pos_feat, act_array, boxes)
                conf = self.get_act_conf(traj_array, gt_rois, obj_id)
                # keep, per sample, whichever action scored highest so far
                pred_acts[conf > conf_acts] = act_array[conf > conf_acts]
                conf_acts[conf > conf_acts] = conf[conf > conf_acts]

                # for i in range(C.SOLVER.BATCH_SIZE):
                #     plot_image_idx = C.SOLVER.BATCH_SIZE * batch_idx + i
                #     video_idx, img_idx = self.val_loader.dataset.video_info[plot_image_idx]
                #     video_name = self.val_loader.dataset.video_list[video_idx]
                #     search_suffix = self.val_loader.dataset.search_suffix
                #     image_list = sorted(glob(f'{video_name}/{search_suffix}'))
                #     im_name = image_list[img_idx]
                #     video_id, image_id = im_name.split('.')[0].split('/')[-2:]
                #     output_name = f'{video_id}_{image_id}_{idx}'
                #     im_data = get_im_data(im_name, gt_rois[[i], 0:1], C.DATA_ROOT, False)
                #     plt.axis('off')
                #     plt.imshow(im_data[..., ::-1])
                #     _plot_bbox_traj(traj_array[i], size=160, alpha=1.0)
                #     x = gt_rois[i, 0, obj_id, 0] + self.ball_radius
                #     y = gt_rois[i, 0, obj_id, 1] + self.ball_radius
                #     dy, dx = act[0] * act[2], act[1] * act[2]
                #     plt.arrow(x, y, dx, dy, color=(0.99, 0.99, 0.99), linewidth=5)
                #     os.makedirs(f'{self.output_dir}/plan', exist_ok=True)
                #     kwargs = {'format': 'svg', 'bbox_inches': 'tight', 'pad_inches': 0}
                #     plt.savefig(f'{self.output_dir}/plan/pred_{output_name}.svg', **kwargs)
                #     plt.close()

            # simulate the selected actions and accumulate the success rate
            sim_rst, debug_gt_traj_array = self.simulate_action(gt_rois, pred_acts)
            pos += sim_rst.sum()
            num_sample += sim_rst.shape[0]
            pprint(f'{task_name} {batch_idx}/{len(self.val_loader)}: {pos / num_sample:.4f}' + ' ' * 10)
        pprint(f'{task_name}: {pos / num_sample:.4f}' + ' ' * 10)
Example #2
0
    def generate_trajs(self, data, boxes):
        """Roll out the dynamics model once and return predicted boxes/masks.

        Args:
            data: input frames, shape (batch, time, 3, H, W).
            boxes: xyxy object boxes, shape (batch, time, num_objs, 4).

        Returns:
            Tuple of (boxes, masks): predicted xyxy boxes in pixel units
            with shape (batch, rollout, num_objs, 4), and the raw mask
            predictions as a numpy array.
        """
        with torch.no_grad():
            num_objs = boxes.shape[2]
            # fully-connected interaction graph: every ordered (i, j), j != i
            g_idx = np.array([[i, j, 1] for i in range(num_objs) for j in range(num_objs) if j != i])
            g_idx = torch.from_numpy(g_idx[None].repeat(data.shape[0], 0))
            rois = xyxy_to_rois(boxes, batch=data.shape[0], time_step=data.shape[1], num_devices=self.num_gpus)
            # BUGFIX: position features must be computed from the raw xyxy
            # boxes, matching every other call site in this file; the original
            # passed the already-converted ROIs ([batch_idx, x1, y1, x2, y2]),
            # which is a different layout than xyxy_to_posf expects.
            pos_feat = xyxy_to_posf(boxes, data.shape)
            outputs = self.model(data, rois, pos_feat, num_rollouts=self.pred_rollout, g_idx=g_idx)
            outputs = {
                'boxes': outputs['boxes'].cpu().numpy(),
                'masks': outputs['masks'].cpu().numpy(),
            }
            # model outputs normalized xywh; scale to pixels, convert to xyxy
            outputs['boxes'][..., 0::2] *= self.input_width
            outputs['boxes'][..., 1::2] *= self.input_height
            outputs['boxes'] = xywh2xyxy(
                outputs['boxes'].reshape(-1, 4)
            ).reshape((data.shape[0], -1, num_objs, 4))

        return outputs['boxes'], outputs['masks']
Example #3
0
    def generate_trajs(self, data, rois, pos_feat, acts, boxes):
        """Predict object trajectories for a batch of candidate actions.

        First rolls out a short action-conditioned horizon (via the oracle
        simulator or the learned action model), re-renders the resulting
        boxes into images, and then runs the prediction model for the long
        rollout.

        Args:
            data: input frames, (batch, time, 3, H, W).
            rois: ROIs in [batch_idx, x1, y1, x2, y2] format.
            pos_feat: position features matching `data`.
            acts: candidate actions, (batch, num_objs, 3).
            boxes: raw xyxy object boxes for the input frames.

        Returns:
            np.ndarray of predicted xyxy boxes, (batch, T, num_objs, 4).
        """
        all_pred_rois = np.zeros((data.shape[0], 0, acts.shape[1], 4))
        with torch.no_grad():
            data, rois, pos_feat, acts = \
                data.to(self.device), rois.to(self.device), pos_feat.to(self.device), acts.to(self.device)
            if self.oracle_actor:
                # use the ground-truth simulator for the action horizon;
                # temporarily override its rollout length, then restore it
                sim_len_backup = self.sim_rollout_length
                self.sim_rollout_length = self.act_rollout + 1
                _, pred_rois = self.simulate_action(rois[..., 1:].cpu().numpy(), acts.cpu().numpy(), return_rst=False)
                self.sim_rollout_length = sim_len_backup
                all_pred_rois = np.concatenate([all_pred_rois, pred_rois], axis=1)
            else:
                # learned action model conditions on the first frame only
                data = data[:, [0]]
                rois = xyxy_to_rois(boxes, data.shape[0], data.shape[1], self.num_gpus)
                outputs = self.act_model(data, rois, None, act_features=acts, num_rollouts=self.act_rollout)
                pred_rois = xcyc_to_xyxy(torch.clamp(outputs['bbox'], 0, 1).cpu().numpy()[..., 2:], self.input_height, self.input_width, self.ball_radius)
                # prepend the (known) first-frame boxes to the rollout
                pred_rois = np.concatenate([rois[..., 1:].cpu().numpy(), pred_rois], axis=1)
                all_pred_rois = np.concatenate([all_pred_rois, pred_rois], axis=1)

            # last `input_size` steps become the prediction model's input
            pred_rois = all_pred_rois[:, -self.input_size:].copy()

            # re-render predicted boxes into images and normalize channels
            data = sim_rendering(pred_rois, self.input_height, self.input_width, self.ball_radius)
            for c in range(3):
                data[..., c] -= C.INPUT.IMAGE_MEAN[c]
                data[..., c] /= C.INPUT.IMAGE_STD[c]
            data = data.permute(0, 1, 4, 2, 3)

            # data is (batch x time_step x 3 x h x w)
            boxes = torch.from_numpy(pred_rois.astype(np.float32))
            pos_feat = xyxy_to_posf(boxes, data.shape)
            rois = xyxy_to_rois(boxes, data.shape[0], data.shape[1], self.num_gpus)

            if self.roi_masking:
                # expand to (batch x time_step x num_objs x 3 x h x w)
                data = data[:, :, None].repeat(1, 1, self.num_objs, 1, 1, 1)
                for b, t, o in itertools.product(range(data.shape[0]), range(data.shape[1]), range(self.num_objs)):
                    box = boxes[b, t, o].numpy()
                    # BUGFIX: `np.int` was removed in NumPy 1.24; use the
                    # builtin `int` (identical behavior).
                    x1, y1 = np.floor([box[0], box[1]]).astype(int)
                    x2, y2 = np.ceil([box[2], box[3]]).astype(int)
                    # zero out everything outside the object's box
                    data[b, t, o, :, :, :x1] = 0
                    data[b, t, o, :, :y1, :] = 0
                    data[b, t, o, :, :, x2:] = 0
                    data[b, t, o, :, y2:, :] = 0

            if self.roi_cropping:
                data_c = np.zeros((data.shape[0], data.shape[1], self.num_objs,) + data.shape[2:])
                for b, t, o in itertools.product(range(data.shape[0]), range(data.shape[1]), range(self.num_objs)):
                    box = boxes[b, t, o].numpy()
                    x_c = 0.5 * (box[0] + box[2])
                    y_c = 0.5 * (box[1] + box[3])
                    r = self.roi_crop_r
                    d = 2 * r
                    data_c_ = np.zeros((d, d))
                    image = data[b, t].cpu().numpy().transpose((1, 2, 0))
                    image_pad = np.pad(image, ((d, d), (d, d), (0, 0)))
                    # NOTE(review): this `or` chain is almost always true; an
                    # in-bounds test would normally use `and`. Left as-is —
                    # confirm intended behavior before changing.
                    if x_c > -r or y_c > -r or x_c < self.input_width + r or y_c < self.input_height + r:
                        x_c += d
                        y_c += d
                        data_c_ = image_pad[int(y_c - r):int(y_c + r), int(x_c - r):int(x_c + r), :]
                    data_c_ = cv2.resize(data_c_, (self.input_width, self.input_height))
                    data_c[b, t, o] = data_c_.transpose((2, 0, 1))
                data = torch.from_numpy(data_c.astype(np.float32))

            if self.roi_masking or self.roi_cropping:
                # fold the object dimension into the batch dimension
                data = data.permute((0, 2, 1, 3, 4, 5))
                data = data.reshape((data.shape[0] * data.shape[1],) + data.shape[2:])

            outputs = self.pred_model(data, rois, pos_feat, num_rollouts=self.pred_rollout + self.cons_size)
            bbox_rollouts = outputs['bbox'].cpu().numpy()[..., 2:]
            pred_rois = xcyc_to_xyxy(bbox_rollouts, self.input_height, self.input_width, self.ball_radius)
            # keep the last frame of input plus the prediction horizon
            pred_rois = pred_rois[:, -(1 + self.pred_rollout):]
            all_pred_rois = np.concatenate([all_pred_rois, pred_rois], axis=1)

        return all_pred_rois
Example #4
0
    def train_epoch(self):
        """Run one training epoch: forward, loss, backward, logging.

        Also performs periodic validation/snapshotting every
        `self.val_interval` iterations and stops once `self.max_iters`
        is reached.
        """
        for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes,
                        gt_masks, valid, module_valid, g_idx, seq_l, objinfo,
                        gtindicator) in enumerate(self.train_loader):
            self._adjust_learning_rate()

            if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                # data should be (b x t x o x c x h x w)
                data = data.permute((0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                data = data.reshape((data.shape[0] * data.shape[1], ) +
                                    data.shape[2:])  # (b*o, t, c, h, w)

            data, data_t = data.to(self.device), data_t.to(self.device)
            # `rois` holds raw xyxy boxes here; pos features are computed
            # from them before conversion to [batch_idx, x1, y1, x2, y2] ROIs
            pos_feat = xyxy_to_posf(rois, data.shape)
            rois = xyxy_to_rois(rois,
                                batch=data.shape[0],
                                time_step=data.shape[1],
                                num_devices=self.num_gpus)
            self.optim.zero_grad()

            outputs = self.model(data,
                                 rois,
                                 pos_feat,
                                 valid,
                                 num_rollouts=self.ptrain_size,
                                 g_idx=g_idx,
                                 x_t=data_t,
                                 phase='train')
            labels = {
                'boxes': gt_boxes.to(self.device),
                'masks': gt_masks.to(self.device),
                'valid': valid.to(self.device),
                'module_valid': module_valid.to(self.device),
                'seq_l': seq_l.to(self.device),
                'gt_indicators': gtindicator.to(self.device),
            }
            loss = self.loss(outputs, labels, 'train')
            loss.backward()
            self.optim.step()
            # this is an approximation for printing; the dataset size may not divide the batch size
            self.iterations += self.batch_size

            # assemble a single-line progress message (overwritten in place)
            print_msg = ""
            print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
            print_msg += f" | "
            mean_loss = np.mean(
                np.array(self.box_p_step_losses[:self.ptrain_size]) /
                self.loss_cnt) * 1e3
            print_msg += f"{mean_loss:.3f} | "
            print_msg += f" | ".join([
                "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
                for name in self.loss_name
            ])
            print_msg += " || {:.4f}".format(self.loss_ind)
            speed = self.loss_cnt / (timer() - self.time)
            eta = (self.max_iters - self.iterations) / speed / 3600
            # pad with spaces to clear leftovers from the previous line
            print_msg += (
                " " * (os.get_terminal_size().columns - len(print_msg) - 10))
            tprint(print_msg)

            if self.iterations % self.val_interval == 0:
                # snapshot + validate, then reset accumulators; val() puts
                # the model in eval mode, so switch back to train here
                self.snapshot()
                self.val()
                self._init_loss()
                self.model.train()

            if self.iterations >= self.max_iters:
                print('\r', end='')
                print(f'{self.best_mean:.3f}')
                break
Example #5
0
    def val(self):
        """Validate the model and snapshot the best checkpoint.

        For VAE models, each batch is evaluated 10 times (1 + 9 extra
        runs) and the best run per batch is accumulated; otherwise a
        single pass is used. Logs a one-line summary at the end.
        """
        self.model.eval()
        self._init_loss()

        if C.RIN.VAE:
            # per-epoch accumulators for the best-of-10 VAE evaluation
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes,
                        gt_masks, valid, module_valid, g_idx, seq_l, objinfo,
                        gtindicator) in enumerate(self.val_loader):
            tprint(f'eval: {batch_idx}/{len(self.val_loader)}')
            with torch.no_grad():

                if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                    # fold object dim into batch: (b*o, t, c, h, w)
                    data = data.permute((0, 2, 1, 3, 4, 5))
                    data = data.reshape((data.shape[0] * data.shape[1], ) +
                                        data.shape[2:])

                data = data.to(self.device)
                # `rois` holds raw xyxy boxes; compute pos features first,
                # then convert to [batch_idx, x1, y1, x2, y2] ROIs
                pos_feat = xyxy_to_posf(rois, data.shape)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                    'module_valid': module_valid.to(self.device),
                    'seq_l': seq_l.to(self.device),
                    'gt_indicators': gtindicator.to(self.device),
                }

                outputs = self.model(data,
                                     rois,
                                     pos_feat,
                                     valid,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')
                self.loss(outputs, labels, 'test')
                # VAE multiple runs
                if C.RIN.VAE:
                    # score of the first run; the 9 reruns below try to beat it
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    for i in range(9):
                        outputs = self.model(data,
                                             rois,
                                             pos_feat,
                                             valid,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        if mean_loss < vae_best_mean:
                            # keep the best run's accumulators for this batch
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    for k, v in losses.items():
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

        if C.RIN.VAE:
            # NOTE(review): masks_step_losses is accumulated above but never
            # copied back to self — confirm whether that is intentional.
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
        print_msg += f" | "
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "

        if mean_loss < self.best_mean:
            # NOTE(review): filename looks like a typo for 'ckpt_best.pth.tar'
            # — confirm against whatever loads this checkpoint before changing.
            self.snapshot('ckpt_best.path.tar')
            self.best_mean = mean_loss

        print_msg += f" | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        print_msg += " || {:.4f}".format(self.loss_ind)
        # pad with spaces to clear leftovers from the previous line
        print_msg += (" " *
                      (os.get_terminal_size().columns - len(print_msg) - 10))
        self.logger.info(print_msg)
    def test(self):
        """Evaluate the model on the validation set and optionally plot.

        Similar to `val` but restricted to the ball-only module, with
        best-of-100 evaluation for VAE models and optional rollout
        visualization for the first `self.plot_image` samples. Prints a
        one-line loss summary at the end.
        """
        self.model.eval()

        if C.RIN.VAE:
            # per-epoch accumulators for the best-of-100 VAE evaluation
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        for batch_idx, (data, data_pred, data_t, env_name, rois, gt_boxes,
                        gt_masks, valid, module_valid, g_idx,
                        _) in enumerate(self.val_loader):
            with torch.no_grad():

                # decide module_valid here for evaluation
                mid = 0  # ball-only
                module_valid = module_valid[:, mid, :, :]

                if C.RIN.ROI_MASKING or C.RIN.ROI_CROPPING:
                    # data should be (b x t x o x c x h x w)
                    data = data.permute(
                        (0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                    data = data.reshape((data.shape[0] * data.shape[1], ) +
                                        data.shape[2:])  # (b*o, t, c, h, w)

                data = data.to(self.device)
                # `rois` holds raw xyxy boxes; compute pos features first,
                # then convert to [batch_idx, x1, y1, x2, y2] ROIs
                pos_feat = xyxy_to_posf(rois, data.shape)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                    'module_valid': module_valid.to(self.device),
                }
                outputs = self.model(data,
                                     rois,
                                     pos_feat,
                                     valid,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')

                # # *********************************************************************************
                # # VISUALIZATION - generate input image and GT outputs + model outputs
                # input_data = data.cpu().detach().numpy()    # (128, 1, 3, 128, 128)
                # gt_data = data_pred.cpu().detach().numpy()    # (128, 10, 3, 128, 128)
                # data_t = data_t.cpu().detach().numpy()    # (128, 1, 128, 128)
                # validity = valid.cpu().detach().numpy()    # (128, 6)
                # outputs_boxes = outputs['boxes'].cpu().detach().numpy()    # (128, 10, 6, 4)
                # outputs_masks = outputs['masks'].cpu().detach().numpy()    # (128, 10, 6, 21, 21)
                # np.save('save/'+str(batch_idx)+'input_data.npy', input_data)
                # np.save('save/'+str(batch_idx)+'gt_data.npy', gt_data)
                # np.save('save/'+str(batch_idx)+'data_t.npy', data_t)
                # np.save('save/'+str(batch_idx)+'validity.npy', validity)
                # np.save('save/'+str(batch_idx)+'outputs_boxes.npy', outputs_boxes)
                # np.save('save/'+str(batch_idx)+'outputs_masks.npy', outputs_masks)
                # # self.visualize_results(input_data, gt_data, outputs, data_t, validity, env_name)
                # # *********************************************************************************

                self.loss(outputs, labels, 'test')
                # VAE multiple runs
                if C.RIN.VAE:
                    # score of the first run; the 99 reruns try to beat it
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    for i in range(99):
                        # NOTE(review): this call passes None for pos_feat and
                        # omits `valid`, unlike the call above — confirm the
                        # model signature tolerates this before changing.
                        outputs = self.model(data,
                                             rois,
                                             None,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        if mean_loss < vae_best_mean:
                            # keep the best run's accumulators for this batch
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    for k, v in losses.items():
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

                tprint(f'eval: {batch_idx}/{len(self.val_loader)}:' + ' ' * 20)

            if self.plot_image > 0:
                # convert predictions and labels to pixel-space xyxy numpy
                outputs = {
                    'boxes':
                    outputs['boxes'].cpu().numpy(),
                    'masks':
                    outputs['masks'].cpu().numpy()
                    if C.RIN.MASK_LOSS_WEIGHT else None,
                }
                outputs['boxes'][..., 0::2] *= self.input_width
                outputs['boxes'][..., 1::2] *= self.input_height
                outputs['boxes'] = xywh2xyxy(outputs['boxes'].reshape(
                    -1, 4)).reshape((data.shape[0], -1, C.RIN.NUM_OBJS, 4))

                labels = {
                    'boxes': labels['boxes'].cpu().numpy(),
                    'masks': labels['masks'].cpu().numpy(),
                }
                labels['boxes'][..., 0::2] *= self.input_width
                labels['boxes'][..., 1::2] *= self.input_height
                labels['boxes'] = xywh2xyxy(labels['boxes'].reshape(
                    -1, 4)).reshape((data.shape[0], -1, C.RIN.NUM_OBJS, 4))

                for i in range(rois.shape[0]):
                    batch_size = C.SOLVER.BATCH_SIZE if not C.RIN.VAE else 1
                    plot_image_idx = batch_size * batch_idx + i
                    if plot_image_idx < self.plot_image:
                        tprint(f'plotting: {plot_image_idx}' + ' ' * 20)
                        video_idx, img_idx = self.val_loader.dataset.video_info[
                            plot_image_idx]
                        video_name = self.val_loader.dataset.video_list[
                            video_idx]

                        # BUGFIX: `np.bool` was removed in NumPy 1.24; the
                        # builtin `bool` is the documented replacement.
                        v = valid[i].numpy().astype(bool)
                        pred_boxes_i = outputs['boxes'][i][:, v]
                        gt_boxes_i = labels['boxes'][i][:, v]

                        if 'PHYRE' in C.DATA_ROOT:
                            # [::-1] is to make it consistency with others where opencv is used
                            im_data = phyre.observations_to_float_rgb(
                                np.load(video_name).astype(
                                    np.uint8))[..., ::-1]
                            a, b, c = video_name.split('/')[5:8]
                            output_name = f'{a}_{b}_{c.replace(".npy", "")}'

                            # background image with foreground ids zeroed out
                            bg_image = np.load(video_name).astype(np.uint8)
                            for fg_id in [1, 2, 3, 5]:
                                bg_image[bg_image == fg_id] = 0
                            bg_image = phyre.observations_to_float_rgb(
                                bg_image)

                            # if f'{a}_{b}' not in [
                            #     '00014_123', '00014_528', '00015_257', '00015_337', '00019_273', '00019_296'
                            # ]:
                            #     continue

                            # if f'{a}_{b}' not in [
                            #     '00000_069', '00001_000', '00002_185', '00003_064', '00004_823',
                            #     '00005_111', '00006_033', '00007_090', '00008_177', '00009_930',
                            #     '00010_508', '00011_841', '00012_071', '00013_074', '00014_214',
                            #     '00015_016', '00016_844', '00017_129', '00018_192', '00019_244',
                            #     '00020_010', '00021_115', '00022_537', '00023_470', '00024_048'
                            # ]:
                            #     continue
                        else:
                            bg_image = None
                            image_list = sorted(
                                glob(
                                    f'{video_name}/*{self.val_loader.dataset.image_ext}'
                                ))
                            im_name = image_list[img_idx]
                            output_name = '_'.join(
                                im_name.split('.')[0].split('/')[-2:])
                            # deal with image data here
                            # plot rollout function only take care of the usage of plt
                            # if output_name not in ['009_015', '009_031', '009_063', '039_038', '049_011', '059_033']:
                            #     continue
                            # if output_name not in ['00002_00037', '00008_00047', '00011_00048', '00013_00036',
                            #                        '00014_00033', '00020_00054', '00021_00013', '00024_00011']:
                            #     continue
                            if output_name not in [
                                    '0016_000', '0045_000', '0120_000',
                                    '0163_000'
                            ]:
                                continue
                            gt_boxes_i = labels['boxes'][i][:, v]
                            im_data = get_im_data(im_name, gt_boxes_i[None,
                                                                      0:1],
                                                  C.DATA_ROOT,
                                                  self.high_resolution_plot)

                        if self.high_resolution_plot:
                            # rescale boxes from model input size to image size
                            scale_w = im_data.shape[1] / self.input_width
                            scale_h = im_data.shape[0] / self.input_height
                            pred_boxes_i[..., [0, 2]] *= scale_w
                            pred_boxes_i[..., [1, 3]] *= scale_h
                            gt_boxes_i[..., [0, 2]] *= scale_w
                            gt_boxes_i[..., [1, 3]] *= scale_h

                        pred_masks_i = None
                        if C.RIN.MASK_LOSS_WEIGHT:
                            pred_masks_i = outputs['masks'][i][:, v]

                        plot_rollouts(im_data,
                                      pred_boxes_i,
                                      gt_boxes_i,
                                      pred_masks_i,
                                      labels['masks'][i][:, v],
                                      output_dir=self.output_dir,
                                      output_name=output_name,
                                      bg_image=bg_image)

        if C.RIN.VAE:
            # NOTE(review): masks_step_losses is accumulated above but never
            # copied back to self — confirm whether that is intentional.
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "
        print_msg += f" | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        pprint(print_msg)