Exemplo n.º 1
0
def main(args):
    """Train and/or evaluate a ``MeshRegNet`` as configured by ``args``.

    The function seeds the run, creates a timestamped experiment folder
    under ``checkpoints/``, builds one training loader per entry of
    ``args.train_datasets`` (skipped when ``args.evaluate`` is set) and a
    validation loader (when ``args.evaluate`` is set or ``args.eval_freq``
    is not -1). It then iterates epochs from the resumed epoch (0 when
    ``args.resume`` is None) up to ``args.epochs``.

    After every training epoch the accumulated results are pickled to
    ``<exp_id>/results/results.pkl`` and a checkpoint is written through
    ``modelio.save_checkpoint``. When ``args.evaluate`` is set, no training
    pass is run and the loop stops after the first validation pass.

    Args:
        args: Option namespace (it must support ``vars(args)``). Attributes
            read here include dataset options (``train_datasets``,
            ``train_splits``, ``val_dataset``, ``val_split``, ...), the
            ``mano_*`` / ``obj_*`` loss weights, optimisation settings
            (``optimizer``, ``lr``, ``momentum``, ``weight_decay``,
            ``lr_decay_gamma``, ``lr_decay_step``) and run control
            (``evaluate``, ``resume``, ``epochs``, ``eval_freq``,
            ``snapshot``). ``args.epochs`` is overwritten when evaluating a
            resumed model.

    Raises:
        ValueError: If training is requested and ``args.optimizer`` is
            neither ``"adam"`` nor ``"sgd"``.
    """
    setseeds.set_all_seeds(args.manual_seed)

    # Experiment folder: datasets, splits and mini factor, then a timestamp
    # with minute resolution (hyper-parameters are not encoded in the name).
    dat_str = "_".join(args.train_datasets)
    now = datetime.now()
    split_str = "_".join(args.train_splits)
    exp_id = (
        f"checkpoints/{dat_str}_{split_str}_mini{args.mini_factor}/"
        f"{now.year}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}/"
    )

    # Initialize local checkpoint folder
    print(f"Saving experiment logs, models, and training curves and images at {exp_id}")
    save_args(args, exp_id, "opt")
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Marker file named after the pyapt id and the launch time
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Training loaders: one per (split, dataset) pair, merged into one loader
    loaders = []
    if not args.evaluate:
        for train_split, dat_name in zip(args.train_splits, args.train_datasets):
            train_dataset, _ = get_dataset.get_dataset(
                dat_name,
                split=train_split,
                meta={"version": args.version, "split_mode": args.split_mode},
                use_cache=args.use_cache,
                mini_factor=args.mini_factor,
                no_augm=args.no_augm,
                block_rot=args.block_rot,
                max_rot=args.max_rot,
                center_idx=args.center_idx,
                scale_jittering=args.scale_jittering,
                center_jittering=args.center_jittering,
                fraction=args.fraction,
                mode="strong",
                sample_nb=None,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                # Worker budget is shared between the training loaders
                num_workers=int(args.workers / len(args.train_datasets)),
                drop_last=True,
                collate_fn=collate.meshreg_collate,
            )
            loaders.append(train_loader)
        loader = concatloader.ConcatDataloader(loaders)

    # Validation loader: augmentation is always disabled
    if args.evaluate or args.eval_freq != -1:
        val_dataset, _ = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": args.split_mode},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_center_idx=args.center_idx,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model, replacing it by the reloaded one when resuming
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        model.cuda()
        if args.evaluate:
            # Guarantees the epoch loop below runs exactly once
            args.epochs = epoch + 1
    else:
        epoch = 0

    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    else:
        optimizer = None
    if optimizer is None and not args.evaluate:
        # Previously surfaced later as a NameError on `optimizer`
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}, expected 'adam' or 'sgd'"
        )
    if args.resume is not None and optimizer is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)

    # Always bind `scheduler`: it is passed to the training pass and stored in
    # the checkpoint even when no learning-rate decay is configured.
    # NOTE(review): assumes epoch_pass tolerates scheduler=None (the validation
    # call already passes None) -- confirm for train=True.
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma
        )
    fig = plt.figure(figsize=(10, 10))
    save_results = {
        "opt": dict(vars(args)),
        "train_losses": [],
        "val_losses": [],
    }

    monitor = MetricMonitor()
    print("Will monitor metrics.")

    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        # With frozen batchnorm the model is kept in eval mode during training
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()

        if not args.evaluate:
            save_dict = epochpass.epoch_pass(
                loader,
                model,
                train=True,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["train_losses"].append(save_dict)
            # The results file is rewritten in full after every epoch
            with open(result_path, "wb") as p_f:
                pickle.dump(save_results, p_f)
            modelio.save_checkpoint(
                {
                    "epoch": epoch_idx + 1,
                    "network": "correspnet",
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler,
                },
                is_best=True,
                checkpoint=exp_id,
                snapshot=args.snapshot,
            )

        if args.evaluate or (args.eval_freq != -1 and epoch_idx % args.eval_freq == 0):
            val_save_dict = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )

            save_results["val_losses"].append(val_save_dict)
            monitor.plot(os.path.join(exp_id, "training.html"), plotly=True)
            monitor.plot_histogram(os.path.join(exp_id, "evaluation.html"), plotly=True)
            if args.matplotlib:
                monitor.plot(result_folder, matplotlib=True)
                monitor.plot_histogram(result_folder, matplotlib=True)
            monitor.save_metrics(os.path.join(exp_id, "metrics.pkl"))
            if args.evaluate:
                # Evaluation-only run: report once and stop
                print(val_save_dict)
                break
def main(args):
    """Train a ``MeshRegNet`` with an additional consistency dataset.

    The function seeds the run, derives an experiment folder whose name
    encodes the main hyper-parameters, builds one loader per training
    dataset plus an optional loader for ``args.consist_dataset``, and wraps
    the network in ``warpreg.WarpRegNet``. That wrapper is handed to
    ``epochpassconsist.epoch_pass`` as ``premodel``.

    After every epoch the training metrics are logged through ``Monitor``,
    the accumulated results are pickled to ``<exp_id>/results/results.pkl``
    and a checkpoint is saved. When ``args.eval_freq`` is not -1, a
    validation pass is run every ``args.eval_freq`` epochs.

    Args:
        args: Option namespace (it must support ``vars(args)``). Attributes
            read here include dataset options (``train_datasets``,
            ``train_splits``, ``consist_dataset``, ``consist_split``,
            ``val_dataset``, ``val_split``, ...), the ``mano_*`` / ``obj_*``
            loss weights, consistency options (``lambda_consist``,
            ``lambda_data``, ``consist_criterion``, ...) and optimisation
            settings (``optimizer``, ``lr``, ``lr_decay_gamma``, ...).

    Raises:
        ValueError: If the training and consistency datasets report
            different input resolutions, if no dataset is loaded at all
            (so no input resolution is available for ``WarpRegNet``), or if
            ``args.optimizer`` is neither ``"adam"`` nor ``"sgd"``.
    """
    setseeds.set_all_seeds(args.manual_seed)
    train_dat_str = "_".join(args.train_datasets)
    dat_str = f"{train_dat_str}_warp{args.consist_dataset}"
    now = datetime.now()
    # The experiment folder name encodes the main hyper-parameters
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction:.1e}"
        f"_ld{args.lambda_data:.1e}_lc{args.lambda_consist:.1e}"
        f"crit{args.consist_criterion}sca{args.consist_scale}"
        f"opt{args.optimizer}_lr{args.lr}_crit{args.criterion2d}"
        f"_mom{args.momentum}_bs{args.batch_size}"
        f"_lmj3d{args.mano_lambda_joints3d:.1e}"
        f"_lmbeta{args.mano_lambda_shape:.1e}"
        f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        f"_lmrw3d{args.mano_lambda_recov_verts3d:.1e}"
        f"_lov2d{args.obj_lambda_verts2d:.1e}_lov3d{args.obj_lambda_verts3d:.1e}"
        f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        f"cj{args.center_jittering}seed{args.manual_seed}"
        f"sample_nb{args.sample_nb}_spac{args.spacing}"
        f"csteps{args.progressive_consist_steps}")
    if args.no_augm:
        exp_id = f"{exp_id}_no_augm"
    if args.no_consist_augm:
        exp_id = f"{exp_id}_noconsaugm"
    if args.block_rot:
        exp_id = f"{exp_id}_block_rot"
    if args.consist_gt_refs:
        exp_id = f"{exp_id}_gt_refs"

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    monitor = Monitor(exp_id, hosting_folder=exp_id)
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Marker file named after the pyapt id and the launch time
    pyapt_path = os.path.join(result_folder,
                              f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Number of loaders sharing the worker budget: one per training dataset
    # plus one for the consistency dataset when it is configured.
    train_loader_nb = len(args.train_datasets)
    if args.consist_dataset is not None:
        train_loader_nb = train_loader_nb + 1

    input_res = None
    loaders = []
    for train_split, dat_name in zip(args.train_splits, args.train_datasets):
        train_dataset, input_res = get_dataset.get_dataset(
            dat_name,
            split=train_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            mini_factor=args.mini_factor,
            mode="strong",
            no_augm=args.no_augm,
            scale_jittering=args.scale_jittering,
            sample_nb=1,
            use_cache=args.use_cache,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(train_loader)

    if args.consist_dataset is not None:
        consist_dataset, consist_input_res = get_dataset.get_dataset(
            args.consist_dataset,
            split=args.consist_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            block_rot=args.block_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            max_rot=args.max_rot,
            mode="full",
            no_augm=args.no_consist_augm,
            sample_nb=args.sample_nb,
            scale_jittering=args.scale_jittering,
            spacing=args.spacing,  # Otherwise black padding gets included on warps
        )
        print(
            f"Got consist dataset {args.consist_dataset} of size {len(consist_dataset)}"
        )
        if not loaders:
            # No training dataset was loaded (this used to be a NameError):
            # the consistency dataset alone determines the input resolution.
            input_res = consist_input_res
        elif input_res != consist_input_res:
            raise ValueError(
                "train and consist dataset should have same input sizes "
                f"but got {input_res} and {consist_input_res}")
        consist_loader = torch.utils.data.DataLoader(
            consist_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(consist_loader)
    loader = concatloader.ConcatLoader(loaders)
    if args.eval_freq != -1:
        # NOTE(review): this rebinds input_res to the validation dataset's
        # value, which is what WarpRegNet receives below -- confirm intended.
        val_dataset, input_res = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={
                "version": args.version,
                "split_mode": "objects"
            },
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )
    elif not loaders:
        # Neither training, consistency nor validation data was loaded, so
        # there is no input resolution to build WarpRegNet with.
        raise ValueError(
            "No dataset was loaded: provide train_datasets, consist_dataset "
            "or enable evaluation so that the input resolution is known")

    model = MeshRegNet(
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model, replacing it by the reloaded one when resuming
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    # Wrapper around the network used by the consistency training pass
    pmodel = warpreg.WarpRegNet(
        input_res,
        model,
        lambda_consist=args.lambda_consist,
        lambda_data=args.lambda_data,
        criterion=args.consist_criterion,
        consist_scale=args.consist_scale,
        gt_refs=args.consist_gt_refs,
        progressive_steps=args.progressive_consist_steps,
        use_backward=args.consist_use_backward,
    )
    pmodel.cuda()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        # Previously surfaced later as a NameError on `optimizer`
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}, expected 'adam' or 'sgd'"
        )
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)

    # Always bind `scheduler`: it is passed to the training pass and stored in
    # the checkpoint even when no learning-rate decay is configured.
    # NOTE(review): assumes epoch_pass tolerates scheduler=None (the validation
    # call already passes None) -- confirm for train=True.
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.lr_decay_step,
                                                    gamma=args.lr_decay_gamma)
    fig = plt.figure(figsize=(10, 10))
    save_results = {
        "opt": dict(vars(args)),
        "train_losses": [],
        "val_losses": [],
    }
    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        # With frozen batchnorm the model is kept in eval mode during training
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()
        save_dict, avg_meters, _ = epochpassconsist.epoch_pass(
            loader,
            model,
            train=True,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch_idx,
            img_folder=img_folder,
            fig=fig,
            display_freq=args.display_freq,
            epoch_display_freq=args.epoch_display_freq,
            lr_decay_gamma=args.lr_decay_gamma,
            premodel=pmodel,
            loader_nb=train_loader_nb,
        )
        monitor.log_train(
            epoch_idx + 1,
            {key: val.avg
             for key, val in avg_meters.average_meters.items()})
        monitor.metrics.save_metrics(epoch_idx + 1, save_dict)
        monitor.metrics.plot_metrics()
        save_results["train_losses"].append(save_dict)
        # The results file is rewritten in full after every epoch
        with open(result_path, "wb") as p_f:
            pickle.dump(save_results, p_f)
        modelio.save_checkpoint(
            {
                "epoch": epoch_idx + 1,
                "network": "correspnet",
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler,
            },
            is_best=True,
            checkpoint=exp_id,
            snapshot=args.snapshot,
        )
        if args.eval_freq != -1 and epoch_idx % args.eval_freq == 0:
            val_save_dict, val_avg_meters, _ = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
            )

            save_results["val_losses"].append(val_save_dict)
            monitor.log_val(epoch_idx + 1, {
                key: val.avg
                for key, val in val_avg_meters.average_meters.items()
            })
            monitor.metrics.save_metrics(epoch_idx + 1, val_save_dict)
            monitor.metrics.plot_metrics()