import datetime
import os

import torch

# Repo-local modules; import paths are assumed from the names used below.
import ref
from opts import opts
from utils.logger import Logger
from utils.utils import adjust_learning_rate
from models.hg_3d import HourglassNet3D
from datasets.h36m import H36M
from train import train, val


def main():
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))

    # Load a full checkpoint if one is given, otherwise build the network from scratch.
    if opt.loadModel != 'none':
        model = torch.load(opt.loadModel).cuda()
    else:
        model = HourglassNet3D(opt.nStack, opt.nModules, opt.nFeats,
                               opt.nRegModules).cuda()

    criterion = torch.nn.MSELoss().cuda()
    optimizer = torch.optim.RMSprop(
        model.parameters(), opt.LR,
        alpha=ref.alpha, eps=ref.epsilon,
        weight_decay=ref.weightDecay, momentum=ref.momentum)

    val_loader = torch.utils.data.DataLoader(
        H36M(opt, 'val'),
        batch_size=1, shuffle=False, num_workers=int(ref.nThreads))

    if opt.test:
        val(0, opt, val_loader, model, criterion)
        return

    train_loader = torch.utils.data.DataLoader(
        H36M(opt, 'train'),
        batch_size=opt.trainBatch,
        shuffle=(opt.DEBUG == 0),  # keep ordering deterministic when debugging
        num_workers=int(ref.nThreads))

    for epoch in range(1, opt.nEpochs + 1):
        loss_train, acc_train, mpjpe_train, loss3d_train = train(
            epoch, opt, train_loader, model, criterion, optimizer)
        logger.scalar_summary('loss_train', loss_train, epoch)
        logger.scalar_summary('acc_train', acc_train, epoch)
        logger.scalar_summary('mpjpe_train', mpjpe_train, epoch)
        logger.scalar_summary('loss3d_train', loss3d_train, epoch)
        if epoch % opt.valIntervals == 0:
            loss_val, acc_val, mpjpe_val, loss3d_val = val(
                epoch, opt, val_loader, model, criterion)
            logger.scalar_summary('loss_val', loss_val, epoch)
            logger.scalar_summary('acc_val', acc_val, epoch)
            logger.scalar_summary('mpjpe_val', mpjpe_val, epoch)
            logger.scalar_summary('loss3d_val', loss3d_val, epoch)
            torch.save(model, os.path.join(opt.saveDir,
                                           'model_{}.pth'.format(epoch)))
            logger.write('{:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} \n'.format(
                loss_train, acc_train, mpjpe_train, loss3d_train,
                loss_val, acc_val, mpjpe_val, loss3d_val))
        else:
            logger.write('{:8f} {:8f} {:8f} {:8f} \n'.format(
                loss_train, acc_train, mpjpe_train, loss3d_train))
        adjust_learning_rate(optimizer, epoch, opt.dropLR, opt.LR)
    logger.close()
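# `adjust_learning_rate` is imported from a utils module above (path assumed)
# and not defined in this file. A minimal sketch of the step-decay scheme such
# a helper commonly implements: cut the base LR by 10x every `dropLR` epochs.
# The decay factor and the interpretation of `dropLR` are assumptions, not
# confirmed by this script.
def adjust_learning_rate_sketch(optimizer, epoch, dropLR, LR):
    lr = LR * (0.1 ** (epoch // dropLR))  # step decay from the base LR
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr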
# Fragment; assumes `import torch`, `import torch.distributed as dist`, and an
# argparse `args` namespace are already in scope.
args.distributed = args.world_size > 1
if args.distributed:
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size)

# create model
# if args.pretrained:
#     print("=> using pre-trained model '{}'".format(args.arch))
#     model = models.__dict__[args.arch](pretrained=True)
# else:
#     print("=> creating model '{}'".format(args.arch))
#     model = models.__dict__[args.arch]()
print('Create model')
# Build on CPU; each branch below handles device placement, so an eager
# .cuda() here would be redundant (and wrong for the explicit-GPU branch).
model = HourglassNet3D(args.nStack, args.nModules, args.nFeats,
                       args.nRegModules)
print(model)

if args.gpu is not None:
    # Pin the model to one explicitly chosen GPU.
    model = model.cuda(args.gpu)
elif args.distributed:
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
else:
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        # AlexNet/VGG keep the classifier on one GPU; parallelize only the features.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

print('Loss function')
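# The fragment stops before the data pipeline. A minimal sketch of how the
# distributed branch is typically completed, assuming an H36M-style dataset as
# in the single-GPU script; the sampler wiring is an assumption, not part of
# the original file.
import torch.utils.data.distributed

train_dataset = H36M(args, 'train')
train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                 if args.distributed else None)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.trainBatch,
    shuffle=(train_sampler is None),  # the sampler already shards and shuffles
    num_workers=int(ref.nThreads),
    sampler=train_sampler)
# Each epoch should then call train_sampler.set_epoch(epoch) so that the
# per-process shards are reshuffled.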
import datetime
import os
import pickle

import torch

# Repo-local modules; import paths are assumed from the names used below.
import ref
from opts import opts
from utils.logger import Logger
from models.hg_3d import HourglassNet3D
from models.discriminator import Discriminator
from datasets.true_dataset import TRUE
from datasets.fusion import Fusion
from train import train_gan, val_gan


def main():
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))

    # Load generator / discriminator checkpoints if given, otherwise build fresh.
    if opt.loadGModel != 'none':
        generator = torch.load(opt.loadGModel,
                               map_location=lambda storage, loc: storage,
                               pickle_module=pickle).cuda()
    else:
        generator = HourglassNet3D(opt.nStack, opt.nModules, opt.nFeats,
                                   opt.nRegModules, opt.nCamModules).cuda()
    if opt.loadDModel != 'none':
        discriminator = torch.load(opt.loadDModel,
                                   map_location=lambda storage, loc: storage,
                                   pickle_module=pickle).cuda()
    else:
        discriminator = Discriminator(3, opt.sizeLSTM).cuda()

    criterion = torch.nn.MSELoss().cuda()
    optimizer_G = torch.optim.Adam(generator.parameters(), opt.ganLR,
                                   eps=ref.epsilon,
                                   weight_decay=ref.ganWeightDecay)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), opt.ganLR)

    val_real_loader = torch.utils.data.DataLoader(
        TRUE(opt, 'val'),
        batch_size=1, shuffle=False, num_workers=int(ref.nThreads))
    val_fake_loader = torch.utils.data.DataLoader(
        Fusion(opt, 'val'),
        batch_size=1, shuffle=False, num_workers=int(ref.nThreads))

    if opt.test:
        val_gan(0, opt, val_real_loader, val_fake_loader,
                generator, discriminator, criterion)
        return

    train_real_loader = torch.utils.data.DataLoader(
        TRUE(opt, 'train'),
        batch_size=opt.trainBatch,
        shuffle=(opt.DEBUG == 0),
        num_workers=int(ref.nThreads))
    train_fake_loader = torch.utils.data.DataLoader(
        Fusion(opt, 'train'),
        batch_size=opt.trainBatch,
        shuffle=(opt.DEBUG == 0),
        num_workers=int(ref.nThreads))

    for epoch in range(1, opt.nEpochs + 1):
        lossg_train, lossd_train, loss2d_train, acc_train, mpjpe_train = train_gan(
            epoch, opt, train_real_loader, train_fake_loader,
            generator, discriminator, criterion, optimizer_G, optimizer_D)
        logger.scalar_summary('lossg_train', lossg_train, epoch)
        logger.scalar_summary('lossd_train', lossd_train, epoch)
        logger.scalar_summary('loss2d_train', loss2d_train, epoch)
        logger.scalar_summary('acc_train', acc_train, epoch)
        logger.scalar_summary('mpjpe_train', mpjpe_train, epoch)
        if epoch % opt.valIntervals == 0:
            lossg_val, lossd_val, loss2d_val, acc_val, mpjpe_val = val_gan(
                epoch, opt, val_real_loader, val_fake_loader,
                generator, discriminator, criterion)
            logger.scalar_summary('lossg_val', lossg_val, epoch)
            logger.scalar_summary('lossd_val', lossd_val, epoch)
            logger.scalar_summary('loss2d_val', loss2d_val, epoch)
            logger.scalar_summary('acc_val', acc_val, epoch)
            logger.scalar_summary('mpjpe_val', mpjpe_val, epoch)
            # Checkpoint the generator, which is the model evaluated above.
            torch.save(generator, os.path.join(opt.saveDir,
                                               'model_{}.pth'.format(epoch)))
            logger.write('{:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f}\n'.format(
                lossg_train, lossd_train, loss2d_train, acc_train, mpjpe_train,
                lossg_val, lossd_val, loss2d_val, acc_val, mpjpe_val))
        else:
            logger.write('{:8f} {:8f} {:8f} {:8f} {:8f}\n'.format(
                lossg_train, lossd_train, loss2d_train, acc_train, mpjpe_train))
    logger.close()
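# `train_gan` is defined elsewhere. A minimal sketch of one adversarial step it
# presumably performs, using the MSE criterion created above (an LSGAN-style
# objective). The variable names `real_pose` and `fake_input`, the real=1 /
# fake=0 target convention, and taking the last stack's output from the
# generator are all assumptions for illustration.
import torch


def gan_step(generator, discriminator, criterion,
             optimizer_G, optimizer_D, real_pose, fake_input):
    pred_3d = generator(fake_input)[-1]       # last stack's 3D prediction (assumed)

    # Discriminator update: push real poses toward 1, generated poses toward 0.
    d_real = discriminator(real_pose)
    d_fake = discriminator(pred_3d.detach())  # detach so G receives no gradient here
    loss_d = criterion(d_real, torch.ones_like(d_real)) + \
             criterion(d_fake, torch.zeros_like(d_fake))
    optimizer_D.zero_grad()
    loss_d.backward()
    optimizer_D.step()

    # Generator update: make the discriminator score generated poses as real.
    d_fake = discriminator(pred_3d)
    loss_g = criterion(d_fake, torch.ones_like(d_fake))
    optimizer_G.zero_grad()
    loss_g.backward()
    optimizer_G.step()
    return loss_g.item(), loss_d.item()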