def getfile(self):
    """ Loads the next input batch and initializes the editing state """
    self.batch = next(self.data_loader)
    self.imgs, self.objs, self.boxes, self.triples, self.obj_to_img, self.triple_to_img, self.imgs_in = \
        [x.cuda() for x in self.batch]

    # masks controlling which GT boxes / visual features / image regions are kept
    self.keep_box_idx = torch.ones_like(self.objs.unsqueeze(1), dtype=torch.float)
    self.keep_feat_idx = torch.ones_like(self.objs.unsqueeze(1), dtype=torch.float)
    self.keep_image_idx = torch.ones_like(self.objs.unsqueeze(1), dtype=torch.float)
    self.combine_gt_pred_box_idx = torch.zeros_like(self.objs)
    self.added_objs_idx = torch.zeros_like(self.objs.unsqueeze(1), dtype=torch.float)

    self.new_triples, self.new_objs = None, None

    image = imagenet_deprocess_batch(self.imgs)
    image = image[0].numpy().transpose(1, 2, 0).copy()
    self.image = image
    self.draw_input_image(new_image=True)
def save_image_with_label(img_pred, img_gt, img_dir, filename, txt_str):
    # Saves the GT and the generated image concatenated side by side,
    # together with a text label describing the change.
    # Used for easier visualization of results.
    img_pred = imagenet_deprocess_batch(img_pred)
    img_gt = imagenet_deprocess_batch(img_gt)

    img_pred_np = img_pred[0].numpy().transpose(1, 2, 0)
    img_gt_np = img_gt[0].numpy().transpose(1, 2, 0)

    img_pred_np = cv2.resize(img_pred_np, (128, 128))
    img_gt_np = cv2.resize(img_gt_np, (128, 128))

    wspace = np.zeros([img_pred_np.shape[0], 10, 3])
    text = np.zeros([30, img_pred_np.shape[1] * 2 + 10, 3])
    text = cv2.putText(text, txt_str, (0, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                       (255, 255, 255), lineType=cv2.LINE_AA)

    img_pred_gt = np.concatenate([img_gt_np, wspace, img_pred_np], axis=1).astype('uint8')
    img_pred_gt = np.concatenate([text, img_pred_gt], axis=0).astype('uint8')
    img_path = os.path.join(img_dir, filename)
    imsave(img_path, img_pred_gt)
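# Example usage of save_image_with_label (hypothetical tensors and paths; assumes
# NCHW float tensors in the normalized range expected by imagenet_deprocess_batch):
#
#   save_image_with_label(imgs_pred, imgs, 'results/images',
#                         '%04d_replace.png' % img_idx, 'replace: sky -> sea')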
def eval_model(model, loader, device, vocab, use_gt_boxes=False, use_feats=False, filter_box=False):
    all_boxes = defaultdict(list)
    total_iou = []
    total_boxes = 0
    num_batches = 0
    num_samples = 0
    mae_per_image = []
    mae_roi_per_image = []
    roi_only_iou = []
    ssim_per_image = []
    ssim_rois = []
    rois = 0
    margin = 2

    # Initialize the perceptual loss model
    lpips_model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
    perceptual_error_image = []

    img_idx = 0
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            num_batches += 1
            # if num_batches > 10:
            #     break
            batch = [tensor.to(device) for tensor in batch]
            masks = None

            imgs, objs, boxes, triples, obj_to_img, triple_to_img, imgs_in = batch
            predicates = triples[:, 1]

            if not args.generative:
                imgs, imgs_in, objs, boxes, triples, obj_to_img, \
                    dropimage_indices, dropfeats_indices = [b.to(device) for b in process_batch(
                        imgs, imgs_in, objs, boxes, triples, obj_to_img, triple_to_img,
                        device, use_feats=use_feats, filter_box=filter_box)]
                dropbox_indices = dropimage_indices
            else:
                dropbox_indices = torch.ones_like(objs.unsqueeze(1).float()).to(device)
                dropfeats_indices = torch.ones_like(objs.unsqueeze(1).float()).to(device)
                dropimage_indices = torch.zeros_like(objs.unsqueeze(1).float()).to(device)

            if imgs.shape[0] == 0:
                continue

            if args.visualize_graphs:
                # visualize scene graphs for debugging purposes
                visualize_scene_graphs(obj_to_img, objs, triples, vocab, device)

            if use_gt_boxes:
                model_out = model(objs, triples, obj_to_img,
                                  boxes_gt=boxes, masks_gt=masks, src_image=imgs_in,
                                  keep_box_idx=torch.ones_like(dropimage_indices),
                                  keep_feat_idx=dropfeats_indices,
                                  keep_image_idx=dropimage_indices, mode='eval')
            else:
                model_out = model(objs, triples, obj_to_img,
                                  boxes_gt=boxes, src_image=imgs_in,
                                  keep_box_idx=dropimage_indices,
                                  keep_feat_idx=dropfeats_indices,
                                  keep_image_idx=dropimage_indices, mode='eval')

            # OUTPUT
            imgs_pred, boxes_pred, masks_pred, _, _ = model_out

            # Save all box predictions
            all_boxes['boxes_gt'].append(boxes)
            all_boxes['objs'].append(objs)
            all_boxes['boxes_pred'].append(boxes_pred)
            all_boxes['drop_targets'].append(dropbox_indices)

            # IoU over all boxes
            total_iou.append(jaccard(boxes_pred, boxes).detach().cpu().numpy())
            total_boxes += boxes_pred.size(0)

            # IoU over target boxes only
            pred_dropbox = boxes_pred[dropbox_indices.squeeze() == 0, :]
            gt_dropbox = boxes[dropbox_indices.squeeze() == 0, :]
            roi_only_iou.append(jaccard(pred_dropbox, gt_dropbox).detach().cpu().numpy())
            rois += pred_dropbox.size(0)

            num_samples += imgs.shape[0]
            imgs = imagenet_deprocess_batch(imgs).float()
            imgs_pred = imagenet_deprocess_batch(imgs_pred).float()

            if args.visualize_imgs_boxes:
                # visualize images with drawn boxes for debugging purposes
                visualize_imgs_boxes(imgs, imgs_pred, boxes, boxes_pred)

            if args.save_images:
                # save reconstructed images for later FID and Inception computation
                if args.save_gt_images:
                    # pass imgs as an argument to additionally save the GT images
                    save_images(imgs_pred, img_idx, imgs)
                else:
                    save_images(imgs_pred, img_idx)

            # MAE per image
            mae_per_image.append(
                torch.mean(torch.abs(imgs - imgs_pred).view(imgs.shape[0], -1), 1).cpu().numpy())

            for s in range(imgs.shape[0]):
                # get coordinates of the target box
                left, right, top, bottom = bbox_coordinates_with_margin(boxes[s, :], margin, imgs)
                if left > right or top > bottom:
                    continue

                # calculate errors only in the RoI, one sample at a time
                mae_roi_per_image.append(
                    torch.mean(torch.abs(imgs[s, :, top:bottom, left:right]
                                         - imgs_pred[s, :, top:bottom, left:right])).cpu().item())

                ssim_per_image.append(
                    pytorch_ssim.ssim(imgs[s:s + 1, :, :, :] / 255.0,
                                      imgs_pred[s:s + 1, :, :, :] / 255.0,
                                      window_size=3).cpu().item())
                ssim_rois.append(
                    pytorch_ssim.ssim(imgs[s:s + 1, :, top:bottom, left:right] / 255.0,
                                      imgs_pred[s:s + 1, :, top:bottom, left:right] / 255.0,
                                      window_size=3).cpu().item())

                # normalize to [-1, 1] as expected by the LPIPS model
                imgs_pred_norm = imgs_pred[s:s + 1, :, :, :] / 127.5 - 1
                imgs_gt_norm = imgs[s:s + 1, :, :, :] / 127.5 - 1

                perceptual_error_image.append(
                    lpips_model.forward(imgs_pred_norm, imgs_gt_norm).detach().cpu().numpy())

            if num_batches % args.print_every == 0:
                calculate_scores(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                                 ssim_per_image, ssim_rois, perceptual_error_image)

            if num_batches % args.save_every == 0:
                save_results(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                             ssim_per_image, ssim_rois, perceptual_error_image,
                             all_boxes, num_batches)

            img_idx += 1

    calculate_scores(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                     ssim_per_image, ssim_rois, perceptual_error_image)
    save_results(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                 ssim_per_image, ssim_rois, perceptual_error_image, all_boxes, 'final')
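# The RoI metrics above rely on bbox_coordinates_with_margin, which is defined
# elsewhere in the repository. A minimal sketch of the assumed behavior, under the
# assumption that boxes are normalized [x0, y0, x1, y1] coordinates and imgs is an
# NCHW tensor (the name _sketch marks this as a hypothetical reimplementation; the
# real helper's rounding/clipping details may differ):

def bbox_coordinates_with_margin_sketch(bbox, margin, imgs):
    H, W = imgs.shape[2], imgs.shape[3]
    # scale normalized box coordinates to pixel space and add the margin
    left = int(bbox[0] * W) - margin
    right = int(bbox[2] * W) + margin
    top = int(bbox[1] * H) - margin
    bottom = int(bbox[3] * H) + margin
    # clip to the image bounds
    return max(0, left), min(W, right), max(0, top), min(H, bottom)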
def gen_image(self):
    """ Generates an image, as indicated by the modified graph """
    if self.new_triples is not None:
        triples_ = self.new_triples
    else:
        triples_ = self.triples

    query_feats = None

    model_out = self.model(
        self.new_objs, triples_, None,
        boxes_gt=self.boxes, masks_gt=None, src_image=self.imgs_in,
        mode=self.mode, query_feats=query_feats,
        keep_box_idx=self.keep_box_idx, keep_feat_idx=self.keep_feat_idx,
        combine_gt_pred_box_idx=self.combine_gt_pred_box_idx,
        keep_image_idx=self.keep_image_idx, random_feats=args.random_feats,
        get_layout_boxes=True)

    imgs_pred, boxes_pred, masks_pred, noised_srcs, _, layout_boxes = model_out

    image = imagenet_deprocess_batch(imgs_pred)
    image = image[0].detach().numpy().transpose(1, 2, 0).copy()
    if args.update_input:
        self.image = image.copy()

    image = QtGui.QImage(image, image.shape[1], image.shape[0], QtGui.QImage.Format_RGB888)
    im_pm = QtGui.QPixmap(image)
    self.ima.setPixmap(im_pm.scaled(200, 200))
    self.ima.setVisible(True)
    self.imCounter += 1

    if args.update_input:
        # reset everything so that the predicted image becomes the input image for the next step
        self.imgs = imgs_pred.detach().clone()
        self.imgs_in = torch.cat([self.imgs, torch.zeros_like(self.imgs[:, 0:1, :, :])], 1)
        self.draw_input_image()

        self.boxes = layout_boxes.detach().clone()
        self.keep_box_idx = torch.ones_like(self.objs.unsqueeze(1), dtype=torch.float)
        self.keep_feat_idx = torch.ones_like(self.objs.unsqueeze(1), dtype=torch.float)
        self.keep_image_idx = torch.ones_like(self.objs.unsqueeze(1), dtype=torch.float)
        self.combine_gt_pred_box_idx = torch.zeros_like(self.objs)
    else:
        # the input image is still the original one - don't reset anything
        # if an object is added for the first time, the GT/input box is still a dummy (set in add_triple)
        # in this case, we update the GT/input box using the box predicted by the SGN,
        # so that it can be used in future changes that rely on the GT/input box, e.g. replacement
        self.boxes = self.added_objs_idx * layout_boxes.detach().clone() \
            + (1 - self.added_objs_idx) * self.boxes
        self.added_objs_idx = torch.zeros_like(self.objs.unsqueeze(1), dtype=torch.float)
def save_image_from_tensor(img, img_dir, filename):
    # de-normalize, convert to HWC numpy layout, and write to disk
    img = imagenet_deprocess_batch(img)
    img_np = img[0].numpy().transpose(1, 2, 0)
    img_path = os.path.join(img_dir, filename)
    imsave(img_path, img_np)
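# Example usage (hypothetical arguments; img is a normalized NCHW tensor):
#
#   save_image_from_tensor(imgs_pred, 'results/images', '%04d_changed.png' % img_idx)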
def run_model(args, checkpoint, loader=None):
    output_dir = args.exp_dir
    model = build_model(args, checkpoint)
    if loader is None:
        loader = build_eval_loader(args, checkpoint, vocab_t)

    img_dir = makedir(output_dir, 'images_' + SPLIT)
    graph_json_dir = makedir(output_dir, 'graphs_json')

    f = open(output_dir + "/result_ids.txt", "w")

    img_idx = 0

    total_iou_all = []
    total_iou = get_def_dict()
    total_boxes = 0
    mae_per_image_all = []
    mae_per_image = get_def_dict()
    mae_roi_per_image_all = []
    mae_roi_per_image = get_def_dict()
    roi_only_iou_all = []
    roi_only_iou = get_def_dict()
    ssim_per_image_all = []
    ssim_per_image = get_def_dict()
    ssim_rois_all = []
    ssim_rois = get_def_dict()
    rois = 0
    margin = 2

    # Initialize the perceptual loss model
    lpips_model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
    perceptual_error_image_all = []
    perceptual_error_image = get_def_dict()
    perceptual_error_roi_all = []
    perceptual_error_roi = get_def_dict()

    for batch in loader:
        imgs, imgs_src, objs, objs_src, boxes, boxes_src, triples, triples_src, obj_to_img, \
            triple_to_img, imgs_in = [x.cuda() for x in batch]

        imgs_gt = imagenet_deprocess_batch(imgs_src)
        imgs_target_gt = imagenet_deprocess_batch(imgs)

        # Determine the edit mode from the difference between the target and
        # source scene graphs, using multisets of triples and object ids
        graph_set_bef = Counter(tuple(row) for row in tripleToObjID(triples_src, objs_src))
        obj_set_bef = Counter([int(obj.cpu()) for obj in objs_src])
        graph_set_aft = Counter(tuple(row) for row in tripleToObjID(triples, objs))
        obj_set_aft = Counter([int(obj.cpu()) for obj in objs])

        if len(objs) > len(objs_src):
            mode = "addition"
            changes = graph_set_aft - graph_set_bef
            obj_ids = list(obj_set_aft - obj_set_bef)
            new_ids = (objs == obj_ids[0]).nonzero()

        elif len(objs) < len(objs_src):
            mode = "remove"
            changes = graph_set_bef - graph_set_aft
            obj_ids = list(obj_set_bef - obj_set_aft)
            new_ids_src = (objs_src == obj_ids[0]).nonzero()

            # re-append the removed object and its box, so the model can erase it
            new_objs = [obj for obj in objs]
            new_objs.append(objs_src[new_ids_src[0]])
            objs = torch.tensor(new_objs).cuda()
            num_objs = len(objs)
            new_ids = [torch.tensor(num_objs - 1)]
            new_boxes = [bbox for bbox in boxes]
            new_boxes.append(boxes_src[new_ids_src[0]][0])
            boxes = torch.stack(new_boxes)
            obj_to_img = torch.zeros(num_objs, dtype=objs.dtype, device=objs.device)

        elif torch.all(torch.eq(objs, objs_src)):
            mode = "reposition"
            changes = (graph_set_bef - graph_set_aft) + (graph_set_aft - graph_set_bef)
            # the object involved in the most changed triples is the repositioned one
            idx_cnt = np.zeros((25, 1))
            for [s, p, o] in list(changes):
                idx_cnt[s] += 1
                idx_cnt[o] += 1
            obj_ids = idx_cnt.argmax(0)
            id_src = (objs_src == obj_ids[0]).nonzero()
            box_src = boxes_src[id_src[0]]
            new_ids = (objs == obj_ids[0]).nonzero()
            boxes[new_ids[0]] = box_src

        elif len(objs) == len(objs_src):
            mode = "replace"
            changes = (graph_set_bef - graph_set_aft) + (graph_set_aft - graph_set_bef)
            obj_ids = [list(obj_set_bef - obj_set_aft)[0], list(obj_set_aft - obj_set_bef)[0]]
            new_ids = (objs == obj_ids[1]).nonzero()
        else:
            assert False

        new_ids = [int(new_id.cpu()) for new_id in new_ids]

        show_im = False
        if show_im:
            img_gt = imgs_gt[0].numpy().transpose(1, 2, 0)
            img_gt_target = imgs_target_gt[0].numpy().transpose(1, 2, 0)
            fig = plt.figure()
            fig.add_subplot(1, 2, 1)
            plt.imshow(img_gt)
            fig.add_subplot(1, 2, 2)
            plt.imshow(img_gt_target)
            plt.show(block=True)

        query_feats = None
        if args.with_query_image:
            img, box = query_image_by_semantic_id(new_ids, img_idx, loader)
            query_feats = model.forward_visual_feats(img, box)

            img_filename_query = '%04d_query.png' % img_idx
            img = imagenet_deprocess_batch(img)
            img_np = img[0].numpy().transpose(1, 2, 0).astype(np.uint8)
            img_path = os.path.join(img_dir, img_filename_query)
            imsave(img_path, img_np)

        img_gt_filename = '%04d_gt_src.png' % img_idx
        img_target_gt_filename = '%04d_gt_target.png' % img_idx
        img_pred_filename = '%04d_changed.png' % img_idx
        img_filename_noised = '%04d_noised.png' % img_idx

        triples_ = triples
        boxes_gt = boxes

        keep_box_idx = torch.ones_like(objs.unsqueeze(1), dtype=torch.float)
        keep_feat_idx = torch.ones_like(objs.unsqueeze(1), dtype=torch.float)
        keep_image_idx = torch.ones_like(objs.unsqueeze(1), dtype=torch.float)

        subject_node = new_ids[0]
        keep_image_idx[subject_node] = 0

        if mode == 'reposition':
            keep_box_idx[subject_node] = 0
        elif mode == "remove":
            keep_feat_idx[subject_node] = 0
        else:
            if mode == "replace":
                keep_feat_idx[subject_node] = 0
            if mode == 'auto_withfeats':
                keep_image_idx[subject_node] = 0
            if mode == 'auto_nofeats':
                if not args.with_query_image:
                    keep_feat_idx[subject_node] = 0

        model_out = model(objs, triples_, obj_to_img,
                          boxes_gt=boxes_gt, masks_gt=None, src_image=imgs_in,
                          mode=mode, query_feats=query_feats,
                          keep_box_idx=keep_box_idx, keep_feat_idx=keep_feat_idx,
                          keep_image_idx=keep_image_idx)

        imgs_pred, boxes_pred_o, masks_pred, noised_srcs, _ = model_out

        imgs = imagenet_deprocess_batch(imgs).float()
        imgs_pred = imagenet_deprocess_batch(imgs_pred).float()

        # Metrics
        # IoU over all boxes
        curr_iou = jaccard(boxes_pred_o, boxes).detach().cpu().numpy()
        total_iou_all.append(curr_iou)
        total_iou[mode].append(curr_iou)
        total_boxes += boxes_pred_o.size(0)

        # IoU over target boxes only
        pred_dropbox = boxes_pred_o[keep_box_idx.squeeze() == 0, :]
        gt_dropbox = boxes[keep_box_idx.squeeze() == 0, :]
        curr_iou_roi = jaccard(pred_dropbox, gt_dropbox).detach().cpu().numpy()
        roi_only_iou_all.append(curr_iou_roi)
        roi_only_iou[mode].append(curr_iou_roi)
        rois += pred_dropbox.size(0)

        # MAE per image
        curr_mae = torch.mean(
            torch.abs(imgs - imgs_pred).view(imgs.shape[0], -1), 1).cpu().numpy()
        mae_per_image[mode].append(curr_mae)
        mae_per_image_all.append(curr_mae)

        for s in range(imgs.shape[0]):
            # get coordinates of the target box
            left, right, top, bottom = bbox_coordinates_with_margin(boxes[s, :], margin, imgs)
            if left > right or top > bottom:
                continue
            # print("bboxes with margin: ", left, right, top, bottom)

            # calculate errors only in the RoI, one sample at a time
            curr_mae_roi = torch.mean(
                torch.abs(imgs[s, :, top:bottom, left:right]
                          - imgs_pred[s, :, top:bottom, left:right])).cpu().item()
            mae_roi_per_image[mode].append(curr_mae_roi)
            mae_roi_per_image_all.append(curr_mae_roi)

            curr_ssim = pytorch_ssim.ssim(imgs[s:s + 1, :, :, :] / 255.0,
                                          imgs_pred[s:s + 1, :, :, :] / 255.0,
                                          window_size=3).cpu().item()
            ssim_per_image_all.append(curr_ssim)
            ssim_per_image[mode].append(curr_ssim)

            curr_ssim_roi = pytorch_ssim.ssim(imgs[s:s + 1, :, top:bottom, left:right] / 255.0,
                                              imgs_pred[s:s + 1, :, top:bottom, left:right] / 255.0,
                                              window_size=3).cpu().item()
            ssim_rois_all.append(curr_ssim_roi)
            ssim_rois[mode].append(curr_ssim_roi)

            # normalize to [-1, 1] as expected by the LPIPS model
            imgs_pred_norm = imgs_pred[s:s + 1, :, :, :] / 127.5 - 1
            imgs_gt_norm = imgs[s:s + 1, :, :, :] / 127.5 - 1

            curr_lpips = lpips_model.forward(imgs_pred_norm, imgs_gt_norm).detach().cpu().numpy()
            perceptual_error_image_all.append(curr_lpips)
            perceptual_error_image[mode].append(curr_lpips)

        for i in range(imgs_pred.size(0)):
            if args.save_imgs:
                img_gt = imgs_gt[i].numpy().transpose(1, 2, 0).astype(np.uint8)
                img_gt = cv2.resize(img_gt, (128, 128))
                img_gt_path = os.path.join(img_dir, img_gt_filename)
                imsave(img_gt_path, img_gt)

                img_gt_target = imgs_target_gt[i].numpy().transpose(1, 2, 0).astype(np.uint8)
                img_gt_target = cv2.resize(img_gt_target, (128, 128))
                img_gt_target_path = os.path.join(img_dir, img_target_gt_filename)
                imsave(img_gt_target_path, img_gt_target)

                noised_src_np = imagenet_deprocess_batch(noised_srcs[:, :3, :, :])
                noised_src_np = noised_src_np[i].numpy().transpose(1, 2, 0).astype(np.uint8)
                noised_src_np = cv2.resize(noised_src_np, (128, 128))
                img_path_noised = os.path.join(img_dir, img_filename_noised)
                imsave(img_path_noised, noised_src_np)

                img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0).astype(np.uint8)
                img_pred_np = cv2.resize(img_pred_np, (128, 128))
                img_path = os.path.join(img_dir, img_pred_filename)
                imsave(img_path, img_pred_np)

                save_graph_json(objs, triples, boxes, "after", graph_json_dir, img_idx)

        img_idx += 1

        if img_idx % print_every == 0:
            calculate_scores(mae_per_image_all, mae_roi_per_image_all, total_iou_all,
                             roi_only_iou_all, ssim_per_image_all, ssim_rois_all,
                             perceptual_error_image_all, perceptual_error_roi_all)
            calculate_scores_modes(mae_per_image, mae_roi_per_image, total_iou,
                                   roi_only_iou, ssim_per_image, ssim_rois,
                                   perceptual_error_image, perceptual_error_roi)

        print('Saved %d images' % img_idx)

    f.close()
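# run_model aggregates per-mode metrics through get_def_dict, which is defined
# elsewhere in the repository. A minimal sketch of the assumed behavior (a mapping
# from each editing mode to a list of scores, created on first access; the _sketch
# suffix marks this as a hypothetical reimplementation):

def get_def_dict_sketch():
    from collections import defaultdict
    return defaultdict(list)

# Example: get_def_dict_sketch()['replace'].append(0.42) works without
# initializing the 'replace' key first.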
def check_model(args, t, loader, model):
    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0
    total_boxes = 0

    with torch.no_grad():
        for batch in loader:
            batch = [tensor.cuda() for tensor in batch]
            masks = None
            imgs_src = None

            if args.dataset == "vg" or (args.dataset == "clevr" and not args.is_supervised):
                imgs, objs, boxes, triples, obj_to_img, triple_to_img, imgs_in = batch
            elif args.dataset == "clevr":
                imgs, imgs_src, objs, objs_src, boxes, boxes_src, triples, triples_src, obj_to_img, \
                    triple_to_img, imgs_in = batch

            model_masks = masks
            model_out = model(objs, triples, obj_to_img,
                              boxes_gt=boxes, masks_gt=model_masks,
                              src_image=imgs_in, imgs_src=imgs_src)
            imgs_pred, boxes_pred, masks_pred, _, _ = model_out

            skip_pixel_loss = False
            total_loss, losses = calculate_model_losses(args, skip_pixel_loss, imgs,
                                                        imgs_pred, boxes, boxes_pred)

            total_iou += jaccard(boxes_pred, boxes)
            total_boxes += boxes_pred.size(0)

            for loss_name, loss_val in losses.items():
                all_losses[loss_name].append(loss_val)
            num_samples += imgs.size(0)
            if num_samples >= args.num_val_samples:
                break

        samples = {}
        samples['gt_img'] = imgs

        model_out = model(objs, triples, obj_to_img,
                          boxes_gt=boxes, masks_gt=masks,
                          src_image=imgs_in, imgs_src=imgs_src)
        samples['gt_box_gt_mask'] = model_out[0]

        model_out = model(objs, triples, obj_to_img,
                          boxes_gt=boxes, src_image=imgs_in, imgs_src=imgs_src)
        samples['generated_img_gt_box'] = model_out[0]
        samples['masked_img'] = model_out[3][:, :3, :, :]

        for k, v in samples.items():
            samples[k] = imagenet_deprocess_batch(v)

        mean_losses = {k: np.mean(v) for k, v in all_losses.items()}
        avg_iou = total_iou / total_boxes

        masks_to_store = masks
        if masks_to_store is not None:
            masks_to_store = masks_to_store.data.cpu().clone()

        masks_pred_to_store = masks_pred
        if masks_pred_to_store is not None:
            masks_pred_to_store = masks_pred_to_store.data.cpu().clone()

    batch_data = {
        'objs': objs.detach().cpu().clone(),
        'boxes_gt': boxes.detach().cpu().clone(),
        'masks_gt': masks_to_store,
        'triples': triples.detach().cpu().clone(),
        'obj_to_img': obj_to_img.detach().cpu().clone(),
        'triple_to_img': triple_to_img.detach().cpu().clone(),
        'boxes_pred': boxes_pred.detach().cpu().clone(),
        'masks_pred': masks_pred_to_store
    }
    out = [mean_losses, samples, batch_data, avg_iou]

    return tuple(out)
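# The box metrics throughout this file use jaccard, the repository's IoU helper.
# A minimal sketch under two assumptions: boxes are [x0, y0, x1, y1] tensors of
# shape (N, 4), and the helper returns the *summed* IoU over box pairs (consistent
# with avg_iou = total_iou / total_boxes in check_model above). The _sketch suffix
# marks this as a hypothetical reimplementation:

def jaccard_sketch(boxes_a, boxes_b):
    # intersection rectangle corners, computed elementwise per box pair
    x0 = torch.max(boxes_a[:, 0], boxes_b[:, 0])
    y0 = torch.max(boxes_a[:, 1], boxes_b[:, 1])
    x1 = torch.min(boxes_a[:, 2], boxes_b[:, 2])
    y1 = torch.min(boxes_a[:, 3], boxes_b[:, 3])

    inter = (x1 - x0).clamp(min=0) * (y1 - y0).clamp(min=0)
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    union = area_a + area_b - inter

    # small epsilon guards against division by zero for degenerate boxes
    return torch.sum(inter / (union + 1e-8))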