Example #1
0
    def __init__(self, mano_path, detnet_path=None, iknet_path=None):
        """Set up MANO parameters and the detection / IK sub-networks.

        mano_path: path to the MANO model files, forwarded to para_init().
        detnet_path / iknet_path: optional checkpoint paths handed to
            model_init() for weight loading.
        """
        super().__init__()

        # MANO-related parameters first, then the two sub-networks.
        self.para_init(mano_path)

        detection_net = detnet(stacks=1)
        inverse_kinematics_net = iknet(inc=84 * 3, depth=6, width=1024)
        self.detnet = detection_net
        self.iknet = inverse_kinematics_net

        # Load pretrained weights (if any paths were given).
        self.model_init(detnet_path, iknet_path)
Example #2
0
import torch
from manopth import manolayer
from model.detnet import detnet
from utils import func, bone, AIK, smoother
import numpy as np
import matplotlib.pyplot as plt
from utils import vis
from op_pso import PSO
import open3d
from model import shape_net
import os

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
_mano_root = 'mano/models'

# Build the detection network and load its pretrained checkpoint,
# keeping only the weights whose names match the current model.
module = detnet().to(device)
print('load model start')
check_point = torch.load('new_check_point/ckp_detnet_83.pth',
                         map_location=device)
model_state = module.state_dict()
state = {k: v for k, v in check_point.items() if k in model_state}
for k in check_point:
    if k not in model_state:
        # Report checkpoint entries that do not exist in this architecture.
        print(k, ' is NOT in current model')
model_state.update(state)
module.load_state_dict(model_state)
print('load model finished')

shape_model = shape_net.ShapeNet()
def main(args):
    """Train and/or evaluate the detnet hand-pose network.

    ``args`` is expected to carry the CLI options referenced below
    (checkpoint/outpath directories, dataset names, batch sizes,
    learning-rate schedule, resume/evaluate flags, etc. — defined by the
    caller's argument parser, not visible here). Returns 0 on completion.
    """
    # Create the checkpoint and output directories on first run.
    for path in [args.checkpoint, args.outpath]:
        if not os.path.isdir(path):
            os.makedirs(path)

    misc.print_args(args)

    print("\nCREATE NETWORK")
    model = detnet()
    model.to(device)

    # Define the loss function (criterion) and optimizer.
    criterion_det = losses.DetLoss(
        lambda_hm=100.,
        lambda_dm=1.,
        lambda_lm=10.,
    )
    criterion = {'det': criterion_det}
    # 'initial_lr' is needed so StepLR can resume mid-schedule
    # (see last_epoch=args.start_epoch below).
    optimizer = torch.optim.Adam([
        {
            'params': model.parameters(),
            'initial_lr': args.learning_rate
        },
    ],
                                 lr=args.learning_rate)

    # One test set / loader per requested dataset name, plus per-dataset
    # best-accuracy trackers and evaluation histories.
    test_set_dic = {}
    test_loader_dic = {}
    best_acc = {}
    auc_all = {}
    acc_hm_all = {}
    for test_set_name in args.datasets_test:
        if test_set_name in ['stb', 'rhd', 'do']:
            test_set_dic[test_set_name] = HandDataset(
                data_split='test',
                train=False,
                subset_name=test_set_name,
                data_root=args.data_root,
            )
        elif test_set_name == 'eo':
            test_set_dic[test_set_name] = EgoDexter(data_split='test',
                                                    data_root=args.data_root,
                                                    hand_side="right")
            print(test_set_dic[test_set_name])

        test_loader_dic[test_set_name] = torch.utils.data.DataLoader(
            test_set_dic[test_set_name],
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False)
        best_acc[test_set_name] = 0
        auc_all[test_set_name] = []
        acc_hm_all[test_set_name] = []

    total_test_set_size = 0
    for key, value in test_set_dic.items():
        total_test_set_size += len(value)
    print("Total test set size: {}".format(total_test_set_size))

    # Either restore weights from a checkpoint, or Kaiming-initialize
    # the convolution layers for training from scratch.
    if args.resume or args.evaluate:
        print("\nLOAD CHECKPOINT")
        state_dict = torch.load(
            os.path.join(args.checkpoint,
                         'ckp_detnet_{}.pth'.format(args.evaluate_id)))
        # NOTE(review): clean_state_dict is applied unconditionally; a
        # former `if args.clean:` guard appears to have been removed.
        state_dict = misc.clean_state_dict(state_dict)

        model.load_state_dict(state_dict)
    else:
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)

    # Evaluation-only mode: validate on every test set and exit.
    if args.evaluate:
        for key, value in test_loader_dic.items():
            validate(value, model, criterion, key, args=args)
        return 0

    train_dataset = HandDataset(
        data_split='train',
        train=True,
        subset_name=args.datasets_train,
        data_root=args.data_root,
        scale_jittering=0.1,
        center_jettering=0.1,
        max_rot=0.5 * np.pi,
    )

    print("Total train dataset size: {}".format(len(train_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    # Wrap in DataParallel so training can use multiple GPUs.
    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_step,
                                                gamma=args.gamma,
                                                last_epoch=args.start_epoch)

    acc_hm = {}
    # Per-loss training history; extended in place by train() each epoch.
    loss_all = {
        "lossH": [],
        "lossD": [],
        "lossL": [],
    }

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        #############  train for one epoch  ###############
        train(train_loader,
              model,
              criterion,
              optimizer,
              args=args,
              loss_all=loss_all)
        ##################################################
        # Shallow copy suffices: values are numbers, replaced key-by-key.
        auc = best_acc.copy()
        for key, value in test_loader_dic.items():
            auc[key], acc_hm[key] = validate(value,
                                             model,
                                             criterion,
                                             key,
                                             args=args)
            auc_all[key].append([epoch + 1, auc[key]])
            acc_hm_all[key].append([epoch + 1, acc_hm[key]])

        # Checkpoint BEFORE updating best_acc, so save_checkpoint can
        # compare this epoch's auc against the previous best.
        misc.save_checkpoint({
            'epoch': epoch + 1,
            'model': model,
        },
                             checkpoint=args.checkpoint,
                             filename='{}.pth'.format(args.saved_prefix),
                             snapshot=args.snapshot,
                             is_best=[auc, best_acc])

        # Update the running best AUC for each test set.
        for key, value in test_loader_dic.items():
            if auc[key] > best_acc[key]:
                best_acc[key] = auc[key]

        misc.out_loss_auc(loss_all, auc_all, acc_hm_all, outpath=args.outpath)

        scheduler.step()

    return 0  # end of main