Exemple #1
0
def evaluate(args):
  """Detect facial landmarks on a single image with a saved detector checkpoint.

  Args:
    args: parsed options; this function reads ``args.image`` (image path),
      ``args.model`` (checkpoint path), ``args.face`` (four values describing
      the face box) and ``args.save`` (output path for the visualization, or
      a falsy value to skip it).

  The checkpoint must contain the training options under ``'args'`` and the
  detector weights under ``'detector'``. Predicted coordinates and scores are
  printed; nothing is returned.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True

  print ('The image is {:}'.format(args.image))
  print ('The model is {:}'.format(args.model))
  snapshot = Path(args.model)
  # The message used to contain a bare '{:}' placeholder that was never filled in.
  assert snapshot.exists(), 'The model path {:} does not exist'.format(snapshot)
  print ('The face bounding box is {:}'.format(args.face))
  assert len(args.face) == 4, 'Invalid face input : {:}'.format(args.face)
  # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
  # The checkpoint also stores a non-tensor object under 'args'; recent PyTorch
  # releases may need weights_only=False for that -- confirm against the version in use.
  snapshot = torch.load(snapshot)

  # General data augmentation: normalize with the ImageNet mean / std.
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

  # Rebuild the evaluation pipeline from the options saved at training time.
  param = snapshot['args']
  eval_transform  = transforms.Compose([transforms.PreCrop(param.pre_crop_expand), transforms.TrainScale2WH((param.crop_width, param.crop_height)), transforms.ToTensor(), normalize])
  model_config = load_configure(param.model_config, None)
  dataset = Dataset(eval_transform, param.sigma, model_config.downsample, param.heatmap_type, param.data_indicator)
  dataset.reset(param.num_pts)

  # The network is built with one extra output channel on top of num_pts.
  net = obtain_model(model_config, param.num_pts + 1)
  net = net.cuda()
  weights = remove_module_dict(snapshot['detector'])
  net.load_state_dict(weights)
  print ('Prepare input data')
  [image, _, _, _, _, _, cropped_size], _ = dataset.prepare_input(args.image, args.face)
  inputs = image.unsqueeze(0).cuda()
  # network forward
  with torch.no_grad():
    _, batch_locs, batch_scos = net(inputs)
  # obtain the locations on the image in the original size
  cpu = torch.device('cpu')
  np_batch_locs, np_batch_scos, cropped_size = batch_locs.to(cpu).numpy(), batch_scos.to(cpu).numpy(), cropped_size.numpy()
  # Drop the last entry of the single batch item (the extra channel).
  locations, scores = np_batch_locs[0,:-1,:], np.expand_dims(np_batch_scos[0,:-1], -1)

  # cropped_size[0] / [1] are divided by the network input height / width;
  # cropped_size[2] / [3] are added back as x / y offsets.
  scale_h, scale_w = cropped_size[0] * 1. / inputs.size(-2) , cropped_size[1] * 1. / inputs.size(-1)

  locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[2], locations[:, 1] * scale_h + cropped_size[3]
  # prediction has one column per landmark: rows are x, y, score.
  prediction = np.concatenate((locations, scores), axis=1).transpose(1,0)

  print ('the coordinates for {:} facial landmarks:'.format(param.num_pts))
  for i in range(param.num_pts):
    point = prediction[:, i]
    print ('the {:02d}/{:02d}-th point : ({:.1f}, {:.1f}), score = {:.2f}'.format(i, param.num_pts, float(point[0]), float(point[1]), float(point[2])))

  if args.save:
    resize = 512
    image = draw_image_by_points(args.image, prediction, 2, (255, 0, 0), args.face, resize)
    image.save(args.save)
    print ('save the visualization results into {:}'.format(args.save))
  else:
    print ('ignore the visualization procedure')
def evaluate(args):
  """Detect facial landmarks on a single image with a saved model checkpoint.

  Args:
    args: parsed options; this function reads ``args.image`` (image path),
      ``args.model`` (checkpoint path), ``args.face`` (four values describing
      the face box) and ``args.save`` (output path for the visualization, or
      a falsy value to skip it).

  The checkpoint must contain the training options under ``'args'`` and the
  network weights under ``'state_dict'``. Predicted coordinates and scores
  are printed; nothing is returned.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True

  print ('The image is {:}'.format(args.image))
  print ('The model is {:}'.format(args.model))
  snapshot = Path(args.model)
  # The message used to contain a bare '{:}' placeholder that was never filled in.
  assert snapshot.exists(), 'The model path {:} does not exist'.format(snapshot)
  print ('The face bounding box is {:}'.format(args.face))
  assert len(args.face) == 4, 'Invalid face input : {:}'.format(args.face)
  # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
  snapshot = torch.load(snapshot)

  # General data augmentation: normalize with the ImageNet mean / std.
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

  # Rebuild the evaluation pipeline from the options saved at training time.
  param = snapshot['args']
  eval_transform  = transforms.Compose([transforms.PreCrop(param.pre_crop_expand), transforms.TrainScale2WH((param.crop_width, param.crop_height)),  transforms.ToTensor(), normalize])
  model_config = load_configure(param.model_config, None)
  dataset = Dataset(eval_transform, param.sigma, model_config.downsample, param.heatmap_type, param.data_indicator)
  dataset.reset(param.num_pts)

  # The network is built with one extra output channel on top of num_pts.
  net = obtain_model(model_config, param.num_pts + 1)
  net = net.cuda()
  weights = remove_module_dict(snapshot['state_dict'])
  net.load_state_dict(weights)
  print ('Prepare input data')
  [image, _, _, _, _, _, cropped_size], _ = dataset.prepare_input(args.image, args.face)
  inputs = image.unsqueeze(0).cuda()
  # network forward
  with torch.no_grad():
    _, batch_locs, batch_scos = net(inputs)
  # obtain the locations on the image in the original size
  cpu = torch.device('cpu')
  np_batch_locs, np_batch_scos, cropped_size = batch_locs.to(cpu).numpy(), batch_scos.to(cpu).numpy(), cropped_size.numpy()
  # Drop the last entry of the single batch item (the extra channel).
  locations, scores = np_batch_locs[0,:-1,:], np.expand_dims(np_batch_scos[0,:-1], -1)

  # cropped_size[0] / [1] are divided by the network input height / width;
  # cropped_size[2] / [3] are added back as x / y offsets.
  scale_h, scale_w = cropped_size[0] * 1. / inputs.size(-2) , cropped_size[1] * 1. / inputs.size(-1)

  locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[2], locations[:, 1] * scale_h + cropped_size[3]
  # prediction has one column per landmark: rows are x, y, score.
  prediction = np.concatenate((locations, scores), axis=1).transpose(1,0)

  print ('the coordinates for {:} facial landmarks:'.format(param.num_pts))
  for i in range(param.num_pts):
    point = prediction[:, i]
    print ('the {:02d}/{:02d}-th point : ({:.1f}, {:.1f}), score = {:.2f}'.format(i, param.num_pts, float(point[0]), float(point[1]), float(point[2])))

  if args.save:
    resize = 512
    image = draw_image_by_points(args.image, prediction, 2, (255, 0, 0), args.face, resize)
    image.save(args.save)
    print ('save the visualization results into {:}'.format(args.save))
  else:
    print ('ignore the visualization procedure')
def crop_style(list_file, num_pts, save_dir):
    """Crop every sample listed in ``list_file`` and save the crops to ``save_dir``.

    Each sample goes through PreCrop(0.2) followed by a 256x256 rescale; the
    first element of the sample is saved under the basename of the matching
    entry in ``data.datas``. ``save_dir`` is created when it does not exist.
    """
    print('crop face images into {}'.format(save_dir))
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)

    pipeline = transforms.Compose([
        transforms.PreCrop(0.2),
        transforms.TrainScale2WH((256, 256)),
    ])
    data = GenDataset(pipeline, 1, 8, 'gaussian', 'test')
    data.load_list(list_file, num_pts, True)

    for index, sample in enumerate(data):
        cropped = sample[0]
        target = osp.join(save_dir, osp.basename(data.datas[index]))
        cropped.save(target)
        # Progress is reported once every PRINT_GAP samples.
        if index % PRINT_GAP != 0:
            continue
        message = '--->>> process the {:4d}/{:4d}-th image'
        print(message.format(index, len(data)))
def main(args):
  """Train the landmark detector and evaluate it after every epoch.

  Builds the train / eval transforms and data loaders from ``args``, creates
  the network and optimizer from the model / optimizer config files, resumes
  from the logger's last-info checkpoint when one exists, and then runs the
  train -> checkpoint -> evaluate loop until ``opt_config.epochs``. When
  ``args.eval_once`` is set, only a single evaluation pass is performed.
  Requires CUDA. Returns nothing.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  prepare_seed(args.rand_seed)

  # Logger is keyed by the seed and the start time; dump the run configuration first.
  logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
  logger = Logger(args.save_path, logstr)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow  version : {}".format(PIL.__version__))
  logger.log("PyTorch version : {}".format(torch.__version__))
  logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

  # General Data Augmentation
  # mean_fill is the ImageNet mean scaled to 0-255; it is passed to AugCrop below.
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
  # Horizontal flipping is not supported by this script (see the disabled branch below).
  assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
  train_transform  = [transforms.PreCrop(args.pre_crop_expand)]
  train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
  train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
  #if args.arg_flip:
  #  train_transform += [transforms.AugHorizontalFlip()]
  if args.rotate_max:
    train_transform += [transforms.AugRotate(args.rotate_max)]
  train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
  train_transform += [transforms.ToTensor(), normalize]
  train_transform  = transforms.Compose( train_transform )

  # Evaluation uses the deterministic part of the pipeline only.
  eval_transform  = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),  transforms.ToTensor(), normalize])
  # The evaluation scale must be the midpoint of the training scale range.
  assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)
  
  # Model Configure Load
  model_config = load_configure(args.model_config, logger)
  # args.sigma is overwritten in place; the scaled value is also what gets saved in checkpoints.
  args.sigma   = args.sigma * args.scale_eval
  logger.log('Real Sigma : {:}'.format(args.sigma))

  # Training Dataset
  train_data   = Dataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
  train_data.load_list(args.train_lists, args.num_pts, True)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)


  # Evaluation Dataloader
  # Each entry is (loader, is_video): True for eval_vlists, False for eval_ilists.
  eval_loaders = []
  if args.eval_vlists is not None:
    for eval_vlist in args.eval_vlists:
      eval_vdata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_vdata.load_list(eval_vlist, args.num_pts, True)
      eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_vloader, True))

  if args.eval_ilists is not None:
    for eval_ilist in args.eval_ilists:
      eval_idata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_idata.load_list(eval_ilist, args.num_pts, True)
      eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_iloader, False))

  # Define network
  logger.log('configure : {:}'.format(model_config))
  # The network is built with one extra output channel on top of num_pts.
  net = obtain_model(model_config, args.num_pts + 1)
  assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
  logger.log("=> network :\n {}".format(net))

  logger.log('Training-data : {:}'.format(train_data))
  for i, eval_loader in enumerate(eval_loaders):
    eval_loader, is_video = eval_loader
    logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))
    
  logger.log('arguments : {:}'.format(args))

  opt_config = load_configure(args.opt_config, logger)

  # Let the network supply its own parameter groups when it offers that hook.
  if hasattr(net, 'specify_parameter'):
    net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
  else:
    net_param_dict = net.parameters()

  optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
  logger.log('criterion : {:}'.format(criterion))
  net, criterion = net.cuda(), criterion.cuda()
  # Wrapped after the optimizer is built; state_dict keys are those of the DataParallel wrapper from here on.
  net = torch.nn.DataParallel(net)

  # Resume from the last-info file when it exists; otherwise start from epoch 0.
  last_info = logger.last_info()
  if last_info.exists():
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch'] + 1
    checkpoint  = torch.load(last_info['last_checkpoint'])
    assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done" .format(logger.last_info(), checkpoint['epoch']))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0


  if args.eval_once:
    logger.log("=> only evaluate the model once")
    eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
    logger.close() ; return


  # Main Training and Evaluation Loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(start_epoch, opt_config.epochs):

    # NOTE(review): the scheduler is stepped before the epoch's optimizer updates; recent
    # PyTorch releases expect scheduler.step() after optimizer.step() and offer
    # get_last_lr() instead of get_lr() -- confirm against the PyTorch version in use.
    scheduler.step()
    # Remaining-time estimate from the running average epoch duration.
    need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
    epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
    LRs       = scheduler.get_lr()
    logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

    # train for one epoch
    train_loss, train_nme = train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
    # log the results    
    logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))

    # save a checkpoint for this epoch (every epoch is saved; no best-model selection here)
    save_path = save_checkpoint({
          'epoch': epoch,
          'args' : deepcopy(args),
          'arch' : model_config.arch,
          'state_dict': net.state_dict(),
          'scheduler' : scheduler.state_dict(),
          'optimizer' : optimizer.state_dict(),
          }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

    # The last-info file points at the checkpoint just written; it is what the resume branch reads.
    last_info = save_checkpoint({
          'epoch': epoch,
          'last_checkpoint': save_path,
          }, logger.last_info(), logger)

    eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
    
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.close()
Exemple #5
0
# Load the checkpoint at `model_path` and rebuild the detector it describes.
snapshot = Path(model_path)
# The message used to contain a bare '{:}' placeholder that was never filled in.
assert snapshot.exists(), 'The model path {:} does not exist'.format(snapshot)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
snapshot = torch.load(snapshot)

# ImageNet mean scaled to 0-255, and the matching mean / std normalization.
mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Rebuild the evaluation pipeline from the options saved at training time.
param = snapshot['args']
eval_transform = transforms.Compose([
    transforms.PreCrop(param.pre_crop_expand),
    transforms.TrainScale2WH((param.crop_width, param.crop_height)),
    transforms.ToTensor(), normalize
])
model_config = load_configure(param.model_config, None)
dataset = Dataset(eval_transform, param.sigma, model_config.downsample,
                  param.heatmap_type, param.data_indicator)
dataset.reset(param.num_pts)

# The network is built with one extra output channel on top of num_pts.
net = obtain_model(model_config, param.num_pts + 1)
net = net.cuda()
# Prefer the 'detector' weights and fall back to 'state_dict' only when that
# key is absent. Catching KeyError alone keeps genuine failures (for example
# inside remove_module_dict) visible instead of silently retrying.
try:
    raw_weights = snapshot['detector']
except KeyError:
    raw_weights = snapshot['state_dict']
weights = remove_module_dict(raw_weights)
net.load_state_dict(weights)


def evaluate(args):
    # NOTE(review): only the CUDA check and the cuDNN switch are visible here;
    # the rest of this function appears to be cut off in this excerpt.
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
Exemple #6
0
def evaluate(args):
    """Export a trained detector to ONNX and write a sanity-check prediction.

    Args:
        args: parsed options; only ``args.model`` (checkpoint path) is read.

    The ONNX file is written to the working directory under the checkpoint's
    base name with an ``.onnx`` suffix. As a sanity check, the network is run
    on a hard-coded sample image; the preprocessed input is pickled to
    ``pick.pick`` and the image with the predicted points drawn on it is
    written to ``py_0.jpg``. Nothing is returned.

    Raises:
        FileNotFoundError: if the hard-coded sample image cannot be read.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    model_name = os.path.split(args.model)[-1]
    onnx_name = os.path.splitext(model_name)[0] + ".onnx"

    print('The model is {:}'.format(args.model))
    print('Model name is {:} \nOutput onnx file is {:}'.format(
        model_name, onnx_name))

    snapshot = Path(args.model)
    # The message used to contain a bare '{:}' placeholder that was never filled in.
    assert snapshot.exists(), 'The model does not exist {:}'.format(snapshot)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    snapshot = torch.load(snapshot)

    # General data augmentation: normalize with the ImageNet mean / std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    param = snapshot['args']
    print(param)

    eval_transform = transforms.Compose([
        transforms.PreCrop(param.pre_crop_expand),
        transforms.TrainScale2WH((param.crop_width, param.crop_height)),
        transforms.ToTensor(), normalize
    ])
    model_config = load_configure(param.model_config, None)
    print(model_config)

    dataset = Dataset(eval_transform, param.sigma, model_config.downsample,
                      param.heatmap_type, param.data_indicator)
    dataset.reset(param.num_pts)

    # The network stays on the CPU for the sanity-check forward pass below.
    net = obtain_model(model_config, param.num_pts + 1)
    weights = remove_module_dict(snapshot['state_dict'])

    # Keep only the part of each key after the last 'detector.' so the keys
    # match the bare detector network.
    nu_weights = {}
    for key, val in weights.items():
        stripped = key.split('detector.')[-1]
        nu_weights[stripped] = val
        print(stripped)
    weights = nu_weights

    net.load_state_dict(weights)

    input_name = ['image_in']
    output_name = ['locs', 'scors', 'crap']

    sample_path = 'Menpo51220/val/0000018.jpg'
    im = cv2.imread(sample_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if im is None:
        raise FileNotFoundError(
            'Cannot read the sample image {:}'.format(sample_path))

    imshape = im.shape
    # NOTE(review): im.shape is (rows, cols, channels), so this box is
    # [0, 0, height, width]; confirm the order prepare_input expects.
    face = [0, 0, imshape[0], imshape[1]]
    [image, _, _, _, _, _,
     cropped_size], _ = dataset.prepare_input(sample_path, face)
    dummy_input = torch.randn(1,
                              3,
                              256,
                              256,
                              requires_grad=True,
                              dtype=torch.float32)
    # This was input(...), which blocked on stdin; the dtype is now only printed.
    print(dummy_input.dtype)

    inputs = image.unsqueeze(0)
    out_in = inputs.data.numpy()
    with open('pick.pick', 'wb') as dump_file:
        pickle.dump(out_in, dump_file)

    with torch.no_grad():
        # CPU forward pass first; net.cuda() then moves the network in place
        # for the export.
        batch_locs, batch_scos, heatmap = net(inputs)
        torch.onnx.export(net.cuda(),
                          dummy_input.cuda(),
                          onnx_name,
                          verbose=True,
                          input_names=input_name,
                          output_names=output_name,
                          export_params=True)
        print(batch_locs)
        print(batch_scos)
        print(heatmap)
    cpu = torch.device('cpu')
    np_batch_locs = batch_locs.to(cpu).numpy()
    np_batch_scos = batch_scos.to(cpu).numpy()
    cropped_size = cropped_size.numpy()
    # Drop the last entry (the extra channel). NOTE(review): no batch index is
    # taken here, so the outputs are presumably not batched -- confirm.
    locations = np_batch_locs[:-1, :]
    scores = np.expand_dims(np_batch_scos[:-1], -1)

    scale_h = cropped_size[0] * 1. / inputs.size(-2)
    scale_w = cropped_size[1] * 1. / inputs.size(-1)

    # Rescale to the cropped size; no offset is added in this variant.
    locations[:, 0], locations[:, 1] = (locations[:, 0] * scale_w,
                                        locations[:, 1] * scale_h)
    prediction = np.concatenate((locations, scores), axis=1).transpose(1, 0)

    # Back to one row per point, with the score column dropped.
    pred_pts = np.transpose(prediction, [1, 0])
    pred_pts = pred_pts[:, :-1]
    sim = draw_pts(im, pred_pts=pred_pts, get_l1e=False)
    cv2.imwrite('py_0.jpg', sim)
Exemple #7
0
def train(args):
  """Build the training pipeline and dump annotated images for visual inspection.

  As written, no training happens: the epoch loop exits on its first
  statement, so the only effect besides setup is the final loop, which writes
  one ``ori_<index>.jpg`` per dataset entry with its points and box drawn on
  the image. Requires CUDA. Returns nothing.
  """
  
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  
  
  print('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    print('{:16} : {:}'.format(name, value))
    
    
  # Data Augmentation    
  # NOTE(review): mean_fill and normalize are never used in this function.
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])


  # train_transform = [transforms.AugTransBbox(1, 0.5)]
  train_transform = [transforms.PreCrop(args.pre_crop_expand)]
  
  # Fixed 1024x1024 rescale, then a random horizontal flip; no normalization.
  train_transform += [transforms.TrainScale2WH((1024, 1024))]
  train_transform += [transforms.AugHorizontalFlip(args.flip_prob)]

  train_transform += [transforms.ToTensor()]
  train_transform  = transforms.Compose( train_transform )


  # Training datasets
  train_data = GeneralDataset(args.num_pts, train_transform, args.train_lists)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

  
    
  net = Model(args.num_pts)
  
  # print(len(net.children()))
  #for m in net.children():
  #  print(type(m))
  criterion = wing_loss(args) 

  optimizer = torch.optim.SGD(net.parameters(), lr=args.LR, momentum=args.momentum,
                          weight_decay=args.decay, nesterov=args.nesterov)
    
  scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)
  
  net = net.cuda()
  net = torch.nn.DataParallel(net)
    
    
    
  print('--------------', len(train_loader))
  for epoch in range(3):
    # NOTE(review): this break leaves the loop immediately, so everything below
    # it in the loop body never runs. The disabled code also reads `meta`, which
    # is only assigned in the final loop of this function, and calls .numpy()
    # on a tensor that was moved to CUDA -- both need fixing before re-enabling.
    break
    for i , (inputs, target, mask, cropped_size) in enumerate(train_loader):
        
        
      target = target.squeeze(1)
      inputs = inputs.cuda()
      target = target.cuda()
      mask = mask.cuda()
      prediction = net(inputs)            
      loss = criterion(prediction, target, mask) 
      
      nums_img = inputs.size()[0]
      for j in range(nums_img): 
        temp_img = inputs[j].permute(1,2,0)
        temp_img = temp_img.mul(255).numpy()
        temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)

        pts = []
        for d in range(args.num_pts):
          pts.append((target[j][0][2*d].item(), target[j][0][2*d+1].item()))
        bbox = [int(index[0].item())  for index in meta]
        #print(pts)
        draw_points(temp_img, pts, (0, 255, 255))
        #draw_points(temp_img, [(bbox[0],bbox[1])], (0, 0, 255))
        #draw_points(temp_img, [(bbox[2],bbox[3])], (0, 0, 255))
        cv2.rectangle(temp_img,(bbox[0],bbox[1]),(bbox[2],bbox[3]),(0,255,0),4)
        cv2.imwrite('{}-{}-{}.jpg'.format(epoch,i,j), temp_img)
        # assert 1==0
      #if i > 5:
      #  break
  # Draw the raw annotations (points and box) on each original image and save it.
  for a, v in enumerate(train_data.data_value):
   
    image = cv2.imread(v['image_path'])
    meta = v['meta']
    # bbox = v['bbox']
    bbox = v['meta'].get_box()
    pts = []
    for d in range(args.num_pts):
      # meta.points is indexed [0, d] for the first coordinate and [1, d] for the second.
      pts.append((meta.points[0, d], meta.points[1, d]))
    draw_points(image, pts, (0, 255, 255))
    cv2.rectangle(image,(int(bbox[0]), int(bbox[1])),(int(bbox[2]), int(bbox[3])),(0,255,0),4)
    cv2.imwrite('ori_{}.jpg'.format(a), image)
def main(args):
    """Train the combined model and evaluate its detector after every epoch.

    Builds the train / eval transforms and data loaders from ``args``, creates
    the network from the model and LK config files, resumes from the logger's
    last-info checkpoint (or initializes the detector from ``args.init_model``)
    and then runs the train -> checkpoint -> evaluate loop until
    ``opt_config.epochs``. An evaluation pass is always run before training;
    when ``args.eval_once`` is set the function returns right after it.
    Requires CUDA. Returns nothing.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    prepare_seed(args.rand_seed)

    # Logger is keyed by the seed and the start time; dump the run configuration first.
    logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
    logger = Logger(args.save_path, logstr)
    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow  version : {}".format(PIL.__version__))
    logger.log("PyTorch version : {}".format(torch.__version__))
    logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

    # General Data Augmentation
    # mean_fill is the ImageNet mean scaled to 0-255; it is passed to AugCrop below.
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Horizontal flipping is not supported by this script (see the disabled branch below).
    assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(
        args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [
        transforms.TrainScale2WH((args.crop_width, args.crop_height))
    ]
    train_transform += [
        transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)
    ]
    #if args.arg_flip:
    #  train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [
        transforms.AugCrop(args.crop_width, args.crop_height,
                           args.crop_perturb_max, mean_fill)
    ]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    # Evaluation uses the deterministic part of the pipeline only.
    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])
    # The evaluation scale must be the midpoint of the training scale range.
    assert (
        args.scale_min + args.scale_max
    ) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(
        args.scale_min, args.scale_max, args.scale_eval)

    # Model Configure Load
    model_config = load_configure(args.model_config, logger)
    # args.sigma is overwritten in place; the scaled value is also what gets saved in checkpoints.
    args.sigma = args.sigma * args.scale_eval
    logger.log('Real Sigma : {:}'.format(args.sigma))

    # Training Dataset
    # Training uses VDataset (with args.video_parser); evaluation below uses IDataset.
    train_data = VDataset(train_transform, args.sigma, model_config.downsample,
                          args.heatmap_type, args.data_indicator,
                          args.video_parser)
    train_data.load_list(args.train_lists, args.num_pts, True)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # Evaluation Dataloader
    # Each entry is (loader, is_video): True for eval_vlists, False for eval_ilists.
    eval_loaders = []
    if args.eval_vlists is not None:
        for eval_vlist in args.eval_vlists:
            eval_vdata = IDataset(eval_transform, args.sigma,
                                  model_config.downsample, args.heatmap_type,
                                  args.data_indicator)
            eval_vdata.load_list(eval_vlist, args.num_pts, True)
            eval_vloader = torch.utils.data.DataLoader(
                eval_vdata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_vloader, True))

    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = IDataset(eval_transform, args.sigma,
                                  model_config.downsample, args.heatmap_type,
                                  args.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, True)
            eval_iloader = torch.utils.data.DataLoader(
                eval_idata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_iloader, False))

    # Define network
    lk_config = load_configure(args.lk_config, logger)
    logger.log('model configure : {:}'.format(model_config))
    logger.log('LK configure : {:}'.format(lk_config))
    # The network is built with one extra output channel on top of num_pts.
    net = obtain_model(model_config, lk_config, args.num_pts + 1)
    assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(
        model_config.downsample, net.downsample)
    logger.log("=> network :\n {}".format(net))

    logger.log('Training-data : {:}'.format(train_data))
    for i, eval_loader in enumerate(eval_loaders):
        eval_loader, is_video = eval_loader
        logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(
            i, len(eval_loaders), 'video' if is_video else 'image',
            eval_loader.dataset))

    logger.log('arguments : {:}'.format(args))

    opt_config = load_configure(args.opt_config, logger)

    # Let the network supply its own parameter groups when it offers that hook.
    if hasattr(net, 'specify_parameter'):
        net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
    else:
        net_param_dict = net.parameters()

    optimizer, scheduler, criterion = obtain_optimizer(net_param_dict,
                                                       opt_config, logger)
    logger.log('criterion : {:}'.format(criterion))
    net, criterion = net.cuda(), criterion.cuda()
    # Wrapped after the optimizer is built; the bare model is net.module from here on.
    net = torch.nn.DataParallel(net)

    # Three start modes: resume from last-info, initialize the detector from
    # args.init_model, or start from scratch.
    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint[
            'epoch'], 'Last-Info is not right {:} vs {:}'.format(
                last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
            logger.last_info(), checkpoint['epoch']))
    elif args.init_model is not None:
        init_model = Path(args.init_model)
        assert init_model.exists(), 'init-model {:} does not exist'.format(
            init_model)
        checkpoint = torch.load(init_model)
        # Only the detector sub-module is initialized; optimizer and scheduler start fresh.
        checkpoint = remove_module_dict(checkpoint['state_dict'], True)
        net.module.detector.load_state_dict(checkpoint)
        logger.log("=> initialize the detector : {:}".format(init_model))
        start_epoch = 0
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    # Evaluation runs on the detector sub-module alone, wrapped in its own DataParallel.
    detector = torch.nn.DataParallel(net.module.detector)

    # This initial evaluation always runs, whether or not eval_once is set.
    eval_results = eval_all(args, eval_loaders, detector, criterion,
                            'start-eval', logger, opt_config)
    if args.eval_once:
        logger.log("=> only evaluate the model once")
        logger.close()
        return

    # Main Training and Evaluation Loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(start_epoch, opt_config.epochs):

        # NOTE(review): the scheduler is stepped before the epoch's optimizer updates;
        # recent PyTorch releases expect scheduler.step() after optimizer.step() and
        # offer get_last_lr() instead of get_lr() -- confirm against the version in use.
        scheduler.step()
        # Remaining-time estimate from the running average epoch duration.
        need_time = convert_secs2time(
            epoch_time.avg * (opt_config.epochs - epoch), True)
        epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
        LRs = scheduler.get_lr()
        logger.log(
            '\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.
            format(time_string(), epoch_str, need_time, min(LRs), max(LRs),
                   opt_config))

        # train for one epoch
        # The last argument is a flag that becomes True once epoch reaches lk_config.start.
        train_loss = train(args, train_loader, net, criterion, optimizer,
                           epoch_str, logger, opt_config, lk_config,
                           epoch >= lk_config.start)
        # log the results
        logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}'.format(
            time_string(), epoch_str, train_loss))

        # save a checkpoint for this epoch; it stores both the full network
        # ('state_dict') and the detector alone ('detector')
        save_path = save_checkpoint(
            {
                'epoch': epoch,
                'args': deepcopy(args),
                'arch': model_config.arch,
                'state_dict': net.state_dict(),
                'detector': detector.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            logger.path('model') /
            '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

        # The last-info file points at the checkpoint just written; it is what the resume branch reads.
        last_info = save_checkpoint(
            {
                'epoch': epoch,
                'last_checkpoint': save_path,
            }, logger.last_info(), logger)

        eval_results = eval_all(args, eval_loaders, detector, criterion,
                                epoch_str, logger, opt_config)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.close()
Exemple #9
0
def train(args):
    """Train the landmark model and evaluate it after every epoch.

    Builds the train/eval data pipelines, the model, an SGD optimizer with a
    MultiStepLR schedule and the wing loss, resumes from the logger's
    last-info file when one exists, then alternates one training pass and
    one evaluation pass per epoch. A checkpoint and a last-info file are
    written after every training pass; average loss and NME are sent to the
    logger and to TensorBoard.

    Args:
        args: parsed command-line namespace (``args._get_kwargs()`` is
            used). Attributes read here: ``save_path``, ``transbbox_prob``,
            ``transbbox_percent``, ``pre_crop_expand``, ``crop_width``,
            ``crop_height``, ``rotate_max``, ``gaussianblur_prob``,
            ``gaussianblur_kernel_size``, ``gaussianblur_sigma``,
            ``num_pts``, ``train_lists``, ``eval_lists``, ``batch_size``,
            ``workers``, ``LR``, ``momentum``, ``decay``, ``nesterov``,
            ``schedule``, ``gamma``, ``epochs`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is unavailable, or if the last-info file and
            the checkpoint it points to disagree about the epoch.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    tfboard_writer = SummaryWriter()
    logname = '{}'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H:%M'))
    logger = Logger(args.save_path, logname)
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))

    # Data Augmentation
    # mean_fill is only referenced by the disabled AugCrop transform below.
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = [
        transforms.AugTransBbox(args.transbbox_prob, args.transbbox_percent),
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
    ]
    #train_transform += [transforms.AugHorizontalFlip(args.flip_prob)]
    #train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
    #train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
    if args.rotate_max:
        train_transform.append(transforms.AugRotate(args.rotate_max))
    train_transform += [
        transforms.AugGaussianBlur(args.gaussianblur_prob,
                                   args.gaussianblur_kernel_size,
                                   args.gaussianblur_sigma),
        transforms.ToTensor(),
        normalize,
    ]
    train_transform = transforms.Compose(train_transform)

    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])

    # Training datasets
    train_data = GeneralDataset(args.num_pts, train_transform,
                                args.train_lists)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # Evaluation dataloaders: one loader per list in args.eval_lists.
    eval_loaders = []
    for eval_ilist in args.eval_lists:
        eval_idata = GeneralDataset(args.num_pts, eval_transform, eval_ilist)
        eval_iloader = torch.utils.data.DataLoader(eval_idata,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        eval_loaders.append(eval_iloader)

    net = Model(args.num_pts)

    logger.log("=> network :\n {}".format(net))
    logger.log('arguments : {:}'.format(args))

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       net.parameters()),
                                lr=args.LR,
                                momentum=args.momentum,
                                weight_decay=args.decay,
                                nesterov=args.nesterov)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.schedule,
                                                     gamma=args.gamma)

    criterion = wing_loss(args)
    # criterion = torch.nn.MSELoss(reduce=True)

    net = net.cuda()
    criterion = criterion.cuda()
    # Checkpoints below store the DataParallel-wrapped state dict, so the
    # wrapper must be in place before a checkpoint is loaded.
    net = torch.nn.DataParallel(net)

    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint[
            'epoch'], 'Last-Info is not right {:} vs {:}'.format(
                last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
            logger.last_info(), checkpoint['epoch']))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    cpu = torch.device('cpu')
    num_train_batches = len(train_loader)

    for epoch in range(start_epoch, args.epochs):
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() to be called
        # after the epoch's optimizer.step() calls; the original ordering is
        # kept so resumed runs follow the same LR schedule as before.
        scheduler.step()

        net.train()

        # train
        img_prediction = []
        img_target = []
        train_losses = AverageMeter()
        for batch_idx, (inputs, target) in enumerate(train_loader):

            target = target.squeeze(1)
            inputs = inputs.cuda()
            target = target.cuda()

            prediction = net(inputs)

            loss = criterion(prediction, target)
            train_losses.update(loss.item(), inputs.size(0))

            # Keep per-image predictions/targets on the CPU for the NME.
            img_prediction.extend(prediction.detach().to(cpu).numpy())
            img_target.extend(target.detach().to(cpu).numpy())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch_idx % args.print_freq == 0
                    or batch_idx + 1 == num_train_batches):
                logger.log(
                    '[train Info]: [epoch-{}-{}][{:04d}/{:04d}][Loss:{:.2f}]'.
                    format(epoch, args.epochs, batch_idx, num_train_batches,
                           loss.item()))

        train_nme = compute_nme(args.num_pts, img_prediction, img_target)
        logger.log('epoch {:02d} completed!'.format(epoch))
        logger.log(
            '[train Info]: [epoch-{}-{}][Avg Loss:{:.6f}][NME:{:.2f}]'.format(
                epoch, args.epochs, train_losses.avg, train_nme * 100))
        tfboard_writer.add_scalar('Average Loss', train_losses.avg, epoch)
        tfboard_writer.add_scalar('NME', train_nme * 100,
                                  epoch)  # training data nme

        # save checkpoint
        filename = 'epoch-{}-{}.pth'.format(epoch, args.epochs)
        save_path = logger.path('model') / filename
        torch.save(
            {
                'epoch': epoch,
                'args': deepcopy(args),
                'state_dict': net.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_path)
        logger.log('save checkpoint into {}'.format(filename))
        # The last-info file points at the newest checkpoint so that a later
        # run can resume from it (see the resume branch above).
        torch.save({
            'epoch': epoch,
            'last_checkpoint': save_path
        }, logger.last_info())

        # eval
        logger.log('Basic-Eval-All evaluates {} dataset'.format(
            len(eval_loaders)))

        net.eval()
        for i, loader in enumerate(eval_loaders):

            eval_losses = AverageMeter()
            eval_prediction = []
            eval_target = []
            num_eval_batches = len(loader)
            with torch.no_grad():
                for i_batch, (inputs, target) in enumerate(loader):

                    target = target.squeeze(1)
                    inputs = inputs.cuda()
                    target = target.cuda()
                    prediction = net(inputs)
                    loss = criterion(prediction, target)
                    eval_losses.update(loss.item(), inputs.size(0))

                    eval_prediction.extend(
                        prediction.detach().to(cpu).numpy())
                    eval_target.extend(target.detach().to(cpu).numpy())

                    # Progress is reported per batch: use the batch index
                    # here, not the index of the dataset being evaluated.
                    if (i_batch % args.print_freq == 0
                            or i_batch + 1 == num_eval_batches):
                        logger.log(
                            '[Eval Info]: [epoch-{}-{}][{:04d}/{:04d}][Loss:{:.2f}]'
                            .format(epoch, args.epochs, i_batch,
                                    num_eval_batches, loss.item()))

            eval_nme = compute_nme(args.num_pts, eval_prediction, eval_target)
            logger.log(
                '[Eval Info]: [evaluate the {}/{}-th dataset][epoch-{}-{}][Avg Loss:{:.6f}][NME:{:.2f}]'
                .format(i, len(eval_loaders), epoch, args.epochs,
                        eval_losses.avg, eval_nme * 100))
            tfboard_writer.add_scalar('eval_nme/{}'.format(i), eval_nme * 100,
                                      epoch)

    # Flush pending TensorBoard events before shutting the logger down.
    tfboard_writer.close()
    logger.close()
def evaluate(args):
    """Run a saved detector over the module-level ``images`` list.

    Loads the snapshot at ``args.model``, rebuilds the evaluation transform
    and network from the arguments stored in the snapshot, then for every
    file name in ``images`` predicts landmark locations, maps them back to
    original image coordinates and, when ``args.save`` is truthy, hands the
    image, the ground-truth points and the predictions to the record writer.

    Besides ``args`` this relies on names defined outside the function:
    ``images``, ``_pts_path_``, ``_image_path``, ``_output_path``,
    ``Collection_engine`` and ``experiment_path``.

    Args:
        args: namespace providing ``image``, ``model`` and ``save``.
            ``args.image``, ``args.face`` and ``args.save`` are overwritten
            for every processed image.

    Raises:
        AssertionError: if CUDA is unavailable or ``args.model`` does not
            exist.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print('The image is {:}'.format(args.image))
    print('The model is {:}'.format(args.model))
    snapshot = Path(args.model)
    # The message used to be emitted with a literal, unfilled '{:}'.
    assert snapshot.exists(), 'The model path {:} does not exist'.format(
        snapshot)
    snapshot = torch.load(snapshot)

    # General Data Augmentation
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    param = snapshot['args']
    eval_transform = transforms.Compose([
        transforms.PreCrop(param.pre_crop_expand),
        transforms.TrainScale2WH((param.crop_width, param.crop_height)),
        transforms.ToTensor(), normalize
    ])
    model_config = load_configure(param.model_config, None)
    dataset = Dataset(eval_transform, param.sigma, model_config.downsample,
                      param.heatmap_type, param.data_indicator)
    dataset.reset(param.num_pts)

    net = obtain_model(model_config, param.num_pts + 1)
    net = net.cuda()
    weights = remove_module_dict(snapshot['state_dict'])
    # Keep only the part of each key after the last 'detector.' so the keys
    # line up with the bare detector network built above.
    nu_weights = {}
    for key, val in weights.items():
        short_key = key.split('detector.')[-1]
        nu_weights[short_key] = val
        print(short_key)
    net.load_state_dict(nu_weights)

    print('Prepare input data')
    record_writer = Collection_engine.produce_generator()
    total_images = len(images)
    cpu = torch.device('cpu')
    for im_ind, aimage in enumerate(images):
        progressbar(im_ind, total_images)

        pts_name = os.path.splitext(aimage)[0] + '.pts'
        pts_full = _pts_path_ + pts_name
        # NOTE(review): 90 is passed through unchanged; presumably the number
        # of ground-truth points - confirm against get_pts.
        gtpts = get_pts(pts_full, 90)
        aim = _image_path + aimage
        args.image = aim
        im = cv2.imread(aim)
        imshape = im.shape
        # The whole image is used as the face box.
        # NOTE(review): im.shape is (rows, cols, channels), so this passes
        # the height before the width - confirm the order prepare_input
        # expects.
        args.face = [0, 0, imshape[0], imshape[1]]
        [image, _, _, _, _, _,
         cropped_size], meta = dataset.prepare_input(args.image, args.face)
        inputs = image.unsqueeze(0).cuda()
        # network forward
        with torch.no_grad():
            batch_heatmaps, batch_locs, batch_scos = net(inputs)
        # obtain the locations on the image in the original size
        np_batch_locs = batch_locs.to(cpu).numpy()
        np_batch_scos = batch_scos.to(cpu).numpy()
        cropped_size = cropped_size.numpy()
        # The last entry of the point axis is dropped (the network was built
        # with num_pts + 1 outputs).
        locations = np_batch_locs[0, :-1, :]
        scores = np.expand_dims(np_batch_scos[0, :-1], -1)

        scale_h = cropped_size[0] * 1. / inputs.size(-2)
        scale_w = cropped_size[1] * 1. / inputs.size(-1)

        # Rescale from network-input coordinates and add the crop offsets
        # stored in cropped_size[2] and cropped_size[3].
        locations[:, 0] = locations[:, 0] * scale_w + cropped_size[2]
        locations[:, 1] = locations[:, 1] * scale_h + cropped_size[3]
        prediction = np.concatenate((locations, scores),
                                    axis=1).transpose(1, 0)

        if args.save:
            # args.save is overwritten with the per-image output path; the
            # value itself is not read again inside this function.
            args.save = _output_path + aimage
            # Back to one row per point, with the score column dropped.
            pred_pts = np.transpose(prediction, [1, 0])
            pred_pts = pred_pts[:, :-1]
            record_writer.consume_data(im,
                                       gt_pts=gtpts,
                                       pred_pts=pred_pts,
                                       name=aimage)
        else:
            print('ignore the visualization procedure')

    record_writer.post_process()
    record_writer.generate_output(output_path=_output_path,
                                  epochs=50,
                                  name='Supervision By Registration')
    # NOTE(review): experiment_path is not defined in this function and is
    # not used elsewhere in it; this line may have been spliced in from
    # another script - confirm it belongs here.
    os.makedirs(experiment_path)

# Script fragment: when args.run is set, load a saved model and its params,
# draw one batch from the test set, sample trajectories from the model and
# pickle the samples under experiment_path.
# NOTE(review): this region appears to be spliced from more than one script.
# The pickle.dump(...) call below is never closed and the `else:` branch
# reads names (`model`, `params`) that are only assigned in the `if` branch.
# Restore the missing lines before running it.
if args.run:
    # Load params
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    params = pickle.load(open(model_path + '/params.p', 'rb'))

    # Load model
    # The map_location lambda returns the storage unchanged instead of moving
    # it to the device recorded in the checkpoint.
    state_dict = torch.load('{}/model/{}_state_dict_best.pth'.format(
        model_path, params['model']),
                            map_location=lambda storage, loc: storage)
    model = load_model(params['model'], params)
    model.load_state_dict(state_dict)

    # Load ground-truth states from test set
    test_loader = DataLoader(GeneralDataset(params['dataset'],
                                            train=False,
                                            normalize_data=params['normalize'],
                                            subsample=params['subsample']),
                             batch_size=args.n_samples,
                             shuffle=args.shuffle)
    data, macro_intents = next(iter(test_loader))
    # Swap the first two axes of both tensors before sampling.
    data, macro_intents = data.transpose(0, 1), macro_intents.transpose(0, 1)

    # Sample trajectories
    samples, macro_samples = model.sample(data,
                                          macro_intents,
                                          burn_in=args.burn_in)

    # Save samples
    samples = samples.detach().numpy()
    # NOTE(review): the call below is missing its closing parenthesis (and
    # possibly a trailing argument); the file handle is also never closed.
    pickle.dump(samples,
                open(experiment_path + '/samples.p', 'wb'),
    # NOTE(review): the next line looks like the start of a different
    # fragment whose `if` header is missing.
    model.load_state_dict(state_dict)
else:
    # Log the run configuration.
    printlog('{:03d} {} {}'.format(args.trial, args.model, args.dataset))
    printlog(model.params_str)
    printlog(
        'start_lr {} | min_lr {} | subsample {} | batch_size {} | seed {}'.
        format(args.start_lr, args.min_lr, args.subsample, args.batch_size,
               args.seed))
    printlog('n_params: {:,}'.format(params['total_params']))
    printlog('best_loss:')
printlog('############################################################')

# Dataset loaders
# Extra DataLoader options (one worker, pinned memory) are passed only when
# args.cuda is set.
if args.cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Training split.
train_set = GeneralDataset(args.dataset,
                           train=True,
                           normalize_data=args.normalize,
                           subsample=args.subsample)
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          **kwargs)

# Test split, built with the same options (including shuffling).
test_set = GeneralDataset(args.dataset,
                          train=False,
                          normalize_data=args.normalize,
                          subsample=args.subsample)
test_loader = DataLoader(test_set,
                         batch_size=batch_size,
                         shuffle=True,
                         **kwargs)

############################# TRAIN LOOP #############################

best_test_loss = 0
# Script: plot (or animate) ground-truth examples drawn from the test set and
# save one output per example under datasets/<dataset>/data/examples.

# Create save destination
save_path = 'datasets/{}/data/examples'.format(args.dataset)
# exist_ok avoids the check-then-create race of testing os.path.exists first.
os.makedirs(save_path, exist_ok=True)

# Set params
params = {
    'dataset': args.dataset,
    'normalize': True,
    'n_samples': args.n_samples,
    'burn_in': 0,
    'genMacro': True,
}

# Load ground-truth states from test set
test_loader = DataLoader(
    GeneralDataset(params['dataset'],
                   train=False,
                   normalize_data=params['normalize'],
                   subsample=1),
    batch_size=args.n_samples,
    shuffle=args.shuffle)
data, macro_intents = next(iter(test_loader))
data, macro_intents = data.detach().numpy(), macro_intents.detach().numpy()

# Get dataset plot function
dataset = import_module('datasets.{}'.format(params['dataset']))
plot_func = dataset.animate if args.animate else dataset.display

# The loader yields fewer than n_samples items when the test set is smaller
# than the requested batch, so bound the loop by what was actually loaded.
for k in range(min(args.n_samples, len(data))):
    print('Sample {:02d}'.format(k))
    save_file = '{}/{:02d}'.format(save_path, k)
    plot_func(data[k], macro_intents[k], params=params, save_file=save_file)
    
    
def evaluate(args):
    """Compare the PyTorch detector against ``rep`` on a directory listing.

    Loads the snapshot at ``args.model``, rebuilds the network from the
    arguments stored in it, then iterates over the naturally sorted entries
    of ``args.image_path``. For each entry the image is passed through both
    ``net`` and the module-level ``rep.run``; the two sets of predicted
    points are mapped to original image coordinates and, when ``args.save``
    is truthy, drawn onto the image and written to that directory with the
    prefixes ``py_`` (network) and ``caf`` (``rep``).

    NOTE(review): this is an interactive debugging routine. It blocks on
    several ``input(...)`` prompts and always reads the hard-coded file
    ``'0.jpg'`` instead of the listed image.

    Args:
        args: namespace providing ``model``, ``image_path`` and ``save``.
            ``args.image`` and ``args.face`` are overwritten per iteration.

    Raises:
        AssertionError: if CUDA is unavailable or ``args.model`` does not
            exist.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print('The model is {:}'.format(args.model))
    snapshot = Path(args.model)
    # The message used to be emitted with a literal, unfilled '{:}'.
    assert snapshot.exists(), 'The model path {:} does not exist'.format(
        snapshot)
    snapshot = torch.load(snapshot)

    # General Data Augmentation
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    param = snapshot['args']
    eval_transform = transforms.Compose([
        transforms.PreCrop(param.pre_crop_expand),
        transforms.TrainScale2WH((param.crop_width, param.crop_height)),
        transforms.ToTensor(), normalize
    ])
    model_config = load_configure(param.model_config, None)
    dataset = Dataset(eval_transform, param.sigma, model_config.downsample,
                      param.heatmap_type, param.data_indicator)
    dataset.reset(param.num_pts)

    net = obtain_model(model_config, param.num_pts + 1)
    net = net.cuda()
    weights = remove_module_dict(snapshot['state_dict'])
    # Keep only the part of each key after the last 'detector.' so the keys
    # line up with the bare detector network built above.
    nu_weights = {}
    for key, val in weights.items():
        short_key = key.split('detector.')[-1]
        nu_weights[short_key] = val
        print(short_key)
    net.load_state_dict(nu_weights)

    print('Prepare input data')
    images = os.listdir(args.image_path)
    images = natsort.natsorted(images)
    total_images = len(images)
    cpu = torch.device('cpu')

    for im_ind, aimage in enumerate(images):
        progressbar(im_ind, total_images)
        # NOTE(review): debug override - the listed image is ignored and the
        # same local file is read every iteration. The intended path was
        # presumably os.path.join(args.image_path, aimage).
        aim = '0.jpg'
        args.image = aim
        im = cv2.imread(aim)
        imshape = im.shape
        print(imshape)
        input('crap12')  # NOTE(review): debug pause, waits for Enter
        # The whole image is used as the face box.
        # NOTE(review): im.shape is (rows, cols, channels), so this passes
        # the height before the width - confirm the order prepare_input
        # expects.
        args.face = [0, 0, imshape[0], imshape[1]]
        [image, _, _, _, _, _,
         cropped_size], meta = dataset.prepare_input(args.image, args.face)
        inputs = image.unsqueeze(0).cuda()
        # The scale factors are computed while cropped_size is still a
        # tensor; it is converted to a numpy array further down.
        scale_h = cropped_size[0] * 1. / inputs.size(-2)
        scale_w = cropped_size[1] * 1. / inputs.size(-1)
        print(inputs.size(-2))
        print(inputs.size(-1))
        print(scale_w.data.numpy())
        print(scale_h.data.numpy())
        print(cropped_size.data.numpy())
        input('crap')  # NOTE(review): debug pause, waits for Enter

        # network forward
        with torch.no_grad():
            batch_locs, batch_scos = net(inputs)
            # Feed the same preprocessed image, as a numpy batch of one, to
            # the reference implementation.
            c_im = np.expand_dims(image.data.numpy(), 0)
            c_locs, c_scors = rep.run(c_im)
        # obtain the locations on the image in the original size
        np_batch_locs = batch_locs.to(cpu).numpy()
        np_batch_scos = batch_scos.to(cpu).numpy()
        cropped_size = cropped_size.numpy()
        # The last entry of the point axis is dropped (the network was built
        # with num_pts + 1 outputs).
        locations = np_batch_locs[0, :-1, :]
        scores = np.expand_dims(np_batch_scos[0, :-1], -1)

        # Rescale from network-input coordinates and add the crop offsets
        # stored in cropped_size[2] and cropped_size[3].
        locations[:, 0], locations[:, 1] = (
            locations[:, 0] * scale_w + cropped_size[2],
            locations[:, 1] * scale_h + cropped_size[3])
        prediction = np.concatenate((locations, scores),
                                    axis=1).transpose(1, 0)

        # Same post-processing for the reference outputs.
        c_locations = c_locs[0, :-1, :]
        c_locations[:, 0], c_locations[:, 1] = (
            c_locations[:, 0] * scale_w + cropped_size[2],
            c_locations[:, 1] * scale_h + cropped_size[3])
        c_scores = np.expand_dims(c_scors[0, :-1], -1)

        c_pred_pts = np.concatenate((c_locations, c_scores),
                                    axis=1).transpose(1, 0)

        # Back to one row per point, with the score column dropped.
        pred_pts = np.transpose(prediction, [1, 0])
        pred_pts = pred_pts[:, :-1]

        c_pred_pts = np.transpose(c_pred_pts, [1, 0])
        c_pred_pts = c_pred_pts[:, :-1]
        # Print both score arrays and their element-wise difference.
        print(c_scors, '\n\n\n')
        print(np_batch_scos)
        print(c_scors - np_batch_scos)

        if args.save:
            json_file = os.path.splitext(aimage)[0] + '.jpg'
            save_path = os.path.join(args.save, 'caf' + json_file)
            save_path2 = os.path.join(args.save, 'py_' + json_file)

            sim2 = draw_pts(im, pred_pts=pred_pts, get_l1e=False)
            sim = draw_pts(im, pred_pts=c_pred_pts, get_l1e=False)
            cv2.imwrite(save_path, sim)
            cv2.imwrite(save_path2, sim2)
            input('save1')  # NOTE(review): debug pause, waits for Enter

        else:
            print('ignore the visualization procedure')