Example #1
0
def main(args):
    """Evaluate a fine-tuned NetBiHand (seed + lift + sik) on one test set.

    Builds the three-stage network, loads the SeedNet checkpoint plus the
    LiftNet/SIKNet checkpoints fine-tuned on ``args.fine_tune``, and runs
    ``validate`` over that dataset's test split.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``fine_tune`` (``'rhd'`` or ``'stb'``), ``checkpoint``,
            ``njoints``, ``hg_stacks``, ``hg_blocks``, ``data_root``,
            ``batch_size``, ``workers`` and ``vis``. ``args.datasets`` is
            overwritten with ``[args.fine_tune]``.

    Returns:
        0 once evaluation has finished.

    Raises:
        ValueError: if ``args.fine_tune`` is missing or is not ``'rhd'`` /
            ``'stb'``. ValueError subclasses Exception, so handlers that
            catch Exception still apply.
    """
    if args.fine_tune not in ('rhd', 'stb'):
        raise ValueError('expect --fine_tune in [rhd|stb], got {}'.format(
            args.fine_tune))
    # Evaluation runs only on the dataset the model was fine-tuned on.
    args.datasets = [args.fine_tune]
    misc.print_args(args)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.checkpoint, exist_ok=True)

    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=['seed', 'lift', 'sik'],
        njoints=args.njoints,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=args.hg_stacks,
        upstream_hg_blocks=args.hg_blocks,
    )
    model = model.to(device)

    print("\nCREATE TESTSET")
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    print("Total test dataset size: {}".format(len(val_dataset)))

    print("\nLOAD CHECKPOINT")
    # The SeedNet checkpoint sits at the top of args.checkpoint; the LiftNet
    # and SIKNet checkpoints live in a per-dataset sub-directory.
    finetune_dir = os.path.join(args.checkpoint, args.fine_tune)
    model.load_checkpoints(
        ckp_seednet=os.path.join(args.checkpoint, 'ckp_seednet_all.pth.tar'),
        ckp_liftnet=os.path.join(
            finetune_dir, 'ckp_liftnet_{}.pth.tar'.format(args.fine_tune)),
        ckp_siknet=os.path.join(
            finetune_dir, 'ckp_siknet_{}.pth.tar'.format(args.fine_tune)),
    )

    validate(val_loader, model, vis=args.vis)
    return 0
Example #2
0
def main(args):
    """Train (or evaluate) the SeedNet stage of NetBiHand.

    Only ``model.upstream.seednet`` parameters are given to the optimizer.
    With ``args.evaluate`` the function loads the SeedNet checkpoint, runs
    ``validate`` once and returns. Otherwise it trains, saves a checkpoint
    every epoch and marks it "best" when validation accuracy improves.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``checkpoint``, ``net_modules``, ``learning_rate``, ``datasets``,
            ``data_root``, ``train_batch``, ``test_batch``, ``workers``,
            ``resume``, ``evaluate``, ``lr_decay_step``, ``gamma``,
            ``start_epoch``, ``epochs``, ``saved_prefix`` and ``snapshot``.

    Returns:
        0 when evaluation or training finishes.
    """
    best_acc = 0
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.checkpoint, exist_ok=True)
    misc.print_args(args)
    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=args.net_modules,  # only SeedNet is optimized below
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    # define loss function (criterion) and optimizer
    criterion_ups = losses.UpstreamLoss(
        lambda_hm=100.0,
        lambda_mask=1.0
    )
    criterion = {
        'ups': criterion_ups,
    }
    # 'initial_lr' is required by PyTorch schedulers constructed with
    # last_epoch != -1, which StepLR below may be.
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.upstream.seednet.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    print("\nCREATE DATASET")
    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    # Validation does not need random order; a fixed order keeps evaluation
    # runs reproducible.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print("Total test dataset size: {}".format(len(val_dataset)))
    print("\nLOAD CHECKPOINT")
    if args.resume or args.evaluate:
        model.load_checkpoints(
            ckp_seednet=os.path.join(
                args.checkpoint,
                'ckp_seednet_all.pth.tar'
            )
        )
    else:
        # Fresh run: Kaiming-initialise every Conv2d weight.
        for module in model.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)

    if args.evaluate:
        validate(val_loader, model, criterion, args=args)
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    # NOTE(review): PyTorch schedulers take one step in their constructor, so
    # last_epoch=args.start_epoch appears to make the first decay happen one
    # epoch earlier than last_epoch=args.start_epoch - 1 would. Kept as-is to
    # preserve the existing schedule -- confirm intent.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.lr_decay_step, gamma=args.gamma,
        last_epoch=args.start_epoch
    )

    # NOTE(review): the range includes args.epochs, so a run starting at
    # epoch 0 trains args.epochs + 1 epochs.
    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for group_idx, group in enumerate(optimizer.param_groups):
            print('group %d lr:' % group_idx, group['lr'])
        #############  train for one epoch  ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        # Validation only runs every 5th epoch from epoch 50 onward; on other
        # epochs acc_hm equals best_acc, so the checkpoint is never "best".
        acc_hm = best_acc
        if epoch >= 50 and epoch % 5 == 0:
            acc_hm = validate(val_loader, model, criterion, args=args)
        is_best = acc_hm > best_acc
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.upstream.seednet.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}.pth.tar'.format(args.saved_prefix),
            snapshot=args.snapshot,
            is_best=is_best
        )
        if is_best:
            best_acc = acc_hm
        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
Example #3
0
def main(args):
    """Generate SIKNet training/testing data from a fine-tuned seed+lift model.

    Runs ``validate`` over ``args.data_split`` ``args.epochs`` times (forced
    to 1 for the test split). Each pass is pickled to a temporary file in
    ``args.sik_genpath``. The ``'jointGt_'`` and ``'jointImpl_'`` arrays of
    all passes are then concatenated into a single pickle, and the temporary
    files are deleted.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``fine_tune`` (``'rhd'`` or ``'stb'``), ``data_split``
            (``'train'`` or ``'test'``), ``epochs``, ``sik_genpath``,
            ``net_modules``, ``data_root``, ``train_batch``, ``workers`` and
            ``checkpoint``. ``args.datasets`` is overwritten, and
            ``args.epochs`` is set to 1 for the test split.

    Returns:
        0 once the merged pickle has been written.

    Raises:
        ValueError: if ``fine_tune`` or ``data_split`` is missing or invalid.
            ValueError subclasses Exception, so handlers that catch
            Exception still apply.
    """
    if args.fine_tune not in ('rhd', 'stb'):
        raise ValueError('expect --fine_tune in [rhd|stb], got {}'
                         .format(args.fine_tune))
    if args.data_split not in ('train', 'test'):
        raise ValueError('expect --data_split in [train|test], got {}'
                         .format(args.data_split))

    args.datasets = [args.fine_tune]
    is_train = (args.data_split == 'train')
    if not is_train:
        # The test split is loaded with train=False (presumably no random
        # augmentation), so a single pass is enough.
        args.epochs = 1
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.sik_genpath, exist_ok=True)
    misc.print_args(args)

    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=args.net_modules,
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    # Empty criterion: validate() is used here only to collect the
    # predictions that get pickled below.
    criterion = {}

    print("\nCREATE DATASET")
    print(colored(args.datasets, 'yellow', attrs=['bold']),
          colored(args.data_split, 'blue', attrs=['bold']),
          colored('is_train:{}'.format(is_train), 'red', attrs=['bold']))

    gen_set = HandDataset(
        data_split=args.data_split,
        train=is_train,
        subset_name=args.datasets,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        data_root=args.data_root,
    )

    gen_loader = torch.utils.data.DataLoader(
        gen_set,
        batch_size=args.train_batch,
        shuffle=is_train,
        num_workers=args.workers,
        pin_memory=True
    )

    print("\nLOAD CHECKPOINT")
    model.load_checkpoints(
        ckp_seednet=os.path.join(args.checkpoint, 'ckp_seednet_all.pth.tar'),
        ckp_liftnet=os.path.join(args.checkpoint, args.fine_tune,
                                 'ckp_liftnet_{}.pth.tar'.format(args.fine_tune))
    )
    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    # One temporary pickle per pass.
    all_save_at = []
    for i in range(args.epochs):
        saving = validate(gen_loader, model, criterion, args=args)
        save_at = os.path.join(args.sik_genpath, "sik_{}_{}_{}.pkl"
                               .format(args.data_split, args.fine_tune, i))
        with open(save_at, 'wb') as fid:
            pickle.dump(saving, fid)
        cprint("saving {} epoch data at {}".format(i, save_at), 'yellow')
        all_save_at.append(save_at)

    # Merge all temp files. pickle.load is acceptable here only because these
    # files were written by this function just above.
    allJointGt_, allJointImpl_ = [], []
    for save_at in all_save_at:
        with open(save_at, 'rb') as fid:
            raw = dict(pickle.load(fid))
        allJointGt_.append(raw['jointGt_'])
        allJointImpl_.append(raw['jointImpl_'])

    allJointGt_ = np.concatenate(allJointGt_, axis=0)
    allJointImpl_ = np.concatenate(allJointImpl_, axis=0)
    sikdata = {
        'jointGt_': allJointGt_,
        'jointImpl_': allJointImpl_
    }
    sikdata_at = os.path.join(
        args.sik_genpath,
        'sik_{}_{}{}.pkl'.format(
            args.data_split,
            args.fine_tune,
            '_{}epochs'.format(args.epochs) if is_train else ''
        )
    )
    with open(sikdata_at, 'wb') as fid:
        pickle.dump(sikdata, fid)
    cprint('Saved {} samples at {}'
           .format(allJointGt_.shape[0], sikdata_at), 'yellow', attrs=['bold'])

    # Delete the intermediate per-pass outputs.
    for save_at in all_save_at:
        os.remove(save_at)
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
Example #4
0
def main(args):
    """Fine-tune SIKNet on RHD or STB with the upstream networks frozen.

    Loads the SeedNet checkpoint, the dataset-specific LiftNet checkpoint and
    the synthetic SIKNet checkpoint, and freezes every ``model.upstream``
    parameter. It then optimizes only ``model.siknet``. With
    ``args.evaluate`` / ``args.resume`` the previously fine-tuned SIKNet
    weights are loaded on top. With ``args.evaluate`` a single validation
    pass is run and the function returns.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``fine_tune`` (``'rhd'`` or ``'stb'``), ``checkpoint``,
            ``learning_rate``, ``data_root``, ``train_batch``,
            ``test_batch``, ``workers``, ``evaluate``, ``resume``,
            ``lr_decay_step``, ``gamma``, ``start_epoch``, ``epochs``,
            ``saved_prefix`` and ``snapshot``. ``args.datasets`` is
            overwritten with ``[args.fine_tune]``.

    Returns:
        0 when evaluation or training finishes.

    Raises:
        ValueError: if ``args.fine_tune`` is missing or invalid. ValueError
            subclasses Exception, so handlers that catch Exception still
            apply.
    """
    if args.fine_tune not in ('rhd', 'stb'):
        raise ValueError('expect --fine_tune in [rhd|stb], got {}'.format(
            args.fine_tune))
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.checkpoint, exist_ok=True)

    args.datasets = [args.fine_tune]
    misc.print_args(args)
    auc_best = 0
    print("\nCREATE NETWORK")

    model = models.NetBiHand(
        net_modules=['seed', 'lift', 'sik'],
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    criterion = losses.SIKLoss(lambda_quat=0.0,
                               lambda_joint=1.0,
                               lambda_shape=1.0)

    # Only SIKNet is optimized. 'initial_lr' is required by PyTorch
    # schedulers constructed with last_epoch != -1, which StepLR below may be.
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.siknet.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )

    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # The SeedNet and synthetic SIKNet checkpoints sit at the top of
    # args.checkpoint; dataset-specific checkpoints live in a sub-directory.
    finetune_dir = os.path.join(args.checkpoint, args.fine_tune)
    model.load_checkpoints(
        ckp_seednet=os.path.join(args.checkpoint, 'ckp_seednet_all.pth.tar'),
        ckp_liftnet=os.path.join(
            finetune_dir, 'ckp_liftnet_{}.pth.tar'.format(args.fine_tune)),
        ckp_siknet=os.path.join(args.checkpoint, 'ckp_siknet_synth.pth.tar'))
    # Freeze everything under model.upstream so gradients reach SIKNet only.
    for param in model.upstream.parameters():
        param.requires_grad = False

    if args.evaluate or args.resume:
        # Replace the synthetic SIKNet weights with the fine-tuned ones.
        model.load_checkpoints(ckp_siknet=os.path.join(
            finetune_dir, 'ckp_siknet_{}.pth.tar'.format(args.fine_tune)))

    if args.evaluate:
        validate(val_loader, model, criterion, args=args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    # NOTE(review): PyTorch schedulers take one step in their constructor, so
    # last_epoch=args.start_epoch appears to shift the first decay one epoch
    # earlier than last_epoch=args.start_epoch - 1 would. Kept as-is to
    # preserve the existing schedule -- confirm intent.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_step,
                                                gamma=args.gamma,
                                                last_epoch=args.start_epoch)

    # NOTE(review): the range includes args.epochs, so a run starting at
    # epoch 0 trains args.epochs + 1 epochs.
    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch))
        for group_idx, group in enumerate(optimizer.param_groups):
            print('group %d lr:' % group_idx, group['lr'])
        #############  train for one epoch  ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        # Pass args by keyword, matching the evaluate-only call above.
        auc_all = validate(val_loader, model, criterion, args=args)
        is_best = auc_all > auc_best
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune),
            snapshot=args.snapshot,
            is_best=is_best)
        if is_best:
            auc_best = auc_all

        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main