Example #1
0
def main():
    """Load one ObjectronSequence batch and report per-key tensor shapes."""
    from top.data.transforms.sequence import (CropSequence)
    from top.data.transforms.common import InstancePadding
    from top.run.app_util import update_settings

    settings = update_settings(ObjectronSequence.Settings())
    transform = transforms.Compose([
        CropSequence(CropSequence.Settings()),
        InstancePadding(InstancePadding.Settings(instance_dim=1)),
    ])
    dataset = ObjectronSequence(settings, transform=transform)
    data_loader = th.utils.data.DataLoader(dataset,
                                           batch_size=8,
                                           num_workers=0,
                                           collate_fn=_skip_none)

    # Inspect a single batch: instance counts, then shape (or raw value)
    # for every entry in the batch dict.
    for batch in data_loader:
        print(batch[Schema.INSTANCE_NUM])
        summary = {
            key: (value.shape if isinstance(value, th.Tensor) else value)
            for key, value in batch.items()
        }
        print(summary)
        break
Example #2
0
def main():
    """Dump debug renderings (image, heatmaps, keypoints) for one sample."""
    settings = update_settings(AppSettings())

    device = resolve_device('cpu')

    transform = Compose([
        BoxPoints2D(device=device,
                    key_out='points_2d_debug',
                    key_out_3d='points_3d_debug'),
        DrawKeypoints(DrawKeypoints.Settings()),
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(), device),
        InstancePadding(InstancePadding.Settings()),
    ])

    dataset, _ = get_loaders(settings.dataset, device, None, transform)

    for sample in dataset:
        # Write the input image, every heatmap channel, and the rendered
        # keypoint overlay to /tmp for visual inspection, then stop.
        save_image(sample[Schema.IMAGE] / 255.0, F'/tmp/img.png')
        for i, img in enumerate(sample[Schema.HEATMAP]):
            save_image(img, F'/tmp/heatmap-{i}.png', normalize=True)
        save_image(sample['rendered_keypoints'] / 255.0, F'/tmp/rkpts.png')

        break
def main():
    """Print tensor shapes and save one image from ColoredCubeDataset."""
    settings = update_settings(
        ColoredCubeDataset.Settings(batch_size=32, unstack=False))

    device = resolve_device()
    dataset = ColoredCubeDataset(settings, device)
    for sample in dataset:
        shapes = {key: value.shape for key, value in sample.items()}
        print(shapes)
        save_image(sample[Schema.IMAGE] / 255.0, F'/tmp/img.png')
        break
Example #4
0
def main():
    """Fetch a single SampleObjectron batch and list its keys."""
    settings = update_settings(SampleObjectron.Settings())
    dataset = SampleObjectron(settings)
    data_loader = th.utils.data.DataLoader(dataset,
                                           batch_size=1,
                                           num_workers=0,
                                           collate_fn=None)

    for sample in data_loader:
        print(sample.keys())
        break
def main():
    """Visualize the effect of `GeometricAugment` with a forced LR flip.

    Renders keypoint heatmap + bounding box for both the original and the
    augmented sample, composited over the source image, in a 2x1 matplotlib
    figure (advances one batch per button press).
    """
    opts = DatasetSettings()
    opts = update_settings(opts)
    # p_flip_lr=1.0 forces the horizontal flip so the augmentation is
    # always visible; in_place=False keeps the original sample intact.
    augment = GeometricAugment(GeometricAugment.Settings(p_flip_lr=1.0),
                               in_place=False)
    key_out = '__keypoint_image__'  # try not to result in key collision

    visualize = Compose([
        # NOTE(review): DenseMapsMobilePose presumably populates the
        # KEYPOINT_HEATMAP entry consumed by DrawKeypointMap below — confirm.
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(),
                            device=th.device('cpu')),
        DrawKeypointMap(
            DrawKeypointMap.Settings(key_in=Schema.KEYPOINT_HEATMAP,
                                     key_out=key_out,
                                     as_displacement=False)),
        DrawBoundingBoxFromKeypoints(
            DrawBoundingBoxFromKeypoints.Settings(key_in=key_out,
                                                  key_out=key_out))
    ])

    # visualize = DrawKeypoints(DrawKeypoints.Settings(key_out=key_out))
    #visualize = DrawKeypointMap(DrawKeypointMap.Settings(key_in=Schema.IMAGE, key_out=key_out,
    #    as_displacement = False))
    train_loader, _ = get_loaders(opts,
                                  device=th.device('cpu'),
                                  batch_size=None,
                                  transform=None)

    fig, ax = plt.subplots(2, 1)
    for data in train_loader:
        # Upsample the rendered maps back to the input image resolution.
        resize = Resize(size=data[Schema.IMAGE].shape[-2:])

        aug_data = augment(data)
        # aug_data = data
        print(data[Schema.KEYPOINT_2D])
        print(aug_data[Schema.KEYPOINT_2D])
        v0 = resize(visualize(data)[key_out])
        v1 = resize(visualize(aug_data)[key_out])
        # v0 = 255 * resize(v0) + data[Schema.IMAGE]
        # v1 = 255 * resize(v1) + aug_data[Schema.IMAGE]
        # Composite: where the rendering is blank (<= 0), show the source
        # image; elsewhere show the rendering rescaled to 0-255.
        v0 = th.where(v0 <= 0, data[Schema.IMAGE].float(), 255.0 * v0)
        v1 = th.where(v1 <= 0, aug_data[Schema.IMAGE].float(), 255.0 * v1)

        v0 = _to_image(v0) / 255.0
        v1 = _to_image(v1) / 255.0
        #image = data[Schema.IMAGE][0].permute((1, 2, 0))  # 13HW
        #aug_image = data['aug_img'][0].permute((1, 2, 0))  # 13HW
        ax[0].imshow(v0)
        ax[0].set_title('orig')
        ax[1].imshow(v1)
        ax[1].set_title('aug')
        # Block until a key/mouse press before advancing to the next batch.
        k = plt.waitforbuttonpress()
def main():
    """Run evaluation, optionally under the torch autograd profiler.

    When `opts.profile` is set, wraps `eval_main` in `profiler.profile`
    and dumps a per-op timing table plus a chrome trace to /tmp even if
    evaluation raises.
    """
    logging.captureWarnings(True)
    opts = AppSettings()
    opts = update_settings(opts)
    if opts.profile:
        # Pre-bind `prof` so the `finally` block cannot raise a NameError
        # (masking the original exception) if profiler construction fails.
        prof = None
        try:
            with profiler.profile(record_shapes=True, use_cuda=True) as prof:
                eval_main(opts)
        finally:
            if prof is not None:
                print('tracing...')
                print(prof.key_averages().table(
                    sort_by='cpu_time_total',
                    row_limit=16))
                prof.export_chrome_trace("/tmp/trace.json")
    else:
        eval_main(opts)
def main():
    """Iterate a few batches and print each entry's shape (or value)."""
    settings = update_settings(Settings())
    device = resolve_device(settings.device)
    transform = Compose([
        DenseMapsMobilePose(DenseMapsMobilePose.Settings(), device),
        Normalize(Normalize.Settings()),
        InstancePadding(InstancePadding.Settings())
    ])
    data_loader, _ = get_loaders(settings.dataset, device,
                                 settings.batch_size, transform)

    # Tensors are summarized by shape; everything else is printed raw.
    for index, batch in enumerate(data_loader):
        for key, value in batch.items():
            if isinstance(value, th.Tensor):
                print(key, value.shape)
            else:
                print(key, value)
        if index >= settings.num_samples:
            break
def main():
    """Show original vs. photometrically-augmented batches side by side."""
    settings = update_settings(DatasetSettings())
    key_out = '__aug_img__'  # Try to prevent key collision
    transform = Compose([
        InstancePadding(InstancePadding.Settings()),
        PhotometricAugment(PhotometricAugment.Settings(key_out=key_out))
    ])
    train_loader, test_loader = get_loaders(settings,
                                            device=th.device('cpu'),
                                            batch_size=4,
                                            transform=transform)

    fig, ax = plt.subplots(2, 1)
    for batch in train_loader:
        original = _to_image(_stack_images(batch[Schema.IMAGE]))
        augmented = _to_image(_stack_images(batch[key_out]))
        ax[0].imshow(original)
        ax[1].imshow(augmented)
        # Wait for a button press before moving on to the next batch.
        k = plt.waitforbuttonpress()
def main():
    """Inspect the crops produced by `CropObject` on the test split."""
    settings = update_settings(DatasetSettings())

    transform = Compose([
        # NOTE(ycho): `FormatLabel` must be applied prior
        # to `CropObject` since it modifies the requisite tensors.
        # FormatLabel(FormatLabel.Settings(), opts.vis_thresh),
        CropObject(CropObject.Settings()),
        # Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE,)))
    ])
    _, test_loader = get_loaders(settings,
                                 device=th.device('cpu'),
                                 batch_size=4,
                                 transform=transform,
                                 collate_fn=collate_cropped_img)

    fig, ax = plt.subplots(1, 1)
    for batch in test_loader:
        # Expected crop layout: (batch_size, num_instances, 3, H, W).
        print(batch[Schema.IMAGE].shape)
        print(batch[Schema.CROPPED_IMAGE].shape)
        print(batch[Schema.INDEX])
        print(batch[Schema.INSTANCE_NUM])
        print('points_2d')
        print(batch[Schema.KEYPOINT_2D])
        print('points_3d')
        print(batch[Schema.KEYPOINT_3D])

        crops = _to_image(_stack_images(batch[Schema.CROPPED_IMAGE]))
        ax.imshow(crops)
        ax.set_xticks([])
        ax.set_yticks([])
        # Wait for a button press before moving on to the next batch.
        k = plt.waitforbuttonpress()
def main():
    """Train `BoundingBoxRegressionModel` on cropped object images.

    Wires up data loading (crop + normalize), event-hub callbacks
    (settings snapshot, periodic checkpointing, logging, evaluation) and
    an L1 scale + orientation loss, then runs the trainer — optionally
    under the autograd profiler.
    """
    logging.basicConfig(level=logging.WARN)
    opts = AppSettings()
    opts = update_settings(opts)
    path = RunPath(opts.path)

    device = resolve_device(opts.device)
    model = BoundingBoxRegressionModel(opts.model).to(device)
    # NOTE(review): hardcoded lr=1e-5 — consider promoting to settings.
    optimizer = th.optim.Adam(model.parameters(), lr=1e-5)
    writer = SummaryWriter(path.log)

    transform = Compose([
        CropObject(CropObject.Settings()),
        Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE, )))
    ])
    # Data is loaded on the CPU; tensors are moved to `device` inside the
    # loss/eval functions below.
    train_loader, test_loader = get_loaders(opts.dataset,
                                            device=th.device('cpu'),
                                            batch_size=opts.batch_size,
                                            transform=transform,
                                            collate_fn=collate_cropped_img)

    # NOTE(ycho): Synchronous event hub.
    hub = Hub()

    # Save meta-parameters.
    def _save_params():
        # Snapshot the run settings alongside the run directory.
        opts.save(path.dir / 'opts.yaml')

    hub.subscribe(Topic.TRAIN_BEGIN, _save_params)

    # Periodically log training statistics.
    # FIXME(ycho): hardcoded logging period.
    # NOTE(ycho): Currently only plots `loss`.
    collect = Collect(hub, Topic.METRICS, [])
    train_logger = TrainLogger(hub, writer, opts.log_period)

    # Periodically save model, per epoch.
    # TODO(ycho): Consider folding this callback inside Trainer().
    hub.subscribe(
        Topic.EPOCH, lambda epoch: Saver(model, optimizer).save(
            path.ckpt / F'epoch-{epoch}.zip'))

    # Periodically save model, per N training steps.
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {save_period} args to Trainer instead.
    hub.subscribe(
        Topic.STEP,
        Periodic(
            opts.save_period, lambda step: Saver(model, optimizer).save(
                path.ckpt / F'step-{step}.zip')))

    # Periodically evaluate model, per N training steps.
    # NOTE(ycho): Load and process test data ...
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {test_loader, eval_fn} args to Trainer instead.
    def _eval_fn(model, data):
        # TODO(Jiyong): hardcode for cropped image size
        # Flattens (batch, instance) into one leading dim of 224x224 crops.
        crop_img = data[Schema.CROPPED_IMAGE].view(-1, 3, 224, 224)
        return model(crop_img.to(device))

    evaluator = Evaluator(Evaluator.Settings(period=opts.eval_period), hub,
                          model, test_loader, _eval_fn)

    # TODO(Jiyong):
    # All metrics evaluation should reset stats at eval_begin(),
    # aggregate stats at eval_step(),
    # and output stats at eval_end(). These signals are all implemented.
    # What are the appropriate metrics to implement for bounding box regression?
    def _on_eval_step(inputs, outputs):
        # Placeholder: per-step evaluation metrics not implemented yet.
        pass

    hub.subscribe(Topic.EVAL_STEP, _on_eval_step)

    # NOTE(review): `collect` is re-bound here, replacing the instance
    # created above — confirm this is intended.
    collect = Collect(hub, Topic.METRICS, [])

    def _log_all(metrics: Dict[Topic, Any]):
        # Placeholder: aggregate metric logging not implemented yet.
        pass

    hub.subscribe(Topic.METRICS, _log_all)

    orientation_loss_func = nn.L1Loss().to(device)
    scale_loss_func = nn.L1Loss().to(device)

    def _loss_fn(model: th.nn.Module, data):
        """Compute {total, scale, orientation} L1 losses for one batch."""
        # Now that we're here, convert all inputs to the device.
        image = data[Schema.CROPPED_IMAGE].to(device)
        c, h, w = image.shape[-3:]
        # Collapse (batch, instance) dims so the model sees a flat batch.
        image = image.view(-1, c, h, w)
        truth_quat = data[Schema.QUATERNION].to(device)
        truth_quat = truth_quat.view(-1, 4)
        truth_dim = data[Schema.SCALE].to(device)
        truth_dim = truth_dim.view(-1, 3)
        truth_trans = data[Schema.TRANSLATION].to(device)
        truth_trans = truth_trans.view(-1, 3)

        dim, quat = model(image)

        outputs = {}
        outputs[Schema.SCALE] = dim
        outputs[Schema.QUATERNION] = quat

        # Also make input/output pair from training
        # iterations available to the event bus.
        hub.publish(Topic.TRAIN_OUT, inputs=data, outputs=outputs)

        loss = {}

        scale_loss = scale_loss_func(dim, truth_dim)
        orient_loss = orientation_loss_func(quat, truth_quat)
        # `alpha` weights the scale term against the orientation term.
        total_loss = opts.alpha * scale_loss + orient_loss

        loss["total"] = total_loss
        loss["scale"] = scale_loss
        loss["orientation"] = orient_loss

        return loss

    ## Load from checkpoint
    if opts.load_ckpt:
        logging.info(F'Loading checkpoint {opts.load_ckpt} ...')
        Saver(model, optimizer).load(opts.load_ckpt)

    ## Trainer
    trainer = Trainer(opts.train, model, optimizer, _loss_fn, hub,
                      train_loader)

    # Train, optionally profile
    if opts.profile:
        try:
            with profiler.profile(record_shapes=True, use_cuda=True) as prof:
                trainer.train()
        finally:
            print(prof.key_averages().table(sort_by='cpu_time_total',
                                            row_limit=16))
            prof.export_chrome_trace("/tmp/trace.json")
    else:
        trainer.train()
Example #11
0
def main():
    """Train a simple cross-entropy classifier with event-hub callbacks.

    Demonstrates the Trainer/Hub/Evaluator plumbing: periodic evaluation,
    an accuracy metric, and stdout + tensorboard logging.
    """
    # Settings ...
    opts = AppSettings()
    opts = update_settings(opts)

    # Path configuration ...
    path = RunPath(opts.run)

    # Device resolution ...
    device = resolve_device(opts.device)

    # Data
    train_loader, test_loader = load_data(opts)

    # Model, loss
    model = Model().to(device)
    xs_loss = nn.CrossEntropyLoss()

    def loss_fn(model: nn.Module, data):
        # Move the (inputs, target) pair to the device, then compute loss.
        inputs, target = data
        inputs = inputs.to(device)
        target = target.to(device)
        output = model(inputs)
        return xs_loss(output, target)

    # Optimizer
    # NOTE(review): hardcoded lr=1e-3 — consider promoting to settings.
    optimizer = th.optim.Adam(model.parameters(), lr=1e-3)

    # Callbacks, logging, ...
    writer = th.utils.tensorboard.SummaryWriter(path.log)

    def _eval_fn(model, data):
        # Evaluation forward pass: labels are ignored here; the accuracy
        # metric consumes inputs/outputs downstream via the hub.
        inputs, _ = data
        output = model(inputs.to(device))
        return output

    hub = Hub()

    # TODO(ycho): The default behavior of evaluator (num_samples==1)
    # might be confusing and unintuitive - prefer more reasonable default?
    evaluator = Evaluator(
        Evaluator.Settings(period=opts.eval_period, num_samples=128), hub,
        model, test_loader, _eval_fn)

    accuracy = Accuracy(hub, 'accuracy')
    metrics = Collect(hub, Topic.METRICS, (Topic.STEP, 'accuracy'))

    def _on_metrics(data):
        # TODO(ycho): Fix clunky syntax with `Collect`.
        # `Collect` delivers (args, kwargs) pairs; unpack the first
        # positional argument for each collected topic.
        step_arg, _ = data[Topic.STEP]
        step = step_arg[0]
        acc_arg, _ = data['accuracy']
        accuracy = acc_arg[0]

        # Print to stdout ...
        print(F'@{step} accuracy={accuracy} ')

        # Tensorboard logging ...
        writer.add_scalar('accuracy', accuracy, step)

    hub.subscribe(Topic.METRICS, _on_metrics)

    # Trainer
    trainer = Trainer(opts.train, model, optimizer, loss_fn, hub, train_loader)

    trainer.train()
Example #12
0
def main():
    """Train `KeypointNetwork2D` for heatmap-based keypoint regression.

    Sets up CPU-side data transforms, the synchronous event hub with
    checkpoint/logging/evaluation callbacks, and a combined center-heatmap
    + keypoint-heatmap + scale loss, then runs the trainer — optionally
    under the autograd profiler.
    """
    logging.basicConfig(level=logging.WARN)
    opts = AppSettings()
    opts = update_settings(opts)
    path = RunPath(opts.path)

    device = resolve_device(opts.device)
    model = KeypointNetwork2D(opts.model).to(device)
    # FIXME(ycho): Hardcoded lr == 1e-3
    optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
    writer = th.utils.tensorboard.SummaryWriter(path.log)

    # NOTE(ycho): Force data loading on the CPU.
    data_device = th.device('cpu')

    # TODO(ycho): Consider scripted compositions?
    # If a series of transforms can be fused and compiled,
    # it would probably make it a lot faster to train...
    transform = Compose([
        DenseMapsMobilePose(opts.maps, data_device),
        PhotometricAugment(opts.photo_aug, False),
        Normalize(Normalize.Settings()),
        InstancePadding(opts.padding)
    ])

    train_loader, test_loader = get_loaders(opts.dataset,
                                            device=data_device,
                                            batch_size=opts.batch_size,
                                            transform=transform)

    # NOTE(ycho): Synchronous event hub.
    hub = Hub()

    def _on_train_begin():
        """Persist run settings and emit the model graph to tensorboard."""
        # Save meta-parameters.
        opts.save(path.dir / 'opts.yaml')
        # NOTE(ycho): Currently `load` only works with a modified version of the
        # main SimpleParsing repository.
        # opts.load(path.dir / 'opts.yaml')

        # Generate tensorboard graph.
        data = next(iter(test_loader))
        dummy = data[Schema.IMAGE].to(device).detach()
        # NOTE(ycho): No need to set model to `eval`,
        # eval mode is set internally within add_graph().
        writer.add_graph(ModelAsTuple(model), dummy)

    hub.subscribe(Topic.TRAIN_BEGIN, _on_train_begin)

    # Periodically log training statistics.
    # FIXME(ycho): hardcoded logging period.
    # NOTE(ycho): Currently only plots `loss`.
    collect = Collect(hub, Topic.METRICS, [])
    train_logger = TrainLogger(hub, writer, opts.log_period)

    # Periodically save model, per epoch.
    # TODO(ycho): Consider folding this callback inside Trainer().
    hub.subscribe(
        Topic.EPOCH, lambda epoch: Saver(model, optimizer).save(
            path.ckpt / F'epoch-{epoch}.zip'))

    # Periodically save model, per N training steps.
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {save_period} args to Trainer instead.
    hub.subscribe(
        Topic.STEP,
        Periodic(
            opts.save_period, lambda step: Saver(model, optimizer).save(
                path.ckpt / F'step-{step}.zip')))

    # Periodically evaluate model, per N training steps.
    # NOTE(ycho): Load and process test data ...
    # TODO(ycho): Consider folding this callback inside Trainer()
    # and adding {test_loader, eval_fn} args to Trainer instead.
    def _eval_fn(model, data):
        # TODO(ycho): Actually implement evaluation function.
        # return model(data[Schema.IMAGE].to(device))
        return None

    evaluator = Evaluator(Evaluator.Settings(period=opts.eval_period), hub,
                          model, test_loader, _eval_fn)

    # TODO(ycho):
    # All metrics evaluation should reset stats at eval_begin(),
    # aggregate stats at eval_step(),
    # and output stats at eval_end(). These signals are all implemented.
    # What are the appropriate metrics to implement for keypoint regression?
    # - keypoint matching F1 score(?)
    # - loss_fn() but for the evaluation datasets
    def _on_eval_step(inputs, outputs):
        # Placeholder: per-step evaluation metrics not implemented yet.
        pass

    hub.subscribe(Topic.EVAL_STEP, _on_eval_step)

    # NOTE(review): `collect` is re-bound here, replacing the instance
    # created above — confirm this is intended.
    collect = Collect(hub, Topic.METRICS, [])

    def _log_all(metrics: Dict[Topic, Any]):
        # Placeholder: aggregate metric logging not implemented yet.
        pass

    hub.subscribe(Topic.METRICS, _log_all)

    # TODO(ycho): weight the losses with some constant ??
    losses = {
        Schema.HEATMAP:
        ObjectHeatmapLoss(key=Schema.HEATMAP),
        # Schema.DISPLACEMENT_MAP: KeypointDisplacementLoss(),
        Schema.KEYPOINT_HEATMAP:
        ObjectHeatmapLoss(key=Schema.KEYPOINT_HEATMAP),
        Schema.SCALE:
        KeypointScaleLoss()
    }

    def _loss_fn(model: th.nn.Module, data):
        """Sum keypoint-heatmap, center-heatmap and scale losses for a batch."""
        # Now that we're here, convert all inputs to the device.
        data = {
            k: (v.to(device) if isinstance(v, th.Tensor) else v)
            for (k, v) in data.items()
        }
        image = data[Schema.IMAGE]
        outputs = model(image)
        # Also make input/output pair from training
        # iterations available to the event bus.
        hub.publish(Topic.TRAIN_OUT, inputs=data, outputs=outputs)
        kpt_heatmap_loss = losses[Schema.KEYPOINT_HEATMAP](outputs, data)
        heatmap_loss = losses[Schema.HEATMAP](outputs, data)
        scale_loss = losses[Schema.SCALE](outputs, data)
        # Independently log stuff
        hub.publish(
            Topic.TRAIN_LOSSES, {
                'keypoint': kpt_heatmap_loss,
                'center': heatmap_loss,
                'scale': scale_loss
            })
        # Unweighted sum — see TODO above about loss weighting.
        return (kpt_heatmap_loss + heatmap_loss + scale_loss)

    ## Load from checkpoint
    if opts.load_ckpt:
        logging.info(F'Loading checkpoint {opts.load_ckpt} ...')
        Saver(model, optimizer).load(opts.load_ckpt)

    ## Trainer
    trainer = Trainer(opts.train, model, optimizer, _loss_fn, hub,
                      train_loader)

    # Train, optionally profile
    if opts.profile:
        try:
            with profiler.profile(record_shapes=True, use_cuda=True) as prof:
                trainer.train()
        finally:
            print(prof.key_averages().table(sort_by='cpu_time_total',
                                            row_limit=16))
            prof.export_chrome_trace("/tmp/trace.json")
    else:
        trainer.train()
def main():
    """Download Objectron shards (train and test) up to a byte budget.

    Spawns worker processes — either an `mp.Pool` or raw `Process`es with
    a shared stop flag — that download shards into the cache directory,
    while the parent tracks progress with tqdm and stops once the
    configured `max_*_bytes` budget is reached.
    """
    logging.basicConfig(level=logging.INFO)
    opts = Settings()
    opts = update_settings(opts)
    # NOTE(review): `pool_states` appears unused below — confirm before removal.
    pool_states = [{} for _ in range(opts.num_workers)]
    # for train in [False, True]:
    for train in [False, True]:
        name = 'objectron-train' if train else 'objectron-test'
        logging.info(F'Processing {name}')

        # Separate byte budgets for the train and test splits.
        max_bytes = opts.max_train_bytes if train else opts.max_test_bytes

        # TODO(ycho): Consider fancier (e.g. class-equalizing) shard samplers.
        shards = ObjectronDetection(ObjectronDetection.Settings(local=False),
                                    train).shards

        out_dir = (Path(opts.cache_dir).expanduser() / name)
        out_dir.mkdir(parents=True, exist_ok=True)

        if opts.use_pool:
            # NOTE(ycho): The initial approach based on mp.Pool().
            # Turned out that it is not possible to guarantee graceful exit in
            # this way.
            _download = functools.partial(download_shard, out_dir=out_dir)
            with mp.Pool(opts.num_workers, init_worker) as p:
                with tqdm(total=max_bytes) as pbar:
                    total_bytes = 0
                    for shard_bytes in p.imap_unordered(_download, shards):
                        pbar.update(shard_bytes)
                        # Accumulate and check for termination.
                        total_bytes += shard_bytes
                        if total_bytes >= max_bytes:
                            logging.info(F'Done: {total_bytes} > {max_bytes}')
                            # NOTE(ycho): Due to bug in mp.Pool(), imap_unordered() with close()/join()
                            # does NOT work, thus we implicitly call terminate() via context manager
                            # which may result in incomplete shards. This condition
                            # must be checked.
                            break
        else:
            # Resume support: count bytes already present in the cache dir.
            init_bytes = sum(f.stat().st_size for f in out_dir.rglob('*')
                             if f.is_file())
            logging.info(F'Starting from {init_bytes}/{max_bytes} ...')
            ctx = mp.get_context('fork')
            # Shared boolean flag workers poll to decide when to stop;
            # pre-set if the budget is already exhausted.
            stop = ctx.Value('b', (init_bytes >= max_bytes))
            queue = ctx.Queue()
            # Round-robin shard assignment: worker i gets shards[i::num_workers].
            workers = [
                ctx.Process(target=download_shards,
                            args=(shards[i::opts.num_workers], out_dir, stop,
                                  queue)) for i in range(opts.num_workers)
            ]
            # Start!
            for p in workers:
                p.start()

            # Progress logging ...
            try:
                with tqdm(initial=init_bytes, total=max_bytes) as pbar:
                    # Periodically check progress...
                    total_bytes = init_bytes
                    while True:
                        # Blocks until a worker reports a completed shard.
                        shard_bytes = queue.get()
                        pbar.update(shard_bytes)
                        total_bytes += shard_bytes
                        if total_bytes >= max_bytes:
                            break
            except KeyboardInterrupt:
                logging.info('Cancelling download, trying to clean up ...')
                pass
            finally:
                # Stop.
                with stop.get_lock():
                    stop.value = True

                # Join.
                logging.info(
                    'Download completed, joining the rest of the processes...')
                for p in workers:
                    p.join()
Example #14
0
def main():
    """Evaluate a trained `KeypointNetwork2D` checkpoint.

    Reloads the run's saved settings and latest checkpoint, runs inference
    on the test split, decodes keypoints from the predicted heatmaps, and
    recovers pose via EPnP, printing recovered vs. ground-truth R/T.
    """
    # logging.basicConfig(level=logging.DEBUG)

    # Initial parsing looking for `RunPath` ...
    opts = AppSettings()
    opts = update_settings(opts)
    if not opts.path.key:
        raise ValueError('opts.path.key required for evaluation (For now)')
    path = RunPath(opts.path)

    # Re-parse full args with `base_opts` as default instead
    # TODO(ycho): Verify if this works.
    base_opts = update_settings(
        opts, argv=['--config_file',
                    str(path.dir / 'opts.yaml')])
    opts = update_settings(base_opts)

    # Instantiation ...
    device = resolve_device(opts.device)
    model = KeypointNetwork2D(opts.model).to(device)

    # Load checkpoint.
    ckpt_file = get_latest_file(path.ckpt)
    print('ckpt = {}'.format(ckpt_file))
    # Optimizer state is not needed for evaluation, hence `None`.
    Saver(model, None).load(ckpt_file)

    # NOTE(ycho): Forcing data loading on the CPU.
    # TODO(ycho): Consider scripted compositions?
    transform = Compose([
        DenseMapsMobilePose(opts.maps, th.device('cpu:0')),
        Normalize(Normalize.Settings()),
        InstancePadding(opts.padding)
    ])
    _, test_loader = get_loaders(opts.dataset,
                                 device=th.device('cpu:0'),
                                 batch_size=opts.batch_size,
                                 transform=transform)

    model.eval()
    for data in test_loader:
        # Now that we're here, convert all inputs to the device.
        data = {
            k: (v.to(device) if isinstance(v, th.Tensor) else v)
            for (k, v) in data.items()
        }
        image = data[Schema.IMAGE]
        image_scale = th.as_tensor(image.shape[-2:])  # (h,w) order
        print('# instances = {}'.format(data[Schema.INSTANCE_NUM]))
        with th.no_grad():
            outputs = model(image)

            heatmap = outputs[Schema.HEATMAP]
            kpt_heatmap = outputs[Schema.KEYPOINT_HEATMAP]

            # FIXME(ycho): hardcoded obj==1 assumption
            scores, indices = decode_kpt_heatmap(kpt_heatmap,
                                                 max_num_instance=4)

            # hmm...
            # Ratio between input resolution and heatmap resolution,
            # broadcastable against the decoded (i, j) indices.
            upsample_ratio = th.as_tensor(image_scale /
                                          th.as_tensor(heatmap.shape[-2:]),
                                          device=indices.device)
            upsample_ratio = upsample_ratio[None, None, None, :]

        # Map heatmap-space indices back to input-image pixel coordinates.
        scaled_indices = indices * upsample_ratio

        # Visualize inferred keypoints ...
        # NOTE(review): disabled debug-visualization branch (`if False`).
        if False:
            # FIXME(ycho): Pedantically incorrect!!
            heatmap_vis = DrawKeypointMap(
                DrawKeypointMap.Settings(as_displacement=False))(heatmap)
            kpt_heatmap_vis = DrawKeypointMap(
                DrawKeypointMap.Settings(as_displacement=False))(kpt_heatmap)

            fig, ax = plt.subplots(3, 1)
            hv_cpu = heatmap_vis[0].detach().cpu().numpy().transpose(1, 2, 0)
            khv_cpu = kpt_heatmap_vis[0].detach().cpu().numpy().transpose(
                1, 2, 0)
            img_cpu = th.clip(0.5 + (image[0] * 0.25), 0.0,
                              1.0).detach().cpu().numpy().transpose(1, 2, 0)
            ax[0].imshow(hv_cpu)
            ax[1].imshow(khv_cpu / khv_cpu.max())
            ax[2].imshow(img_cpu)
            plt.show()

        # scores = (32,9,4)
        # (i,j)  = (32,2,9,4)
        for i_batch in range(scores.shape[0]):
            # GROUND_TRUTH
            kpt_in = data[Schema.KEYPOINT_2D][i_batch, ..., :2]
            # Ground-truth keypoints are normalized; scale to pixel coords.
            kpt_in = kpt_in * image_scale.to(kpt_in.device)
            # X-Y order (J-I order)
            # print(kpt_in)

            # print(scaled_indices[i_batch])  # Y-X order (I-J order)
            print('scale.shape')  # 32,4,3
            print(data[Schema.SCALE].shape)
            # Flip (i, j) -> (j, i) and re-normalize by image scale so the
            # keypoints match the convention expected by EPnP.
            sol = compute_pose_epnp(
                data[Schema.PROJECTION][i_batch],
                # not estimating scale info for now ...,
                data[Schema.SCALE][i_batch],
                th.flip(scaled_indices[i_batch], dims=(-1, )) /
                image_scale.to(scaled_indices.device))
            if sol is None:
                # EPnP failed for this sample; try the next one.
                continue
            R, T = sol
            # Compare recovered pose against ground truth.
            print(R, data[Schema.ORIENTATION][i_batch])
            print(T, data[Schema.TRANSLATION][i_batch])
            break

        # Dump raw heatmaps for offline inspection.
        np.save(F'/tmp/heatmap.npy', heatmap.cpu().numpy())
        np.save(F'/tmp/kpt_heatmap.npy', kpt_heatmap.cpu().numpy())
        break