Example #1
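All three examples are training entry points from the same hand-pose project (the model class is NetBiHand). They share the module-level setup sketched below; the project-local modules (models, losses, datasets, misc, HandDataset) and the train/validate loops are defined elsewhere in the repository, so the exact import paths here are assumptions:

import os

import numpy as np
import torch
from termcolor import cprint

# Project-local modules -- the names match the snippets below, but the
# actual package paths in a given checkout may differ (assumption).
import datasets
import losses
import misc
import models
from datasets import HandDataset

# All three scripts move the model to one global device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')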
def main(args):
    if not args.fine_tune or args.fine_tune not in ['rhd', 'stb']:
        raise Exception('expect --fine_tune in [rhd|stb], got {}'.format(
            args.fine_tune))
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    args.datasets = [
        args.fine_tune,
    ]
    misc.print_args(args)
    auc_best = 0
    print("\nCREATE NETWORK")

    model = models.NetBiHand(
        net_modules=['seed', 'lift', 'sik'],
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

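    # Joint and shape terms are active; the quaternion term is disabled (lambda_quat=0.0).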
    criterion = losses.SIKLoss(lambda_quat=0.0,
                               lambda_joint=1.0,
                               lambda_shape=1.0)

    optimizer = torch.optim.Adam(
        [
            {
                'params': model.siknet.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

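    # Training-time augmentation: scale/center jitter of 0.2 and random rotation up to 90 degrees.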
    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )

    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )

    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    model.load_checkpoints(
        ckp_seednet=os.path.join(args.checkpoint, 'ckp_seednet_all.pth.tar'),
        ckp_liftnet=os.path.join(
            args.checkpoint, args.fine_tune,
            'ckp_liftnet_{}.pth.tar'.format(args.fine_tune)),
        ckp_siknet=os.path.join(args.checkpoint, 'ckp_siknet_synth.pth.tar'))
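    # Freeze the upstream stack (SeedNet + LiftNet); only SIKNet parameters stay trainable.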
    for params in model.upstream.parameters():
        params.requires_grad = False

    if args.evaluate or args.resume:
        model.load_checkpoints(ckp_siknet=os.path.join(
            args.checkpoint, args.fine_tune, 'ckp_siknet_{}.pth.tar'.format(
                args.fine_tune)))

    if args.evaluate:
        validate(val_loader, model, criterion, args=args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

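    # Wrap the model for multi-GPU data-parallel training.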
    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_step,
                                                gamma=args.gamma,
                                                last_epoch=args.start_epoch)

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        #############  train for one epoch  ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        auc_all = validate(val_loader, model, criterion, args)
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune),
            snapshot=args.snapshot,
            is_best=auc_all > auc_best)
        if auc_all > auc_best:
            auc_best = auc_all

        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
Example #2

def main(args):
    if not args.fine_tune or args.fine_tune not in ['rhd', 'stb']:
        raise Exception('expect --fine_tune in [rhd|stb], got {}'
                        .format(args.fine_tune))
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    args.datasets = [args.fine_tune, ]
    misc.print_args(args)
    auc_best = 0
    print("\nCREATE NETWORK")
    model = models.SIKNet()
    model = model.to(device)
    criterion = losses.SIKLoss(
        lambda_quat=0.0,
        lambda_joint=1.0,
        lambda_shape=1.0
    )

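    # Only SIKNet's inverse-kinematics and shape-regression heads are optimized.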
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.invk_layers.parameters(),
                'initial_lr': args.learning_rate
            },
            {
                'params': model.shapereg_layers.parameters(),
                'initial_lr': args.learning_rate
            },

        ],
        lr=args.learning_rate,
    )

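    # Offline SIK datasets built from the selected source (rhd or stb).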
    train_dataset = datasets.SIKOFFLINE(
        data_root=args.data_root,
        data_split="train",
        data_source=args.datasets,
    )

    val_dataset = datasets.SIKOFFLINE(
        data_root=args.data_root,
        data_split="test",
        data_source=args.datasets,
    )

    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )

    if args.evaluate or args.resume:
        misc.load_checkpoint(
            model,
            os.path.join(
                args.checkpoint, args.fine_tune,
                '{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune)
            )
        )

    if args.evaluate:
        validate(val_loader, model, criterion, args=args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.lr_decay_step, gamma=args.gamma,
        last_epoch=args.start_epoch
    )

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        #############  train for one epoch  ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        auc_all = validate(val_loader, model, criterion, args)
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune),
            snapshot=args.snapshot,
            is_best=auc_all > auc_best
        )
        if auc_all > auc_best:
            auc_best = auc_all

        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
Example #3
def main(args):
    best_acc = 0
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    misc.print_args(args)
    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=args.net_modules,  # only train hm
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    # define loss function (criterion) and optimizer
    criterion_ups = losses.UpstreamLoss(lambda_hm=100.0, lambda_mask=1.0)

    criterion = {
        'ups': criterion_ups,
    }
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.upstream.seednet.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    print("\nCREATE DATASET")
    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,  # no need to shuffle validation data
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("Total test dataset size: {}".format(len(val_dataset)))
    print("\nLOAD CHECKPOINT")
    if args.resume_seednet_pth:
        model.load_checkpoints(ckp_seednet=args.resume_seednet_pth)
    else:
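        # No pretrained SeedNet given: initialize conv layers with Kaiming-normal weights.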
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)

    if args.evaluate_seednet_pth:
        model.load_checkpoints(ckp_seednet=args.evaluate_seednet_pth)
        validate(val_loader, model, criterion, args=args)
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_step,
                                                gamma=args.gamma,
                                                last_epoch=args.start_epoch)

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        #############  train for one epoch  ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
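        # Validate only every 5th epoch once past epoch 50; in between, acc_hm equals best_acc so is_best stays False.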
        acc_hm = best_acc
        if epoch >= 50 and epoch % 5 == 0:
            acc_hm = validate(val_loader, model, criterion, args=args)
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.upstream.seednet.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}.pth.tar'.format(args.saved_prefix),
            snapshot=args.snapshot,
            is_best=acc_hm > best_acc)
        if acc_hm > best_acc:
            best_acc = acc_hm
        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main