Example #1
    def getfile(self):
        """
        Loads input data
        """
        self.batch = next(self.data_loader)

        self.imgs, self.objs, self.boxes, self.triples, self.obj_to_img, self.triple_to_img, self.imgs_in = \
            [x.cuda() for x in self.batch]

        # keep_* masks: 1 = keep the GT value for this object, 0 = drop it so
        # the network predicts it instead
        self.keep_box_idx = torch.ones_like(self.objs.unsqueeze(1),
                                            dtype=torch.float)
        self.keep_feat_idx = torch.ones_like(self.objs.unsqueeze(1),
                                             dtype=torch.float)
        self.keep_image_idx = torch.ones_like(self.objs.unsqueeze(1),
                                              dtype=torch.float)
        self.combine_gt_pred_box_idx = torch.zeros_like(self.objs)
        # 1 = object was newly added and still has a dummy GT box
        self.added_objs_idx = torch.zeros_like(self.objs.unsqueeze(1),
                                               dtype=torch.float)

        self.new_triples, self.new_objs = None, None

        image = imagenet_deprocess_batch(self.imgs)
        image = image[0].numpy().transpose(1, 2, 0).copy()

        self.image = image
        self.draw_input_image(new_image=True)
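
The examples here rely on imagenet_deprocess_batch to map a normalized batch back to displayable images. A minimal sketch of such a helper, assuming standard ImageNet mean/std preprocessing (the repo's actual implementation may clamp and rescale differently):

import torch

# Sketch only: assumes inputs were normalized with ImageNet statistics.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def imagenet_deprocess_batch_sketch(imgs):
    # imgs: (N, 3, H, W) tensor in the model's normalized space
    imgs = imgs.detach().cpu()
    imgs = imgs * IMAGENET_STD + IMAGENET_MEAN    # undo normalization
    return (imgs * 255.0).clamp(0, 255).byte()    # back to uint8 range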
Example #2
def save_image_with_label(img_pred, img_gt, img_dir, filename, txt_str):
    """
    Saves the ground-truth and generated images concatenated side by side,
    with a text label describing the change, for easier visualization.
    """

    img_pred = imagenet_deprocess_batch(img_pred)
    img_gt = imagenet_deprocess_batch(img_gt)

    img_pred_np = img_pred[0].numpy().transpose(1, 2, 0)
    img_gt_np = img_gt[0].numpy().transpose(1, 2, 0)

    img_pred_np = cv2.resize(img_pred_np, (128, 128))
    img_gt_np = cv2.resize(img_gt_np, (128, 128))

    wspace = np.zeros([img_pred_np.shape[0], 10, 3])
    text = np.zeros([30, img_pred_np.shape[1] * 2 + 10, 3])
    text = cv2.putText(text, txt_str, (0, 20), cv2.FONT_HERSHEY_SIMPLEX,
                       0.5, (255, 255, 255), lineType=cv2.LINE_AA)

    img_pred_gt = np.concatenate([img_gt_np, wspace, img_pred_np], axis=1).astype('uint8')
    img_pred_gt = np.concatenate([text, img_pred_gt], axis=0).astype('uint8')
    img_path = os.path.join(img_dir, filename)
    imsave(img_path, img_pred_gt)
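
A hypothetical call for illustration; the tensors, directory, and label string are made up, not from the repo:

# img_pred / img_gt: (1, 3, H, W) tensors in the model's normalized space
save_image_with_label(img_pred, img_gt,
                      img_dir='results/visualizations',
                      filename='0001_replace.png',
                      txt_str='replace: sheep -> horse')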
Example #3
def eval_model(model,
               loader,
               device,
               vocab,
               use_gt_boxes=False,
               use_feats=False,
               filter_box=False):
    all_boxes = defaultdict(list)
    total_iou = []
    total_boxes = 0
    num_batches = 0
    num_samples = 0
    mae_per_image = []
    mae_roi_per_image = []
    roi_only_iou = []
    ssim_per_image = []
    ssim_rois = []
    rois = 0
    margin = 2

    ## Initializing the perceptual loss model
    lpips_model = models.PerceptualLoss(model='net-lin',
                                        net='alex',
                                        use_gpu=True)
    perceptual_error_image = []
    # ---------------------------------------

    img_idx = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            num_batches += 1
            batch = [tensor.to(device) for tensor in batch]
            masks = None
            #len", len(batch))

            # batch is already on `device`, so unpack it directly
            imgs, objs, boxes, triples, obj_to_img, triple_to_img, imgs_in = batch
            predicates = triples[:, 1]

            if not args.generative:
                imgs, imgs_in, objs, boxes, triples, obj_to_img, \
                dropimage_indices, dropfeats_indices = [b.to(device) for b in process_batch(
                    imgs, imgs_in, objs, boxes, triples, obj_to_img, triple_to_img, device,
                    use_feats=use_feats, filter_box=filter_box)]

                dropbox_indices = dropimage_indices
            else:
                dropbox_indices = torch.ones_like(
                    objs.unsqueeze(1).float()).to(device)
                dropfeats_indices = torch.ones_like(
                    objs.unsqueeze(1).float()).to(device)
                dropimage_indices = torch.zeros_like(
                    objs.unsqueeze(1).float()).to(device)

            if imgs.shape[0] == 0:
                continue

            if args.visualize_graphs:
                # visualize scene graphs for debugging purposes
                visualize_scene_graphs(obj_to_img, objs, triples, vocab,
                                       device)

            if use_gt_boxes:
                model_out = model(
                    objs,
                    triples,
                    obj_to_img,
                    boxes_gt=boxes,
                    masks_gt=masks,
                    src_image=imgs_in,
                    keep_box_idx=torch.ones_like(dropimage_indices),
                    keep_feat_idx=dropfeats_indices,
                    keep_image_idx=dropimage_indices,
                    mode='eval')
            else:
                model_out = model(objs,
                                  triples,
                                  obj_to_img,
                                  boxes_gt=boxes,
                                  src_image=imgs_in,
                                  keep_box_idx=dropimage_indices,
                                  keep_feat_idx=dropfeats_indices,
                                  keep_image_idx=dropimage_indices,
                                  mode='eval')

            # OUTPUT
            imgs_pred, boxes_pred, masks_pred, _, _ = model_out
            # ----------------------------------------------------------------------------------------------------------

            # Save all box predictions
            all_boxes['boxes_gt'].append(boxes)
            all_boxes['objs'].append(objs)
            all_boxes['boxes_pred'].append(boxes_pred)
            all_boxes['drop_targets'].append(dropbox_indices)

            # IoU over all
            total_iou.append(jaccard(boxes_pred, boxes).detach().cpu().numpy())
            total_boxes += boxes_pred.size(0)

            # IoU over targets only
            pred_dropbox = boxes_pred[dropbox_indices.squeeze() == 0, :]
            gt_dropbox = boxes[dropbox_indices.squeeze() == 0, :]
            roi_only_iou.append(
                jaccard(pred_dropbox, gt_dropbox).detach().cpu().numpy())
            rois += pred_dropbox.size(0)

            num_samples += imgs.shape[0]
            imgs = imagenet_deprocess_batch(imgs).float()
            imgs_pred = imagenet_deprocess_batch(imgs_pred).float()

            if args.visualize_imgs_boxes:
                # visualize images with drawn boxes for debugging purposes
                visualize_imgs_boxes(imgs, imgs_pred, boxes, boxes_pred)

            if args.save_images:
                # save reconstructed images for later FID and Inception computation
                if args.save_gt_images:
                    # pass imgs as argument to additionally save gt images
                    save_images(imgs_pred, img_idx, imgs)
                else:
                    save_images(imgs_pred, img_idx)

            # MAE per image
            mae_per_image.append(
                torch.mean(
                    torch.abs(imgs - imgs_pred).view(imgs.shape[0], -1),
                    1).cpu().numpy())

            for s in range(imgs.shape[0]):
                # get coordinates of target
                left, right, top, bottom = bbox_coordinates_with_margin(
                    boxes[s, :], margin, imgs)
                if left > right or top > bottom:
                    continue

                # calculate errors only in RoI one by one
                mae_roi_per_image.append(
                    torch.mean(
                        torch.abs(imgs[s, :, top:bottom, left:right] -
                                  imgs_pred[s, :, top:bottom,
                                            left:right])).cpu().item())

                ssim_per_image.append(
                    pytorch_ssim.ssim(imgs[s:s + 1, :, :, :] / 255.0,
                                      imgs_pred[s:s + 1, :, :, :] / 255.0,
                                      window_size=3).cpu().item())
                ssim_rois.append(
                    pytorch_ssim.ssim(
                        imgs[s:s + 1, :, top:bottom, left:right] / 255.0,
                        imgs_pred[s:s + 1, :, top:bottom, left:right] / 255.0,
                        window_size=3).cpu().item())

                # normalize as expected from the LPIPS model
                imgs_pred_norm = imgs_pred[s:s + 1, :, :, :] / 127.5 - 1
                imgs_gt_norm = imgs[s:s + 1, :, :, :] / 127.5 - 1
                perceptual_error_image.append(
                    lpips_model.forward(imgs_pred_norm,
                                        imgs_gt_norm).detach().cpu().numpy())

            if num_batches % args.print_every == 0:
                calculate_scores(mae_per_image, mae_roi_per_image, total_iou,
                                 roi_only_iou, ssim_per_image, ssim_rois,
                                 perceptual_error_image)

            if num_batches % args.save_every == 0:
                save_results(mae_per_image, mae_roi_per_image, total_iou,
                             roi_only_iou, ssim_per_image, ssim_rois,
                             perceptual_error_image, all_boxes, num_batches)

            img_idx += 1

    calculate_scores(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                     ssim_per_image, ssim_rois, perceptual_error_image)
    save_results(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                 ssim_per_image, ssim_rois, perceptual_error_image, all_boxes,
                 'final')
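
eval_model leans on two helpers that are not defined in these excerpts: jaccard, which scores predicted boxes against ground truth, and bbox_coordinates_with_margin, which turns a normalized box into padded pixel coordinates. Plausible sketches, assuming boxes are [x0, y0, x1, y1] in [0, 1] (the repo's real helpers may differ):

import torch

def jaccard_sketch(boxes_pred, boxes_gt):
    # Element-wise IoU between aligned (N, 4) boxes in [x0, y0, x1, y1] format.
    x0 = torch.max(boxes_pred[:, 0], boxes_gt[:, 0])
    y0 = torch.max(boxes_pred[:, 1], boxes_gt[:, 1])
    x1 = torch.min(boxes_pred[:, 2], boxes_gt[:, 2])
    y1 = torch.min(boxes_pred[:, 3], boxes_gt[:, 3])
    inter = (x1 - x0).clamp(min=0) * (y1 - y0).clamp(min=0)
    area_p = (boxes_pred[:, 2] - boxes_pred[:, 0]) * (boxes_pred[:, 3] - boxes_pred[:, 1])
    area_g = (boxes_gt[:, 2] - boxes_gt[:, 0]) * (boxes_gt[:, 3] - boxes_gt[:, 1])
    return inter / (area_p + area_g - inter + 1e-8)

def bbox_coordinates_with_margin_sketch(box, margin, imgs):
    # Map one normalized box to integer pixel coordinates, padded by `margin`
    # and clipped to the image; callers skip degenerate results.
    H, W = imgs.shape[2], imgs.shape[3]
    left = max(0, int(box[0] * W) - margin)
    right = min(W, int(box[2] * W) + margin)
    top = max(0, int(box[1] * H) - margin)
    bottom = min(H, int(box[3] * H) + margin)
    return left, right, top, bottom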
Example #4
    def gen_image(self):
        """
        Generates an image, as indicated by the modified graph
        """
        if self.new_triples is not None:
            triples_ = self.new_triples
        else:
            triples_ = self.triples

        query_feats = None

        model_out = self.model(
            self.new_objs,
            triples_,
            None,
            boxes_gt=self.boxes,
            masks_gt=None,
            src_image=self.imgs_in,
            mode=self.mode,
            query_feats=query_feats,
            keep_box_idx=self.keep_box_idx,
            keep_feat_idx=self.keep_feat_idx,
            combine_gt_pred_box_idx=self.combine_gt_pred_box_idx,
            keep_image_idx=self.keep_image_idx,
            random_feats=args.random_feats,
            get_layout_boxes=True)

        imgs_pred, boxes_pred, masks_pred, noised_srcs, _, layout_boxes = model_out

        image = imagenet_deprocess_batch(imgs_pred)
        image = image[0].detach().numpy().transpose(1, 2, 0).copy()
        if args.update_input:
            self.image = image.copy()

        image = QtGui.QImage(image, image.shape[1], image.shape[0],
                             QtGui.QImage.Format_RGB888)

        im_pm = QtGui.QPixmap(image)
        self.ima.setPixmap(im_pm.scaled(200, 200))
        self.ima.setVisible(True)
        self.imCounter += 1

        if args.update_input:
            # reset everything so that the predicted image is now the input image for the next step
            self.imgs = imgs_pred.detach().clone()
            self.imgs_in = torch.cat(
                [self.imgs,
                 torch.zeros_like(self.imgs[:, 0:1, :, :])], 1)
            self.draw_input_image()
            self.boxes = layout_boxes.detach().clone()
            self.keep_box_idx = torch.ones_like(self.objs.unsqueeze(1),
                                                dtype=torch.float)
            self.keep_feat_idx = torch.ones_like(self.objs.unsqueeze(1),
                                                 dtype=torch.float)
            self.keep_image_idx = torch.ones_like(self.objs.unsqueeze(1),
                                                  dtype=torch.float)
            self.combine_gt_pred_box_idx = torch.zeros_like(self.objs)
        else:
            # the input image is still the original one, so don't reset anything;
            # if an object was just added, its GT/input box is still a dummy (set in add_triple),
            # so we replace it with the box predicted by the SGN, making it usable for
            # future changes that rely on the GT/input box, e.g. replacement
            self.boxes = self.added_objs_idx * layout_boxes.detach().clone() \
                + (1 - self.added_objs_idx) * self.boxes
            self.added_objs_idx = torch.zeros_like(self.objs.unsqueeze(1),
                                                   dtype=torch.float)
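
The box update in the final branch is a 0/1 mask blend: rows flagged in added_objs_idx take the SGN-predicted box, every other row keeps its GT box. A tiny standalone demonstration of the same pattern, with made-up tensors:

import torch

added_objs_idx = torch.tensor([[0.], [1.], [0.]])   # 1 = newly added object
boxes = torch.tensor([[.1, .1, .2, .2],
                      [0., 0., 0., 0.],             # dummy box for the new object
                      [.5, .5, .9, .9]])
layout_boxes = torch.tensor([[.1, .1, .2, .2],
                             [.3, .3, .6, .6],      # box predicted by the SGN
                             [.5, .5, .9, .9]])
# the (3, 1) mask broadcasts over the 4 box coordinates of each row
boxes = added_objs_idx * layout_boxes + (1 - added_objs_idx) * boxes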
Example #5
def save_image_from_tensor(img, img_dir, filename):
    """De-processes a single-image batch and saves it to img_dir/filename."""
    img = imagenet_deprocess_batch(img)
    img_np = img[0].numpy().transpose(1, 2, 0)
    img_path = os.path.join(img_dir, filename)
    imsave(img_path, img_np)
Example #6
def run_model(args, checkpoint, loader=None):
  """Evaluates image manipulation on (source, target) pairs: infers the edit
  mode by comparing the two scene graphs, runs the model, and accumulates
  reconstruction and perceptual metrics, overall and per mode."""

  output_dir = args.exp_dir
  model = build_model(args, checkpoint)
  if loader is None:
    loader = build_eval_loader(args, checkpoint, vocab_t)

  img_dir = makedir(output_dir, 'images_' + SPLIT)
  graph_json_dir = makedir(output_dir, 'graphs_json')

  f = open(os.path.join(output_dir, "result_ids.txt"), "w")

  img_idx = 0
  total_iou_all = []
  total_iou = get_def_dict()
  total_boxes = 0
  mae_per_image_all = []
  mae_per_image = get_def_dict()
  mae_roi_per_image_all = []
  mae_roi_per_image = get_def_dict()
  roi_only_iou_all = []
  roi_only_iou = get_def_dict()
  ssim_per_image_all = []
  ssim_per_image = get_def_dict()
  ssim_rois_all = []
  ssim_rois = get_def_dict()
  rois = 0
  margin = 2

  ## Initializing the perceptual loss model
  lpips_model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
  perceptual_error_image_all = []
  perceptual_error_image = get_def_dict()
  perceptual_error_roi_all = []
  perceptual_error_roi = get_def_dict()

  for batch in loader:

    imgs, imgs_src, objs, objs_src, boxes, boxes_src, triples, triples_src, obj_to_img, \
        triple_to_img, imgs_in = [x.cuda() for x in batch]

    imgs_gt = imagenet_deprocess_batch(imgs_src)
    imgs_target_gt = imagenet_deprocess_batch(imgs)

    # Infer the edit mode by comparing source and target scene graphs as multisets
    graph_set_bef = Counter(tuple(row) for row in tripleToObjID(triples_src, objs_src))
    obj_set_bef = Counter([int(obj.cpu()) for obj in objs_src])
    graph_set_aft = Counter(tuple(row) for row in tripleToObjID(triples, objs))
    obj_set_aft = Counter([int(obj.cpu()) for obj in objs])

    if len(objs) > len(objs_src):
      mode = "addition"
      changes = graph_set_aft - graph_set_bef
      obj_ids = list(obj_set_aft - obj_set_bef)
      new_ids = (objs == obj_ids[0]).nonzero()
    elif len(objs) < len(objs_src):
      mode = "remove"
      changes = graph_set_bef - graph_set_aft
      obj_ids = list(obj_set_bef - obj_set_aft)
      new_ids_src = (objs_src == obj_ids[0]).nonzero()
      new_objs = [obj for obj in objs]
      new_objs.append(objs_src[new_ids_src[0]])
      objs = torch.tensor(new_objs).cuda()
      num_objs = len(objs)
      new_ids = [torch.tensor(num_objs-1)]
      new_boxes = [bbox for bbox in boxes]
      new_boxes.append(boxes_src[new_ids_src[0]][0])
      boxes = torch.stack(new_boxes)
      obj_to_img = torch.zeros(num_objs, dtype=objs.dtype, device=objs.device)
    elif torch.all(torch.eq(objs, objs_src)):
      mode = "reposition"
      changes = (graph_set_bef - graph_set_aft) + (graph_set_aft - graph_set_bef)
      idx_cnt = np.zeros((25, 1))  # occurrence count per object category id
      for [s, p, o] in list(changes):
        idx_cnt[s] += 1
        idx_cnt[o] += 1

      obj_ids = idx_cnt.argmax(0)
      id_src = (objs_src == obj_ids[0]).nonzero()
      box_src = boxes_src[id_src[0]]
      new_ids = (objs == obj_ids[0]).nonzero()
      boxes[new_ids[0]] = box_src

    elif len(objs) == len(objs_src):
      mode = "replace"
      changes = (graph_set_bef - graph_set_aft) + (graph_set_aft - graph_set_bef)
      obj_ids = [list(obj_set_bef - obj_set_aft)[0], list(obj_set_aft - obj_set_bef)[0]]
      new_ids = (objs == obj_ids[1]).nonzero()
    else:
      assert False, "could not infer edit mode from the scene graphs"

    new_ids = [int(new_id.cpu()) for new_id in new_ids]

    show_im = False
    if show_im:
      img_gt = imgs_gt[0].numpy().transpose(1, 2, 0)
      img_gt_target = imgs_target_gt[0].numpy().transpose(1, 2, 0)
      fig = plt.figure()
      fig.add_subplot(1, 2, 1)
      plt.imshow(img_gt)
      fig.add_subplot(1, 2, 2)
      plt.imshow(img_gt_target)
      plt.show(block=True)

    query_feats = None

    if args.with_query_image:
      img, box = query_image_by_semantic_id(new_ids, img_idx, loader)
      query_feats = model.forward_visual_feats(img, box)

      img_filename_query = '%04d_query.png' % (img_idx)
      img = imagenet_deprocess_batch(img)
      img_np = img[0].numpy().transpose(1, 2, 0).astype(np.uint8)
      img_path = os.path.join(img_dir, img_filename_query)
      imsave(img_path, img_np)


    img_gt_filename = '%04d_gt_src.png' % (img_idx)
    img_target_gt_filename = '%04d_gt_target.png' % (img_idx)
    img_pred_filename = '%04d_changed.png' % (img_idx)
    img_filename_noised = '%04d_noised.png' % (img_idx)

    triples_ = triples

    boxes_gt = boxes

    keep_box_idx = torch.ones_like(objs.unsqueeze(1), dtype=torch.float)
    keep_feat_idx = torch.ones_like(objs.unsqueeze(1), dtype=torch.float)
    keep_image_idx = torch.ones_like(objs.unsqueeze(1), dtype=torch.float)

    subject_node = new_ids[0]
    keep_image_idx[subject_node] = 0

    if mode == 'reposition':
      keep_box_idx[subject_node] = 0
    elif mode == "remove":
      keep_feat_idx[subject_node] = 0
    else:
      if mode == "replace":
        keep_feat_idx[subject_node] = 0
      if mode == 'auto_withfeats':
        keep_image_idx[subject_node] = 0

      if mode == 'auto_nofeats':
        if not args.with_query_image:
          keep_feat_idx[subject_node] = 0

    model_out = model(objs, triples_, obj_to_img,
        boxes_gt=boxes_gt, masks_gt=None, src_image=imgs_in, mode=mode,
        query_feats=query_feats, keep_box_idx=keep_box_idx, keep_feat_idx=keep_feat_idx,
        keep_image_idx=keep_image_idx)

    imgs_pred, boxes_pred_o, masks_pred, noised_srcs, _ = model_out

    imgs = imagenet_deprocess_batch(imgs).float()
    imgs_pred = imagenet_deprocess_batch(imgs_pred).float()

    # Metrics

    # IoU over all
    curr_iou = jaccard(boxes_pred_o, boxes).detach().cpu().numpy()
    total_iou_all.append(curr_iou)
    total_iou[mode].append(curr_iou)
    total_boxes += boxes_pred_o.size(0)

    # IoU over targets only
    pred_dropbox = boxes_pred_o[keep_box_idx.squeeze() == 0, :]
    gt_dropbox = boxes[keep_box_idx.squeeze() == 0, :]
    curr_iou_roi = jaccard(pred_dropbox, gt_dropbox).detach().cpu().numpy()
    roi_only_iou_all.append(curr_iou_roi)
    roi_only_iou[mode].append(curr_iou_roi)
    rois += pred_dropbox.size(0)

    # MAE per image
    curr_mae = torch.mean(
      torch.abs(imgs - imgs_pred).view(imgs.shape[0], -1), 1).cpu().numpy()
    mae_per_image[mode].append(curr_mae)
    mae_per_image_all.append(curr_mae)

    for s in range(imgs.shape[0]):
      # get coordinates of target
      left, right, top, bottom = bbox_coordinates_with_margin(boxes[s, :], margin, imgs)
      if left > right or top > bottom:
        continue
      # print("bboxes with margin: ", left, right, top, bottom)

      # calculate errors only in RoI one by one
      curr_mae_roi = torch.mean(
        torch.abs(imgs[s, :, top:bottom, left:right] - imgs_pred[s, :, top:bottom, left:right])).cpu().item()
      mae_roi_per_image[mode].append(curr_mae_roi)
      mae_roi_per_image_all.append(curr_mae_roi)

      curr_ssim = pytorch_ssim.ssim(imgs[s:s + 1, :, :, :] / 255.0,
                          imgs_pred[s:s + 1, :, :, :] / 255.0, window_size=3).cpu().item()
      ssim_per_image_all.append(curr_ssim)
      ssim_per_image[mode].append(curr_ssim)

      curr_ssim_roi = pytorch_ssim.ssim(imgs[s:s + 1, :, top:bottom, left:right] / 255.0,
                          imgs_pred[s:s + 1, :, top:bottom, left:right] / 255.0, window_size=3).cpu().item()
      ssim_rois_all.append(curr_ssim_roi)
      ssim_rois[mode].append(curr_ssim_roi)

      imgs_pred_norm = imgs_pred[s:s + 1, :, :, :] / 127.5 - 1
      imgs_gt_norm = imgs[s:s + 1, :, :, :] / 127.5 - 1

      curr_lpips = lpips_model.forward(imgs_pred_norm, imgs_gt_norm).detach().cpu().numpy()
      perceptual_error_image_all.append(curr_lpips)
      perceptual_error_image[mode].append(curr_lpips)

    for i in range(imgs_pred.size(0)):

      if args.save_imgs:
        img_gt = imgs_gt[i].numpy().transpose(1, 2, 0).astype(np.uint8)
        img_gt = cv2.resize(img_gt, (128, 128))
        img_gt_path = os.path.join(img_dir, img_gt_filename)
        imsave(img_gt_path, img_gt)

        img_gt_target = imgs_target_gt[i].numpy().transpose(1, 2, 0).astype(np.uint8)
        img_gt_target = cv2.resize(img_gt_target, (128, 128))
        img_gt_target_path = os.path.join(img_dir, img_target_gt_filename)
        imsave(img_gt_target_path, img_gt_target)

        noised_src_np = imagenet_deprocess_batch(noised_srcs[:, :3, :, :])
        noised_src_np = noised_src_np[i].numpy().transpose(1, 2, 0).astype(np.uint8)
        noised_src_np = cv2.resize(noised_src_np, (128, 128))
        img_path_noised = os.path.join(img_dir, img_filename_noised)
        imsave(img_path_noised, noised_src_np)

        img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0).astype(np.uint8)
        img_pred_np = cv2.resize(img_pred_np, (128, 128))
        img_path = os.path.join(img_dir, img_pred_filename)
        imsave(img_path, img_pred_np)

      save_graph_json(objs, triples, boxes, "after", graph_json_dir, img_idx)


    img_idx += 1

    if img_idx % print_every == 0:
      calculate_scores(mae_per_image_all, mae_roi_per_image_all, total_iou_all, roi_only_iou_all, ssim_per_image_all,
                       ssim_rois_all, perceptual_error_image_all, perceptual_error_roi_all)
      calculate_scores_modes(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou, ssim_per_image, ssim_rois,
                       perceptual_error_image, perceptual_error_roi)

    print('Saved %d images' % img_idx)

  f.close()
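
run_model uses two small helpers that are not defined in these excerpts. Plausible sketches, labeled as assumptions: get_def_dict presumably returns a fresh defaultdict(list) keyed by edit mode, and tripleToObjID presumably rewrites each triple's subject/object indices into object category ids so the graphs can be compared as multisets with Counter:

from collections import defaultdict

def get_def_dict():
    # Assumption: one list of scores per edit mode ("addition", "remove", ...)
    return defaultdict(list)

def tripleToObjID(triples, objs):
    # Assumption: map each [subject_idx, predicate, object_idx] triple to
    # [subject_category, predicate, object_category], so Counter-based set
    # arithmetic compares graphs by content rather than by node index.
    return [[int(objs[s]), int(p), int(objs[o])] for s, p, o in triples.tolist()]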
Example #7
def check_model(args, t, loader, model):
  """Runs validation: accumulates losses and box IoU over up to
  args.num_val_samples samples and collects example outputs for visualization."""
  num_samples = 0
  all_losses = defaultdict(list)
  total_iou = 0
  total_boxes = 0
  with torch.no_grad():
    for batch in loader:
      batch = [tensor.cuda() for tensor in batch]
      masks = None
      imgs_src = None

      if args.dataset == "vg" or (args.dataset == "clevr" and not args.is_supervised):
        imgs, objs, boxes, triples, obj_to_img, triple_to_img, imgs_in = batch
      elif args.dataset == "clevr":
        imgs, imgs_src, objs, objs_src, boxes, boxes_src, triples, triples_src, obj_to_img, \
        triple_to_img, imgs_in = batch

      model_masks = masks

      model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=model_masks,
                        src_image=imgs_in, imgs_src=imgs_src)
      imgs_pred, boxes_pred, masks_pred, _, _ = model_out

      skip_pixel_loss = False
      total_loss, losses = calculate_model_losses(
                                args, skip_pixel_loss, imgs, imgs_pred,
                                boxes, boxes_pred)

      total_iou += jaccard(boxes_pred, boxes)
      total_boxes += boxes_pred.size(0)

      for loss_name, loss_val in losses.items():
        all_losses[loss_name].append(loss_val)
      num_samples += imgs.size(0)
      if num_samples >= args.num_val_samples:
        break

    samples = {}
    samples['gt_img'] = imgs

    model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks, src_image=imgs_in, imgs_src=imgs_src)
    samples['gt_box_gt_mask'] = model_out[0]

    model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, src_image=imgs_in, imgs_src=imgs_src)
    samples['generated_img_gt_box'] = model_out[0]

    samples['masked_img'] = model_out[3][:, :3, :, :]

    for k, v in samples.items():
      samples[k] = imagenet_deprocess_batch(v)

    mean_losses = {k: np.mean(v) for k, v in all_losses.items()}
    avg_iou = total_iou / total_boxes

    masks_to_store = masks
    if masks_to_store is not None:
      masks_to_store = masks_to_store.data.cpu().clone()

    masks_pred_to_store = masks_pred
    if masks_pred_to_store is not None:
      masks_pred_to_store = masks_pred_to_store.data.cpu().clone()

  batch_data = {
    'objs': objs.detach().cpu().clone(),
    'boxes_gt': boxes.detach().cpu().clone(),
    'masks_gt': masks_to_store,
    'triples': triples.detach().cpu().clone(),
    'obj_to_img': obj_to_img.detach().cpu().clone(),
    'triple_to_img': triple_to_img.detach().cpu().clone(),
    'boxes_pred': boxes_pred.detach().cpu().clone(),
    'masks_pred': masks_pred_to_store
  }
  out = [mean_losses, samples, batch_data, avg_iou]

  return tuple(out)
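
A hedged usage sketch showing how check_model's return values might be consumed in a training loop; the loader name and logging are illustrative, not from the repo:

# Hypothetical validation step; t is the current training iteration.
mean_losses, samples, batch_data, avg_iou = check_model(args, t, val_loader, model)
print('iter %d | val avg IoU: %.4f' % (t, avg_iou))
for name, val in mean_losses.items():
    print('  %s: %.4f' % (name, val))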