def __init__(self, mano_path, detnet_path=None, iknet_path=None):
    """Set up the instance: MANO-related parameters, then the two sub-networks.

    Args:
        mano_path: Forwarded unchanged to ``self.para_init``.
        detnet_path: Optional; forwarded to ``self.model_init`` (default ``None``).
        iknet_path: Optional; forwarded to ``self.model_init`` (default ``None``).
    """
    super().__init__()

    # Parameter set-up runs before the sub-networks are created; this order
    # is kept exactly as written.
    self.para_init(mano_path)

    # Detection network with a single stack.
    self.detnet = detnet(stacks=1)

    # IK network; the input width is 84 * 3 = 252.
    # NOTE(review): presumably 84 three-component vectors -- confirm against iknet.
    ik_input_width = 84 * 3
    self.iknet = iknet(inc=ik_input_width, depth=6, width=1024)

    # NOTE(review): presumably loads weights from the given paths -- confirm
    # in model_init.
    self.model_init(detnet_path, iknet_path)
import torch
from manopth import manolayer
from model.detnet import detnet
from utils import func, bone, AIK, smoother
import numpy as np
import matplotlib.pyplot as plt
from utils import vis
from op_pso import PSO
import open3d
from model import shape_net
import os

# Run on the first CUDA device when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

_mano_root = 'mano/models'

# Build the detection network and move it to the selected device.
module = detnet().to(device)

print('load model start')
check_point = torch.load('new_check_point/ckp_detnet_83.pth', map_location=device)
model_state = module.state_dict()

# Keep only the checkpoint entries whose key exists in the current model;
# every other key is reported and skipped.
state = {}
for k, v in check_point.items():
    if k not in model_state:
        print(k, ' is NOT in current model')
        continue
    state[k] = v

# Overlay the matching entries on the model's own state, so keys that are
# absent from the checkpoint keep their current values, then load the result.
model_state.update(state)
module.load_state_dict(model_state)
print('load model finished')

shape_model = shape_net.ShapeNet()
def main(args):
    """Train and/or evaluate the detnet model as configured by ``args``.

    Creates the output directories, builds the model, loss and Adam
    optimizer, and creates one test loader per entry of
    ``args.datasets_test``. With ``args.evaluate`` set, a stored checkpoint
    is validated on every test loader and the function returns. Otherwise
    the model is trained from ``args.start_epoch`` through ``args.epochs``
    (inclusive range bound), validating, checkpointing and writing loss/AUC
    output after every epoch.

    Args:
        args: Namespace-like configuration object. Attributes read here:
            ``checkpoint``, ``outpath``, ``learning_rate``,
            ``datasets_test``, ``data_root``, ``test_batch``, ``workers``,
            ``resume``, ``evaluate``, ``evaluate_id``, ``datasets_train``,
            ``train_batch``, ``lr_decay_step``, ``gamma``, ``start_epoch``,
            ``epochs``, ``saved_prefix``, ``snapshot``. The object is also
            passed on to ``misc.print_args``, ``train`` and ``validate``.

    Returns:
        int: Always ``0``.

    Raises:
        KeyError: If a name in ``args.datasets_test`` is none of ``'stb'``,
            ``'rhd'``, ``'do'`` or ``'eo'``; no dataset is created for it,
            so the lookup that follows fails.
    """
    for path in (args.checkpoint, args.outpath):
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(path, exist_ok=True)

    misc.print_args(args)
    print("\nCREATE NETWORK")
    model = detnet()
    model.to(device)

    # define loss function (criterion) and optimizer
    criterion_det = losses.DetLoss(
        lambda_hm=100.,
        lambda_dm=1.,
        lambda_lm=10.,
    )
    criterion = {'det': criterion_det}
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.parameters(),
                # The scheduler below is created with an explicit last_epoch;
                # PyTorch schedulers expect 'initial_lr' in every param
                # group whenever last_epoch is not -1.
                'initial_lr': args.learning_rate,
            },
        ],
        lr=args.learning_rate,
    )

    test_set_dic = {}
    test_loader_dic = {}
    best_acc = {}
    auc_all = {}
    acc_hm_all = {}
    for test_set_name in args.datasets_test:
        if test_set_name in ['stb', 'rhd', 'do']:
            test_set_dic[test_set_name] = HandDataset(
                data_split='test',
                train=False,
                subset_name=test_set_name,
                data_root=args.data_root,
            )
        elif test_set_name == 'eo':
            test_set_dic[test_set_name] = EgoDexter(
                data_split='test',
                data_root=args.data_root,
                hand_side="right",
            )

        print(test_set_dic[test_set_name])

        test_loader_dic[test_set_name] = torch.utils.data.DataLoader(
            test_set_dic[test_set_name],
            batch_size=args.test_batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False,
        )
        best_acc[test_set_name] = 0
        auc_all[test_set_name] = []
        acc_hm_all[test_set_name] = []

    total_test_set_size = sum(len(test_set) for test_set in test_set_dic.values())
    print("Total test set size: {}".format(total_test_set_size))

    if args.resume or args.evaluate:
        print("\nLOAD CHECKPOINT")
        checkpoint_path = os.path.join(
            args.checkpoint,
            'ckp_detnet_{}.pth'.format(args.evaluate_id),
        )
        # map_location remaps the stored tensors onto the device in use, so
        # a checkpoint written on a GPU still loads on a CPU-only host or on
        # a machine with different GPU indices.
        state_dict = torch.load(checkpoint_path, map_location=device)
        # The clean-up runs unconditionally: the `if args.clean:` guard that
        # once preceded it is disabled.
        state_dict = misc.clean_state_dict(state_dict)
        model.load_state_dict(state_dict)
    else:
        # Fresh start: Kaiming-normal initialisation for every Conv2d weight.
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)

    if args.evaluate:
        for key, value in test_loader_dic.items():
            validate(value, model, criterion, key, args=args)
        return 0

    train_dataset = HandDataset(
        data_split='train',
        train=True,
        subset_name=args.datasets_train,
        data_root=args.data_root,
        scale_jittering=0.1,
        center_jettering=0.1,
        max_rot=0.5 * np.pi,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

    # DataParallel so the model can use multiple GPUs
    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        args.lr_decay_step,
        gamma=args.gamma,
        last_epoch=args.start_epoch,
    )

    acc_hm = {}
    loss_all = {
        "lossH": [],
        "lossD": [],
        "lossL": [],
    }

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for i, param_group in enumerate(optimizer.param_groups):
            print('group %d lr:' % i, param_group['lr'])

        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
            loss_all=loss_all,
        )
        ##################################################

        # Work on a copy so best_acc still holds the previous best values
        # when both dicts are handed to save_checkpoint below. A shallow
        # copy is enough: entries are replaced here, never mutated in place.
        auc = best_acc.copy()
        for key, value in test_loader_dic.items():
            auc[key], acc_hm[key] = validate(value, model, criterion, key, args=args)
            auc_all[key].append([epoch + 1, auc[key]])
            acc_hm_all[key].append([epoch + 1, acc_hm[key]])

        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model,
            },
            checkpoint=args.checkpoint,
            filename='{}.pth'.format(args.saved_prefix),
            snapshot=args.snapshot,
            is_best=[auc, best_acc],
        )

        # Only now fold this epoch's results into the running best.
        for key in test_loader_dic:
            if auc[key] > best_acc[key]:
                best_acc[key] = auc[key]

        misc.out_loss_auc(loss_all, auc_all, acc_hm_all, outpath=args.outpath)

        scheduler.step()

    return 0  # end of main