Example #1
def get_rel_score(args, t, loader, model):
    float_dtype = torch.FloatTensor
    long_dtype = torch.LongTensor
    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0
    total_boxes = 0
    rel_score = 0

    with torch.no_grad():
        o_start = o_end = 0
        t_start = t_end = 0
        last_o_idx = last_t_idx = 0

        b = 0
        total_boxes = 0
        total_iou = 0
        for batch in loader:
            batch = [tensor.cuda() for tensor in batch]
            #batch = [tensor for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 8:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks = batch
            #elif len(batch) == 7:
            #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            predicates = triples[:, 1]

            objs = objs.detach()
            triples = triples.detach()
            # Run the model as it has been run during training
            model_masks = masks
            model_out = model(objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=model_masks)
            # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
             layout_boxes, layout_masks, obj_to_img, sg_context_pred,
             sg_context_pred_d, predicate_scores, obj_embeddings,
             pred_embeddings, triple_boxes_pred, triple_boxes_gt,
             triplet_masks_pred) = model_out

            num_samples += imgs.size(0)
            if num_samples >= args.num_val_samples:
                break

            rel_score += relation_score(boxes_pred, boxes, masks_pred, masks,
                                        model.vocab)
            b += 1
            total_iou += jaccard(boxes_pred, boxes)
            total_boxes += boxes_pred.size(0)

        rel_score = rel_score / b
        avg_iou = total_iou / total_boxes
        return rel_score, avg_iou
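
These examples call a jaccard helper that is not shown. The averaging pattern (total_iou accumulated per batch, then divided by total_boxes) implies it returns the sum of per-box IoUs over aligned predicted/ground-truth pairs. A minimal sketch of that contract, assuming (x0, y0, x1, y1) boxes in a [0, 1] coordinate system as the comments elsewhere state; the actual implementation may differ:

import torch

def jaccard(boxes_pred, boxes_gt):
    # Sum of IoUs between aligned (x0, y0, x1, y1) box pairs; the callers
    # above divide this by the box count to get avg_iou.
    x0 = torch.max(boxes_pred[:, 0], boxes_gt[:, 0])
    y0 = torch.max(boxes_pred[:, 1], boxes_gt[:, 1])
    x1 = torch.min(boxes_pred[:, 2], boxes_gt[:, 2])
    y1 = torch.min(boxes_pred[:, 3], boxes_gt[:, 3])
    inter = (x1 - x0).clamp(min=0) * (y1 - y0).clamp(min=0)
    area_p = (boxes_pred[:, 2] - boxes_pred[:, 0]) * (boxes_pred[:, 3] - boxes_pred[:, 1])
    area_g = (boxes_gt[:, 2] - boxes_gt[:, 0]) * (boxes_gt[:, 3] - boxes_gt[:, 1])
    union = area_p + area_g - inter
    return (inter / union.clamp(min=1e-8)).sum()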
Example #2
def get_rel_score(args, t, loader, model):
    float_dtype = torch.FloatTensor
    long_dtype = torch.LongTensor
    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0
    total_boxes = 0

    ###################
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    img_dir = args.output_dir + '/img_dir'

    if not os.path.isdir(img_dir):
        os.mkdir(img_dir)
        print('Created %s' % img_dir)
    ##################
    rel_score = 0
    with torch.no_grad():
        o_start = o_end = 0
        t_start = t_end = 0
        last_o_idx = last_t_idx = 0

        b = 0
        total_boxes = 0
        total_iou = 0
        for batch in loader:
            #batch = [tensor.cuda() for tensor in batch]
            batch = [tensor for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            predicates = triples[:, 1]

            objs = objs.detach()
            triples = triples.detach()
            # Run the model as it has been run during training
            model_masks = masks
            model_out = model(objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=model_masks)
            # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
             layout_boxes, layout_masks, obj_to_img, sg_context_pred,
             sg_context_pred_d, predicate_scores) = model_out

            num_samples += imgs.size(0)
            if num_samples >= args.num_val_samples:
                break

            rel_score += relation_score(boxes_pred, boxes, masks_pred, masks,
                                        model.vocab)
            b += 1

            total_iou += jaccard(boxes_pred, boxes)
            total_boxes += boxes_pred.size(0)

        print('num batches:', b)
        print('total boxes:', total_boxes)

        rel_score = rel_score / b
        avg_iou = total_iou / total_boxes

        # print('rel score:', rel_score)
        # print('avg iou:', avg_iou)

        return rel_score, avg_iou
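
relation_score(boxes_pred, boxes, masks_pred, masks, model.vocab) is also external. Judging by Example #5's reference to sg2im.data.utils.determine_box_relation(), it presumably classifies the coarse spatial relation implied by each (subject, object) box pair and scores how often the predicted layout preserves the ground-truth relation. A hedged sketch of that idea; the pair indices are passed explicitly here, whereas the real function takes masks and the vocab and derives the pairs itself:

def box_relation(sb, ob):
    # Coarse spatial relation between subject and object boxes, both
    # (x0, y0, x1, y1); modeled loosely on the containment/center tests
    # the listings attribute to sg2im.data.utils.determine_box_relation.
    if sb[0] < ob[0] and sb[1] < ob[1] and sb[2] > ob[2] and sb[3] > ob[3]:
        return 'surrounding'
    if sb[0] > ob[0] and sb[1] > ob[1] and sb[2] < ob[2] and sb[3] < ob[3]:
        return 'inside'
    dx = (ob[0] + ob[2]) / 2 - (sb[0] + sb[2]) / 2
    dy = (ob[1] + ob[3]) / 2 - (sb[1] + sb[3]) / 2
    if abs(dx) > abs(dy):
        return 'left of' if dx > 0 else 'right of'
    return 'above' if dy > 0 else 'below'

def relation_score_sketch(boxes_pred, boxes_gt, subj_idx, obj_idx):
    # Fraction of (subject, object) pairs whose predicted-box relation
    # agrees with the ground-truth-box relation.
    hits = sum(box_relation(boxes_pred[s], boxes_pred[o]) ==
               box_relation(boxes_gt[s], boxes_gt[o])
               for s, o in zip(subj_idx, obj_idx))
    return hits / max(len(subj_idx), 1)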
Example #3
def check_model(args,
                t,
                loader,
                model,
                logger=None,
                log_tag='',
                write_images=False):
    # float_dtype = torch.cuda.FloatTensor
    # long_dtype = torch.cuda.LongTensor
    float_dtype = torch.FloatTensor
    long_dtype = torch.LongTensor
    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0
    total_boxes = 0

    ###################
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    img_dir = args.output_dir + '/img_dir'

    if not os.path.isdir(img_dir):
        os.mkdir(img_dir)
        print('Created %s' % img_dir)
    ##################

    t = 0   # reused as the per-image output counter below
    t1 = 0  # scene-graph image counter
    with torch.no_grad():
        o_start = o_end = 0
        t_start = t_end = 0
        last_o_idx = last_t_idx = 0
        for batch in loader:
            #batch = [tensor.cuda() for tensor in batch]
            batch = [tensor for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            predicates = triples[:, 1]

            ############################
            # pdb.set_trace()  # leftover breakpoint; disable for full runs
            ############################

            objs = objs.detach()
            triples = triples.detach()
            # Run the model as it has been run during training
            model_masks = masks
            model_out = model(objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=model_masks)
            # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
             layout_boxes, layout_masks, obj_to_img, sg_context_pred,
             sg_context_pred_d, predicate_scores) = model_out

            num_samples += imgs.size(0)
            if num_samples >= args.num_val_samples:
                break

            samples = {}
            samples['gt_img'] = imgs

            model_out = model(objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=masks)
            samples['gt_box_gt_mask'] = model_out[0]

            model_out = model(objs, triples, obj_to_img, boxes_gt=boxes)
            samples['gt_box_pred_mask'] = model_out[0]

            ##############################################
            # import pdb
            # pdb.set_trace()
            # num_boxes=len(boxes)
            # model_out = model(objs, triples, obj_to_img, boxes_gt=scaled_boxes)
            # samples['gt_scaled_box_pred_mask'] = model_out[0]
            ##############################################

            model_out = model(objs, triples, obj_to_img)
            samples['pred_box_pred_mask'] = model_out[0]

            layout_preds = {}
            layout_preds['pred_boxes'] = model_out[5]
            layout_preds['pred_masks'] = model_out[6]

            for k, v in samples.items():
                samples[k] = imagenet_deprocess_batch(v)

            if write_images:
                #3. Log ground truth and predicted images
                with torch.no_grad():
                    gt_imgs = samples['gt_img'].detach()
                    p_gbox_pmsk_img = samples['gt_box_pred_mask'].detach()
                    p_test_imgs = samples['pred_box_pred_mask'].detach()

                    p_test_boxes = layout_preds['pred_boxes']
                    p_test_masks = layout_preds['pred_masks']

                np_gt_imgs = [
                    gt.cpu().numpy().transpose(1, 2, 0) for gt in gt_imgs
                ]
                np_gbox_pmsk_imgs = [
                    pred.cpu().numpy().transpose(1, 2, 0)
                    for pred in p_gbox_pmsk_img
                ]
                np_test_pred_imgs = [
                    pred.cpu().numpy().transpose(1, 2, 0)
                    for pred in p_test_imgs
                ]

                pred_layout_boxes = p_test_boxes
                pred_layout_masks = p_test_masks
                np_all_imgs = []

                # Overlay box on images
                pred_layout_boxes_t = pred_layout_boxes.detach()
                # overlaid_images = vis.overlay_boxes(np_test_pred_imgs, model.vocab, objs_vec, layout_boxes_t, obj_to_img, W=64, H=64)
                overlaid_images = vis.overlay_boxes(np_test_pred_imgs,
                                                    model.vocab,
                                                    objs_vec,
                                                    pred_layout_boxes_t,
                                                    obj_to_img,
                                                    W=64,
                                                    H=64)

                # # # draw the layout
                # layouts_gt = vis.debug_layout_mask(model.vocab, objs_vec, layout_boxes, layout_masks, obj_to_img, W=128, H=128)
                # layouts_pred = vis.debug_layout_mask(model.vocab, objs_vec, pred_layout_boxes, pred_layout_masks, obj_to_img, W=128, H=128)

                for gt_img, gtb_pm_img, pred_img, overlaid in zip(
                        np_gt_imgs, np_gbox_pmsk_imgs, np_test_pred_imgs,
                        overlaid_images):
                    # for gt_img, gtb_gtm_img, gtb_pm_img, pred_img, gt_layout_img, pred_layout_img, overlaid in zip(np_gt_imgs, np_pred_imgs, np_gbox_pmsk_imgs, np_test_pred_imgs, layouts_gt, layouts_pred, overlaid_images):
                    img_path = os.path.join(img_dir, '%06d_gt_img.png' % t)
                    imwrite(img_path, gt_img)

                    img_path = os.path.join(img_dir, '%06d_gtb_pm_img.png' % t)
                    imwrite(img_path, gtb_pm_img)

                    img_path = os.path.join(img_dir, '%06d_pred_img.png' % t)
                    imwrite(img_path, pred_img)

                    overlaid_path = os.path.join(img_dir,
                                                 '%06d_overlaid.png' % t)
                    imwrite(overlaid_path, overlaid)

                    t = t + 1

                total_iou += jaccard(boxes_pred, boxes)
                total_boxes += boxes_pred.size(0)

                ## Draw scene graph
                tot_obj = 0
                for b_t in range(imgs.size(0)):
                    sg_objs = objs[obj_to_img == b_t]
                    sg_rels = triples[triple_to_img == b_t]
                    sg_img = vis.draw_scene_graph_temp(sg_objs,
                                                       sg_rels,
                                                       tot_obj,
                                                       vocab=model.vocab)
                    sg_img_path = os.path.join(img_dir, '%06d_sg.png' % t1)
                    imwrite(sg_img_path, sg_img)

                    tot_obj = tot_obj + len(sg_objs)  #.size(0)
                    t1 = t1 + 1

                # for gt_img, gtb_gtm_img, gtb_pm_img, pred_img in zip(np_gt_imgs, np_pred_imgs, np_gbox_pmsk_imgs, np_test_pred_imgs):
                #   np_all_imgs.append((gt_img * 255.0).astype(np.uint8))
                #   np_all_imgs.append((gtb_gtm_img * 255.0).astype(np.uint8))
                #   np_all_imgs.append((gtb_pm_img * 255.0).astype(np.uint8))
                #   np_all_imgs.append((pred_img * 255.0).astype(np.uint8))

                # logger.image_summary(log_tag, np_all_imgs, t)
            #########################################################################

        # mean_losses = {k: np.mean(v) for k, v in all_losses.items()}
        # avg_iou = total_iou / total_boxes

        masks_to_store = masks
        if masks_to_store is not None:
            masks_to_store = masks_to_store.data.cpu().clone()

        masks_pred_to_store = masks_pred
        if masks_pred_to_store is not None:
            masks_pred_to_store = masks_pred_to_store.data.cpu().clone()

    batch_data = {
        'objs': objs.detach().cpu().clone(),
        'boxes_gt': boxes.detach().cpu().clone(),
        'masks_gt': masks_to_store,
        'triples': triples.detach().cpu().clone(),
        'obj_to_img': obj_to_img.detach().cpu().clone(),
        'triple_to_img': triple_to_img.detach().cpu().clone(),
        'boxes_pred': boxes_pred.detach().cpu().clone(),
        'masks_pred': masks_pred_to_store
    }
    #out = [mean_losses, samples, batch_data, avg_iou]
    out = [samples]

    ####################
    # note: total_iou/total_boxes are only accumulated when write_images=True
    avg_iou = total_iou / total_boxes
    # print('avg iou:', avg_iou)
    ###################

    return tuple(out)
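
imagenet_deprocess_batch is imported from the surrounding sg2im-style codebase. Its job, implied by the imwrite calls that follow it, is to undo the ImageNet mean/std normalization and return byte images. A minimal sketch of that behavior (the exact rescaling in the real helper may differ):

import torch

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def imagenet_deprocess_batch_sketch(imgs):
    # imgs: (N, 3, H, W) normalized images. Undo the ImageNet mean/std
    # normalization, then rescale to uint8 for imwrite.
    mean = torch.tensor(IMAGENET_MEAN, device=imgs.device).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=imgs.device).view(1, 3, 1, 1)
    imgs = imgs * std + mean              # back to roughly [0, 1]
    return (imgs.clamp(0, 1) * 255).to(torch.uint8).cpu()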
Example #4
def check_model(args, t, loader, model, logger=None, log_tag='', write_images=False):
  float_dtype = torch.cuda.FloatTensor
  long_dtype = torch.cuda.LongTensor
  num_samples = 0
  all_losses = defaultdict(list)
  total_iou = 0
  total_boxes = 0
  with torch.no_grad():
    for batch in loader:
      batch = [tensor.cuda() for tensor in batch]
      masks = None
      if len(batch) == 6:
        imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
        triplet_masks = None
      elif len(batch) == 8:
        imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks = batch
      #elif len(batch) == 7:
        #imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
      predicates = triples[:, 1] 

      # Run the model as it has been run during training
      model_masks = masks
      model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=model_masks)
      # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
      (imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes,
       layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
       predicate_scores, obj_embeddings, pred_embeddings, triplet_boxes_pred,
       triplet_boxes, triplet_masks_pred, boxes_pred_info,
       triplet_superboxes_pred) = model_out

      # add additional information for GT boxes (hack to not change coco.py)
      boxes_info = None
      if args.use_bbox_info and boxes_pred_info is not None:
        boxes_info = add_bbox_info(boxes)
      # GT for triplet superbox
      triplet_superboxes = None
      if args.triplet_superbox_net and triplet_superboxes_pred is not None:
        # triplet_boxes = [ x1_0 y1_0 x1_1 y1_1 x2_0 y2_0 x2_1 y2_1]
        min_pts = triplet_boxes[:,:2]
        max_pts = triplet_boxes[:,6:8]
        triplet_superboxes = torch.cat([min_pts, max_pts], dim=1)

      # for layout model, we don't care about these
      #skip_pixel_loss = False
      #skip_perceptual_loss = False
      skip_pixel_loss = True 
      skip_perceptual_loss = True 

      # calculate all losses here
      total_loss, losses =  calculate_model_losses(
                                args, skip_pixel_loss, model, imgs, imgs_pred,
                                boxes, boxes_pred, masks, masks_pred,
                                boxes_info, boxes_pred_info,
                                predicates, predicate_scores,
                                triplet_boxes, triplet_boxes_pred, 
                                triplet_masks, triplet_masks_pred,
                                triplet_superboxes, triplet_superboxes_pred,
                                skip_perceptual_loss)

      losses['total_loss'] = total_loss.item()

      total_iou += jaccard(boxes_pred, boxes)
      total_boxes += boxes_pred.size(0)

      for loss_name, loss_val in losses.items():
        all_losses[loss_name].append(loss_val)
      num_samples += imgs.size(0)
      if num_samples >= args.num_val_samples:
        break

    
    samples = {}
    samples['gt_img'] = imgs

    #pdb.set_trace()
    #model_out = model(objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks)
    #samples['gt_box_gt_mask'] = model_out[0]

    #model_out = model(objs, triples, obj_to_img, boxes_gt=boxes)
    #samples['gt_box_pred_mask'] = model_out[0]

    #model_out = model(objs, triples, obj_to_img)
    #samples['pred_box_pred_mask'] = model_out[0]
    
    #for k, v in samples.items():
    #  samples[k] = imagenet_deprocess_batch(v) 

    #  if logger is not None and write_images:
       #   #3. Log ground truth and predicted images
    #     with torch.no_grad():
    #       p_imgs = samples['gt_box_gt_mask'].detach() 
    #       gt_imgs = samples['gt_img'].detach() 
    #       p_gbox_pmsk_img = samples['gt_box_pred_mask'] 
    #       p_test_imgs = samples['pred_box_pred_mask'] 

    #     np_gt_imgs = [gt.cpu().numpy().transpose(1,2,0) for gt in gt_imgs]
    #     np_pred_imgs = [pred.cpu().numpy().transpose(1,2,0) for pred in p_imgs]
    #     np_gbox_pmsk_imgs = [pred.cpu().numpy().transpose(1,2,0) for pred in p_gbox_pmsk_img] 
    #     np_test_pred_imgs = [pred.cpu().numpy().transpose(1,2,0) for pred in p_test_imgs]  
    #     np_all_imgs = []
      
    #     for gt_img, gtb_gtm_img, gtb_pm_img, pred_img in zip(np_gt_imgs, np_pred_imgs, np_gbox_pmsk_imgs, np_test_pred_imgs):
    #       np_all_imgs.append((gt_img * 255.0).astype(np.uint8))
    #       np_all_imgs.append((gtb_gtm_img * 255.0).astype(np.uint8))
    #       np_all_imgs.append((gtb_pm_img * 255.0).astype(np.uint8)) 
    #       np_all_imgs.append((pred_img * 255.0).astype(np.uint8))  

    #     logger.image_summary(log_tag, np_all_imgs, t)
      ######################################################################### 

    mean_losses = {k: np.mean(v) for k, v in all_losses.items()}
    avg_iou = total_iou / total_boxes

    masks_to_store = masks
    if masks_to_store is not None:
      masks_to_store = masks_to_store.data.cpu().clone()

    masks_pred_to_store = masks_pred
    if masks_pred_to_store is not None:
      masks_pred_to_store = masks_pred_to_store.data.cpu().clone()

  batch_data = {
    'objs': objs.detach().cpu().clone(),
    'boxes_gt': boxes.detach().cpu().clone(), 
    'masks_gt': masks_to_store,
    'triples': triples.detach().cpu().clone(),
    'obj_to_img': obj_to_img.detach().cpu().clone(),
    'triple_to_img': triple_to_img.detach().cpu().clone(),
    'boxes_pred': boxes_pred.detach().cpu().clone(),
    'masks_pred': masks_pred_to_store
  }
  out = [mean_losses, samples, batch_data, avg_iou]

  return tuple(out)
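
Example #4 builds the ground-truth triplet superbox by slicing the subject's min corner (columns 0:2) and the object's max corner (columns 6:8), which only equals the true union box when those happen to be the extreme corners. An elementwise min/max version, matching the per-coordinate min/max the later examples compute with numpy, might look like this (a sketch, not the repository's code):

import torch

def triplet_superbox(triplet_boxes):
    # triplet_boxes rows: [x1_0, y1_0, x1_1, y1_1, x2_0, y2_0, x2_1, y2_1]
    # (subject box then object box). Union box via elementwise min/max.
    subj, obj = triplet_boxes[:, 0:4], triplet_boxes[:, 4:8]
    min_pts = torch.min(subj[:, 0:2], obj[:, 0:2])
    max_pts = torch.max(subj[:, 2:4], obj[:, 2:4])
    return torch.cat([min_pts, max_pts], dim=1)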
Example #5
def get_rel_score(args, t, loader, model):
    float_dtype = torch.FloatTensor
    long_dtype = torch.LongTensor
    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0
    total_boxes = 0

    ###################
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    img_dir = args.output_dir + '/img_dir'

    if not os.path.isdir(img_dir):
        os.mkdir(img_dir)
        print('Created %s' % img_dir)
    ##################
    rel_score = 0
    with torch.no_grad():
        o_start = o_end = 0
        t_start = t_end = 0
        last_o_idx = last_t_idx = 0

        b = 0
        total_boxes = 0
        total_iou = 0
        for batch in loader:
            batch = [tensor.cuda() for tensor in batch]
            #batch = [tensor for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 10:
                (imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img,
                 triplet_masks, triplet_contours, contours) = batch
            #elif len(batch) == 8:
            #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks = batch
            #elif len(batch) == 7:
            #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            predicates = triples[:, 1]

            objs = objs.detach()
            triples = triples.detach()
            # Run the model as it has been run during training
            model_masks = masks
            model_out = model(objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=model_masks)
            # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            #imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores, obj_embeddings, pred_embeddings, triple_boxes_pred, triple_boxes_gt = model_out
            #imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores, obj_embeddings, pred_embeddings, triple_boxes_pred, triple_boxes_gt, triplet_masks_pred = model_out
            (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
             layout_boxes, layout_masks, obj_to_img, sg_context_pred,
             sg_context_pred_d, predicate_scores, obj_embeddings,
             pred_embeddings, triple_boxes_pred, triple_boxes_gt,
             triplet_masks_pred, triplet_contour_pred, contours_pred) = model_out

            num_samples += imgs.size(0)
            if num_samples >= args.num_val_samples:
                break

            # if obj contours are predicted, derive bounding box from these;
            # if GT boxes are passed in, layout_masks are GT boxes
            if contours_pred is not None and boxes_pred is None:
                boxes_gt = boxes
                # get min-max bboxes, contour mean (not bbox mean), contour std dev (e.g. use predicted bboxes to estimate centers)
                boxes_pred, mean_pts_pred, std_pts_pred = min_max_bbox_fr_contours(
                    contours_pred)  # predicted
                #boxes, mean_pts, std_pts = min_max_bbox_fr_contours(contours) # GT

                # use the contour's geometric center to calculate the relation score
                # (determine_box_relation() in sg2im.data.utils can't compute a contour's geometric center)
                #boxes_pred_mn = torch.cat([mean_pts_pred, mean_pts_pred], dim=1)
                #boxes_mn = torch.cat([mean_pts, mean_pts], dim=1)
                #pdb.set_trace()

                if 0:
                    # visualize first contour of each batch
                    import matplotlib.pyplot as plt
                    cc = contours.clone().view(-1, 12, 2)
                    cp = contours_pred.view(-1, 12, 2).clone().detach()
                    cp = cp.cpu().numpy()
                    #bb = new_boxes[0].view(2,2)
                    bb = boxes[0].view(2, 2)
                    bbp = boxes_pred.clone().detach()
                    #bbp = new_boxes_pred.clone().detach()
                    bbp = bbp[0].view(2, 2)
                    fig, ax = plt.subplots()
                    #ax.imshow(masks[0])
                    # without mask, origin will be LLHC
                    ax.scatter(cc[0, :, 0], cc[0, :, 1], linewidth=0.5)
                    ax.scatter(cp[0, :, 0], cp[0, :, 1], linewidth=0.5)
                    ax.scatter(bb[:, 0], bb[:, 1], linewidth=1.0, marker="x")
                    ax.scatter(bbp[:, 0], bbp[:, 1], linewidth=1.0, marker="x")
                    plt.show()
                    #pdb.set_trace()

            # use geometric center of contour for l/r/t/b spatial relationships; use min-max box for surrounding/inside
            #rel_score += relation_score_cont(boxes_pred_mn, boxes_mn, masks_pred, masks, boxes_pred, boxes, model.vocab) # no masks, fake boxes
            rel_score += relation_score(boxes_pred, boxes, masks_pred, masks,
                                        model.vocab)
            b += 1

            #total_iou += jaccard(new_boxes_pred, boxes)
            total_iou += jaccard(boxes_pred, boxes)
            total_boxes += boxes_pred.size(0)

        rel_score = rel_score / b
        avg_iou = total_iou / total_boxes

        # print('rel score:', rel_score)
        # print('avg iou:', avg_iou)
        return rel_score, avg_iou
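
min_max_bbox_fr_contours is another external helper; Example #5 unpacks a box, a contour mean, and a contour standard deviation per object, and the disabled visualization block reshapes contours with .view(-1, 12, 2). A plausible sketch under those assumptions:

import torch

def min_max_bbox_fr_contours_sketch(contours, num_pts=12):
    # contours: (N, num_pts * 2) flattened (x, y) points per object.
    pts = contours.view(-1, num_pts, 2)
    min_pts = pts.min(dim=1).values               # (N, 2)
    max_pts = pts.max(dim=1).values               # (N, 2)
    boxes = torch.cat([min_pts, max_pts], dim=1)  # (N, 4): x0, y0, x1, y1
    return boxes, pts.mean(dim=1), pts.std(dim=1)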
Example #6
def check_model(args, t, loader, model, log_tag='', write_images=False):

    if torch.cuda.is_available():
        float_dtype = torch.cuda.FloatTensor
        long_dtype = torch.cuda.LongTensor
    else:
        float_dtype = torch.FloatTensor
        long_dtype = torch.LongTensor

    ###################
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    img_dir = args.output_dir + '/img_dir'

    if not os.path.isdir(img_dir):
        os.mkdir(img_dir)
        print('Created %s' % img_dir)
    ##################

    ## if specified, load saved object embeddings
    if args.coco_object_db_json is not None:
        # pdb.set_trace()  # leftover breakpoint
        object_db = db_utils.read_fr_JSON(args.coco_object_db_json)
        triplet_db = None
    else:
        ## begin extract embedding data from model
        num_samples = 0
        all_losses = defaultdict(list)
        total_iou = 0
        total_boxes = 0
        t = 0
        # relationship (triplet) database
        triplet_db = dict()
        object_db = dict()

        # iterate over all batches of images
        with torch.no_grad():

            for batch in loader:

                if torch.cuda.is_available():
                    batch = [tensor.cuda() for tensor in batch]
                else:
                    batch = [tensor for tensor in batch]

                masks = None
                if len(batch) == 6:  # VG
                    imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
                elif len(batch) == 11:  # COCO
                    (imgs, objs, boxes, masks, triples, obj_to_img,
                     triple_to_img, triplet_masks, extreme_points, cat_words,
                     cat_ids) = batch
                predicates = triples[:, 1]

                # Run the model as it has been run during training
                model_masks = masks
                model_out = model(objs,
                                  triples,
                                  obj_to_img,
                                  boxes_gt=boxes,
                                  masks_gt=model_masks)
                # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
                #imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores, obj_embeddings, pred_embeddings, triple_boxes_pred, triple_boxes_gt, triplet_masks_pred, triplet_contours_pred = model_out
                (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
                 layout_boxes, layout_masks, obj_to_img, sg_context_pred,
                 sg_context_pred_d, predicate_scores, obj_embeddings,
                 pred_embeddings, triple_boxes_pred, triple_boxes_gt,
                 triplet_masks_pred, boxes_pred_info,
                 triplet_superboxes_pred) = model_out
                # Run model without GT boxes to get predicted layout masks
                #model_out = model(objs, triples, obj_to_img)
                #layout_boxes, layout_masks = model_out[5], model_out[6]

                num_batch_samples = imgs.size(0)
                num_samples += num_batch_samples
                if num_samples >= args.num_train_samples:
                    break

                super_boxes = []

                # open file to record all triplets, per image, in a batch
                # (note: 'w' mode means each batch overwrites the previous one)
                file_path = os.path.join(img_dir, 'all_batch_triplets.txt')
                f = open(file_path, 'w')
                ### embedding stuff below here ####
                for i in range(0, num_batch_samples):
                    print('Processing image', i + 1, 'of batch size',
                          args.batch_size)
                    f.write('---------- image ' + str(i) + '----------\n')

                    # process objects and get embeddings, per image
                    # object indices for this image in the batch
                    objs_index = np.where(obj_to_img.cpu().numpy() == i)[0]
                    # object class id labels for this image
                    objs_img = objs[objs_index].cpu()
                    # index the vocab names with the object class ids
                    obj_names = np.array(model.vocab['object_idx_to_name'])[objs_img]
                    # scene graph embedding
                    obj_embeddings_img = obj_embeddings[objs_index]
                    # word embedding
                    obj_embeddings_word = cat_words[objs_index]
                    num_objs = len(objs_index)

                    for j in range(0, num_objs):
                        name = obj_names[j]
                        entry = {
                            'id': objs_img[j].tolist(),
                            'embed': obj_embeddings_img[j].tolist(),
                            'word_embed': obj_embeddings_word[j].tolist()
                        }
                        if name not in object_db:
                            object_db[name] = dict()
                            object_db[name]['objs'] = [entry]
                            object_db[name]['count'] = 1
                            #object_db[name]['word_embed'] = obj_embeddings_word[j].tolist() # word embedding
                        elif name in object_db:
                            object_db[name]['objs'] += [entry]
                            object_db[name]['count'] += 1
                        f.write('obj ' + str(j) + ': ' + name + '\n')

                    # test if db is serializable
                    #pdb.set_trace()
                    #import json
                    #jd = json.dumps(object_db)

                    # process all triples in the image
                    tr_index = np.where(triple_to_img.cpu().numpy() == i)
                    tr_img = triples[tr_index]
                    # 8-point triple boxes
                    np_triple_boxes_gt = np.array(triple_boxes_gt).astype(float)
                    tr_img_boxes = np_triple_boxes_gt[tr_index]
                    assert len(tr_img) == len(tr_img_boxes)

                    # vocab['object_idx_to_name'], vocab['pred_idx_to_name']
                    # s,o: indices for "objs" array (yields 'object_idx' for 'object_idx_to_name')
                    # p: use this value as is (yields 'pred_idx' for 'pred_idx_to_name')
                    s, p, o = np.split(tr_img, 3, axis=1)

                    # iterate over all triplets in image to form (subject, predicate, object) tuples
                    relationship_data = []
                    num_triples = len(tr_img)

                    # need to iterate over all triples due to information that needs to be extracted per triple
                    for n in range(0, num_triples):
                        # tuple = (objs[obj_index], p, objs[subj_index])
                        subj_index = s[n]
                        subj = np.array(model.vocab['object_idx_to_name'])[
                            objs[subj_index]]
                        # object whitelist
                        if subj != 'person':
                            continue
                        pred = np.array(model.vocab['pred_idx_to_name'])[p[n]]
                        obj_index = o[n]
                        obj = np.array(
                            model.vocab['object_idx_to_name'])[objs[obj_index]]
                        triplet = tuple([subj, pred, obj])
                        relationship_data += [tuple([subj, pred, obj])]
                        #print(tuple([subj, pred, obj]))
                        #print('--------------------')
                        f.write('(' + db_utils.tuple_to_string(
                            tuple([subj, pred, obj])) + ')\n')

                        # GT bounding boxes: (x0, y0, x1, y1) format, in a [0, 1] coordinate system
                        # (from "boxes" (one for each object in "objs") using subj_index and obj_index)
                        subj_bbox = tr_img_boxes[n, 0:4]
                        obj_bbox = tr_img_boxes[n, 4:8]
                        #print(tuple([subj, pred, obj]), subj_bbox, obj_bbox)

                        # SG GCNN embeddings
                        subj_embed = obj_embeddings[subj_index].cpu().numpy().tolist()
                        pred_embed = pred_embeddings[n].cpu().numpy().tolist()
                        obj_embed = obj_embeddings[obj_index].cpu().numpy().tolist()
                        #pooled_embed = subj_embed + pred_embed + obj_embed

                        # add relationship to database
                        relationship = dict()
                        relationship['subject'] = subj
                        relationship['predicate'] = pred
                        relationship['object'] = obj
                        # JSON can't serialize np.array()
                        relationship['subject_bbox'] = subj_bbox.tolist()
                        relationship['object_bbox'] = obj_bbox.tolist()

                        # get super box
                        min_x = np.min([subj_bbox[0], obj_bbox[0]])
                        min_y = np.min([subj_bbox[1], obj_bbox[1]])
                        max_x = np.max([subj_bbox[2], obj_bbox[2]])
                        max_y = np.max([subj_bbox[3], obj_bbox[3]])
                        relationship['super_bbox'] = [
                            min_x, min_y, max_x, max_y
                        ]
                        super_boxes += [relationship['super_bbox']]
                        relationship['subject_embed'] = subj_embed
                        relationship['predicate_embed'] = pred_embed
                        relationship['object_embed'] = obj_embed
                        #relationship['embed'] = pooled_embed

                        # keys are strings, so test membership on the string
                        # key (testing the tuple itself never matches)
                        key = db_utils.tuple_to_string(triplet)
                        if key not in triplet_db:
                            triplet_db[key] = [relationship]
                        else:
                            triplet_db[key] += [relationship]
                    #print('------- end of processing for image --------------------------')

                f.close()
                # measure IoU as a basic metric for bbox prediction
                total_iou += jaccard(boxes_pred, boxes)
                total_boxes += boxes_pred.size(0)
                ####### end single batch process #########
                print('------- end of single batch process --------------------------')

        # write object database to JSON file
        # pdb.set_trace()  # leftover breakpoint
        db_utils.write_to_JSON(object_db, args.coco_object_db_json_write)
        # write triplet database to JSON file
        #db_utils.write_to_JSON(triplet_db, "coco_triplet_db.json")
    #####  end embedding extraction

    # analyze JSON database (stats, examples, etc.)
    # also, compare word vs SG embedding
    analyze_object_db(object_db, analyze_word_embed=True)

    #batch_data = {
    #  'objs': objs.detach().cpu().clone(),
    #  'boxes_gt': boxes.detach().cpu().clone(),
    #  'masks_gt': masks_to_store,
    #  'triples': triples.detach().cpu().clone(),
    #  'obj_to_img': obj_to_img.detach().cpu().clone(),
    #  'triple_to_img': triple_to_img.detach().cpu().clone(),
    #  'boxes_pred': boxes_pred.detach().cpu().clone(),
    #  'masks_pred': masks_pred_to_store
    #}
    #out = [mean_losses, samples, batch_data, avg_iou]
    #out = [mean_losses, avg_iou]
    out = []

    ####################
    # note: only meaningful when embeddings were extracted above; when a
    # saved object db is loaded, no batches are processed and these are unset
    avg_iou = total_iou / total_boxes
    print('average bbox IoU = ', avg_iou.cpu().numpy())
    ###################

    return tuple(out)
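
Examples #6 and #7 lean on small db_utils helpers to (de)serialize the databases. Their behavior is implied by the call sites; minimal stand-ins (names from the listings, bodies assumed):

import json

def tuple_to_string(triplet, sep=', '):
    # ('person', 'riding', 'horse') -> 'person, riding, horse'
    return sep.join(str(x) for x in triplet)

def write_to_JSON(db, path):
    # values must already be plain lists/dicts; the listings call .tolist()
    # on arrays because JSON can't serialize np.array()
    with open(path, 'w') as f:
        json.dump(db, f)

def read_fr_JSON(path):
    with open(path) as f:
        return json.load(f)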
Example #7
def check_model(args, t, loader, model, log_tag='', write_images=False):

    if torch.cuda.is_available():
        float_dtype = torch.cuda.FloatTensor
        long_dtype = torch.cuda.LongTensor
    else:
        float_dtype = torch.FloatTensor
        long_dtype = torch.LongTensor

    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0
    total_boxes = 0

    ###################
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    img_dir = args.output_dir + '/img_dir'

    if not os.path.isdir(img_dir):
        os.mkdir(img_dir)
        print('Created %s' % img_dir)
    ##################

    t = 0
    # relationship (triplet) database
    triplet_db = dict()

    # iterate over all batches of images
    with torch.no_grad():
        for batch in loader:

            # TODO: HERE
            if torch.cuda.is_available():
                batch = [tensor.cuda() for tensor in batch]
            else:
                batch = [tensor for tensor in batch]

            masks = None
            if len(batch) == 6:  # VG
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            #elif len(batch) == 8: # COCO
            #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks = batch
            #elif len(batch) == 9: # COCO
            #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks, triplet_contours = batch
            elif len(batch) == 10:  # COCO
                (imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img,
                 triplet_masks, triplet_contours, obj_contours) = batch
            #elif len(batch) == 7:
            #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            predicates = triples[:, 1]

            # Run the model as it has been run during training
            model_masks = masks
            model_out = model(objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=model_masks)
            # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
             layout_boxes, layout_masks, obj_to_img, sg_context_pred,
             sg_context_pred_d, predicate_scores, obj_embeddings,
             pred_embeddings, triple_boxes_pred, triple_boxes_gt,
             triplet_masks_pred, triplet_contours_pred,
             obj_contours_pred) = model_out
            #imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores, obj_embeddings, pred_embeddings, triple_boxes_pred, triple_boxes_gt, triplet_masks_pred, triplet_contours_pred = model_out
            #imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores, obj_embeddings, pred_embeddings, triple_boxes_pred, triple_boxes_gt, triplet_masks_pred = model_out
            # Run model without GT boxes to get predicted layout masks

            # use this to get layout boxes/masks from predicted boxes
            #model_out = model(objs, triples, obj_to_img)
            #layout_boxes, layout_masks = model_out[5], model_out[6]

            # if obj contours are predicted, derive bounding box from these;
            # if GT boxes are passed in, layout_masks are GT boxes
            if obj_contours_pred is not None and boxes_pred is None:
                boxes_pred = min_max_bbox_fr_contours(obj_contours_pred)

                if 0:
                    import matplotlib.pyplot as plt
                    cc = obj_contours.clone().view(-1, 12, 2)
                    cp = obj_contours_pred.view(-1, 12, 2).clone().detach()
                    cp = cp.cpu().numpy()
                    bb = boxes[0].view(2, 2)
                    bbp = boxes_pred.clone().detach()
                    bbp = bbp[0].view(2, 2)
                    fig, ax = plt.subplots()
                    #ax.imshow(masks[0])
                    # without mask, origin will be LLHC
                    ax.scatter(cc[0, :, 0], cc[0, :, 1], linewidth=0.5)
                    ax.scatter(cp[0, :, 0], cp[0, :, 1], linewidth=0.5)
                    ax.scatter(bb[:, 0], bb[:, 1], linewidth=1.0, marker="x")
                    ax.scatter(bbp[:, 0], bbp[:, 1], linewidth=1.0, marker="x")
                    plt.show()
                    # display masks
                    masks_pred = masks_pred.detach()
                    np_masks_pred = [mask.cpu().numpy() for mask in masks_pred]
                    fig = plt.figure()
                    ax1 = fig.add_subplot(1, 2, 1)
                    ax1.imshow(masks[0])
                    ax2 = fig.add_subplot(1, 2, 2)
                    ax2.imshow(np_masks_pred[0])
                    plt.show()
                    pdb.set_trace()

            num_batch_samples = imgs.size(0)
            num_samples += num_batch_samples
            if num_samples >= args.num_val_samples:
                break

            super_boxes = []

            # open file to record all triplets, per image, in a batch
            # (note: 'w' mode means each batch overwrites the previous one)
            file_path = os.path.join(img_dir, 'all_batch_triplets.txt')
            f = open(file_path, 'w')
            ### embedding stuff below here ####
            for i in range(0, num_batch_samples):
                print('Processing image', i + 1, 'of batch size',
                      args.batch_size)
                f.write('---------- image ' + str(i) + '----------\n')

                # from batch: objs, triples, triple_to_img, objs_to_img (need indices in that to select to tie triplets to image)
                # from model: obj_embed, pred_embed

                # find all triple indices for specific image
                # all triples for image i
                # TODO: clean up code so it is numpy() equivalent in all places
                tr_index = np.where(triple_to_img.cpu().numpy() == i)
                tr_img = triples.cpu().numpy()[tr_index, :]
                tr_img = np.squeeze(tr_img, axis=0)
                # 8 point triple boxes
                np_triple_boxes_gt = np.array(triple_boxes_gt).astype(float)
                tr_img_boxes = np_triple_boxes_gt[tr_index]
                assert len(tr_img) == len(tr_img_boxes)

                # vocab['object_idx_to_name'], vocab['pred_idx_to_name']
                # s,o: indices for "objs" array (yields 'object_idx' for 'object_idx_to_name')
                # p: use this value as is (yields 'pred_idx' for 'pred_idx_to_name')
                s, p, o = np.squeeze(np.split(tr_img, 3, axis=1))

                # iterate over all triplets in image to form (subject, predicate, object) tuples
                relationship_data = []
                num_triples = len(tr_img)

                # need to iterate over all triples due to information that needs to be extracted per triple
                for n in range(0, num_triples):
                    # tuple = (objs[obj_index], p, objs[subj_index])
                    subj_index = s[n]
                    subj = np.array(
                        model.vocab['object_idx_to_name'])[objs[subj_index]]
                    pred = np.array(model.vocab['pred_idx_to_name'])[p[n]]
                    obj_index = o[n]
                    obj = np.array(
                        model.vocab['object_idx_to_name'])[objs[obj_index]]
                    triplet = tuple([subj, pred, obj])
                    relationship_data += [tuple([subj, pred, obj])]
                    print(tuple([subj, pred, obj]))
                    #print('--------------------')
                    f.write(
                        '(' +
                        db_utils.tuple_to_string(tuple([subj, pred, obj])) +
                        ')\n')

                    # GT bounding boxes: (x0, y0, x1, y1) format, in a [0, 1] coordinate system
                    # (from "boxes" (one for each object in "objs") using subj_index and obj_index)
                    subj_bbox = tr_img_boxes[n, 0:4]
                    obj_bbox = tr_img_boxes[n, 4:8]
                    print(tuple([subj, pred, obj]), subj_bbox, obj_bbox)

                    # SG GCNN embeddings to be used for search (nth triplet corresponds to nth embedding)
                    #subj_embed = obj_embeddings[subj_index].numpy().tolist()
                    #pred_embed = pred_embeddings[n].numpy().tolist()
                    #obj_embed = obj_embeddings[obj_index].numpy().tolist()
                    subj_embed = obj_embeddings[subj_index].cpu().numpy().tolist()
                    pred_embed = pred_embeddings[n].cpu().numpy().tolist()
                    obj_embed = obj_embeddings[obj_index].cpu().numpy().tolist()
                    pooled_embed = subj_embed + pred_embed + obj_embed

                    # add relationship to database
                    relationship = dict()
                    relationship['subject'] = subj
                    relationship['predicate'] = pred
                    relationship['object'] = obj
                    # JSON can't serialize np.array()
                    relationship['subject_bbox'] = subj_bbox.tolist()
                    relationship['object_bbox'] = obj_bbox.tolist()

                    # get super box
                    #min_x = np.min([tr_img_boxes[n][0], tr_img_boxes[n][4]])
                    #min_y = np.min([tr_img_boxes[n][1], tr_img_boxes[n][5]])
                    #max_x = np.max([tr_img_boxes[n][2], tr_img_boxes[n][6]])
                    #max_y = np.max([tr_img_boxes[n][3], tr_img_boxes[n][7]])
                    min_x = np.min([subj_bbox[0], obj_bbox[0]])
                    min_y = np.min([subj_bbox[1], obj_bbox[1]])
                    max_x = np.max([subj_bbox[2], obj_bbox[2]])
                    max_y = np.max([subj_bbox[3], obj_bbox[3]])
                    #print([min_x, min_y, max_x, max_y])
                    #print([_min_x, _min_y, _max_x, _max_y])
                    relationship['super_bbox'] = [min_x, min_y, max_x, max_y]
                    super_boxes += [relationship['super_bbox']]
                    #relationship['subject_embed'] = subj_embed
                    #relationship['predicate_embed'] = pred_embed
                    #relationship['object_embed'] = obj_embed
                    relationship['embed'] = pooled_embed

                    # keys are strings, so test membership on the string key
                    # (testing the tuple itself never matches)
                    key = db_utils.tuple_to_string(triplet)
                    if key not in triplet_db:
                        triplet_db[key] = [relationship]
                    else:
                        triplet_db[key] += [relationship]
                    #pprint.pprint(triplet_db)
                    #pdb.set_trace()

                print('---------------------------------')
                #pprint.pprint(relationship_data)
                #pprint.pprint(triplet_db)  # printed per image iteration
                print(
                    '------- end of processing for image --------------------------'
                )

            ####### process batch images by visualizing triplets on all #########
            f.close()
            # measure IoU as a basic metric for bbox prediction
            total_iou += jaccard(boxes_pred, boxes)
            total_boxes += boxes_pred.size(0)

            # detach
            imgs = imgs.detach()
            triplet_masks = triplet_masks.detach()
            if triplet_masks_pred is not None:
                triplet_masks_pred = triplet_masks_pred.detach()
            else:
                triplet_masks_pred = triplet_masks
            boxes_pred = boxes_pred.detach()

            # deprocess (normalize) images
            samples = {}
            samples['gt_imgs'] = imgs

            for k, v in samples.items():
                samples[k] = imagenet_deprocess_batch(v)

            # GT images
            np_imgs = [gt.cpu().numpy().transpose(1, 2, 0) for gt in imgs]
            np_triplet_masks = [mask.cpu().numpy() for mask in triplet_masks]
            np_triplet_masks_pred = [
                mask.cpu().numpy() for mask in triplet_masks_pred
            ]
            # object masks (one list entry per object)
            np_masks_pred = [mask.cpu().numpy() for mask in masks_pred]
            np_masks = [mask.cpu().numpy() for mask in model_masks]
            np_layout_masks = [mask.cpu().numpy() for mask in layout_masks]

            # visualize predicted boxes/images
            # (output image is always 64x64 based upon how current model is trained)
            pred_overlaid_images = vis.overlay_boxes(np_imgs,
                                                     model.vocab,
                                                     objs_vec,
                                                     boxes_pred,
                                                     obj_to_img,
                                                     W=256,
                                                     H=256)

            # predicted layouts and bounding boxes (layout_boxes may be ground
            # truth; here boxes_pred is passed instead)
            layouts = vis.debug_layout_mask(model.vocab,
                                            objs,
                                            boxes_pred,
                                            layout_masks,
                                            obj_to_img,
                                            W=256,
                                            H=256)
            #layouts = vis.debug_layout_mask(model.vocab, objs, layout_boxes, layout_masks, obj_to_img, W=256, H=256)

            # visualize GT boxes/images
            #overlaid_images = vis.overlay_boxes(np_imgs, model.vocab, objs_vec, boxes, obj_to_img, W=64, H=64)
            overlaid_images = vis.overlay_boxes(np_imgs,
                                                model.vocab,
                                                objs_vec,
                                                boxes,
                                                obj_to_img,
                                                W=256,
                                                H=256)

            # triples to image
            # visualize superboxes with object boxes underneath
            ##norm_overlaid_images = [i/255.0 for i in overlaid_images]
            ##sb_overlaid_images = vis.overlay_boxes(norm_overlaid_images, model.vocab, objs_vec, torch.tensor(super_boxes), triple_to_img, W=256, H=256, drawText=False, drawSuperbox=True)

            import matplotlib.pyplot as plt
            print("---- saving first GT image of batch -----")
            img_gt = np_imgs[0]
            imwrite('./test_GT_img_coco.png', img_gt)
            #plt.imshow(img_gt)  # can visualize [0-1] or [0-255] color scaling
            #plt.show()

            #print("---- saving first predicted triplet mask of batch -----")
            #gt_mask_np = np_triplet_masks[1]
            #plt.imshow(gt_mask_np)
            #plt.show()
            #pred_mask_np = np_triplet_masks_pred[1]
            #imwrite('./test_pred_overlay_mask_coco.png', img_np)
            #plt.imshow(pred_mask_np)
            #plt.show()

            print("---- saving first overlay image of batch -----")
            imwrite('./test_overlay_img_coco.png', overlaid_images[0])
            #plt.imshow(overlaid_images[0])
            #plt.show()

            print("---- saving first layout image of batch -----")
            imwrite('./test_layout_img_coco.png', layouts[0])
            #plt.imshow(layouts[0])
            #plt.show()

            # display GT / layout mask together
            #fig = plt.figure()
            #ax1 = fig.add_subplot(1,2,1)
            #ax1.imshow(overlaid_images[0])
            #ax2 = fig.add_subplot(1,2,2)
            #ax2.imshow(layouts[0])
            #plt.show()

            #print("---- saving first superbox overlay image of batch -----")
            #imwrite('./test_sb_overlay_img_coco.png', sb_overlaid_images[0])
            #plt.imshow(sb_overlaid_images[0])
            #plt.show()

            # pdb.set_trace()  # leftover breakpoint
            # visualize predicted object contours with GT singleton mask
            #for np_img in np_imgs:
            for c in range(len(obj_contours)):
                fig, ax = plt.subplots()
                ax.imshow(np_imgs[0])
                oc = obj_contours[c].view(12, 2).cpu() * 256.0
                ocp = obj_contours_pred[c].view(12, 2).cpu() * 256.0
                ax.scatter(oc[:, 0], oc[:, 1],
                           linewidth=0.5)  # order was switched in coco_cont.py
                ax.scatter(ocp[:, 0], ocp[:, 1],
                           linewidth=0.5)  # order was switched in coco_cont.py
                plt.show()
                #pdb.set_trace()

            print("---- saving batch images -----")
            if write_images:
                t = 0  # note: resetting per batch overwrites earlier batches' images
                for gt_img, pred_overlaid_img, overlaid_img, layout_img in zip(
                        np_imgs, pred_overlaid_images, overlaid_images,
                        layouts):
                    #for gt_img, pred_overlaid_img, overlaid_img, sb_overlaid_img, layout_img in zip(np_imgs, pred_overlaid_images, overlaid_images, sb_overlaid_images, layouts):
                    img_path = os.path.join(img_dir, '%06d_gt_img.png' % t)
                    imwrite(img_path, gt_img)

                    img_path = os.path.join(img_dir, '%06d_pred_bbox.png' % t)
                    imwrite(img_path, pred_overlaid_img)

                    img_path = os.path.join(img_dir,
                                            '%06d_gt_bbox_img.png' % t)
                    imwrite(img_path, overlaid_img)

                    #img_path = os.path.join(img_dir, '%06d_gt_superbox_img.png' % t)
                    #imwrite(img_path, sb_overlaid_img)

                    img_path = os.path.join(img_dir, '%06d_layout.png' % t)
                    imwrite(img_path, layout_img)

                    t += 1

        # write database to JSON file
        db_utils.write_to_JSON(triplet_db, "coco_test_db.json")

        masks_to_store = masks
        if masks_to_store is not None:
            masks_to_store = masks_to_store.data.cpu().clone()

        masks_pred_to_store = masks_pred
        if masks_pred_to_store is not None:
            masks_pred_to_store = masks_pred_to_store.data.cpu().clone()

    #batch_data = {
    #  'objs': objs.detach().cpu().clone(),
    #  'boxes_gt': boxes.detach().cpu().clone(),
    #  'masks_gt': masks_to_store,
    #  'triples': triples.detach().cpu().clone(),
    #  'obj_to_img': obj_to_img.detach().cpu().clone(),
    #  'triple_to_img': triple_to_img.detach().cpu().clone(),
    #  'boxes_pred': boxes_pred.detach().cpu().clone(),
    #  'masks_pred': masks_pred_to_store
    #}
    #out = [mean_losses, samples, batch_data, avg_iou]
    #out = [mean_losses, avg_iou]
    out = []

    ####################
    avg_iou = total_iou / total_boxes
    print('average bbox IoU = ', avg_iou.cpu().numpy())
    ###################

    return tuple(out)
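
Finally, the geometry behind the vis.overlay_boxes calls (drawing each object's box onto its image) can be sketched with plain numpy. The real helper also renders object labels from model.vocab and supports superboxes; this sketch only draws outlines:

import numpy as np

def overlay_boxes_sketch(np_img, boxes, color=(255, 0, 0)):
    # np_img: (H, W, 3) uint8 image; boxes: (N, 4) rows of (x0, y0, x1, y1)
    # in [0, 1] coordinates, as in the listings.
    img = np_img.copy()
    H, W = img.shape[:2]
    for x0, y0, x1, y1 in np.clip(np.asarray(boxes, dtype=float), 0.0, 1.0):
        x0, x1 = int(x0 * (W - 1)), int(x1 * (W - 1))
        y0, y1 = int(y0 * (H - 1)), int(y1 * (H - 1))
        img[y0, x0:x1 + 1] = color  # top edge
        img[y1, x0:x1 + 1] = color  # bottom edge
        img[y0:y1 + 1, x0] = color  # left edge
        img[y0:y1 + 1, x1] = color  # right edge
    return img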