Пример #1
0
    def __init__(self, args):
        """Assemble data, model and optimisation objects from ``args``.

        Seeds NumPy/PyTorch, picks the device via ``U.check_gpu``, builds the
        NTU train/eval ``DataLoader``s with the configured transform pipeline,
        then creates the RA-GCN model (moved to the device and wrapped in
        ``nn.DataParallel``), a Nesterov SGD optimizer, a cross-entropy loss
        and the ``Mask`` helper bound to the unwrapped model.

        Raises:
            ValueError: if ``args.subset`` is not 'cs', 'cv', 'csub' or 'cset'.
        """
        print('Starting preparing ...')
        self.args = args

        # -- Reproducibility and device ------------------------------------
        seed = args.seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # cuDNN autotuning: faster convolutions, not bit-exact across runs.
        cudnn.benchmark = True
        self.device = U.check_gpu(args.gpus)

        # -- Data ----------------------------------------------------------
        # 'cs'/'cv' -> 60 classes, 'csub'/'cset' -> 120 classes.
        num_class = {'cs': 60, 'cv': 60, 'csub': 120, 'cset': 120}.get(args.subset)
        if num_class is None:
            raise ValueError('Do NOT exist this subset: {}'.format(args.subset))

        # Shape handed to the dataset and the shape-aware transforms;
        # presumably (channels, frames, joints, bodies) -- TODO confirm in NTU.
        raw_shape = (3, args.max_frame, 25, 2)
        pipeline = transforms.Compose([
            Data_transform(args.data_transform),
            Occlusion_part(args.occlusion_part),
            Occlusion_time(args.occlusion_time),
            Occlusion_block(args.occlusion_block),
            Occlusion_rand(args.occlusion_rand, raw_shape),
            Jittering_joint(args.jittering_joint, raw_shape, sigma=args.sigma),
            Jittering_frame(args.jittering_frame, raw_shape, sigma=args.sigma),
        ])
        loader_kwargs = dict(
            batch_size=args.batch_size,
            num_workers=2 * len(args.gpus),
            pin_memory=True,
        )
        train_set = NTU('train', args.subset, raw_shape, transform=pipeline)
        self.train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **loader_kwargs)
        # NOTE(review): evaluation reuses the same pipeline, occlusion and
        # jittering included -- presumably deliberate (evaluating on perturbed
        # skeletons); confirm.
        eval_set = NTU('eval', args.subset, raw_shape, transform=pipeline)
        self.eval_loader = DataLoader(eval_set, shuffle=False, drop_last=False, **loader_kwargs)

        # The model is built for 9 input channels when the data transform is on.
        model_shape = ((9,) + raw_shape[1:]) if args.data_transform else raw_shape

        # -- Graph ---------------------------------------------------------
        skeleton = Graph(max_hop=args.gcn_kernel_size[1])
        adjacency = torch.tensor(skeleton.A, dtype=torch.float32, requires_grad=False)
        adjacency = adjacency.to(self.device)

        # -- Model ---------------------------------------------------------
        self.model_name = ''.join([
            str(args.config), '_', str(args.model_stream), 's_RA-GCN_NTU', args.subset,
        ])
        network = RA_GCN(model_shape, num_class, adjacency, args.drop_prob,
                         args.gcn_kernel_size, args.model_stream, args.subset,
                         args.pretrained)
        self.model = nn.DataParallel(network.to(self.device))

        # -- Optimisation --------------------------------------------------
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=args.learning_rate,
            momentum=0.9,
            weight_decay=0.0001,
            nesterov=True,
        )
        self.loss_func = nn.CrossEntropyLoss()

        # The mask helper works on the bare model, not the DataParallel wrapper.
        self.mask_func = Mask(args.model_stream, self.model.module)

        print('Successful!\n')
Пример #2
0
from src.nets import ST_GCN
import time
import torch
import numpy as np
from torch.backends import cudnn
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from src.graph import Graph
from src.utils import check_gpu

# Rebuild the ST-GCN baseline and restore its weights from the NTU 'cset'
# checkpoint. The adjacency tensor and the model are both placed on the
# selected device, matching how the trainer builds its model.

gcn_kernel_size = [5, 2]  # element [1] is passed to Graph as max_hop
graph = Graph(max_hop=gcn_kernel_size[1])
device = check_gpu([0])  # a longer list (e.g. [0, 1]) selects more GPUs
A = torch.tensor(graph.A, dtype=torch.float32, requires_grad=False).to(device)

# (9, 300, 25, 2) input with 120 classes is what this checkpoint is loaded
# into. NOTE(review): a 60-class NTU checkpoint trained without the 9-channel
# data transform would presumably need (3, 300, 25, 2) and 60 -- confirm.
# Move the model to the same device as A so parameters and adjacency agree.
a = ST_GCN((9, 300, 25, 2), 120, A, 0.5, gcn_kernel_size).to(device)

# map_location='cpu' lets a checkpoint saved on GPU load on any machine;
# load_state_dict then copies the tensors onto the model's own device.
checkpoint = torch.load('./models/baseline_NTUcset.pth.tar', map_location='cpu')

# NOTE(review): if this fails on keys prefixed with 'module.', the state dict
# was presumably saved from an nn.DataParallel wrapper; wrap `a` in
# nn.DataParallel (or strip the prefix) before loading in that case.
a.load_state_dict(checkpoint['model'])