def reload_model(resume_checkpoint,
                 opts,
                 optimizer=None,
                 mano_lambda_pose_reg=None,
                 mano_lambda_shape=None):
    """Rebuild a MeshRegNet from a checkpoint's stored options and load its weights.

    Args:
        resume_checkpoint: Checkpoint location. It is always handed to
            ``load_opts`` and, when truthy, to ``modelio.load_checkpoint``.
        opts: Currently ignored: the options are always re-read from
            ``resume_checkpoint``.
            NOTE(review): confirm whether the passed-in options were meant
            to be honoured instead of being reloaded.
        optimizer: Optional optimizer forwarded to ``modelio.load_checkpoint``.
        mano_lambda_pose_reg: When not None, replaces the stored
            ``mano_lambda_pose_reg`` option.
        mano_lambda_shape: When not None, replaces the stored
            ``mano_lambda_shape`` option (which also drives ``mano_use_shape``).

    Returns:
        Tuple ``(model, start_epoch)``; ``start_epoch`` is 0 when
        ``resume_checkpoint`` is falsy.
    """
    ckpt_opts = load_opts(resume_checkpoint)

    # Apply the optional overrides on top of the stored options
    overrides = (
        ("mano_lambda_pose_reg", mano_lambda_pose_reg),
        ("mano_lambda_shape", mano_lambda_shape),
    )
    for opt_name, opt_value in overrides:
        if opt_value is not None:
            ckpt_opts[opt_name] = opt_value

    # Checkpoints that stored the singular key get it mirrored to the plural one
    if "train_dataset" in ckpt_opts:
        ckpt_opts["train_datasets"] = ckpt_opts["train_dataset"]

    # Options forwarded verbatim to the network, split around mano_use_shape
    # so that lookups happen in the same order as the keyword arguments
    head_keys = (
        "mano_lambda_verts2d",
        "mano_lambda_verts3d",
        "mano_lambda_joints3d",
        "mano_lambda_recov_joints3d",
        "mano_lambda_recov_verts3d",
        "mano_lambda_joints2d",
        "mano_lambda_shape",
    )
    tail_keys = (
        "mano_lambda_pose_reg",
        "obj_lambda_recov_verts3d",
        "obj_lambda_verts2d",
        "obj_lambda_verts3d",
        "obj_trans_factor",
        "obj_scale_factor",
    )
    net_kwargs = {key: ckpt_opts[key] for key in head_keys}
    net_kwargs["mano_use_shape"] = ckpt_opts["mano_lambda_shape"] > 0
    net_kwargs.update({key: ckpt_opts[key] for key in tail_keys})
    net_kwargs["mano_fhb_hand"] = "fhbhands" in ckpt_opts["train_datasets"]
    net_kwargs["uncertainty_pnp"] = True
    # Checkpoints without a stored "domain_norm" option default to False
    use_domain_norm = False
    if "domain_norm" in ckpt_opts:
        use_domain_norm = ckpt_opts["domain_norm"]
    net_kwargs["domain_norm"] = use_domain_norm

    # Initialize model
    model = MeshRegNet(**net_kwargs)
    # model = torch.nn.DataParallel(model)

    if not resume_checkpoint:
        return model, 0

    start_epoch, _ = modelio.load_checkpoint(
        model,
        optimizer=optimizer,
        resume_path=resume_checkpoint,
        strict=False,
        as_parallel=False,
    )
    return model, start_epoch
# Example #2
# 0
def main(args):
    """Train and/or evaluate a MeshRegNet as configured by ``args``.

    The function seeds the random generators, creates a time-stamped
    experiment folder under ``checkpoints/``, builds the training loaders
    (skipped when ``args.evaluate`` is set) and the validation loader (when
    evaluating or when ``args.eval_freq != -1``), instantiates the model
    (optionally replaced by the one reloaded from ``args.resume``) and then
    runs the epoch loop. After every training epoch the accumulated results
    are pickled and a checkpoint is saved.

    Args:
        args: Namespace of options (datasets, loss weights, optimizer
            settings, logging frequencies, ...).

    Raises:
        ValueError: If ``args.optimizer`` is neither ``"adam"`` nor ``"sgd"``.
    """
    setseeds.set_all_seeds(args.manual_seed)
    # Initialize hosting
    dat_str = "_".join(args.train_datasets)
    now = datetime.now()
    split_str = "_".join(args.train_splits)
    exp_id = (
        f"checkpoints/{dat_str}_{split_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}/"
        #f"{args.com}_frac{args.fraction:.1e}"
        #f"lr{args.lr}_mom{args.momentum}_bs{args.batch_size}_"
        #f"_lmbeta{args.mano_lambda_shape:.1e}"
        #f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        #f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        #f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        #f"seed{args.manual_seed}"
    )
    # if args.no_augm:
    #     exp_id = f"{exp_id}_no_augm"
    # if args.block_rot:
    #     exp_id = f"{exp_id}_block_rot"
    # if args.freeze_batchnorm:
    #     exp_id = f"{exp_id}_fbn"

    # Initialize local checkpoint folder
    print(f"Saving experiment logs, models, and training curves and images at {exp_id}")
    save_args(args, exp_id, "opt")
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after the pyapt id and the launch time
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Training data: one loader per (split, dataset) pair, concatenated
    loaders = []
    if not args.evaluate:
        for train_split, dat_name in zip(args.train_splits, args.train_datasets):
            train_dataset, _ = get_dataset.get_dataset(
                dat_name,
                split=train_split,
                meta={"version": args.version, "split_mode": args.split_mode},
                use_cache=args.use_cache,
                mini_factor=args.mini_factor,
                no_augm=args.no_augm,
                block_rot=args.block_rot,
                max_rot=args.max_rot,
                center_idx=args.center_idx,
                scale_jittering=args.scale_jittering,
                center_jittering=args.center_jittering,
                fraction=args.fraction,
                mode="strong",
                sample_nb=None,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                # Split the worker budget evenly between the training loaders
                num_workers=int(args.workers / len(args.train_datasets)),
                drop_last=True,
                collate_fn=collate.meshreg_collate,
            )
            loaders.append(train_loader)
        loader = concatloader.ConcatDataloader(loaders)

    # Validation data, only needed when evaluation can actually happen
    if args.evaluate or args.eval_freq != -1:
        val_dataset, _ = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": args.split_mode},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_center_idx=args.center_idx,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model: when resuming, the freshly built model is replaced
    # by the one rebuilt from the checkpoint options
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        model.cuda()
        if args.evaluate:
            # Guarantees the epoch loop below runs for evaluation
            args.epochs = epoch + 1
    else:
        epoch = 0

    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    else:
        # Previously fell through and crashed later with a NameError
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}, expected 'adam' or 'sgd'"
        )
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)

    # Default to no scheduler so the name is always bound: it is passed to
    # epoch_pass and stored in the checkpoint even without learning-rate decay
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)

    fig = plt.figure(figsize=(10, 10))
    save_results = {
        "opt": dict(vars(args)),
        "train_losses": [],
        "val_losses": [],
    }

    monitor = MetricMonitor()
    print("Will monitor metrics.")

    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        # Mode is re-selected every epoch; with frozen batchnorm the model
        # stays in eval mode during training as well
        if args.freeze_batchnorm:
            model.eval()
        else:
            model.train()

        if not args.evaluate:
            save_dict = epochpass.epoch_pass(
                loader,
                model,
                train=True,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["train_losses"].append(save_dict)
            with open(result_path, "wb") as p_f:
                pickle.dump(save_results, p_f)
            modelio.save_checkpoint(
                {
                    "epoch": epoch_idx + 1,
                    "network": "correspnet",
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler,
                },
                is_best=True,
                checkpoint=exp_id,
                snapshot=args.snapshot,
            )

        if args.evaluate or (args.eval_freq != -1 and epoch_idx % args.eval_freq == 0):
            val_save_dict = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )

            save_results["val_losses"].append(val_save_dict)
            monitor.plot(os.path.join(exp_id, "training.html"), plotly=True)
            monitor.plot_histogram(os.path.join(exp_id, "evaluation.html"), plotly=True)
            if args.matplotlib:
                monitor.plot(result_folder, matplotlib=True)
                monitor.plot_histogram(result_folder, matplotlib=True)
            monitor.save_metrics(os.path.join(exp_id, "metrics.pkl"))
            if args.evaluate:
                # Evaluation-only runs stop after a single validation pass
                print(val_save_dict)
                break
def main(args):
    """Train a MeshRegNet with an additional consistency objective.

    The function seeds the random generators, derives the experiment folder
    name from the hyper-parameters, builds one loader per supervised training
    dataset plus an optional loader for ``args.consist_dataset``, wraps the
    network in a ``warpreg.WarpRegNet`` and runs the training loop. After
    each epoch metrics are logged, results are pickled and a checkpoint is
    saved; validation runs every ``args.eval_freq`` epochs unless that
    option is ``-1``.

    Args:
        args: Namespace of options (datasets, loss weights, consistency
            settings, optimizer settings, logging frequencies, ...).

    Raises:
        ValueError: If the training and consistency datasets report
            different input resolutions, or if ``args.optimizer`` is neither
            ``"adam"`` nor ``"sgd"``.
    """
    setseeds.set_all_seeds(args.manual_seed)
    train_dat_str = "_".join(args.train_datasets)
    dat_str = f"{train_dat_str}_warp{args.consist_dataset}"
    now = datetime.now()
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction:.1e}"
        f"_ld{args.lambda_data:.1e}_lc{args.lambda_consist:.1e}"
        f"crit{args.consist_criterion}sca{args.consist_scale}"
        f"opt{args.optimizer}_lr{args.lr}_crit{args.criterion2d}"
        f"_mom{args.momentum}_bs{args.batch_size}"
        f"_lmj3d{args.mano_lambda_joints3d:.1e}"
        f"_lmbeta{args.mano_lambda_shape:.1e}"
        f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        f"_lmrw3d{args.mano_lambda_recov_verts3d:.1e}"
        f"_lov2d{args.obj_lambda_verts2d:.1e}_lov3d{args.obj_lambda_verts3d:.1e}"
        f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        f"cj{args.center_jittering}seed{args.manual_seed}"
        f"sample_nb{args.sample_nb}_spac{args.spacing}"
        f"csteps{args.progressive_consist_steps}")
    if args.no_augm:
        exp_id = f"{exp_id}_no_augm"
    if args.no_consist_augm:
        exp_id = f"{exp_id}_noconsaugm"
    if args.block_rot:
        exp_id = f"{exp_id}_block_rot"
    if args.consist_gt_refs:
        exp_id = f"{exp_id}_gt_refs"

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    monitor = Monitor(exp_id, hosting_folder=exp_id)
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after the pyapt id and the launch time
    pyapt_path = os.path.join(result_folder,
                              f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Number of loaders that share the worker budget
    train_loader_nb = len(args.train_datasets)
    if args.consist_dataset is not None:
        train_loader_nb += 1

    loaders = []
    # Stays None when there is no supervised training dataset
    input_res = None
    for train_split, dat_name in zip(args.train_splits, args.train_datasets):
        train_dataset, input_res = get_dataset.get_dataset(
            dat_name,
            split=train_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            mini_factor=args.mini_factor,
            mode="strong",
            no_augm=args.no_augm,
            scale_jittering=args.scale_jittering,
            sample_nb=1,
            use_cache=args.use_cache,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(train_loader)

    if args.consist_dataset is not None:
        consist_dataset, consist_input_res = get_dataset.get_dataset(
            args.consist_dataset,
            split=args.consist_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            block_rot=args.block_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            max_rot=args.max_rot,
            mode="full",
            no_augm=args.no_consist_augm,
            sample_nb=args.sample_nb,
            scale_jittering=args.scale_jittering,
            # Otherwise black padding gets included on warps
            spacing=args.spacing,
        )
        print(
            f"Got consist dataset {args.consist_dataset} of size {len(consist_dataset)}"
        )
        if input_res is None:
            # No supervised dataset: the consistency data fixes the resolution
            input_res = consist_input_res
        elif input_res != consist_input_res:
            raise ValueError(
                "train and consist dataset should have same input sizes "
                f"but got {input_res} and {consist_input_res}")
        consist_loader = torch.utils.data.DataLoader(
            consist_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(consist_loader)
    loader = concatloader.ConcatLoader(loaders)

    if args.eval_freq != -1:
        # NOTE(review): this rebinds input_res to the validation dataset's
        # resolution, which is then what WarpRegNet receives below; kept
        # as in the original, confirm it is intended.
        val_dataset, input_res = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model: when resuming, the freshly built model is replaced
    # by the one rebuilt from the checkpoint options
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        # The reloaded model replaces the one moved to GPU above, so move
        # it explicitly too (a no-op if it already lives on the GPU)
        model.cuda()
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    pmodel = warpreg.WarpRegNet(
        input_res,
        model,
        lambda_consist=args.lambda_consist,
        lambda_data=args.lambda_data,
        criterion=args.consist_criterion,
        consist_scale=args.consist_scale,
        gt_refs=args.consist_gt_refs,
        progressive_steps=args.progressive_consist_steps,
        use_backward=args.consist_use_backward,
    )
    pmodel.cuda()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        # Previously fell through and crashed later with a NameError
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}, expected 'adam' or 'sgd'"
        )
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)

    # Default to no scheduler so the name is always bound: it is passed to
    # epoch_pass and stored in the checkpoint even without learning-rate decay
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.lr_decay_step,
                                                    gamma=args.lr_decay_gamma)
    fig = plt.figure(figsize=(10, 10))
    save_results = {
        "opt": dict(vars(args)),
        "train_losses": [],
        "val_losses": [],
    }
    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        # Mode is re-selected every epoch; with frozen batchnorm the model
        # stays in eval mode during training as well
        if args.freeze_batchnorm:
            model.eval()
        else:
            model.train()
        save_dict, avg_meters, _ = epochpassconsist.epoch_pass(
            loader,
            model,
            train=True,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch_idx,
            img_folder=img_folder,
            fig=fig,
            display_freq=args.display_freq,
            epoch_display_freq=args.epoch_display_freq,
            lr_decay_gamma=args.lr_decay_gamma,
            premodel=pmodel,
            loader_nb=train_loader_nb,
        )
        monitor.log_train(
            epoch_idx + 1,
            {key: val.avg
             for key, val in avg_meters.average_meters.items()})
        monitor.metrics.save_metrics(epoch_idx + 1, save_dict)
        monitor.metrics.plot_metrics()
        save_results["train_losses"].append(save_dict)
        with open(result_path, "wb") as p_f:
            pickle.dump(save_results, p_f)
        modelio.save_checkpoint(
            {
                "epoch": epoch_idx + 1,
                "network": "correspnet",
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler,
            },
            is_best=True,
            checkpoint=exp_id,
            snapshot=args.snapshot,
        )
        if args.eval_freq != -1 and epoch_idx % args.eval_freq == 0:
            val_save_dict, val_avg_meters, _ = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
            )

            save_results["val_losses"].append(val_save_dict)
            monitor.log_val(epoch_idx + 1, {
                key: val.avg
                for key, val in val_avg_meters.average_meters.items()
            })
            monitor.metrics.save_metrics(epoch_idx + 1, val_save_dict)
            monitor.metrics.plot_metrics()
def main(args):
    """Render side-by-side error visualizations for one or more checkpoints.

    Every checkpoint in ``args.resumes`` is reloaded, put on GPU in eval mode
    and run on each batch of ``args.dataset``. For every image a figure with
    one row per model is saved under ``checkpoints/<dataset>/<com>/images``:
    the input image with hand/object contours, the rendering overlaid on the
    image, and the rotated rendering.

    Args:
        args: Namespace of options (``resumes``, ``model_names``, ``dataset``,
            ``split``, ``mode``, ``max_val``, ``batch_size``, ``workers``, ...).

    Raises:
        ValueError: If ``args.resumes`` is empty; the dataset needs the
            options stored with a checkpoint.
    """
    setseeds.set_all_seeds(args.manual_seed)
    if not args.resumes:
        # Previously crashed further down with a NameError on ``opts``
        raise ValueError("At least one checkpoint is required in args.resumes")

    # Initialize hosting
    exp_id = f"checkpoints/{args.dataset}/" f"{args.com}"

    # Initialize local checkpoint folder
    print(f"Saving info about experiment at {exp_id}")
    save_args(args, exp_id, "opt")
    render_folder = os.path.join(exp_id, "images")
    os.makedirs(render_folder, exist_ok=True)

    # Load models
    models = []
    for resume in args.resumes:
        opts = reloadmodel.load_opts(resume)
        model, _ = reloadmodel.reload_model(resume, opts)
        models.append(model)
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    # The options of the last checkpoint decide the dataset centering
    dataset, input_res = get_dataset.get_dataset(
        args.dataset,
        split=args.split,
        meta={},
        mode=args.mode,
        use_cache=args.use_cache,
        no_augm=True,
        center_idx=opts["center_idx"],
        sample_nb=None,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    fig = plt.figure(figsize=(10, 10))
    # Put models on GPU and evaluation mode
    for model in models:
        model.cuda()
        model.eval()

    # Colour bar settings are identical for every figure
    cmap = plt.get_cmap("jet")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    row_nb = len(models) + 1
    col_nb = 3

    render_step = 0
    for batch in tqdm(loader):
        # Compute model outputs
        all_results = []
        with torch.no_grad():
            for model in models:
                _, results, _ = model(batch)
                all_results.append(results)

        # Densely render error map for the meshes; a single call handles
        # the outputs of all models at once
        render_results, _ = fastrender.comp_render(
            batch,
            all_results,
            rotate=True,
            modes=("all", "obj", "hand"),
            max_val=args.max_val)

        for img_idx, img in enumerate(batch[BaseQueries.IMAGE]):
            # Get rendered results for current image
            render_ress = [res[img_idx] for res in render_results["all"]]
            renderot_ress = [
                res[img_idx] for res in render_results["all_rotated"]
            ]
            # Initialize figure
            fig.clf()
            axes = fig.subplots(row_nb, col_nb)
            # Display cmap (re-created for each image because clf() removes it)
            cax = fig.add_axes([0.27, 0.05, 0.5, 0.02])
            cb = matplotlib.colorbar.ColorbarBase(cax,
                                                  cmap=cmap,
                                                  norm=norm,
                                                  ticks=[0, 1],
                                                  orientation="horizontal")
            cb.ax.set_xticklabels(["0", str(args.max_val * 100)])
            cb.set_label("3D mesh error (cm)")

            # Get masks for hand and object in current image
            obj_masks = [
                res.cpu()[img_idx][:, :].sum(2).numpy()
                for res in render_results["obj"]
            ]
            hand_masks = [
                res.cpu()[img_idx][:, :].sum(2).numpy()
                for res in render_results["hand"]
            ]
            # Compute bounding boxes of masks
            crops = [
                vizdemo.get_crop(render_res) for render_res in render_ress
            ]
            rot_crops = [
                vizdemo.get_crop(renderot_res)
                for renderot_res in renderot_ress
            ]
            # Get crop that encompasses the spatial extent of all results
            crop = vizdemo.get_common_crop(crops)
            rot_crop = vizdemo.get_common_crop(rot_crops)
            for model_idx, (render_res, renderot_res) in enumerate(
                    zip(render_ress, renderot_ress)):
                # Draw input image with predicted contours in column 1
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 0)
                # Initialize white square background and copy input image
                side = max(img.shape)
                viz_img = 255 * img.new_ones(side, side, 3)
                viz_img[:img.shape[0], :img.shape[1], :3] = img

                # Clamp so that displayed rendering values are in [0, 1]
                render_res = render_res.clamp(0, 1)
                renderot_res = renderot_res.clamp(0, 1)

                obj_mask = obj_masks[model_idx] > 0
                hand_mask = hand_masks[model_idx] > 0
                # Draw hand and object contours
                contoured_img = vizdemo.draw_contours(viz_img.numpy(),
                                                      hand_mask,
                                                      color=(0, 210, 255))
                contoured_img = vizdemo.draw_contours(contoured_img,
                                                      obj_mask,
                                                      color=(255, 50, 50))
                ax.imshow(contoured_img[crop[0]:crop[2], crop[1]:crop[3]])
                # Image with rendering overlay
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 1)
                ax.set_title(args.model_names[model_idx])
                ax.imshow(viz_img[crop[0]:crop[2], crop[1]:crop[3]])
                ax.imshow(render_res[crop[0]:crop[2], crop[1]:crop[3]], )

                # Render rotated
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 2)
                ax.imshow(renderot_res[rot_crop[0]:rot_crop[2],
                                       rot_crop[1]:rot_crop[3]])
            fig.tight_layout()
            save_path = os.path.join(render_folder,
                                     f"render{render_step:06d}.png")
            fig.savefig(save_path)
            print(f"Saved demo visualization at {save_path}")
            render_step += 1