def main():
    """Train and evaluate a ``SimpleRegression`` network on the hand-pose data.

    Parses the command line into the module-level ``args``, builds the
    train/test data loaders, optionally resumes from the checkpoint at
    ``args.resume``, and then, for every epoch, adjusts the learning rate,
    trains, evaluates and saves a checkpoint.  The best value returned by
    ``test`` so far is tracked in the module-level ``best_acc``.
    """
    global args, best_acc, plotter

    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        # NOTE(review): the device index is hard-coded; confirm it matches
        # the machine this script is run on.
        torch.cuda.set_device(3)

    plotter = VisdomLinePlotter(env_name=args.name)

    kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}

    print('==>Preparing data...')
    base_path = "./data/handpose_data/"

    # Training pipeline: resize/crop plus light colour and lighting jitter.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
        transforms.ToTensor(),
        Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    ])
    train_loader = torch.utils.data.DataLoader(
        SimpleImageLoader(base_path, train=True, transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        **kwargs)

    # Evaluation pipeline: deterministic preprocessing only.
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    test_loader = torch.utils.data.DataLoader(
        SimpleImageLoader(base_path, False, transform=test_transform),
        batch_size=args.batch_size,
        drop_last=False,
        **kwargs)

    jnet = SimpleRegression()
    if args.cuda:
        jnet.cuda()
        if torch.cuda.device_count() > 1 and args.parallel:
            jnet = nn.DataParallel(jnet)  # dataParallel

    # This flag allows you to enable the inbuilt cudnn auto-tuner to
    # find the best algorithm to use for your hardware.
    cudnn.benchmark = True

    # The optimizer must exist before a checkpoint is loaded, because the
    # resume branch restores its state.
    criterion = torch.nn.MSELoss()
    optimizer = optim.SGD(jnet.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    # optionally resume from a checkpoint
    start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Checkpoints store ``epoch + 1``, i.e. the next epoch to run.
            start_epoch = checkpoint['epoch']
            args.start_epoch = start_epoch
            # Restore the running best so ``is_best`` stays meaningful.
            best_acc = checkpoint['best_prec1']
            jnet.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        adjust_learning_rate(jnet, optimizer, epoch)
        train(train_loader, jnet, criterion, optimizer, epoch)

        # evaluate on validation set
        acc = test(test_loader, jnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': jnet.state_dict(),
            'best_prec1': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best)
def main():
    """Train and evaluate the network selected by ``args.net``.

    Parses the command line into the module-level ``args``, builds the
    train/test data loaders, instantiates ``Tripletnet`` (``"TCN"``),
    ``SimpleRegression`` (``"SIMPLE"``) or ``JointsResNetEmbedSpatial``
    (anything else), optionally resumes from the checkpoint at
    ``args.resume``, and then trains, evaluates and saves a checkpoint for
    every epoch.  The best value returned by ``test`` so far is tracked in
    the module-level ``best_acc``.
    """
    global args, best_acc, plotter

    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    plotter = VisdomLinePlotter(env_name=args.name)

    kwargs = {'num_workers': 60, 'pin_memory': True} if args.cuda else {}

    print('==>Preparing data...')
    base_path = "./data/handpose_data/"

    # Training pipeline: resize/crop plus light colour and lighting jitter.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
        transforms.ToTensor(),
        Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
        # transforms.Normalize(mean=[0.4822, ], std=[0.1839, ])
    ])
    train_loader = torch.utils.data.DataLoader(
        SimpleImageLoader(base_path, train=True, transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        **kwargs)

    # Evaluation pipeline: random augmentation and shuffling are left out so
    # the score returned by ``test`` is comparable from epoch to epoch.
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.4822, ], std=[0.1839, ])
    ])
    test_loader = torch.utils.data.DataLoader(
        SimpleImageLoader(base_path, False, transform=test_transform),
        batch_size=args.batch_size,
        drop_last=False,
        **kwargs)

    if args.net == "TCN":
        jnet = Tripletnet()
    elif args.net == "SIMPLE":
        jnet = SimpleRegression()
    else:
        jnet = JointsResNetEmbedSpatial()

    # optionally resume from a checkpoint
    start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Checkpoints store ``epoch + 1``, i.e. the next epoch to run.
            start_epoch = checkpoint['epoch']
            args.start_epoch = start_epoch
            # Restore the running best so ``is_best`` stays meaningful.
            best_acc = checkpoint['best_prec1']
            jnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.cuda:
        jnet.cuda()

    # This flag allows you to enable the inbuilt cudnn auto-tuner to
    # find the best algorithm to use for your hardware.
    cudnn.benchmark = True

    if args.net == "TCN":
        # MarginRankingLoss Evaluation of similarity loss function
        criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    else:
        criterion = torch.nn.MSELoss()

    # Both loss choices use the same SGD configuration.
    optimizer = optim.SGD(jnet.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        train(train_loader, jnet, criterion, optimizer, epoch)

        # evaluate on validation set
        acc = test(test_loader, jnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': jnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)