Example 1
def main():
  opt = opts().parse()
  now = datetime.datetime.now()
  logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))

  # Resume from a saved checkpoint if one is given, otherwise build the network from scratch.
  if opt.loadModel != 'none':
    model = torch.load(opt.loadModel).cuda()
  else:
    model = HourglassNet3D(opt.nStack, opt.nModules, opt.nFeats, opt.nRegModules).cuda()
  
  criterion = torch.nn.MSELoss().cuda()
  optimizer = torch.optim.RMSprop(model.parameters(), opt.LR, 
                                  alpha = ref.alpha, 
                                  eps = ref.epsilon, 
                                  weight_decay = ref.weightDecay, 
                                  momentum = ref.momentum)

  val_loader = torch.utils.data.DataLoader(
      H36M(opt, 'val'),
      batch_size = 1,
      shuffle = False,
      num_workers = int(ref.nThreads)
  )
  

  if opt.test:
    val(0, opt, val_loader, model, criterion)
    return

  train_loader = torch.utils.data.DataLoader(
      H36M(opt, 'train'),
      batch_size = opt.trainBatch,
      shuffle = (opt.DEBUG == 0),
      num_workers = int(ref.nThreads)
  )

  # Train for nEpochs, validating and saving a snapshot every valIntervals epochs.
  for epoch in range(1, opt.nEpochs + 1):
    loss_train, acc_train, mpjpe_train, loss3d_train = train(epoch, opt, train_loader, model, criterion, optimizer)
    logger.scalar_summary('loss_train', loss_train, epoch)
    logger.scalar_summary('acc_train', acc_train, epoch)
    logger.scalar_summary('mpjpe_train', mpjpe_train, epoch)
    logger.scalar_summary('loss3d_train', loss3d_train, epoch)
    if epoch % opt.valIntervals == 0:
      loss_val, acc_val, mpjpe_val, loss3d_val = val(epoch, opt, val_loader, model, criterion)
      logger.scalar_summary('loss_val', loss_val, epoch)
      logger.scalar_summary('acc_val', acc_val, epoch)
      logger.scalar_summary('mpjpe_val', mpjpe_val, epoch)
      logger.scalar_summary('loss3d_val', loss3d_val, epoch)
      torch.save(model, os.path.join(opt.saveDir, 'model_{}.pth'.format(epoch)))
      logger.write('{:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} {:8f} \n'.format(loss_train, acc_train, mpjpe_train, loss3d_train, loss_val, acc_val, mpjpe_val, loss3d_val))
    else:
      logger.write('{:8f} {:8f} {:8f} {:8f} \n'.format(loss_train, acc_train, mpjpe_train, loss3d_train))
    adjust_learning_rate(optimizer, epoch, opt.dropLR, opt.LR)
  logger.close()
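
The adjust_learning_rate helper called above is not part of this snippet. Below is a minimal sketch, assuming the common step-decay convention in which the base learning rate is divided by 10 every dropLR epochs; this is an assumption, not necessarily this project's exact schedule.

def adjust_learning_rate(optimizer, epoch, dropLR, LR):
  # Assumed step decay: divide the base LR by 10 every `dropLR` epochs.
  lr = LR * (0.1 ** (epoch // dropLR))
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr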
Example 2
    def __init__(self, opt, split):
        self.ratio3D = opt.ratio3D
        self.split = split
        self.dataset3D = H36M(opt, split)
        self.nImages3D = len(self.dataset3D)

        print('#Images3D {}'.format(self.nImages3D))
Example 3
def genInterpolations():
    model = torch.load(args.loadModel)
    num_samples = 100
    interpol_length = 11
    dataset = H36M(opt, 'val')
    for i in range(num_samples):
        # Pick two random validation poses whose latent codes will be interpolated.
        ind1, ind2 = np.random.randint(0, len(dataset), 2)
        _, poseMap1 = dataset[ind1]
        _, poseMap2 = dataset[ind2]
        poseMap1 = poseMap1[None, ...]
        poseMap2 = poseMap2[None, ...]
        pts1 = torch.FloatTensor(getPreds(poseMap1))
        pts2 = torch.FloatTensor(getPreds(poseMap2))
        data1, data2 = Variable(pts1), Variable(pts2)
        if args.cuda:
            data1 = data1.cuda()
            data2 = data2.cuda()
        mu1, logvar1 = model.encode(data1.view(-1, 32))
        z1 = model.reparameterize(mu1, logvar1)
        mu2, logvar2 = model.encode(data2.view(-1, 32))
        z2 = model.reparameterize(mu2, logvar2)
        result = model.decode(z1)
        for j in range(1, interpol_length + 1):
            inter_z = z1 + j * (z2 - z1) / interpol_length
            interpol_pose = model.decode(inter_z)
            result = torch.cat((result, interpol_pose))
        # print(result.shape)
        save_image(makeSkel(result.data * 4, (255, 0, 0)),
                   '../exp/vae_pose/genInterpolations/interpol_' + str(i) +
                   '.png',
                   nrow=(interpol_length + 1) // 3)
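
getPreds is used above to turn heatmaps into 2D joint coordinates. A minimal NumPy sketch of that idea, assuming it simply takes the per-joint argmax of a (batch, nJoints, H, W) heatmap array (getPreds_sketch is a hypothetical stand-in, not the project's own implementation):

import numpy as np

def getPreds_sketch(heatmaps):
    # heatmaps: (batch, nJoints, H, W) -> (batch, nJoints, 2) peak locations as (x, y).
    n, j, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, j, -1)
    idx = flat.argmax(axis=2)
    return np.stack((idx % w, idx // w), axis=2).astype(np.float32)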
Example 4
    def __init__(self, opt, split):
        self.ratio3D = opt.ratio3D
        self.split = split
        self.dataset3D = H36M(opt, split)
        self.dataset2D = MPII(opt, split, returnMeta = True)
        self.nImages2D = len(self.dataset2D)
        self.nImages3D = min(len(self.dataset3D), int(self.nImages2D * self.ratio3D))

        print('#Images2D {}, #Images3D {}'.format(self.nImages2D, self.nImages3D))
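
The rest of this fusion dataset is not shown. One plausible sketch of the remaining methods, assuming the 2D and 3D samples are exposed as a single concatenated index space (a hypothetical layout, not necessarily the author's):

    def __len__(self):
        # Assumed: one flat index over the 2D samples followed by the 3D samples.
        return self.nImages2D + self.nImages3D

    def __getitem__(self, index):
        if index < self.nImages2D:
            return self.dataset2D[index]
        return self.dataset3D[index - self.nImages2D]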
Example 5
def check_logic():
  # Run the stem, then the first hourglass split into its encoder and decoder halves,
  # then the head modules, and visualise the predictions against the ground truth.
  preproc = nn.ModuleList([model.conv1_, model.bn1, model.relu, model.r1, model.maxpool, model.r4, model.r5])
  hg = model.hourglass[0]
  lower_hg = getEncoder(hg)
  data_loader = torch.utils.data.DataLoader(
      H36M(opts, 'val'),
      batch_size = 1, 
      shuffle = False,
      num_workers = int(ref.nThreads)
  )
  for k, (input, target) in enumerate(data_loader):
      if k > nSamples:
          break
      input_var = torch.autograd.Variable(input).float().cuda()
      for mod in preproc:
          input_var = mod(input_var)
      for mod in lower_hg:
          input_var = mod(input_var)    
      #decode  
      ups = input_var
      upper_hg = nn.ModuleList(getDecoder(hg))
      for mod in upper_hg:
          ups = mod(ups)
      Residual = model.Residual
      for j in range(nModules):
          ups = Residual[j](ups)
      lin_ = model.lin_
      ups = lin_[0](ups)
      tmpOut = model.tmpOut
      ups = tmpOut[0](ups)
      pred = eval.getPreds(ups.data.cpu().numpy())*4   
      gt = eval.getPreds(target.cpu().numpy()) * 4
      # init = getPreds(input.numpy()[:, 3:])
      debugger = Debugger()
      img = (input[0].numpy()[:3].transpose(1, 2, 0)*256).astype(np.uint8).copy()
      print(img.shape)
      debugger.addImg(img)
      debugger.addPoint2D(pred[0], (255, 0, 0))
      debugger.addPoint2D(gt[0], (0, 0, 255))
      # debugger.addPoint2D(init[0], (0, 255, 0))
      debugger.showAllImg(pause = True)
Example 6
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)

"""
val_loader = torch.utils.data.DataLoader(H36M(args, 'val'),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=int(ref.nThreads))

train_loader = torch.utils.data.DataLoader(
    Fusion(args, 'train'),
    batch_size=args.batch_size,
    shuffle=True if args.DEBUG == 0 else False,
    num_workers=int(ref.nThreads))

#now = datetime.datetime.now()
#logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))
"""


def train(epoch):