Пример #1
0
def main(args):
    """Train and/or evaluate a MeshRegNet model as configured by ``args``.

    The function creates a time-stamped experiment folder under
    ``checkpoints/``, builds the training loaders (skipped when
    ``args.evaluate`` is set) and the validation loader (built when
    evaluating or when ``args.eval_freq != -1``), optionally reloads the
    model and optimizer from ``args.resume``, and then runs the epoch loop.
    After every training epoch the accumulated losses are pickled to
    ``results/results.pkl`` and a checkpoint is saved.

    Args:
        args: Option namespace (must work with ``vars``) exposing the
            attributes read below (datasets, splits, loss weights,
            optimizer settings, logging frequencies, ...).
    """
    setseeds.set_all_seeds(args.manual_seed)
    # Initialize hosting
    dat_str = "_".join(args.train_datasets)
    now = datetime.now()
    split_str = "_".join(args.train_splits)
    exp_id = (
        f"checkpoints/{dat_str}_{split_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}/"
    )

    # Initialize local checkpoint folder
    print(f"Saving experiment logs, models, and training curves and images at {exp_id}")
    save_args(args, exp_id, "opt")
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after the pyapt job id and the start time
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    loaders = []
    if not args.evaluate:
        for train_split, dat_name in zip(args.train_splits, args.train_datasets):
            train_dataset, _ = get_dataset.get_dataset(
                dat_name,
                split=train_split,
                meta={"version": args.version, "split_mode": args.split_mode},
                use_cache=args.use_cache,
                mini_factor=args.mini_factor,
                no_augm=args.no_augm,
                block_rot=args.block_rot,
                max_rot=args.max_rot,
                center_idx=args.center_idx,
                scale_jittering=args.scale_jittering,
                center_jittering=args.center_jittering,
                fraction=args.fraction,
                mode="strong",
                sample_nb=None,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                # Split the worker budget evenly between the training loaders
                num_workers=int(args.workers / len(args.train_datasets)),
                drop_last=True,
                collate_fn=collate.meshreg_collate,
            )
            loaders.append(train_loader)
        loader = concatloader.ConcatDataloader(loaders)
    if args.evaluate or args.eval_freq != -1:
        val_dataset, _ = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": args.split_mode},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_center_idx=args.center_idx,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model: a resumed checkpoint replaces the freshly built network
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        model.cuda()
        if args.evaluate:
            # Make the epoch loop below run exactly once from the reloaded epoch
            args.epochs = epoch + 1
    else:
        epoch = 0

    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)
    # Default to no scheduler so that the name is always bound: it is passed
    # to epoch_pass and stored in the checkpoint even without LR decay.
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)
    fig = plt.figure(figsize=(10, 10))
    save_results = {}
    save_results["opt"] = dict(vars(args))
    save_results["train_losses"] = []
    save_results["val_losses"] = []

    monitor = MetricMonitor()
    print("Will monitor metrics.")

    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()

        if not args.evaluate:
            save_dict = epochpass.epoch_pass(
                loader,
                model,
                train=True,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["train_losses"].append(save_dict)
            # Rewrite the full results file after every training epoch
            with open(result_path, "wb") as p_f:
                pickle.dump(save_results, p_f)
            modelio.save_checkpoint(
                {
                    "epoch": epoch_idx + 1,
                    "network": "correspnet",
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler,
                },
                is_best=True,
                checkpoint=exp_id,
                snapshot=args.snapshot,
            )

        if args.evaluate or (args.eval_freq != -1 and epoch_idx % args.eval_freq == 0):
            val_save_dict = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )

            save_results["val_losses"].append(val_save_dict)
            monitor.plot(os.path.join(exp_id, "training.html"), plotly=True)
            monitor.plot_histogram(os.path.join(exp_id, "evaluation.html"), plotly=True)
            if args.matplotlib:
                monitor.plot(result_folder, matplotlib=True)
                monitor.plot_histogram(result_folder, matplotlib=True)
            monitor.save_metrics(os.path.join(exp_id, "metrics.pkl"))
            if args.evaluate:
                # Evaluation-only runs stop after a single validation pass
                print(val_save_dict)
                break
Пример #2
0
def main(args):
    """Run a reloaded model over a validation split and dump results to JSON.

    Seeds the torch, numpy and python RNGs, creates a dated experiment folder
    under ``checkpoints/``, builds a non-augmented validation loader, reloads
    the model from ``args.resume`` and hands everything to
    ``evalpass.epoch_pass``, which receives the JSON dump path
    ``<args.json_folder>/<args.val_split>.json``.

    Args:
        args: Option namespace (must work with ``vars``) exposing the
            attributes read below.
    """
    seed = args.manual_seed
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Experiment folder: checkpoints/<dataset>_mini<k>/<date>/<run description>
    dataset_name = args.val_dataset
    now = datetime.now()
    date_tag = f"{now.year}_{now.month:02d}_{now.day:02d}"
    run_tag = (
        f"{args.com}_frac{args.fraction}_mode{args.mode}_bs{args.batch_size}_"
        f"objs{args.obj_scale_factor}_objt{args.obj_trans_factor}"
    )
    exp_id = f"checkpoints/{dataset_name}_mini{args.mini_factor}/{date_tag}/{run_tag}"

    # Persist the options and touch the pyapt marker file
    save_args(args, exp_id, "opt")
    result_folder = os.path.join(exp_id, "results")
    os.makedirs(result_folder, exist_ok=True)
    marker_name = f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}"
    with open(os.path.join(result_folder, marker_name), "a") as marker_file:
        marker_file.write(" ")

    # Validation data: augmentation and jittering are disabled
    dataset_options = dict(
        split=args.val_split,
        meta={"version": args.version, "split_mode": "paper"},
        use_cache=args.use_cache,
        mini_factor=args.mini_factor,
        mode=args.mode,
        fraction=args.fraction,
        no_augm=True,
        center_idx=args.center_idx,
        scale_jittering=0,
        center_jittering=0,
        sample_nb=None,
        has_dist2strong=True,
    )
    val_dataset, input_size = get_dataset.get_dataset(args.val_dataset, **dataset_options)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        collate_fn=collate.meshreg_collate,
        batch_size=args.batch_size,
        num_workers=int(args.workers),
        shuffle=False,
        drop_last=False,
    )

    # Reload the trained network together with the options it was trained with
    opts = reloadmodel.load_opts(args.resume)
    model, epoch = reloadmodel.reload_model(args.resume, opts)

    epoch_tag = f"epoch{epoch:04d}"
    render_folder = None
    if args.render_results:
        render_folder = os.path.join(exp_id, "renders", epoch_tag)
        os.makedirs(render_folder, exist_ok=True)
        print(f"Rendering to {render_folder}")
    img_folder = os.path.join(exp_id, "images", epoch_tag)
    os.makedirs(img_folder, exist_ok=True)
    freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    fig = plt.figure(figsize=(12, 4))
    save_results = {"opt": dict(vars(args)), "val_losses": []}

    os.makedirs(args.json_folder, exist_ok=True)
    json_path = os.path.join(args.json_folder, f"{args.val_split}.json")
    evalpass.epoch_pass(
        val_loader,
        model,
        optimizer=None,
        scheduler=None,
        epoch=epoch,
        img_folder=img_folder,
        fig=fig,
        display_freq=args.display_freq,
        dump_results_path=json_path,
        render_folder=render_folder,
        render_freq=args.render_freq,
        true_root=args.true_root,
    )
    print(f"Saved results for split {args.val_split} to {json_path}")
Пример #3
0
def main(args):
    """Run reloaded models over a dataset and pickle per-sample meshes.

    Every checkpoint in ``args.resumes`` is reloaded and run on each batch
    without gradients, but only the outputs of the first model are exported.
    For each image, ground-truth and predicted hand/object geometry plus the
    ground-truth hand parameters are converted with ``to_cpu_npy`` and
    collected; the resulting list is pickled to ``all_samples.pkl`` in the
    current working directory.

    Args:
        args: Option namespace exposing the attributes read below.

    Raises:
        ValueError: If ``args.resumes`` is empty, since the dataset options
            are taken from the last reloaded checkpoint.
    """
    setseeds.set_all_seeds(args.manual_seed)
    # Initialize hosting
    exp_id = f"checkpoints/{args.dataset}/" f"{args.com}"

    # Initialize local checkpoint folder
    print(f"Saving info about experiment at {exp_id}")
    save_args(args, exp_id, "opt")
    render_folder = os.path.join(exp_id, "images")
    os.makedirs(render_folder, exist_ok=True)

    # Load models
    if not args.resumes:
        # `opts` of the last checkpoint is needed below to build the dataset
        raise ValueError("At least one checkpoint is required in args.resumes")
    models = []
    for resume in args.resumes:
        print('Resuming', resume)
        opts = reloadmodel.load_opts(resume)
        model, epoch = reloadmodel.reload_model(resume, opts)
        models.append(model)
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    dataset, input_res = get_dataset.get_dataset(
        args.dataset,
        split=args.split,
        meta={"version": args.version, "split_mode": args.split_mode},
        mode=args.mode,
        use_cache=args.use_cache,
        no_augm=True,
        center_idx=opts["center_idx"],
        sample_nb=None,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    all_samples = []

    # Put models on GPU and evaluation mode
    for model in models:
        model.cuda()
        model.eval()

    for batch in tqdm(loader):  # Loop over batches
        # Compute model outputs
        with torch.no_grad():
            all_results = [model(batch)[1] for model in models]

        # Only the first model's predictions are exported
        network_out = all_results[0]
        for img_idx, _ in enumerate(batch[BaseQueries.IMAGE]):
            sample_idx = batch['idx'][img_idx]
            handpose, handtrans, handshape = dataset.pose_dataset.get_hand_info(sample_idx)

            sample_dict = dict()
            sample_dict['obj_faces'] = batch[BaseQueries.OBJFACES][img_idx, :, :]
            sample_dict['obj_verts_gt'] = batch[BaseQueries.OBJVERTS3D][img_idx, :, :]
            sample_dict['hand_faces'], _ = manoutils.get_closed_faces()
            sample_dict['hand_verts_gt'] = batch[BaseQueries.HANDVERTS3D][img_idx, :, :]

            sample_dict['obj_verts_pred'] = network_out['recov_objverts3d'][img_idx, :, :]
            sample_dict['hand_verts_pred'] = network_out['recov_handverts3d'][img_idx, :, :]
            sample_dict['hand_pose_pred'] = network_out['pose'][img_idx, :]
            sample_dict['hand_beta_pred'] = network_out['shape'][img_idx, :]
            sample_dict['side'] = batch[BaseQueries.SIDE][img_idx]

            sample_dict['hand_pose_gt'] = handpose
            sample_dict['hand_beta_gt'] = handshape
            sample_dict['hand_trans_gt'] = handtrans
            # Fixed 4x4 matrix that negates the y and z axes
            sample_dict['hand_extr_gt'] = torch.Tensor(
                [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
            )

            for k in sample_dict:
                sample_dict[k] = to_cpu_npy(sample_dict[k])

            all_samples.append(sample_dict)

    print('Saving final dict', len(all_samples))
    with open('all_samples.pkl', 'wb') as handle:
        pickle.dump(all_samples, handle)
    print('Done saving')
Пример #4
0
def main(args):
    """Export per-sample ground truth and predictions of one reloaded model.

    Seeds the RNGs, creates a dated experiment folder under ``checkpoints/``,
    builds a non-augmented validation loader, reloads the model from
    ``args.resume`` and runs it without gradients over the whole loader.
    One dictionary per image (all values converted with ``to_cpu_npy``) is
    collected, and the list is pickled to ``all_samples.pkl`` in the current
    working directory.

    Args:
        args: Option namespace exposing the attributes read below.
    """
    seed = args.manual_seed
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Experiment folder: checkpoints/<dataset>_mini<k>/<date>/<run description>
    dataset_name = args.val_dataset
    now = datetime.now()
    date_tag = f"{now.year}_{now.month:02d}_{now.day:02d}"
    run_tag = (
        f"{args.com}_frac{args.fraction}_mode{args.mode}_bs{args.batch_size}_"
        f"objs{args.obj_scale_factor}_objt{args.obj_trans_factor}"
    )
    exp_id = f"checkpoints/{dataset_name}_mini{args.mini_factor}/{date_tag}/{run_tag}"

    # Persist the options and touch the pyapt marker file
    save_args(args, exp_id, "opt")
    result_folder = os.path.join(exp_id, "results")
    os.makedirs(result_folder, exist_ok=True)
    marker_name = f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}"
    with open(os.path.join(result_folder, marker_name), "a") as marker_file:
        marker_file.write(" ")

    # Validation data: augmentation and jittering are disabled
    dataset_options = dict(
        split=args.val_split,
        meta={"version": args.version, "split_mode": args.split_mode},
        use_cache=args.use_cache,
        mini_factor=args.mini_factor,
        mode=args.mode,
        fraction=args.fraction,
        no_augm=True,
        center_idx=args.center_idx,
        scale_jittering=0,
        center_jittering=0,
        sample_nb=None,
        has_dist2strong=True,
    )
    val_dataset, input_size = get_dataset.get_dataset(args.val_dataset, **dataset_options)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        collate_fn=collate.meshreg_collate,
        batch_size=args.batch_size,
        num_workers=int(args.workers),
        shuffle=False,
        drop_last=False,
    )

    # Reload the trained network together with the options it was trained with
    opts = reloadmodel.load_opts(args.resume)
    model, epoch = reloadmodel.reload_model(args.resume, opts)
    freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    all_samples = []

    model.eval()
    model.cuda()
    for batch in tqdm(val_loader):
        with torch.no_grad():
            _, network_out, _ = model(batch)

        for img_idx, _ in enumerate(batch[BaseQueries.IMAGE]):
            sample_idx = batch['idx'][img_idx]
            pose_gt, trans_gt, beta_gt = val_dataset.pose_dataset.get_hand_info(sample_idx)
            closed_faces, _ = manoutils.get_closed_faces()

            # Key order is kept identical to the exported format
            raw_sample = {
                'obj_faces': batch[BaseQueries.OBJFACES][img_idx, :, :],
                'obj_verts_gt': batch[BaseQueries.OBJVERTS3D][img_idx, :, :],
                'hand_faces': closed_faces,
                'hand_verts_gt': batch[BaseQueries.HANDVERTS3D][img_idx, :, :],
                'obj_verts_pred': network_out['recov_objverts3d'][img_idx, :, :],
                'hand_verts_pred': network_out['recov_handverts3d'][img_idx, :, :],
                'hand_pose_pred': network_out['pose'][img_idx, :],
                'hand_beta_pred': network_out['shape'][img_idx, :],
                'side': batch[BaseQueries.SIDE][img_idx],
                'hand_pose_gt': pose_gt,
                'hand_beta_gt': beta_gt,
                'hand_trans_gt': trans_gt,
                # Fixed 4x4 matrix that negates the y and z axes
                'hand_extr_gt': torch.Tensor([
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1],
                ]),
            }
            all_samples.append(
                {key: to_cpu_npy(value) for key, value in raw_sample.items()}
            )

    print('Saving final dict', len(all_samples))
    with open('all_samples.pkl', 'wb') as handle:
        pickle.dump(all_samples, handle)
    print('Done saving')
def main(args):
    """Train MeshRegNet with an additional consistency loss (WarpRegNet).

    Builds one loader per training dataset plus, when
    ``args.consist_dataset`` is set, a loader for the consistency dataset,
    and concatenates them. The regression network is wrapped in
    ``warpreg.WarpRegNet`` and trained with ``epochpassconsist.epoch_pass``.
    After every epoch the losses are pickled and a checkpoint is saved;
    validation runs every ``args.eval_freq`` epochs unless it is ``-1``.

    Args:
        args: Option namespace (must work with ``vars``) exposing the
            attributes read below.

    Raises:
        ValueError: If the training and consistency datasets report
            different input resolutions, or if no dataset was loaded from
            which the input resolution could be taken.
    """
    setseeds.set_all_seeds(args.manual_seed)
    train_dat_str = "_".join(args.train_datasets)
    dat_str = f"{train_dat_str}_warp{args.consist_dataset}"
    now = datetime.now()
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction:.1e}"
        f"_ld{args.lambda_data:.1e}_lc{args.lambda_consist:.1e}"
        f"crit{args.consist_criterion}sca{args.consist_scale}"
        f"opt{args.optimizer}_lr{args.lr}_crit{args.criterion2d}"
        f"_mom{args.momentum}_bs{args.batch_size}"
        f"_lmj3d{args.mano_lambda_joints3d:.1e}"
        f"_lmbeta{args.mano_lambda_shape:.1e}"
        f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        f"_lmrw3d{args.mano_lambda_recov_verts3d:.1e}"
        f"_lov2d{args.obj_lambda_verts2d:.1e}_lov3d{args.obj_lambda_verts3d:.1e}"
        f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        f"cj{args.center_jittering}seed{args.manual_seed}"
        f"sample_nb{args.sample_nb}_spac{args.spacing}"
        f"csteps{args.progressive_consist_steps}")
    if args.no_augm:
        exp_id = f"{exp_id}_no_augm"
    if args.no_consist_augm:
        exp_id = f"{exp_id}_noconsaugm"
    if args.block_rot:
        exp_id = f"{exp_id}_block_rot"
    if args.consist_gt_refs:
        exp_id = f"{exp_id}_gt_refs"

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    monitor = Monitor(exp_id, hosting_folder=exp_id)
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after the pyapt job id and the start time
    pyapt_path = os.path.join(result_folder,
                              f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Number of loaders that share the worker budget
    train_loader_nb = len(args.train_datasets)
    if args.consist_dataset is not None:
        train_loader_nb = train_loader_nb + 1

    loaders = []
    # Stays None when there is no training dataset, so the consistency
    # dataset can provide the resolution instead of hitting an unbound name.
    input_res = None
    for train_split, dat_name in zip(args.train_splits,
                                     args.train_datasets):
        train_dataset, input_res = get_dataset.get_dataset(
            dat_name,
            split=train_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            mini_factor=args.mini_factor,
            mode="strong",
            no_augm=args.no_augm,
            scale_jittering=args.scale_jittering,
            sample_nb=1,
            use_cache=args.use_cache,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(train_loader)

    if args.consist_dataset is not None:
        consist_dataset, consist_input_res = get_dataset.get_dataset(
            args.consist_dataset,
            split=args.consist_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            block_rot=args.block_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            max_rot=args.max_rot,
            mode="full",
            no_augm=args.no_consist_augm,
            sample_nb=args.sample_nb,
            scale_jittering=args.scale_jittering,
            spacing=args.spacing,  # Otherwise black padding gets included on warps
        )
        print(
            f"Got consist dataset {args.consist_dataset} of size {len(consist_dataset)}"
        )
        if input_res is None:
            input_res = consist_input_res
        elif input_res != consist_input_res:
            raise ValueError(
                "train and consist dataset should have same input sizes "
                f"but got {input_res} and {consist_input_res}")
        consist_loader = torch.utils.data.DataLoader(
            consist_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(consist_loader)
    loader = concatloader.ConcatLoader(loaders)
    if args.eval_freq != -1:
        # NOTE(review): this rebinds input_res to the validation dataset's
        # resolution, which is then what WarpRegNet receives below. Kept
        # as-is to preserve behavior; confirm this is intended.
        val_dataset, input_res = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    if input_res is None:
        raise ValueError(
            "Cannot determine the input resolution: no training, "
            "consistency or validation dataset was loaded")

    model = MeshRegNet(
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model: a resumed checkpoint replaces the freshly built network
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    pmodel = warpreg.WarpRegNet(
        input_res,
        model,
        lambda_consist=args.lambda_consist,
        lambda_data=args.lambda_data,
        criterion=args.consist_criterion,
        consist_scale=args.consist_scale,
        gt_refs=args.consist_gt_refs,
        progressive_steps=args.progressive_consist_steps,
        use_backward=args.consist_use_backward,
    )
    pmodel.cuda()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)
    # Default to no scheduler so that the name is always bound: it is passed
    # to epoch_pass and stored in the checkpoint even without LR decay.
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.lr_decay_step,
                                                    gamma=args.lr_decay_gamma)
    fig = plt.figure(figsize=(10, 10))
    save_results = {}
    save_results["opt"] = dict(vars(args))
    save_results["train_losses"] = []
    save_results["val_losses"] = []
    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()
        save_dict, avg_meters, _ = epochpassconsist.epoch_pass(
            loader,
            model,
            train=True,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch_idx,
            img_folder=img_folder,
            fig=fig,
            display_freq=args.display_freq,
            epoch_display_freq=args.epoch_display_freq,
            lr_decay_gamma=args.lr_decay_gamma,
            premodel=pmodel,
            loader_nb=train_loader_nb,
        )
        monitor.log_train(
            epoch_idx + 1,
            {key: val.avg
             for key, val in avg_meters.average_meters.items()})
        monitor.metrics.save_metrics(epoch_idx + 1, save_dict)
        monitor.metrics.plot_metrics()
        save_results["train_losses"].append(save_dict)
        # Rewrite the full results file after every training epoch
        with open(result_path, "wb") as p_f:
            pickle.dump(save_results, p_f)
        modelio.save_checkpoint(
            {
                "epoch": epoch_idx + 1,
                "network": "correspnet",
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler,
            },
            is_best=True,
            checkpoint=exp_id,
            snapshot=args.snapshot,
        )
        if args.eval_freq != -1 and epoch_idx % args.eval_freq == 0:
            val_save_dict, val_avg_meters, _ = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
            )

            save_results["val_losses"].append(val_save_dict)
            monitor.log_val(epoch_idx + 1, {
                key: val.avg
                for key, val in val_avg_meters.average_meters.items()
            })
            monitor.metrics.save_metrics(epoch_idx + 1, val_save_dict)
            monitor.metrics.plot_metrics()
Пример #6
0
def main(args):
    """Render side-by-side mesh-error visualizations for reloaded models.

    Every checkpoint in ``args.resumes`` is reloaded, moved to the GPU and
    run in evaluation mode over the requested dataset split. For each input
    image one figure is written to ``checkpoints/<dataset>/<com>/images``
    as ``renderNNNNNN.png``. The figure grid has ``len(models) + 1`` rows and
    three columns; for each model the columns show:

    0. the input image with hand and object contours drawn on it,
    1. the input image with the rendered result overlaid,
    2. the rendered result from the rotated view.

    Args:
        args: Parsed command-line namespace. Attributes read here are
            ``manual_seed``, ``dataset``, ``com``, ``resumes``, ``split``,
            ``mode``, ``use_cache``, ``batch_size``, ``workers``,
            ``max_val`` and ``model_names``.

    Raises:
        ValueError: If ``args.resumes`` is empty, since the dataset options
            are taken from the reloaded checkpoints.
    """
    setseeds.set_all_seeds(args.manual_seed)
    # Initialize hosting
    exp_id = f"checkpoints/{args.dataset}/" f"{args.com}"

    # Initialize local checkpoint folder
    print(f"Saving info about experiment at {exp_id}")
    save_args(args, exp_id, "opt")
    render_folder = os.path.join(exp_id, "images")
    os.makedirs(render_folder, exist_ok=True)

    # Load models
    models = []
    opts = None
    for resume in args.resumes:
        opts = reloadmodel.load_opts(resume)
        model, _ = reloadmodel.reload_model(resume, opts)
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm
        models.append(model)
    if not models:
        raise ValueError(
            "args.resumes is empty: at least one checkpoint is required")

    # The dataset centering follows the options of the last checkpoint
    # NOTE(review): presumably all checkpoints share center_idx - confirm
    dataset, input_res = get_dataset.get_dataset(
        args.dataset,
        split=args.split,
        meta={},
        mode=args.mode,
        use_cache=args.use_cache,
        no_augm=True,
        center_idx=opts["center_idx"],
        sample_nb=None,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    fig = plt.figure(figsize=(10, 10))
    # NOTE(review): save_results is filled but never written in this block
    save_results = {}
    save_results["opt"] = dict(vars(args))

    # Put models on GPU and evaluation mode
    for model in models:
        model.cuda()
        model.eval()

    # Colorbar ingredients do not depend on the batch or image
    cmap = plt.get_cmap("jet")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    row_nb = len(models) + 1
    col_nb = 3

    render_step = 0
    for batch in tqdm(loader):
        # Compute model outputs
        all_results = []
        with torch.no_grad():
            for model in models:
                _, results, _ = model(batch)
                all_results.append(results)

        # Densely render error map for the meshes; a single call handles the
        # outputs of all models at once
        render_results, _ = fastrender.comp_render(
            batch,
            all_results,
            rotate=True,
            modes=("all", "obj", "hand"),
            max_val=args.max_val)

        for img_idx, img in enumerate(batch[BaseQueries.IMAGE]):
            # Get rendered results for current image
            render_ress = [res[img_idx] for res in render_results["all"]]
            renderot_ress = [
                res[img_idx] for res in render_results["all_rotated"]
            ]
            # Initialize figure (clearing also removes the previous colorbar,
            # so the colorbar axis is re-created for every image)
            fig.clf()
            axes = fig.subplots(row_nb, col_nb)
            # Display cmap
            cax = fig.add_axes([0.27, 0.05, 0.5, 0.02])
            cb = matplotlib.colorbar.ColorbarBase(cax,
                                                  cmap=cmap,
                                                  norm=norm,
                                                  ticks=[0, 1],
                                                  orientation="horizontal")
            cb.ax.set_xticklabels(["0", str(args.max_val * 100)])
            cb.set_label("3D mesh error (cm)")

            # Get masks for hand and object in current image
            obj_masks = [
                res.cpu()[img_idx][:, :].sum(2).numpy()
                for res in render_results["obj"]
            ]
            hand_masks = [
                res.cpu()[img_idx][:, :].sum(2).numpy()
                for res in render_results["hand"]
            ]
            # Compute bounding boxes of the rendered results
            crops = [
                vizdemo.get_crop(render_res) for render_res in render_ress
            ]
            rot_crops = [
                vizdemo.get_crop(renderot_res)
                for renderot_res in renderot_ress
            ]
            # Get crop that encompasses the spatial extent of all results
            crop = vizdemo.get_common_crop(crops)
            rot_crop = vizdemo.get_common_crop(rot_crops)
            for model_idx, (render_res, renderot_res) in enumerate(
                    zip(render_ress, renderot_ress)):
                # Draw input image with predicted contours in column 1
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 0)
                # Initialize white square background and copy input image
                side = max(img.shape)
                viz_img = 255 * img.new_ones(side, side, 3)
                viz_img[:img.shape[0], :img.shape[1], :3] = img

                # Clamp so that rendered values are in [0, 1]
                render_res = render_res.clamp(0, 1)
                renderot_res = renderot_res.clamp(0, 1)

                obj_mask = obj_masks[model_idx] > 0
                hand_mask = hand_masks[model_idx] > 0
                # Draw hand and object contours
                contoured_img = vizdemo.draw_contours(viz_img.numpy(),
                                                      hand_mask,
                                                      color=(0, 210, 255))
                contoured_img = vizdemo.draw_contours(contoured_img,
                                                      obj_mask,
                                                      color=(255, 50, 50))
                ax.imshow(contoured_img[crop[0]:crop[2], crop[1]:crop[3]])
                # Image with rendering overlay
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 1)
                ax.set_title(args.model_names[model_idx])
                ax.imshow(viz_img[crop[0]:crop[2], crop[1]:crop[3]])
                ax.imshow(render_res[crop[0]:crop[2], crop[1]:crop[3]])

                # Render rotated
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 2)
                ax.imshow(renderot_res[rot_crop[0]:rot_crop[2],
                                       rot_crop[1]:rot_crop[3]])
            fig.tight_layout()
            save_path = os.path.join(render_folder,
                                     f"render{render_step:06d}.png")
            fig.savefig(save_path)
            print(f"Saved demo visualization at {save_path}")
            render_step += 1