def __init__(self, args):
    """Set up everything needed for an RA-GCN run from ``args``.

    Stores ``args`` and builds, on ``self``: the device, the train and eval
    data loaders over the ``NTU`` dataset, the model name, the RA_GCN model
    wrapped in ``nn.DataParallel``, an SGD optimizer, a cross-entropy loss
    and the mask function. Prints a message before and after.

    Args:
        args: configuration namespace. Attributes read here: seed, gpus,
            subset, max_frame, data_transform, occlusion_part,
            occlusion_time, occlusion_block, occlusion_rand,
            jittering_joint, jittering_frame, sigma, batch_size,
            gcn_kernel_size, config, model_stream, drop_prob, pretrained,
            learning_rate.

    Raises:
        ValueError: if ``args.subset`` is none of 'cs', 'cv', 'csub', 'cset'.
    """
    print('Starting preparing ...')
    self.args = args

    # Reproducibility: seed NumPy and torch (CPU and all CUDA devices).
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Lets cuDNN auto-tune kernels; per the PyTorch reproducibility notes this
    # may trade run-to-run determinism for speed.
    cudnn.benchmark = True
    self.device = U.check_gpu(args.gpus)

    # Class count is determined by which subset/split is requested.
    subset = args.subset
    if subset in ('cs', 'cv'):
        num_class = 60
    elif subset in ('csub', 'cset'):
        num_class = 120
    else:
        raise ValueError('Do NOT exist this subset: {}'.format(subset))

    # Shape given to the dataset and to the shape-aware augmentations;
    # presumably (channels, frames, joints, persons) -- TODO confirm.
    input_shape = (3, args.max_frame, 25, 2)
    augment = transforms.Compose([
        Data_transform(args.data_transform),
        Occlusion_part(args.occlusion_part),
        Occlusion_time(args.occlusion_time),
        Occlusion_block(args.occlusion_block),
        Occlusion_rand(args.occlusion_rand, input_shape),
        Jittering_joint(args.jittering_joint, input_shape, sigma=args.sigma),
        Jittering_frame(args.jittering_frame, input_shape, sigma=args.sigma),
    ])

    def make_loader(phase, training):
        # Both loaders share the same transform pipeline (so evaluation data
        # also goes through it); only the training loader shuffles and drops
        # the last incomplete batch.
        dataset = NTU(phase, args.subset, input_shape, transform=augment)
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=2 * len(args.gpus),
            pin_memory=True,
            shuffle=training,
            drop_last=training,
        )

    self.train_loader = make_loader('train', True)
    self.eval_loader = make_loader('eval', False)

    # The model is built for a 9-channel input when data_transform is enabled,
    # while the dataset and augmentations above were configured with 3.
    model_shape = (9, args.max_frame, 25, 2) if args.data_transform else input_shape

    # Graph adjacency; max_hop comes from the second kernel-size entry.
    graph = Graph(max_hop=args.gcn_kernel_size[1])
    adjacency = torch.tensor(
        graph.A, dtype=torch.float32, requires_grad=False
    ).to(self.device)

    self.model_name = '{!s}_{!s}s_RA-GCN_NTU{}'.format(
        args.config, args.model_stream, args.subset
    )
    net = RA_GCN(
        model_shape, num_class, adjacency, args.drop_prob,
        args.gcn_kernel_size, args.model_stream, args.subset, args.pretrained,
    ).to(self.device)
    self.model = nn.DataParallel(net)

    self.optimizer = torch.optim.SGD(
        self.model.parameters(),
        lr=args.learning_rate,
        momentum=0.9,
        weight_decay=0.0001,
        nesterov=True,
    )
    self.loss_func = nn.CrossEntropyLoss()
    # The mask function receives the unwrapped module, not the DataParallel
    # wrapper.
    self.mask_func = Mask(args.model_stream, self.model.module)
    print('Successful!\n')
# Stand-alone check that the baseline ST-GCN checkpoint loads into a freshly
# constructed ST_GCN model.
import time

import numpy as np
import torch
from torch import nn
from torch.backends import cudnn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from src.graph import Graph
from src.nets import ST_GCN
from src.utils import check_gpu

# Index 1 is used as the graph's max_hop; the whole list is also passed to
# ST_GCN.
gcn_kernel_size = [5, 2]
graph = Graph(max_hop=gcn_kernel_size[1])

device = check_gpu([0])
A = torch.tensor(graph.A, dtype=torch.float32, requires_grad=False).to(device)

# Input shape (9, 300, 25, 2) and 120 classes must match how the checkpoint was
# trained (the file name suggests the 'cset' subset) -- confirm if loading fails.
# The model is moved to the same device as its adjacency A so the two are
# consistent.
a = ST_GCN((9, 300, 25, 2), 120, A, 0.5, gcn_kernel_size).to(device)

# map_location remaps tensors saved on another device (e.g. a GPU index that is
# not visible here) onto `device`; without it torch.load tries to restore them
# to their original device and can fail. Assumes check_gpu returns a value
# torch.load accepts as map_location (a torch.device or string) -- it is
# already used with Tensor.to() above.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('./models/baseline_NTUcset.pth.tar', map_location=device)

# If the checkpoint was saved from an nn.DataParallel-wrapped model, its keys
# carry a 'module.' prefix and this call will report missing/unexpected keys;
# in that case load through nn.DataParallel(a) instead.
a.load_state_dict(checkpoint['model'])