Beispiel #1
0
def results(cnf):
    # type: (Conf) -> None
    """
    Interactively inspect the autoencoder on the test set.

    The autoencoder weights are read from ``log/<exp_name>/best.pth``. For every
    test sample (shuffled, one at a time) the ground-truth heatmap is passed
    through the frozen autoencoder. Local maxima are extracted from the
    reconstruction and compared with the ground-truth points (F1 with ``th=1``).
    The source frame is shown at half resolution next to the first channel of
    the reconstructed heatmap.

    :param cnf: experiment configuration (reads ``hmap_d``, ``exp_name``, ``device``)
    """

    # frozen autoencoder, ready for inference
    model = Autoencoder(cnf.hmap_d)
    model.load_w(f'log/{cnf.exp_name}/best.pth')
    model.to(cnf.device)
    model.eval()
    model.requires_grad(False)

    # test samples, one per batch, in random order
    loader = DataLoader(
        dataset=JTAHMapDS(mode='test', cnf=cnf),
        batch_size=1,
        num_workers=0,
        shuffle=True,
    )

    for hmap_true, y_true, frame_paths in loader:
        path = frame_paths[0]
        hmap_true = hmap_true.to(cnf.device)
        points_true = json.loads(y_true[0])

        # hmap_true --> [autoencoder] --> hmap_pred (size-1 dims squeezed away)
        hmap_pred = model.forward(hmap_true).squeeze()

        points_pred = utils.get_multi_local_maxima_3d(
            hmaps3d=hmap_pred, threshold=0.1, device=cnf.device
        )
        score = joint_det_metrics(
            points_pred=points_pred, points_true=points_true, th=1
        )['f1']

        # report for the current frame
        print(f"\n\t▶▶ Showing results of '{path}'")
        print('\t▶▶ F1@1px score:', score)
        print('\t▶▶ Press some key to advance in the depth dimension')

        frame = cv2.resize(cv2.imread(path), (0, 0), fx=0.5, fy=0.5)
        cv2.imshow('img', frame)

        utils.visualize_3d_hmap(hmap=hmap_pred[0, ...])
Beispiel #2
0
def main():
    """
    Debug viewer for the MOTSynth training set (``debug=True``).

    For each sample it does three things:
    - shows channel 13 of the first batch item of the input heatmaps;
    - re-extracts points from the heatmaps with a local-maxima search and
      prints their F1 (``th=1``) against the JSON-encoded labels;
    - prints the input shape and the decoded labels.
    """
    import utils

    cnf = Conf(exp_name='default')
    dataset = MOTSynthDS(mode='train', cnf=cnf, debug=True)
    loader = DataLoader(dataset=dataset, batch_size=1, num_workers=0, shuffle=False)

    for idx, (hmaps, labels, _) in enumerate(loader):
        hmaps = hmaps.to(cnf.device)
        labels = json.loads(labels[0])

        utils.visualize_3d_hmap(hmaps[0, 13])

        points = utils.get_multi_local_maxima_3d(
            hmaps3d=hmaps.squeeze(), threshold=0.1, device=cnf.device
        )
        f1 = joint_det_metrics(points_pred=points, points_true=labels, th=1)['f1']

        print(f'f1 score = {f1}')
        print(f'({idx}) Dataset example: x.shape={tuple(hmaps.shape)}, y={labels}')
Beispiel #3
0
    def test(self):
        """
        Test the model on the Test-Set (samples are taken from ``self.val_loader``).

        For at most ``self.cnf.test_len`` samples this method:
          * computes the center loss (MSE) and the masked width/height losses,
            each scaled by its ``self.cnf.masked_loss_*`` weight;
          * decodes bounding boxes from both the predicted and the true
            heatmaps and computes IoU-, center-, width- and height-F1 scores;
          * for the first 3 samples, writes heatmap videos and bbox images
            into ``self.cnf.exp_log_path``.

        The mean losses and F1 scores are printed and logged to ``self.sw`` at
        ``self.current_epoch``. The model weights are saved to
        ``self.log_path / 'best.pth'`` whenever the mean IoU-F1 is higher than
        ``self.best_val_f1``.
        """

        self.model.eval()
        self.model.requires_grad(False)

        def decode_bboxes(centers, width_hmap, height_hmap):
            """
            Read width/height at each heatmap center and build the boxes.

            :param centers: iterable of (cam_dist, y2d, x2d) heatmap coordinates
            :param width_hmap: width heatmap indexed as [cam_dist, y2d, x2d]
            :param height_hmap: height heatmap indexed as [cam_dist, y2d, x2d]
            :return: (widths, heights, bboxes). ``widths`` and ``heights`` hold
                (cam_dist, y2d, x2d, value) tuples in heatmap coordinates.
                ``bboxes`` holds
                (x2d - width / 2, y2d - height / 2, width, height, cam_dist)
                computed after ``utils.rescale_to_real``.
            """
            widths, heights, bboxes = [], [], []
            for center_coord in centers:
                cam_dist, y2d, x2d = center_coord

                width = float(width_hmap[cam_dist, y2d, x2d])
                height = float(height_hmap[cam_dist, y2d, x2d])

                # denormalize width and height: invert (v - mean) / std
                width = int(round(width * STD_DEV_WIDTH + MEAN_WIDTH))
                height = int(round(height * STD_DEV_HEIGHT + MEAN_HEIGHT))

                widths.append((*center_coord, width))
                heights.append((*center_coord, height))

                x2d, y2d, cam_dist = utils.rescale_to_real(x2d=x2d, y2d=y2d, cam_dist=cam_dist, q=self.cnf.q)
                bboxes.append((x2d - width / 2, y2d - height / 2, width, height, cam_dist))
            return widths, heights, bboxes

        val_f1s = {'f1_iou': [], 'f1_center': [], 'f1_width': [], 'f1_height': []}
        val_losses = {'all': [], 'center': [], 'width': [], 'height': []}

        t = time()
        for step, sample in enumerate(self.val_loader):
            hmap_true, y_true, file_name, _ = sample
            hmap_true = hmap_true.to(self.cnf.device)
            y_true = json.loads(y_true[0])

            hmap_pred = self.model.forward(hmap_true)

            x_true_center, x_true_width, x_true_height = hmap_true[0, 0], hmap_true[0, 1], hmap_true[0, 2]
            x_pred_center, x_pred_width, x_pred_height = hmap_pred[0, 0], hmap_pred[0, 1], hmap_pred[0, 2]

            # log center, width, height losses; width/height are only
            # supervised where the true height map is non-zero
            mask = (x_true_height != 0).float()
            loss_center = self.cnf.masked_loss_c * nn.MSELoss()(x_pred_center, x_true_center)
            loss_width = self.cnf.masked_loss_w * MaskedMSELoss()(x_pred_width, x_true_width, mask=mask)
            loss_height = self.cnf.masked_loss_h * MaskedMSELoss()(x_pred_height, x_true_height, mask=mask)
            loss = loss_center + loss_width + loss_height
            val_losses['all'].append(loss.item())
            val_losses['center'].append(loss_center.item())
            val_losses['width'].append(loss_width.item())
            val_losses['height'].append(loss_height.item())

            y_center = [(coord[0], coord[1], coord[2]) for coord in y_true]
            y_width = [(coord[0], coord[1], coord[2], coord[3]) for coord in y_true]
            y_height = [(coord[0], coord[1], coord[2], coord[4]) for coord in y_true]

            y_center_pred = utils.local_maxima_3d(heatmap=x_pred_center, threshold=0.1, device=self.cnf.device)
            y_width_pred, y_height_pred, bboxes_info_pred = decode_bboxes(
                y_center_pred, x_pred_width, x_pred_height)

            # Only the boxes are needed from the true heatmaps. Their
            # width/height tuples must NOT be mixed into y_width_pred /
            # y_height_pred, otherwise the width/height F1 gets skewed.
            y_center_true = utils.local_maxima_3d(heatmap=x_true_center, threshold=0.1, device=self.cnf.device)
            _, _, bboxes_info_true = decode_bboxes(y_center_true, x_true_width, x_true_height)

            metrics_iou = compute_det_metrics_iou(bboxes_a=bboxes_info_pred, bboxes_b=bboxes_info_true)
            metrics_center = joint_det_metrics(points_pred=y_center_pred, points_true=y_center, th=1)
            metrics_width = joint_det_metrics(points_pred=y_width_pred, points_true=y_width, th=1)
            metrics_height = joint_det_metrics(points_pred=y_height_pred, points_true=y_height, th=1)
            val_f1s['f1_iou'].append(metrics_iou['f1'])
            val_f1s['f1_center'].append(metrics_center['f1'])
            val_f1s['f1_width'].append(metrics_width['f1'])
            val_f1s['f1_height'].append(metrics_height['f1'])

            if step < 3:
                img_original = np.array(utils.imread(self.cnf.mot_synth_path / file_name[0]).convert("RGB"))
                hmap_pred = hmap_pred.squeeze()
                out_path = self.cnf.exp_log_path / f'{step}_center_pred.mp4'
                utils.save_3d_hmap(hmap=hmap_pred[0, ...], path=out_path)
                out_path = self.cnf.exp_log_path / f'{step}_width_pred.mp4'
                utils.save_3d_hmap(hmap=hmap_pred[1, ...], path=out_path, shift_values=True)
                out_path = self.cnf.exp_log_path / f'{step}_height_pred.mp4'
                utils.save_3d_hmap(hmap=hmap_pred[2, ...], path=out_path, shift_values=True)
                out_path = self.cnf.exp_log_path / f'{step}_bboxes_pred.jpg'
                utils.save_bboxes(img_original, bboxes_info_pred, path=out_path, use_z=True, half_images=True)

                hmap_true = hmap_true.squeeze()
                out_path = self.cnf.exp_log_path / f'{step}_center_true.mp4'
                utils.save_3d_hmap(hmap=hmap_true[0, ...], path=out_path)
                out_path = self.cnf.exp_log_path / f'{step}_width_true.mp4'
                utils.save_3d_hmap(hmap=hmap_true[1, ...], path=out_path, shift_values=True)
                out_path = self.cnf.exp_log_path / f'{step}_height_true.mp4'
                utils.save_3d_hmap(hmap=hmap_true[2, ...], path=out_path, shift_values=True)
                out_path = self.cnf.exp_log_path / f'{step}_bboxes_true.jpg'
                utils.save_bboxes(img_original, bboxes_info_true, path=out_path, use_z=True, half_images=True)

            # stop after exactly `test_len` samples
            if step >= self.cnf.test_len - 1:
                break

        # log average losses / f1 scores on test set
        mean_val_loss = np.mean(val_losses['all'])
        mean_val_f1_iou = np.mean(val_f1s['f1_iou'])
        mean_val_f1_center = np.mean(val_f1s['f1_center'])
        mean_val_f1_width = np.mean(val_f1s['f1_width'])
        mean_val_f1_height = np.mean(val_f1s['f1_height'])
        mean_val_loss_center = np.mean(val_losses['center'])
        mean_val_loss_width = np.mean(val_losses['width'])
        mean_val_loss_height = np.mean(val_losses['height'])
        print(f'[TEST] AVG-Loss: {mean_val_loss:.6f}, '
              f'AVG-F1_iou: {mean_val_f1_iou:.6f}, '
              f'AVG-F1_center: {mean_val_f1_center:.6f}, '
              f'AVG-F1_width: {mean_val_f1_width:.6f}, '
              f'AVG-F1_height: {mean_val_f1_height:.6f}'
              f' │ Test time: {time() - t:.2f} s')
        self.sw.add_scalar(tag='val_F1/iou', scalar_value=mean_val_f1_iou, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_F1/center', scalar_value=mean_val_f1_center, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_F1/width', scalar_value=mean_val_f1_width, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_F1/height', scalar_value=mean_val_f1_height, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_loss', scalar_value=mean_val_loss, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_loss/center', scalar_value=mean_val_loss_center, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_loss/width', scalar_value=mean_val_loss_width, global_step=self.current_epoch)
        self.sw.add_scalar(tag='val_loss/height', scalar_value=mean_val_loss_height, global_step=self.current_epoch)

        # save best model: a higher F1 is better
        if self.best_val_f1 is None or mean_val_f1_iou > self.best_val_f1:
            self.best_val_f1 = mean_val_f1_iou
            torch.save(self.model.state_dict(), self.log_path / 'best.pth')
Beispiel #4
0
def compute(exp_name):
    # type: (str) -> None
    """
    Evaluate the full pipeline on the JTA testing set.

    For every threshold in ``THS``, precision / recall / F1 (in %) are printed
    twice: without refinement and with hole-filler refinement.

    Per-image pipeline:
        image --> [code predictor] --> code --> [decoder] --> heatmaps
        --> [3D NMS] --> pseudo-3D coords --> [to_3d] --> real 3D coords
        --> [association] --> poses --> [hole filler] --> refined poses

    :param exp_name: experiment whose ``best.pth`` is loaded into the code predictor
    """

    cnf = Conf(exp_name=exp_name)

    # init Code Predictor
    predictor = CodePredictor()  # type: BaseModel
    predictor.to(cnf.device)
    predictor.eval()
    predictor.requires_grad(False)
    predictor.load_w(cnf.exp_log_path / 'best.pth')

    # init Decoder
    autoencoder = Autoencoder()  # type: BaseModel
    autoencoder.to(cnf.device)
    autoencoder.eval()
    autoencoder.requires_grad(False)
    autoencoder.load_w(Path(__file__).parent / 'models/weights/vha.pth')

    # init Hole Filler
    hole_filler = Refiner(pretrained=True)
    hole_filler.to(cnf.device)
    hole_filler.eval()
    hole_filler.requires_grad(False)
    hole_filler.load_w(
        Path(__file__).parent / 'models/weights/pose_refiner.pth')

    # init data loader
    ts = JTATestingSet(cnf=cnf)
    loader = DataLoader(dataset=ts, batch_size=1, shuffle=False, num_workers=0)

    metrics_dict = {}
    for th in THS:
        for key in ['pr', 're', 'f1']:
            metrics_dict[f'{key}@{th}'] = []  # without refinement
            metrics_dict[f'{key}@{th}+'] = []  # with refinement

    n_images = len(loader)
    for step, sample in enumerate(loader):

        x, coords3d_true, fx, fy, cx, cy, frame_path = sample
        x = x.to(cnf.device)
        coords3d_true = json.loads(coords3d_true[0])
        fx, fy, cx, cy = fx.item(), fy.item(), cx.item(), cy.item()

        # image --> [code_predictor] --> code
        code_pred = predictor.forward(x).unsqueeze(0)

        # code --> [decoder] --> hmap (already squeezed: iterate it directly)
        hmap_pred = autoencoder.decode(code_pred).squeeze()

        # hmap --> [local maxima search] --> pseudo-3D coordinates;
        # each channel of hmap_pred is treated as one joint type
        coords2d_pred = []
        confs = []
        for jtype, hmp in enumerate(hmap_pred):
            # pad by 1 so maxima on the volume border can be detected
            res = nms3d_cuda.NMSFilter3d(nn.ConstantPad3d(1, 0)(hmp), 3, 1)
            nz = torch.nonzero(res).cpu()
            for el in nz:
                confid = res[tuple(el)]
                if confid > 0.1:
                    coords2d_pred.append(
                        (jtype, el[0].item(), el[1].item(), el[2].item()))
                    confs.append(confid.cpu())

        # pseudo-3D coordinates --> [to_3d] --> real 3D coordinates
        coords3d_pred = []
        for joint_type, cam_dist, y2d, x2d in coords2d_pred:
            x2d, y2d, cam_dist = utils.rescale_to_real(x2d,
                                                       y2d,
                                                       cam_dist,
                                                       q=cnf.q)
            x3d, y3d, z3d = utils.to3d(x2d,
                                       y2d,
                                       cam_dist,
                                       fx=fx,
                                       fy=fy,
                                       cx=cx,
                                       cy=cy)
            coords3d_pred.append((joint_type, x3d, y3d, z3d))

        # real 3D coordinates --> [association] --> list of poses
        poses = coords_to_poses(coords3d_pred, confs)

        # A solitary joint is one that was excluded from the association
        # process because no valid connection could be found.
        # Only solitary joints with a confidence value >0.6 are considered.
        # NOTE(review): the set difference relies on the confidence objects in
        # `confs` comparing/hashing equal to `j.confidence` of the pose joints;
        # confs holds 0-dim tensors -- confirm coords_to_poses keeps the same objects.
        all_pose_joints = []
        for pose in poses:
            all_pose_joints += [(j.type, j.confidence, j.x3d, j.y3d, j.z3d)
                                for j in pose]
        coords3d_pred_ = [(c[0], confs[k], c[1], c[2], c[3])
                          for k, c in enumerate(coords3d_pred)]
        solitary = [(s[0], s[2], s[3], s[4])
                    for s in (set(coords3d_pred_) - set(all_pose_joints))
                    if s[1] > 0.6]

        # list of poses --> [hole filler] --> refined list of poses
        refined_poses = []
        for pose in poses:
            confidences = [j.confidence for j in pose]
            pose = [(joint.type, joint.x3d, joint.y3d, joint.z3d)
                    for joint in pose]
            refined_pose = hole_filler.refine(pose=pose,
                                              hole_th=0.2,
                                              confidences=confidences,
                                              replace_th=1)
            refined_poses.append(refined_pose)

        # refined list of poses --> [flatten] --> refined_coords3d_pred
        refined_coords3d_pred = []
        for pose in refined_poses:
            refined_coords3d_pred += pose

        # compute metrics without refinement
        for th in THS:
            __m = joint_det_metrics(points_pred=coords3d_pred,
                                    points_true=coords3d_true,
                                    th=th)
            for key in ['pr', 're', 'f1']:
                metrics_dict[f'{key}@{th}'].append(__m[key])

        # compute metrics with refinement
        for th in THS:
            __m = joint_det_metrics(points_pred=refined_coords3d_pred +
                                    solitary,
                                    points_true=coords3d_true,
                                    th=th)
            for key in ['pr', 're', 'f1']:
                metrics_dict[f'{key}@{th}+'].append(__m[key])

        # print test progress (1-based, so the last image reads "N of N")
        print(f'\r>> processing test image {step + 1} of {n_images}', end='')

    print('\r', end='')
    for th in THS:
        print(f'(PR, RE, F1)@{th}:'
              f'\tno_ref=('
              f'{np.mean(metrics_dict[f"pr@{th}"]) * 100:.2f}, '
              f'{np.mean(metrics_dict[f"re@{th}"]) * 100:.2f}, '
              f'{np.mean(metrics_dict[f"f1@{th}"]) * 100:.2f})'
              f'\twith_ref=('
              f'{np.mean(metrics_dict[f"pr@{th}+"]) * 100:.2f}, '
              f'{np.mean(metrics_dict[f"re@{th}+"]) * 100:.2f}, '
              f'{np.mean(metrics_dict[f"f1@{th}+"]) * 100:.2f}) ')
Beispiel #5
0
    def test(self):
        """
        Evaluate the code predictor on ``self.test_loader``.

        Each image goes through the following steps:
        - it is encoded by ``self.code_predictor``;
        - the code is decoded into heatmaps by ``self.autoencoder``;
        - the heatmaps are converted to real 3D joint coordinates;
        - those coordinates are compared with the ground truth at threshold
          ``self.cnf.det_th``.

        The mean precision / recall / F1 are printed (in %) and logged to
        ``self.sw`` at ``self.epoch``. The predictor weights are saved to
        ``self.log_path / 'best.pth'`` when the mean F1 is >= the best so far.
        """

        self.code_predictor.eval()
        self.code_predictor.requires_grad(False)

        start = time()
        per_image = {'pr': [], 're': [], 'f1': []}
        for sample in self.test_loader:
            x, coords3d_true, fx, fy, cx, cy, _ = sample

            intrinsics = {'fx': fx.item(), 'fy': fy.item(), 'cx': cx.item(), 'cy': cy.item()}
            x = x.to(self.cnf.device)
            coords3d_true = json.loads(coords3d_true[0])

            # image --> [code_predictor] --> code --> [decode] --> hmap(s)
            code = self.code_predictor.forward(x).unsqueeze(0)
            hmaps = self.autoencoder.decode(code).squeeze()

            # hmap --> [local_maxima_3d] --> rescaled pseudo-3D coordinates
            peaks = utils.local_maxima_3d(hmaps3d=hmaps,
                                          threshold=0.1,
                                          device=self.cnf.device)

            # rescaled pseudo-3D coordinates --> [to_3d] --> real 3D coordinates
            coords3d_pred = []
            for joint_type, cam_dist, y2d, x2d in peaks:
                x_real, y_real, dist_real = utils.rescale_to_real(
                    x2d=x2d, y2d=y2d, cam_dist=cam_dist, q=self.cnf.q)
                x3d, y3d, z3d = utils.to3d(
                    x2d=x_real, y2d=y_real, cam_dist=dist_real, **intrinsics)
                coords3d_pred.append((joint_type, x3d, y3d, z3d))

            # per-image detection metrics in real 3D space
            metrics = joint_det_metrics(points_pred=coords3d_pred,
                                        points_true=coords3d_true,
                                        th=self.cnf.det_th)
            for name, values in per_image.items():
                values.append(metrics[name])

        # average over the whole test set
        mean_test_pr, mean_test_re, mean_test_f1 = (
            float(np.mean(per_image[name])) for name in ('pr', 're', 'f1'))

        print(
            f'\t● AVG (PR, RE, F1) on TEST-set: '
            f'({mean_test_pr * 100:.2f}, '
            f'{mean_test_re * 100:.2f}, '
            f'{mean_test_f1 * 100:.2f}) ',
            end='')
        print(f'│ T: {time() - start:.2f} s')

        scalars = (
            ('test/precision', mean_test_pr),
            ('test/recall', mean_test_re),
            ('test/f1', mean_test_f1),
        )
        for tag, value in scalars:
            self.sw.add_scalar(tag=tag, scalar_value=value, global_step=self.epoch)

        # keep the best weights so far (ties favour the most recent epoch)
        improved = self.best_test_f1 is None or mean_test_f1 >= self.best_test_f1
        if improved:
            self.best_test_f1 = mean_test_f1
            torch.save(self.code_predictor.state_dict(), self.log_path / 'best.pth')
Beispiel #6
0
    def test(self):
        """
        Test the model on the Test-Set (samples are taken from ``self.val_loader``).

        For at most ``self.cnf.test_len`` samples this method:
        - computes the MSE loss between predicted and true heatmaps;
        - computes the F1@1px of the local maxima extracted from the prediction;
        - for the first 3 samples, saves channel 0 of the predicted / true
          heatmaps as videos in ``self.cnf.exp_log_path``.

        The mean loss and F1 are printed and logged to ``self.sw`` at
        ``self.epoch``. The accumulators ``self.val_losses`` / ``self.val_f1s``
        are then reset. The weights are saved to ``self.log_path / 'best.pth'``
        whenever the mean F1 is higher than ``self.best_val_f1``.
        """

        self.model.eval()
        self.model.requires_grad(False)

        t = time()
        for step, sample in enumerate(self.val_loader):
            hmap_true, y_true, _ = sample
            hmap_true = hmap_true.to(self.cnf.device)
            y_true = json.loads(y_true[0])

            hmap_pred = self.model.forward(hmap_true)

            loss = nn.MSELoss()(hmap_pred, hmap_true)
            self.val_losses.append(loss.item())

            y_pred = utils.get_multi_local_maxima_3d(
                hmaps3d=hmap_pred.squeeze(),
                threshold=0.1,
                device=self.cnf.device)

            metrics = joint_det_metrics(points_pred=y_pred,
                                        points_true=y_true,
                                        th=1)
            self.val_f1s.append(metrics['f1'])

            if step < 3:
                hmap_pred = hmap_pred.squeeze()
                out_path = self.cnf.exp_log_path / f'{step}_pred.mp4'
                utils.save_3d_hmap(hmap=hmap_pred[0, ...], path=out_path)

                hmap_true = hmap_true.squeeze()
                out_path = self.cnf.exp_log_path / f'{step}_true.mp4'
                utils.save_3d_hmap(hmap=hmap_true[0, ...], path=out_path)

            # stop after exactly `test_len` samples (step is 0-based)
            if step >= self.cnf.test_len - 1:
                break

        # log average loss on test set
        mean_val_loss = np.mean(self.val_losses)
        self.val_losses = []
        print(
            f'\t● AVG Loss on VAL-set: {mean_val_loss:.6f} │ T: {time() - t:.2f} s'
        )
        self.sw.add_scalar(tag='val_loss',
                           scalar_value=mean_val_loss,
                           global_step=self.epoch)

        # log average f1 on test set
        mean_val_f1 = np.mean(self.val_f1s)
        self.val_f1s = []
        print(
            f'\t● AVG F1@1px on VAL-set: {mean_val_f1:.6f} │ T: {time() - t:.2f} s'
        )
        self.sw.add_scalar(tag='val_F1',
                           scalar_value=mean_val_f1,
                           global_step=self.epoch)

        # save best model: a higher F1 is better
        if self.best_val_f1 is None or mean_val_f1 > self.best_val_f1:
            self.best_val_f1 = mean_val_f1
            torch.save(self.model.state_dict(), self.log_path / 'best.pth')
def main():
    """
    Debug viewer for the MOTSynth detection autoencoder (``vha_version=1``).

    Weights come from ``cnf.exp_log_path / 'best.pth'`` and are loaded only
    when ``cnf.model_weights`` is set. For each sample of the selected ``mode``
    the model predicts center/width/height heatmaps. Bounding boxes are
    decoded at the predicted centers and drawn on the original frame. In
    'test' mode the IoU-, center-, width- and height-F1 against the
    ground-truth labels are also printed.
    """
    MAX_WIDTH = 1919
    MAX_HEIGHT = 1079
    from test_metrics import joint_det_metrics, compute_det_metrics_iou
    import json
    import numpy as np
    from torch.utils.data import DataLoader
    from conf import Conf
    from dataset.mot_synth_det_ds import MOTSynthDetDS
    from utils import utils
    import torch

    cnf = Conf(exp_name='vha_d_debug', preload_checkpoint=False)

    # load dataset
    mode = 'test'
    ds = MOTSynthDetDS(mode=mode, cnf=cnf)
    loader = DataLoader(dataset=ds, batch_size=1, num_workers=0, shuffle=False)

    # load model
    from models.vha_det_variable_versions import Autoencoder as AutoencoderVariableVersions
    model = AutoencoderVariableVersions(vha_version=1).to(cnf.device)
    model.eval()
    model.requires_grad(False)
    if cnf.model_weights is not None:
        model.load_state_dict(torch.load(cnf.exp_log_path / 'best.pth',
                                         map_location=torch.device('cpu')),
                              strict=False)

    # ======== MAIN LOOP ========
    for sample in loader:
        x, y, file_name, aug_info = None, None, None, None

        # labelled samples only exist in 'test' mode
        if mode == 'test':
            x, y, file_name, aug_info = sample
        if mode == 'train':
            x, file_name, aug_info = sample
        x = x.to(cnf.device)

        y_pred = model.forward(x)
        x_pred_center, x_pred_width, x_pred_height = y_pred[0, 0], y_pred[
            0, 1], y_pred[0, 2]

        if mode == 'test':
            y = json.loads(y[0])
            y_center = [(coord[0], coord[1], coord[2]) for coord in y]
            y_width = [(coord[0], coord[1], coord[2], coord[3]) for coord in y]
            y_height = [(coord[0], coord[1], coord[2], coord[4])
                        for coord in y]

        y_center_pred = utils.local_maxima_3d(heatmap=x_pred_center,
                                              threshold=0.1,
                                              device=cnf.device)
        y_width_pred = []
        y_height_pred = []
        bboxes_info_pred = []
        # Decode boxes at the *predicted* centers. Ground-truth centers only
        # exist in 'test' mode, and using them would leak labels into the
        # prediction.
        for cam_dist, y2d, x2d in y_center_pred:
            width = float(x_pred_width[cam_dist, y2d, x2d])
            height = float(x_pred_height[cam_dist, y2d, x2d])

            # denormalize width and height (model outputs are scaled by the max frame size)
            width = int(round(width * MAX_WIDTH))
            height = int(round(height * MAX_HEIGHT))

            y_width_pred.append((cam_dist, y2d, x2d, width))
            y_height_pred.append((cam_dist, y2d, x2d, height))

            x2d, y2d, cam_dist = utils.rescale_to_real(x2d=x2d,
                                                       y2d=y2d,
                                                       cam_dist=cam_dist,
                                                       q=cnf.q)
            bboxes_info_pred.append(
                (x2d - width / 2, y2d - height / 2, width, height, cam_dist))

        img_original = np.array(
            utils.imread(cnf.mot_synth_path / file_name[0]).convert("RGB"))
        if mode == 'test':
            bboxes_info_true = []
            for cam_dist, y2d, x2d, width, height in y:
                x2d, y2d, cam_dist = utils.rescale_to_real(x2d=x2d,
                                                           y2d=y2d,
                                                           cam_dist=cam_dist,
                                                           q=cnf.q)
                bboxes_info_true.append((x2d - width / 2, y2d - height / 2,
                                         width, height, cam_dist))

            metrics_iou = compute_det_metrics_iou(bboxes_info_pred,
                                                  bboxes_info_true)
            metrics_center = joint_det_metrics(points_pred=y_center_pred,
                                               points_true=y_center,
                                               th=1)
            metrics_width = joint_det_metrics(points_pred=y_width_pred,
                                              points_true=y_width,
                                              th=1)
            metrics_height = joint_det_metrics(points_pred=y_height_pred,
                                               points_true=y_height,
                                               th=1)
            f1_iou = metrics_iou['f1']
            f1_center = metrics_center['f1']
            f1_width = metrics_width['f1']
            f1_height = metrics_height['f1']
            print(
                f'f1_iou={f1_iou}, f1_center={f1_center}, f1_width={f1_width}, f1_height={f1_height}'
            )

        utils.visualize_bboxes(img_original,
                               bboxes_info_pred,
                               use_z=True,
                               half_images=True,
                               aug_info=aug_info,
                               normalize_z=False)
def main():
    """
    Debug viewer for the MOTSynth detection dataset (heatmap round trip).

    The ground-truth center/width/height heatmaps are decoded back into
    bounding boxes, which are drawn on the original frame. For labelled
    modes ('val' and 'test') the boxes are also scored against the JSON
    labels, and the IoU-, center-, width- and height-F1 are printed. This
    checks that heatmap encoding and decoding are consistent.
    """
    from test_metrics import joint_det_metrics, compute_det_metrics_iou
    cnf = Conf(exp_name='debug')

    # load dataset
    mode = 'val'
    ds = MOTSynthDetDS(mode=mode, cnf=cnf)
    loader = DataLoader(dataset=ds,
                        batch_size=1,
                        num_workers=1,
                        shuffle=False,
                        worker_init_fn=MOTSynthDetDS.wif_test)

    # 'val' and 'test' samples carry JSON labels, 'train' samples do not
    has_labels = mode in ('val', 'test')

    # load model
    # from models.vha_det_c3d_pretrained import Autoencoder as AutoencoderC3dPretrained
    # model = AutoencoderC3dPretrained(hmap_d=cnf.hmap_d, legacy_pretrained=cnf.saved_epoch == 0).to(cnf.device)
    # model.eval()
    # model.requires_grad(False)
    # if cnf.model_weights is not None:
    #     model.load_state_dict(cnf.model_weights, strict=False)

    # ======== MAIN LOOP ========
    for sample in loader:
        x, y, file_name, aug_info = None, None, None, None

        if has_labels:
            x, y, file_name, aug_info = sample
        if mode == 'train':
            x, file_name, aug_info = sample
        x = x.to(cnf.device)
        x_center, x_width, x_height = x[0, 0], x[0, 1], x[0, 2]

        if has_labels:
            y = json.loads(y[0])
            y_center = [(coord[0], coord[1], coord[2]) for coord in y]
            y_width = [(coord[0], coord[1], coord[2], coord[3]) for coord in y]
            y_height = [(coord[0], coord[1], coord[2], coord[4])
                        for coord in y]

        y_center_pred = utils.local_maxima_3d(heatmap=x_center,
                                              threshold=0.1,
                                              device=cnf.device)
        y_width_pred = []
        y_height_pred = []
        bboxes_info_pred = []
        for cam_dist, y2d, x2d in y_center_pred:
            width = float(x_width[cam_dist, y2d, x2d])
            height = float(x_height[cam_dist, y2d, x2d])

            # denormalize width and height: invert (v - mean) / std
            width = int(round(width * STD_DEV_WIDTH + MEAN_WIDTH))
            height = int(round(height * STD_DEV_HEIGHT + MEAN_HEIGHT))

            y_width_pred.append((cam_dist, y2d, x2d, width))
            y_height_pred.append((cam_dist, y2d, x2d, height))

            x2d, y2d, cam_dist = utils.rescale_to_real(x2d=x2d,
                                                       y2d=y2d,
                                                       cam_dist=cam_dist,
                                                       q=cnf.q)
            bboxes_info_pred.append(
                (x2d - width / 2, y2d - height / 2, width, height, cam_dist))

        img_original = np.array(
            utils.imread(cnf.mot_synth_path / file_name[0]).convert("RGB"))
        if has_labels:
            bboxes_info_true = []
            for cam_dist, y2d, x2d, width, height in y:
                x2d, y2d, cam_dist = utils.rescale_to_real(x2d=x2d,
                                                           y2d=y2d,
                                                           cam_dist=cam_dist,
                                                           q=cnf.q)
                bboxes_info_true.append((x2d - width / 2, y2d - height / 2,
                                         width, height, cam_dist))

            metrics_iou = compute_det_metrics_iou(bboxes_info_pred,
                                                  bboxes_info_true)
            metrics_center = joint_det_metrics(points_pred=y_center_pred,
                                               points_true=y_center,
                                               th=1)
            metrics_width = joint_det_metrics(points_pred=y_width_pred,
                                              points_true=y_width,
                                              th=1)
            metrics_height = joint_det_metrics(points_pred=y_height_pred,
                                               points_true=y_height,
                                               th=1)
            f1_iou = metrics_iou['f1']
            f1_center = metrics_center['f1']
            f1_width = metrics_width['f1']
            f1_height = metrics_height['f1']
            print(
                f'f1_iou={f1_iou}, f1_center={f1_center}, f1_width={f1_width}, f1_height={f1_height}'
            )

        utils.visualize_bboxes(img_original,
                               bboxes_info_pred,
                               use_z=True,
                               half_images=True,
                               aug_info=aug_info,
                               normalize_z=False)