Ejemplo n.º 1
0
def main():
  """Train, or with ``opt.test`` only evaluate, the HourglassNet3D model.

  Options come from ``opts().parse()``. The model is loaded from
  ``opt.loadModel`` unless that is ``'none'``, in which case a fresh
  ``HourglassNet3D`` is built. In test mode one validation pass (epoch 0)
  is run and the function returns.

  Otherwise it trains for ``opt.nEpochs`` epochs on the H36M ``'train'``
  split with RMSprop and an MSE criterion. Every ``opt.valIntervals``
  epochs it runs ``val`` on the ``'val'`` split and saves the whole model to
  ``opt.saveDir/model_<epoch>.pth``. Per-epoch metrics are sent to
  ``logger.scalar_summary`` and written as one text line via
  ``logger.write``. ``adjust_learning_rate`` is called after every epoch.
  """
  opt = opts().parse()
  now = datetime.datetime.now()
  logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))

  # try/finally guarantees the logger is closed on every exit path,
  # including the early return in test mode and any exception raised while
  # building the model/data loaders or while training.
  try:
    if opt.loadModel != 'none':
      # torch.load unpickles a whole model object: only load trusted files.
      model = torch.load(opt.loadModel).cuda()
    else:
      model = HourglassNet3D(opt.nStack, opt.nModules, opt.nFeats, opt.nRegModules).cuda()

    criterion = torch.nn.MSELoss().cuda()
    optimizer = torch.optim.RMSprop(model.parameters(), opt.LR,
                                    alpha = ref.alpha,
                                    eps = ref.epsilon,
                                    weight_decay = ref.weightDecay,
                                    momentum = ref.momentum)

    # Validation uses batch size 1 and a fixed (unshuffled) order.
    val_loader = torch.utils.data.DataLoader(
        H36M(opt, 'val'),
        batch_size = 1,
        shuffle = False,
        num_workers = int(ref.nThreads)
    )

    if opt.test:
      val(0, opt, val_loader, model, criterion)
      return

    train_loader = torch.utils.data.DataLoader(
        H36M(opt, 'train'),
        batch_size = opt.trainBatch,
        # Shuffling is disabled whenever opt.DEBUG is non-zero.
        shuffle = opt.DEBUG == 0,
        num_workers = int(ref.nThreads)
    )

    for epoch in range(1, opt.nEpochs + 1):
      loss_train, acc_train, mpjpe_train, loss3d_train = train(epoch, opt, train_loader, model, criterion, optimizer)
      logger.scalar_summary('loss_train', loss_train, epoch)
      logger.scalar_summary('acc_train', acc_train, epoch)
      logger.scalar_summary('mpjpe_train', mpjpe_train, epoch)
      logger.scalar_summary('loss3d_train', loss3d_train, epoch)
      if epoch % opt.valIntervals == 0:
        loss_val, acc_val, mpjpe_val, loss3d_val = val(epoch, opt, val_loader, model, criterion)
        logger.scalar_summary('loss_val', loss_val, epoch)
        logger.scalar_summary('acc_val', acc_val, epoch)
        logger.scalar_summary('mpjpe_val', mpjpe_val, epoch)
        logger.scalar_summary('loss3d_val', loss3d_val, epoch)
        # Checkpoint only on validation epochs.
        torch.save(model, os.path.join(opt.saveDir, 'model_{}.pth'.format(epoch)))
        logger.write('{:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} \n'.format(loss_train, acc_train, mpjpe_train, loss3d_train, loss_val, acc_val, mpjpe_val, loss3d_val))
      else:
        logger.write('{:8f} {:8f} {:8f} {:8f} \n'.format(loss_train, acc_train, mpjpe_train, loss3d_train))
      adjust_learning_rate(optimizer, epoch, opt.dropLR, opt.LR)
  finally:
    logger.close()
Ejemplo n.º 2
0
# Distributed mode is enabled whenever more than one process participates.
args.distributed = args.world_size > 1

if args.distributed:
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size)

# Build the model on the CPU. Every placement branch below moves it to the
# device(s) it will actually use. An eager .cuda() here would first allocate
# a copy on the default GPU and then move it again, which wastes memory on
# the wrong device when args.gpu selects a different GPU.
print('Creat model')
model = HourglassNet3D(args.nStack, args.nModules, args.nFeats,
                       args.nRegModules)
print(model)

if args.gpu is not None:
    # Single explicitly chosen GPU.
    model = model.cuda(args.gpu)
elif args.distributed:
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
else:
    # NOTE(review): this alexnet/vgg branch appears to be carried over from a
    # torchvision ImageNet example; the model here is always HourglassNet3D,
    # so it presumably only runs if args.arch is set to such a name.
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

print('Loss function')
Ejemplo n.º 3
0
def main():
    """Adversarially train the HourglassNet3D generator against a discriminator.

    Options come from ``opts().parse()``. The generator and the discriminator
    are each loaded from ``opt.loadGModel`` / ``opt.loadDModel`` unless the
    option is ``'none'``; otherwise a fresh ``HourglassNet3D`` /
    ``Discriminator(3, opt.sizeLSTM)`` is built.

    The ``TRUE`` dataset feeds the ``*_real_loader``s and the ``Fusion``
    dataset feeds the ``*_fake_loader``s. In test mode one ``val_gan`` pass
    (epoch 0) is run and the function returns.

    Otherwise it trains for ``opt.nEpochs`` epochs with ``train_gan``. Every
    ``opt.valIntervals`` epochs it runs ``val_gan`` and checkpoints both
    networks to ``opt.saveDir``. Per-epoch metrics go to
    ``logger.scalar_summary`` and to one text line via ``logger.write``.
    """
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))

    # try/finally guarantees the logger is closed on every exit path,
    # including the early return in test mode and any exception.
    try:
        # map_location keeps tensors on the CPU while unpickling, so a
        # checkpoint saved on any device can be loaded; .cuda() then moves it.
        # NOTE: torch.load unpickles arbitrary objects: load trusted files only.
        if opt.loadGModel != 'none':
            generator = torch.load(opt.loadGModel, map_location=lambda storage, loc: storage, pickle_module=pickle).cuda()
        else:
            generator = HourglassNet3D(opt.nStack, opt.nModules, opt.nFeats, opt.nRegModules, opt.nCamModules).cuda()

        if opt.loadDModel != 'none':
            discriminator = torch.load(opt.loadDModel, map_location=lambda storage, loc: storage, pickle_module=pickle).cuda()
        else:
            discriminator = Discriminator(3, opt.sizeLSTM).cuda()

        criterion = torch.nn.MSELoss().cuda()
        # The optimizer class is torch.optim.Adam (torch.optim.adam is the
        # module, not a callable optimizer), and the discriminator optimizer
        # must update the discriminator's own parameters.
        optimizer_G = torch.optim.Adam(generator.parameters(), opt.ganLR, eps=ref.epsilon, weight_decay=ref.ganWeightDecay)
        optimizer_D = torch.optim.RMSprop(discriminator.parameters(), opt.ganLR)

        # Validation loaders: batch size 1, fixed (unshuffled) order.
        val_real_loader = torch.utils.data.DataLoader(
                TRUE(opt, 'val'),
                batch_size = 1,
                shuffle = False,
                num_workers = int(ref.nThreads)
                )

        val_fake_loader = torch.utils.data.DataLoader(
                Fusion(opt, 'val'),
                batch_size = 1,
                shuffle = False,
                num_workers = int(ref.nThreads)
                )

        if opt.test:
            val_gan(0, opt, val_real_loader, val_fake_loader, generator, discriminator, criterion)
            return

        # Shuffling is disabled whenever opt.DEBUG is non-zero.
        train_real_loader = torch.utils.data.DataLoader(
                TRUE(opt, 'train'),
                batch_size = opt.trainBatch,
                shuffle = opt.DEBUG == 0,
                num_workers = int(ref.nThreads)
                )

        train_fake_loader = torch.utils.data.DataLoader(
                Fusion(opt, 'train'),
                batch_size = opt.trainBatch,
                shuffle = opt.DEBUG == 0,
                num_workers = int(ref.nThreads)
                )

        for epoch in range(1, opt.nEpochs + 1):
            lossg_train, lossd_train, loss2d_train, acc_train, mpjpe_train = train_gan(epoch, opt, train_real_loader, train_fake_loader, generator, discriminator, criterion, optimizer_G, optimizer_D)
            logger.scalar_summary('lossg_train', lossg_train, epoch)
            logger.scalar_summary('lossd_train', lossd_train, epoch)
            logger.scalar_summary('loss2d_train', loss2d_train, epoch)
            logger.scalar_summary('acc_train', acc_train, epoch)
            logger.scalar_summary('mpjpe_train', mpjpe_train, epoch)
            if epoch % opt.valIntervals == 0:
                lossg_val, lossd_val, loss2d_val, acc_val, mpjpe_val = val_gan(epoch, opt, val_real_loader, val_fake_loader, generator, discriminator, criterion)
                logger.scalar_summary('lossg_val', lossg_val, epoch)
                logger.scalar_summary('lossd_val', lossd_val, epoch)
                logger.scalar_summary('loss2d_val', loss2d_val, epoch)
                logger.scalar_summary('acc_val', acc_val, epoch)
                logger.scalar_summary('mpjpe_val', mpjpe_val, epoch)
                # There is no `model` in this function: checkpoint both GAN
                # networks. The generator keeps the model_<epoch>.pth name.
                torch.save(generator, os.path.join(opt.saveDir, 'model_{}.pth'.format(epoch)))
                torch.save(discriminator, os.path.join(opt.saveDir, 'discriminator_{}.pth'.format(epoch)))
                logger.write('{:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f}\n'.format(lossg_train, lossd_train, loss2d_train, acc_train, mpjpe_train, lossg_val, lossd_val, loss2d_val, acc_val, mpjpe_val))
            else:
                logger.write('{:8f} {:8f} {:8f} {:8f} {:8f}\n'.format(lossg_train, lossd_train, loss2d_train, acc_train, mpjpe_train))
    finally:
        logger.close()