コード例 #1
0
def main():
    """Dump one sample's image, heatmaps and rendered keypoints to /tmp.

    Builds a transform pipeline on the device resolved from 'cpu', takes
    the first sample from the first loader returned by `get_loaders`, and
    writes PNG files for visual inspection.
    """
    opts = update_settings(AppSettings())
    device = resolve_device('cpu')

    pipeline = Compose([
        BoxPoints2D(device=device,
                    key_out='points_2d_debug',
                    key_out_3d='points_3d_debug'),
        DrawKeypoints(DrawKeypoints.Settings()),
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(), device),
        InstancePadding(InstancePadding.Settings()),
    ])

    loader, _ = get_loaders(opts.dataset, device, None, pipeline)

    def _dump(sample):
        """Write the visualisation PNGs for a single sample."""
        # The debug keys 'points_2d_debug' / 'points_3d_debug' are also
        # available on `sample` for comparison against
        # Schema.KEYPOINT_2D / Schema.KEYPOINT_3D if needed.
        save_image(sample[Schema.IMAGE] / 255.0, '/tmp/img.png')
        for index, heatmap in enumerate(sample[Schema.HEATMAP]):
            save_image(heatmap, f'/tmp/heatmap-{index}.png', normalize=True)
        save_image(sample['rendered_keypoints'] / 255.0, '/tmp/rkpts.png')

    # Only the very first sample is inspected.
    for sample in loader:
        _dump(sample)
        break
コード例 #2
0
def main():
    """Interactively compare original and geometrically augmented samples.

    For every sample of the training loader, the keypoint heatmap overlay
    is rendered for both the raw sample and its augmented copy, and the two
    are shown one above the other. A key or mouse press advances to the
    next sample.
    """
    opts = update_settings(DatasetSettings())
    # NOTE(review): p_flip_lr=1.0 presumably forces the left-right flip on
    # every sample -- confirm against GeometricAugment.
    flip_aug = GeometricAugment(GeometricAugment.Settings(p_flip_lr=1.0),
                                in_place=False)
    # Deliberately unusual key to avoid clashing with existing entries.
    key_out = '__keypoint_image__'

    render = Compose([
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(),
                            device=th.device('cpu')),
        DrawKeypointMap(
            DrawKeypointMap.Settings(key_in=Schema.KEYPOINT_HEATMAP,
                                     key_out=key_out,
                                     as_displacement=False)),
        DrawBoundingBoxFromKeypoints(
            DrawBoundingBoxFromKeypoints.Settings(key_in=key_out,
                                                  key_out=key_out)),
    ])

    # No transform is passed to the loader; rendering happens in the loop.
    train_loader, _ = get_loaders(opts,
                                  device=th.device('cpu'),
                                  batch_size=None,
                                  transform=None)

    _, axes = plt.subplots(2, 1)
    titles = ('orig', 'aug')
    for sample in train_loader:
        # Bring the rendered maps to the spatial size of the input image.
        to_image_size = Resize(size=sample[Schema.IMAGE].shape[-2:])

        aug_sample = flip_aug(sample)
        print(sample[Schema.KEYPOINT_2D])
        print(aug_sample[Schema.KEYPOINT_2D])

        sources = (sample, aug_sample)
        # Render both overlays first, then blend, then convert -- the same
        # order of calls as a straight-line implementation.
        overlays = [to_image_size(render(src)[key_out]) for src in sources]
        # Where the overlay is empty (<= 0) show the image, else the overlay.
        blended = [
            th.where(vis <= 0, src[Schema.IMAGE].float(), 255.0 * vis)
            for vis, src in zip(overlays, sources)
        ]
        frames = [_to_image(mix) / 255.0 for mix in blended]

        for axis, frame, title in zip(axes, frames, titles):
            axis.imshow(frame)
            axis.set_title(title)
        plt.waitforbuttonpress()
コード例 #3
0
def eval_main(opts: AppSettings):
    """Evaluate the model returned by `load_model()` on the test loader.

    The preprocessing pipeline is selected from the model type:

    * `GroundTruthDecoder`: dense maps, normalization, instance padding
      and label formatting, with the loader's default collation.
    * `BboxWrapper`: label formatting only, with a pass-through collate
      function (collation is handled elsewhere).

    Evaluation statistics are finalized and the report is written even if
    the loop is interrupted with Ctrl-C or fails with an exception.

    Args:
        opts: Application settings; this function reads `vis_thresh`,
            `dataset`, `batch_size` and `max_num`.

    Raises:
        TypeError: If the loaded model is neither a `GroundTruthDecoder`
            nor a `BboxWrapper`, since no pipeline is defined for it.
    """

    model = load_model()
    evaluator = Evaluator(opts, model)

    if isinstance(model, GroundTruthDecoder):
        # for Keypoints (+ ground-truth kpt-style decoder)
        transform = Compose([
            DenseMapsMobilePose(DenseMapsMobilePose.Settings(),
                                th.device('cpu')),
            Normalize(Normalize.Settings()),
            InstancePadding(InstancePadding.Settings()),
            # TODO(ycho): Does the order between padding<->label matter here?
            FormatLabel(FormatLabel.Settings(), opts.vis_thresh),
        ])
        collate_fn = None
    elif isinstance(model, BboxWrapper):
        # for Bounding Box
        transform = Compose([
            # NOTE(ycho): `FormatLabel` must be applied prior
            # to `CropObject` since it modifies the requisite tensors.
            FormatLabel(FormatLabel.Settings(), opts.vis_thresh),
            # CropObject(CropObject.Settings()),
            # Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE,)))
        ])

        # NOTE(ycho): passthrough collation;
        # actual collation will be handled independently.
        def collate_fn(data):
            return data
    else:
        # Without this branch `transform` / `collate_fn` would be unbound
        # below and surface as an opaque UnboundLocalError.
        raise TypeError(
            'Unsupported model type for evaluation: {}'.format(
                type(model).__name__))

    _, test_loader = get_loaders(opts.dataset,
                                 th.device('cpu'),
                                 opts.batch_size,
                                 transform=transform,
                                 collate_fn=collate_fn)

    # Run evaluation ...
    try:
        for i, data in enumerate(tqdm.tqdm(test_loader)):
            evaluator.evaluate(data)
            # A negative `max_num` disables the limit.
            # NOTE(review): the limit is tested after evaluating with a
            # 0-based index, so up to `max_num + 1` batches are processed
            # -- confirm this is intended.
            if (opts.max_num >= 0) and (i >= opts.max_num):
                break
    except KeyboardInterrupt:
        # Allow manual early stop; the report is still produced below.
        pass
    finally:
        evaluator.finalize()
        evaluator.write_report()
コード例 #4
0
def main():
    """Run `GroundTruthDecoder` once on the first batch of the test loader.

    Serves as a smoke test: the decoder output is computed and discarded.
    """
    decoder = GroundTruthDecoder()
    device = resolve_device('cpu')

    stages = [
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(),
                            th.device('cpu:0')),
        BoxPoints2D(device, key_out='p2d-debug'),
        InstancePadding(InstancePadding.Settings()),
    ]
    transform = Compose(stages)

    # Only the second loader (the test split) is used.
    loaders = get_loaders(DatasetSettings(), device, 1, transform=transform)
    _, test_loader = loaders

    # A single batch is enough to exercise the decoder.
    for batch in test_loader:
        decoder(batch)
        break
コード例 #5
0
def main():
    """Print the contents of the first `opts.num_samples` batches.

    Batches come from the first loader returned by `get_loaders`, after
    the dense-map, normalization and instance-padding transforms. For each
    entry of a batch, tensors are summarised by their shape and any other
    value is printed as-is.
    """
    opts = Settings()
    opts = update_settings(opts)
    device = resolve_device(opts.device)
    transform = Compose([
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(), device),
        Normalize(Normalize.Settings()),
        InstancePadding(InstancePadding.Settings())
    ])
    data_loader, _ = get_loaders(opts.dataset, device, opts.batch_size,
                                 transform)

    limit = opts.num_samples
    if limit <= 0:
        # Nothing requested: do not print (or even load) any batch.
        return

    # Count from 1 so that exactly `limit` batches are printed; checking a
    # 0-based index after printing would show one batch too many.
    for count, data in enumerate(data_loader, start=1):
        for k, v in data.items():
            if isinstance(v, th.Tensor):
                print(k, v.shape)
            else:
                print(k, v)
        if count >= limit:
            break
コード例 #6
0
def main():
    """Interactively show original and photometrically augmented batches.

    Each training batch (batch size 4) is displayed as two stacked-image
    panels: the input images on top and the augmented images, stored by
    the transform under a separate key, below. A key or mouse press
    advances to the next batch.
    """
    opts = update_settings(DatasetSettings())
    # Deliberately unusual key to avoid clashing with existing entries.
    key_out = '__aug_img__'

    pipeline = Compose([
        InstancePadding(InstancePadding.Settings()),
        PhotometricAugment(PhotometricAugment.Settings(key_out=key_out)),
    ])
    # The test loader is not needed for this visualisation.
    train_loader, _ = get_loaders(opts,
                                  device=th.device('cpu'),
                                  batch_size=4,
                                  transform=pipeline)

    _, axes = plt.subplots(2, 1)
    panel_keys = (Schema.IMAGE, key_out)
    for batch in train_loader:
        # Stack both panels first, then convert them, then draw.
        stacked = [_stack_images(batch[key]) for key in panel_keys]
        panels = [_to_image(grid) for grid in stacked]
        for axis, panel in zip(axes, panels):
            axis.imshow(panel)
        plt.waitforbuttonpress()
コード例 #7
0
def main():
    """Interactively inspect cropped object images from the test loader.

    For each test batch (batch size 4, collated with
    `collate_cropped_img`) the shapes of the full and cropped images, the
    sample indices, instance counts and 2D/3D keypoints are printed, and
    the cropped images are shown as one stacked image. A key or mouse
    press advances to the next batch.
    """
    opts = DatasetSettings()
    opts = update_settings(opts)

    transform = Compose([
        # NOTE(ycho): `FormatLabel` must be applied prior
        # to `CropObject` since it modifies the requisite tensors.
        # FormatLabel(FormatLabel.Settings(), opts.vis_thresh),
        CropObject(CropObject.Settings()),
        # Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE,)))
    ])
    _, test_loader = get_loaders(opts,
                                 device=th.device('cpu'),
                                 batch_size=4,
                                 transform=transform,
                                 collate_fn=collate_cropped_img)

    fig, ax = plt.subplots(1, 1)
    for data in test_loader:
        # Batches produced by this crop/collate pipeline can occasionally
        # come without any cropped image; skip them instead of failing
        # with a KeyError below.
        if Schema.CROPPED_IMAGE not in data:
            continue

        # -> 4,1,3,640,480
        # -- (batch_size, ? ,3, 640, 480)
        print(data[Schema.IMAGE].shape)
        print(data[Schema.CROPPED_IMAGE].shape)
        print(data[Schema.INDEX])
        print(data[Schema.INSTANCE_NUM])
        print('points_2d')
        print(data[Schema.KEYPOINT_2D])
        print('points_3d')
        print(data[Schema.KEYPOINT_3D])

        # Show the crops (not the full frames) as a single stacked image.
        image = _stack_images(data[Schema.CROPPED_IMAGE])
        image = _to_image(image)
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.waitforbuttonpress()
コード例 #8
0
def main():
    """Train `BoundingBoxRegressionModel` on cropped object images.

    Sets up the model, an Adam optimizer (lr=1e-5), a tensorboard writer,
    the cropped-image data loaders and a synchronous event hub with
    callbacks for saving settings, logging, checkpointing and periodic
    evaluation, then runs the trainer. With `opts.profile` set, training
    runs under the profiler; a summary table is printed and a Chrome trace
    is written to /tmp/trace.json.

    The loss is `opts.alpha * L1(scale) + L1(quaternion)`.
    """
    logging.basicConfig(level=logging.WARN)
    opts = AppSettings()
    opts = update_settings(opts)
    path = RunPath(opts.path)

    device = resolve_device(opts.device)
    model = BoundingBoxRegressionModel(opts.model).to(device)
    optimizer = th.optim.Adam(model.parameters(), lr=1e-5)
    writer = SummaryWriter(path.log)

    transform = Compose([
        CropObject(CropObject.Settings()),
        Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE, )))
    ])
    # Data is loaded on the CPU; tensors are moved to `device` on use.
    train_loader, test_loader = get_loaders(opts.dataset,
                                            device=th.device('cpu'),
                                            batch_size=opts.batch_size,
                                            transform=transform,
                                            collate_fn=collate_cropped_img)

    # NOTE(ycho): Synchronous event hub.
    hub = Hub()

    # Save meta-parameters.
    def _save_params():
        opts.save(path.dir / 'opts.yaml')

    hub.subscribe(Topic.TRAIN_BEGIN, _save_params)

    # Periodically log training statistics.
    # FIXME(ycho): hardcoded logging period.
    # NOTE(ycho): Currently only plots `loss`.
    collect = Collect(hub, Topic.METRICS, [])
    train_logger = TrainLogger(hub, writer, opts.log_period)

    # Periodically save model, per epoch.
    # TODO(ycho): Consider folding this callback inside Trainer().
    hub.subscribe(
        Topic.EPOCH, lambda epoch: Saver(model, optimizer).save(
            path.ckpt / F'epoch-{epoch}.zip'))

    # Periodically save model, per N training steps.
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {save_period} args to Trainer instead.
    hub.subscribe(
        Topic.STEP,
        Periodic(
            opts.save_period, lambda step: Saver(model, optimizer).save(
                path.ckpt / F'step-{step}.zip')))

    # Periodically evaluate model, per N training steps.
    # NOTE(ycho): Load and process test data ...
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {test_loader, eval_fn} args to Trainer instead.
    def _eval_fn(model, data):
        # Fold all leading dimensions into the batch axis while keeping the
        # trailing (C, H, W) of the crops -- the same reshaping `_loss_fn`
        # uses, instead of a hard-coded 3x224x224 crop size.
        crop_img = data[Schema.CROPPED_IMAGE]
        c, h, w = crop_img.shape[-3:]
        crop_img = crop_img.view(-1, c, h, w)
        return model(crop_img.to(device))

    evaluator = Evaluator(Evaluator.Settings(period=opts.eval_period), hub,
                          model, test_loader, _eval_fn)

    # TODO(Jiyong):
    # All metrics evaluation should reset stats at eval_begin(),
    # aggregate stats at eval_step(),
    # and output stats at eval_end(). These signals are all implemented.
    # What are the appropriate metrics to implement for bounding box regression?
    def _on_eval_step(inputs, outputs):
        pass

    hub.subscribe(Topic.EVAL_STEP, _on_eval_step)

    # NOTE(review): a second `Collect` on the same topic is created here
    # and rebinds `collect`; kept as-is since its constructor may register
    # with the hub -- confirm whether the duplicate is intended.
    collect = Collect(hub, Topic.METRICS, [])

    def _log_all(metrics: Dict[Topic, Any]):
        pass

    hub.subscribe(Topic.METRICS, _log_all)

    orientation_loss_func = nn.L1Loss().to(device)
    scale_loss_func = nn.L1Loss().to(device)

    def _loss_fn(model: th.nn.Module, data):
        """Return a dict with the 'total', 'scale' and 'orientation' losses."""
        # Now that we're here, convert all inputs to the device.
        image = data[Schema.CROPPED_IMAGE].to(device)
        c, h, w = image.shape[-3:]
        image = image.view(-1, c, h, w)
        truth_quat = data[Schema.QUATERNION].to(device)
        truth_quat = truth_quat.view(-1, 4)
        truth_dim = data[Schema.SCALE].to(device)
        truth_dim = truth_dim.view(-1, 3)
        # NOTE(review): the translation target is prepared but not used
        # by any loss term below.
        truth_trans = data[Schema.TRANSLATION].to(device)
        truth_trans = truth_trans.view(-1, 3)

        dim, quat = model(image)

        outputs = {}
        outputs[Schema.SCALE] = dim
        outputs[Schema.QUATERNION] = quat

        # Also make input/output pair from training
        # iterations available to the event bus.
        hub.publish(Topic.TRAIN_OUT, inputs=data, outputs=outputs)

        loss = {}

        scale_loss = scale_loss_func(dim, truth_dim)
        orient_loss = orientation_loss_func(quat, truth_quat)
        total_loss = opts.alpha * scale_loss + orient_loss

        loss["total"] = total_loss
        loss["scale"] = scale_loss
        loss["orientation"] = orient_loss

        return loss

    ## Load from checkpoint
    if opts.load_ckpt:
        logging.info(F'Loading checkpoint {opts.load_ckpt} ...')
        Saver(model, optimizer).load(opts.load_ckpt)

    ## Trainer
    trainer = Trainer(opts.train, model, optimizer, _loss_fn, hub,
                      train_loader)

    # Train, optionally profile
    if opts.profile:
        # Only request CUDA timings when training actually runs on CUDA.
        use_cuda = (th.device(device).type == 'cuda')
        # Build the profiler before `try` so that `prof` is always bound
        # when the `finally` clause runs.
        prof = profiler.profile(record_shapes=True, use_cuda=use_cuda)
        try:
            with prof:
                trainer.train()
        finally:
            # Report whatever was recorded, even if training was aborted.
            print(prof.key_averages().table(sort_by='cpu_time_total',
                                            row_limit=16))
            prof.export_chrome_trace("/tmp/trace.json")
    else:
        trainer.train()
コード例 #9
0
def main():
    """Train `KeypointNetwork2D` on dense keypoint/heatmap targets.

    Sets up the model, an Adam optimizer (lr=1e-3), a tensorboard writer,
    CPU-side data loaders and a synchronous event hub with callbacks for
    saving settings, exporting the model graph, logging, checkpointing and
    periodic evaluation, then runs the trainer. With `opts.profile` set,
    training runs under the profiler; a summary table is printed and a
    Chrome trace is written to /tmp/trace.json.

    The loss is the unweighted sum of the keypoint-heatmap, object-heatmap
    and scale losses.
    """
    logging.basicConfig(level=logging.WARN)
    opts = AppSettings()
    opts = update_settings(opts)
    path = RunPath(opts.path)

    device = resolve_device(opts.device)
    model = KeypointNetwork2D(opts.model).to(device)
    # FIXME(ycho): Hardcoded lr == 1e-3
    optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
    writer = th.utils.tensorboard.SummaryWriter(path.log)

    # NOTE(ycho): Force data loading on the CPU.
    data_device = th.device('cpu')

    # TODO(ycho): Consider scripted compositions?
    # If a series of transforms can be fused and compiled,
    # it would probably make it a lot faster to train...
    transform = Compose([
        DenseMapsMobilePose(opts.maps, data_device),
        PhotometricAugment(opts.photo_aug, False),
        Normalize(Normalize.Settings()),
        InstancePadding(opts.padding)
    ])

    train_loader, test_loader = get_loaders(opts.dataset,
                                            device=data_device,
                                            batch_size=opts.batch_size,
                                            transform=transform)

    # NOTE(ycho): Synchronous event hub.
    hub = Hub()

    def _on_train_begin():
        """Persist the settings and export the model graph."""

        # Save meta-parameters.
        opts.save(path.dir / 'opts.yaml')
        # NOTE(ycho): Currently `load` only works with a modified version of the
        # main SimpleParsing repository.
        # opts.load(path.dir / 'opts.yaml')

        # Generate tensorboard graph.
        data = next(iter(test_loader))
        dummy = data[Schema.IMAGE].to(device).detach()
        # NOTE(ycho): No need to set model to `eval`,
        # eval mode is set internally within add_graph().
        writer.add_graph(ModelAsTuple(model), dummy)

    hub.subscribe(Topic.TRAIN_BEGIN, _on_train_begin)

    # Periodically log training statistics.
    # FIXME(ycho): hardcoded logging period.
    # NOTE(ycho): Currently only plots `loss`.
    collect = Collect(hub, Topic.METRICS, [])
    train_logger = TrainLogger(hub, writer, opts.log_period)

    # Periodically save model, per epoch.
    # TODO(ycho): Consider folding this callback inside Trainer().
    hub.subscribe(
        Topic.EPOCH, lambda epoch: Saver(model, optimizer).save(
            path.ckpt / F'epoch-{epoch}.zip'))

    # Periodically save model, per N training steps.
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {save_period} args to Trainer instead.
    hub.subscribe(
        Topic.STEP,
        Periodic(
            opts.save_period, lambda step: Saver(model, optimizer).save(
                path.ckpt / F'step-{step}.zip')))

    # Periodically evaluate model, per N training steps.
    # NOTE(ycho): Load and process test data ...
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {test_loader, eval_fn} args to Trainer instead.
    def _eval_fn(model, data):
        # TODO(ycho): Actually implement evaluation function.
        # return model(data[Schema.IMAGE].to(device))
        return None

    evaluator = Evaluator(Evaluator.Settings(period=opts.eval_period), hub,
                          model, test_loader, _eval_fn)

    # TODO(ycho):
    # All metrics evaluation should reset stats at eval_begin(),
    # aggregate stats at eval_step(),
    # and output stats at eval_end(). These signals are all implemented.
    # What are the appropriate metrics to implement for keypoint regression?
    # - keypoint matching F1 score(?)
    # - loss_fn() but for the evaluation datasets
    def _on_eval_step(inputs, outputs):
        pass

    hub.subscribe(Topic.EVAL_STEP, _on_eval_step)

    # NOTE(review): a second `Collect` on the same topic is created here
    # and rebinds `collect`; kept as-is since its constructor may register
    # with the hub -- confirm whether the duplicate is intended.
    collect = Collect(hub, Topic.METRICS, [])

    def _log_all(metrics: Dict[Topic, Any]):
        pass

    hub.subscribe(Topic.METRICS, _log_all)

    # TODO(ycho): weight the losses with some constant ??
    losses = {
        Schema.HEATMAP:
        ObjectHeatmapLoss(key=Schema.HEATMAP),
        # Schema.DISPLACEMENT_MAP: KeypointDisplacementLoss(),
        Schema.KEYPOINT_HEATMAP:
        ObjectHeatmapLoss(key=Schema.KEYPOINT_HEATMAP),
        Schema.SCALE:
        KeypointScaleLoss()
    }

    def _loss_fn(model: th.nn.Module, data):
        """Return the summed keypoint, center and scale loss for a batch."""
        # Now that we're here, convert all inputs to the device.
        data = {
            k: (v.to(device) if isinstance(v, th.Tensor) else v)
            for (k, v) in data.items()
        }
        image = data[Schema.IMAGE]
        outputs = model(image)
        # Also make input/output pair from training
        # iterations available to the event bus.
        hub.publish(Topic.TRAIN_OUT, inputs=data, outputs=outputs)
        kpt_heatmap_loss = losses[Schema.KEYPOINT_HEATMAP](outputs, data)
        heatmap_loss = losses[Schema.HEATMAP](outputs, data)
        scale_loss = losses[Schema.SCALE](outputs, data)
        # Independently log stuff
        hub.publish(
            Topic.TRAIN_LOSSES, {
                'keypoint': kpt_heatmap_loss,
                'center': heatmap_loss,
                'scale': scale_loss
            })
        return (kpt_heatmap_loss + heatmap_loss + scale_loss)

    ## Load from checkpoint
    if opts.load_ckpt:
        logging.info(F'Loading checkpoint {opts.load_ckpt} ...')
        Saver(model, optimizer).load(opts.load_ckpt)

    ## Trainer
    trainer = Trainer(opts.train, model, optimizer, _loss_fn, hub,
                      train_loader)

    # Train, optionally profile
    if opts.profile:
        # Only request CUDA timings when training actually runs on CUDA.
        use_cuda = (th.device(device).type == 'cuda')
        # Build the profiler before `try` so that `prof` is always bound
        # when the `finally` clause runs.
        prof = profiler.profile(record_shapes=True, use_cuda=use_cuda)
        try:
            with prof:
                trainer.train()
        finally:
            # Report whatever was recorded, even if training was aborted.
            print(prof.key_averages().table(sort_by='cpu_time_total',
                                            row_limit=16))
            prof.export_chrome_trace("/tmp/trace.json")
    else:
        trainer.train()
コード例 #10
0
def main():
    """Run a trained `KeypointNetwork2D` on one test batch and solve poses.

    Settings are parsed twice: once to locate the run directory, then
    again with the run's saved `opts.yaml` as the configuration file. The
    latest checkpoint of the run is loaded, the first test batch is pushed
    through the network, keypoint locations are decoded from the keypoint
    heatmap and handed to `compute_pose_epnp` together with the ground
    truth projection and scale. The first successfully solved pose is
    printed next to the ground truth, and both heatmaps are saved as
    `.npy` files under /tmp.

    Raises:
        ValueError: If `opts.path.key` is empty.
    """
    # First pass: parse just enough to find the `RunPath`.
    opts = update_settings(AppSettings())
    if not opts.path.key:
        raise ValueError('opts.path.key required for evaluation (For now)')
    run_path = RunPath(opts.path)

    # Second pass: re-parse with the saved run settings as the base.
    # TODO(ycho): Verify if this works.
    saved_config = str(run_path.dir / 'opts.yaml')
    base_opts = update_settings(opts, argv=['--config_file', saved_config])
    opts = update_settings(base_opts)

    device = resolve_device(opts.device)
    model = KeypointNetwork2D(opts.model).to(device)

    # Restore the weights from the most recent checkpoint of this run.
    latest_ckpt = get_latest_file(run_path.ckpt)
    print('ckpt = {}'.format(latest_ckpt))
    Saver(model, None).load(latest_ckpt)

    # NOTE(ycho): Forcing data loading on the CPU.
    # TODO(ycho): Consider scripted compositions?
    cpu = th.device('cpu:0')
    transform = Compose([
        DenseMapsMobilePose(opts.maps, cpu),
        Normalize(Normalize.Settings()),
        InstancePadding(opts.padding)
    ])
    _, test_loader = get_loaders(opts.dataset,
                                 device=cpu,
                                 batch_size=opts.batch_size,
                                 transform=transform)

    model.eval()
    for batch in test_loader:
        # Move every tensor entry of the batch to the model's device.
        batch = {
            key: (val.to(device) if isinstance(val, th.Tensor) else val)
            for key, val in batch.items()
        }
        image = batch[Schema.IMAGE]
        image_hw = th.as_tensor(image.shape[-2:])  # (h,w) order
        print('# instances = {}'.format(batch[Schema.INSTANCE_NUM]))

        with th.no_grad():
            outputs = model(image)

            heatmap = outputs[Schema.HEATMAP]
            kpt_heatmap = outputs[Schema.KEYPOINT_HEATMAP]

            # FIXME(ycho): hardcoded obj==1 assumption
            scores, indices = decode_kpt_heatmap(kpt_heatmap,
                                                 max_num_instance=4)

            # Ratio between input-image and heatmap resolution, used to
            # express heatmap indices in image pixels.
            heatmap_hw = th.as_tensor(heatmap.shape[-2:])
            upsample_ratio = th.as_tensor(image_hw / heatmap_hw,
                                          device=indices.device)
            upsample_ratio = upsample_ratio[None, None, None, :]

        scaled_indices = indices * upsample_ratio

        # Visualize inferred keypoints ... (disabled debugging aid)
        if False:
            # FIXME(ycho): Pedantically incorrect!!
            heatmap_vis = DrawKeypointMap(
                DrawKeypointMap.Settings(as_displacement=False))(heatmap)
            kpt_heatmap_vis = DrawKeypointMap(
                DrawKeypointMap.Settings(as_displacement=False))(kpt_heatmap)

            fig, ax = plt.subplots(3, 1)
            hv_cpu = heatmap_vis[0].detach().cpu().numpy().transpose(1, 2, 0)
            khv_cpu = kpt_heatmap_vis[0].detach().cpu().numpy().transpose(
                1, 2, 0)
            img_cpu = th.clip(0.5 + (image[0] * 0.25), 0.0,
                              1.0).detach().cpu().numpy().transpose(1, 2, 0)
            ax[0].imshow(hv_cpu)
            ax[1].imshow(khv_cpu / khv_cpu.max())
            ax[2].imshow(img_cpu)
            plt.show()

        # scores = (32,9,4)
        # (i,j)  = (32,2,9,4)
        num_in_batch = scores.shape[0]
        for i_batch in range(num_in_batch):
            # Ground-truth keypoints in pixels; X-Y order (J-I order).
            # Only used for manual comparison while debugging.
            kpt_in = batch[Schema.KEYPOINT_2D][i_batch, ..., :2]
            kpt_in = kpt_in * image_hw.to(kpt_in.device)

            print('scale.shape')  # 32,4,3
            print(batch[Schema.SCALE].shape)

            projection = batch[Schema.PROJECTION][i_batch]
            # not estimating scale info for now ...,
            gt_scale = batch[Schema.SCALE][i_batch]
            # `scaled_indices` are in Y-X order (I-J order): flip the last
            # axis, then divide by the image size again.
            flipped = th.flip(scaled_indices[i_batch], dims=(-1, ))
            keypoints = flipped / image_hw.to(scaled_indices.device)

            solution = compute_pose_epnp(projection, gt_scale, keypoints)
            if solution is None:
                continue
            R, T = solution
            print(R, batch[Schema.ORIENTATION][i_batch])
            print(T, batch[Schema.TRANSLATION][i_batch])
            # Stop at the first instance that could be solved.
            break

        np.save('/tmp/heatmap.npy', heatmap.cpu().numpy())
        np.save('/tmp/kpt_heatmap.npy', kpt_heatmap.cpu().numpy())
        # Only the first batch is processed.
        break
コード例 #11
0
def main():
    """Visualise 3D boxes predicted by the bounding-box regression model.

    For every test batch (batch size 1, collated with
    `collate_cropped_img`) the model predicts scale and quaternion from
    the cropped object images, a translation is solved per instance with
    `SolveTranslation`, and the resulting boxes are drawn onto the input
    image and displayed with matplotlib.

    The model runs on CUDA when available and on the CPU otherwise.
    """
    # data
    transform = Compose([
        CropObject(CropObject.Settings()),
        Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE, )))
    ])
    _, test_loader = get_loaders(DatasetSettings(),
                                 th.device('cpu'),
                                 1,
                                 transform=transform,
                                 collate_fn=collate_cropped_img)
    # model
    # Fall back to the CPU instead of failing on hosts without CUDA.
    device = th.device('cuda' if th.cuda.is_available() else 'cpu')
    model = load_model()
    model = model.to(device)
    model.eval()

    # translation solver?
    solve_translation = SolveTranslation()

    box_points = BoxPoints2D(th.device('cpu'), Schema.KEYPOINT_2D)
    draw_bbox = DrawBoundingBoxFromKeypoints(
        DrawBoundingBoxFromKeypoints.Settings())

    # eval
    for data in test_loader:
        # Skip occasional batches without any images.
        if Schema.CROPPED_IMAGE not in data:
            continue

        with th.no_grad():
            # run inference
            # Fold all leading dimensions into the batch axis while keeping
            # the trailing (C, H, W) of the crops, rather than assuming a
            # fixed 3x224x224 crop size.
            crop_img = data[Schema.CROPPED_IMAGE]
            crop_img = crop_img.view(-1, *crop_img.shape[-3:])
            dim, quat = model(crop_img.to(device))
            dim2, quat2 = data[Schema.SCALE], data[Schema.QUATERNION]
            # Lazy %-style arguments: the tensors are only formatted when
            # DEBUG logging is actually enabled.
            logging.debug('D %s %s', dim, dim2)
            logging.debug('Q %s %s', quat, quat2)
            # trans = data[Schema.TRANSLATION]

            # Debug toggle: substitute ground truth for the predictions.
            if False:
                dim = dim2
                quat = quat2
                R = quaternion_to_matrix(quat)

            R = quaternion_to_matrix(quat)

            input_image = data[Schema.IMAGE].detach().cpu()
            proj_matrix = (data[Schema.PROJECTION].detach().cpu().reshape(
                -1, 4, 4))

            # Solve translations.
            translations = []
            for i, proj_i in enumerate(proj_matrix):
                box_i, box_j, box_h, box_w = data[Schema.BOX_2D][i]
                # (i, j, h, w) -> corner form (i0, j0, i1, j1).
                box_2d = th.as_tensor(
                    [box_i, box_j, box_i + box_h, box_j + box_w])
                # NOTE(review): presumably maps [0, 1] image coordinates
                # to [-1, 1] -- confirm against SolveTranslation.
                box_2d = 2.0 * (box_2d - 0.5)
                args = {
                    # inputs from dataset
                    Schema.PROJECTION: proj_i,
                    Schema.BOX_2D: box_2d,
                    # inputs from network
                    Schema.ORIENTATION: R[i],
                    Schema.QUATERNION: quat[i],
                    Schema.SCALE: dim[i]
                }
                # Solve translation
                translation, _ = solve_translation(args)
                translations.append(translation)
            translations = th.as_tensor(translations, dtype=th.float32)

            # Rendering path toggle: the first branch draws boxes through
            # the BoxPoints2D + DrawBoundingBoxFromKeypoints transforms;
            # the alternative uses `plot_regressed_3d_bbox`.
            if True:
                print('num instances = {}'.format(len(translations)))
                pred_data = {
                    Schema.IMAGE: data[Schema.IMAGE][0],
                    Schema.ORIENTATION: R.cpu(),
                    Schema.TRANSLATION: translations,
                    Schema.SCALE: dim.cpu(),
                    Schema.PROJECTION: proj_matrix[0],
                    Schema.INSTANCE_NUM: len(proj_matrix),
                }
                pred_data = box_points(pred_data)
                pred_data = draw_bbox(pred_data)
                image_with_box = pred_data['img_w_bbox']
            else:
                dimensions = dim.detach().cpu()
                quaternion = quat.detach().cpu()
                translations = translations.detach().cpu()

                # draw box
                image_with_box = plot_regressed_3d_bbox(
                    input_image,
                    # keypoints_2d,
                    # data[Schema.BOX_2D],
                    data[Schema.KEYPOINT_2D],
                    proj_matrix,
                    dimensions,
                    quaternion,
                    translations)

            # Redraw the current figure with the channel axis moved last
            # (C, H, W -> H, W, C) for imshow.
            plt.clf()
            plt.imshow(image_with_box.permute(1, 2, 0))
            plt.pause(0.1)