Exemplo n.º 1
0
def evaluate(opt, dloader, model, epoch=0, vis=None, use_saved_file=False):
    """Evaluate `model` on every batch of `dloader` and return the metric scores.

    Args:
        opt: options object. Reads `crop_size`, `log_every` and `ckpt_path`;
            sets `opt.save_visuals` to True (see note below).
        dloader: sized iterable yielding
            (input, output, input_unocc, output_unocc) batches.
        model: object exposing `setup(is_train=...)`,
            `test(input, output) -> (dec_output, latent, nelbo)` and
            `get_visuals()`.
        epoch: unused; kept for interface compatibility.
        vis: optional visualizer with an `add_images` method. When None, a
            `Visualizer` writing to `<opt.ckpt_path>/test_log` is created.
        use_saved_file: unused; kept for interface compatibility.

    Returns:
        dict filled from `utils.Metrics().get_scores()`.
    """
    # NOTE(review): visual saving is forced on regardless of the caller's
    # `opt.save_visuals`. This is preserved from the original code, which then
    # re-checked the (now always true) flag, so a visualizer is always built
    # when none is passed in.
    opt.save_visuals = True
    if vis is None:
        vis = Visualizer(os.path.join(opt.ckpt_path, 'test_log'))

    model.setup(is_train=False)
    metric = utils.Metrics()
    results = {}

    # Loop-invariant: width of the trailing slice along axis 3 that is scored.
    crop_width = opt.crop_size[1]

    for step, (inputs, output, input_unocc, _output_unocc) in enumerate(dloader):
        dec_output, _latent, _nelbo = model.test(inputs, output)

        # Results with partial occlusion in the TOP: only the last `crop_width`
        # entries along axis 3 of target and reconstruction are compared.
        output_eval = torch.cat([input_unocc, output], dim=1)[:, :, :, -crop_width:]
        rec_pred_eval = dec_output[:, :, :, -crop_width:]
        metric.update(output_eval, rec_pred_eval)

        if (step + 1) % opt.log_every == 0:
            print('{}/{}'.format(step + 1, len(dloader)))
            if opt.save_visuals:
                vis.add_images(model.get_visuals(), step, prefix='test_val')

    # BCE, MSE
    results.update(metric.get_scores())

    return results
Exemplo n.º 2
0
def evaluate(opt, dloader, model, use_saved_file=False):
  """Evaluate `model` on every batch of `dloader` and return aggregated scores.

  Args:
    opt: options object. Reads `ckpt_path`, `dset_name`, `n_components`,
      `n_frames_input` and `log_every`, plus the optional flags
      `save_visuals` / `save_all_results`. Each flag is written back as False
      on `opt` when it is missing or falsy.
    dloader: sized data loader yielding (input, gt) batches, or
      (input, gt, positions) while the bouncing-balls velocity evaluation
      has `dloader.dataset.return_positions` switched on.
    model: object exposing `setup(is_train=...)`,
      `test(input, gt) -> (output, latent)` and `get_visuals()`.
    use_saved_file: bouncing-balls only. When True,
      `<ckpt_path>/positions.npy` is passed to `utils.VelocityMetrics`;
      otherwise '' is passed.

  Returns:
    dict with the `utils.Metrics` scores and, for the bouncing-balls case,
    the `utils.VelocityMetrics` scores merged in.
  """
  # Visualizer
  if getattr(opt, 'save_visuals', False):
    vis = Visualizer(os.path.join(opt.ckpt_path, 'tb_test'))
  else:
    opt.save_visuals = False

  model.setup(is_train=False)
  metric = utils.Metrics()
  results = {}

  if getattr(opt, 'save_all_results', False):
    save_dir = os.path.join(opt.ckpt_path, 'results')
    os.makedirs(save_dir, exist_ok=True)
  else:
    opt.save_all_results = False

  # HACK: bouncing balls with 4 components also reports velocity metrics,
  # which needs the dataset to additionally yield ground-truth positions.
  is_bouncing_balls = ('bouncing_balls' in opt.dset_name) and opt.n_components == 4
  if is_bouncing_balls:
    saved_positions = os.path.join(opt.ckpt_path, 'positions.npy') if use_saved_file else ''
    # Build the metric before touching the dataset so a failure here cannot
    # leave the shared dataset flag switched on.
    velocity_metric = utils.VelocityMetrics(saved_positions)
    prev_return_positions = getattr(dloader.dataset, 'return_positions', False)
    dloader.dataset.return_positions = True

  try:
    count = 0
    for step, data in enumerate(dloader):
      if is_bouncing_balls:
        inputs, gt, positions = data
      else:
        inputs, gt = data
      output, latent = model.test(inputs, gt)
      # Only frames after the first `n_frames_input` (along axis 1) are scored.
      pred = output[:, opt.n_frames_input:, ...]
      metric.update(gt, pred)

      if opt.save_all_results:
        # Inputs and targets joined along axis 1 (presumably time) so the
        # saved ground truth lines up with the full model output.
        full_gt = np.concatenate([inputs.numpy(), gt.numpy()], axis=1)
        prediction = utils.to_numpy(output)
        count = save_images(prediction, full_gt, latent, save_dir, count)

      if is_bouncing_balls:
        # Calculate position and velocity from pose
        pose = latent['pose'].data.cpu()
        velocity_metric.update(positions, pose, opt.n_frames_input)

      if (step + 1) % opt.log_every == 0:
        print('{}/{}'.format(step + 1, len(dloader)))
        if opt.save_visuals:
          vis.add_images(model.get_visuals(), step, prefix='test')
  finally:
    if is_bouncing_balls:
      # Don't break the original code: restore the dataset flag even when
      # evaluation raises, so later users keep getting (input, gt) batches.
      dloader.dataset.return_positions = prev_return_positions

  # BCE, MSE
  results.update(metric.get_scores())

  if is_bouncing_balls:
    results.update(velocity_metric.get_scores())

  return results