Ejemplo n.º 1
0
def check_model(args, loader, model):
    """Run ``model`` over ``loader`` and save predicted (and optionally real) images.

    Images are written as zero-padded ``%04d.png`` files under
    ``args.output_dir``: predictions go to ``test`` (``test_patch`` when
    ``args.use_gt_textures`` is false) and, when ``args.save_gt_imgs`` is set,
    the matching real images go to ``test_real``. Exactly
    ``args.num_val_samples`` images are saved (fewer if the loader runs out).

    Args:
        args: options namespace; reads ``output_dir``, ``use_gt_textures``,
            ``save_gt_imgs`` and ``num_val_samples``.
        loader: iterable of batches passed to ``model[batch]``.
        model: object with a ``forward_D`` flag, ``eval()`` and item access
            returning a result exposing ``imgs`` and ``imgs_pred``.
    """
    num_samples = 0
    model.forward_D = False
    model.eval()

    img_dir = makedir(args.output_dir,
                      'test' if args.use_gt_textures else 'test_patch')
    gt_img_dir = makedir(args.output_dir, 'test_real', args.save_gt_imgs)

    with torch.no_grad():
        for batch in loader:
            # NOTE(review): the model is invoked through item access rather than
            # a call; presumably it defines __getitem__ — confirm.
            result = model[batch]
            imgs_pred = imagenet_deprocess_batch(result.imgs_pred)
            # Only de-normalise the real images when they are going to be saved.
            imgs_gt = (imagenet_deprocess_batch(result.imgs)
                       if args.save_gt_imgs else None)

            for i in range(imgs_pred.size(0)):
                # Stop exactly at the requested count instead of always
                # finishing the current batch.
                if num_samples >= args.num_val_samples:
                    break
                img_filename = '%04d.png' % num_samples

                if args.save_gt_imgs:
                    img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
                    imsave(os.path.join(gt_img_dir, img_filename), img_gt)

                img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
                imsave(os.path.join(img_dir, img_filename), img_pred_np)

                num_samples += 1

            print('Saved %d images' % num_samples)
            if num_samples >= args.num_val_samples:
                break
Ejemplo n.º 2
0
 def write_images(self, t, imgs, imgs_pred, layout_one_hot,
                  layout_pred_one_hot, d_real_crops, d_fake_crops):
     """Log per-image-normalised grids of the current tensors to ``self.writer``.

     The step passed to the writer is ``int(t / self.args.print_every)``.
     Real images and both layouts (converted by ``self.one_hot_to_rgb``) are
     always logged; predictions only when ``imgs_pred`` is not None; the
     discriminator crops only when ``self.obj_discriminator`` is not None.
     """
     writer = self.writer
     step = int(t / self.args.print_every)

     def log_grid(tag, tensor):
         # Every grid is normalised with each image scaled independently.
         grid = torchvision.utils.make_grid(tensor,
                                            normalize=True,
                                            scale_each=True)
         writer.add_image(tag, grid, step)

     log_grid('img/real', imagenet_deprocess_batch(imgs))
     if imgs_pred is not None:
         log_grid('img/pred', imagenet_deprocess_batch(imgs_pred))
     if self.obj_discriminator is not None:
         log_grid('objs/d_real', imagenet_deprocess_batch(d_real_crops))
         log_grid('objs/g_fake', imagenet_deprocess_batch(d_fake_crops))

     layouts = (('img/layout', layout_one_hot),
                ('img/layout_pred', layout_pred_one_hot))
     for tag, one_hot in layouts:
         log_grid(tag, self.one_hot_to_rgb(one_hot).cpu().data)
Ejemplo n.º 3
0
def check_model(args, loader, model):
    """Save ``args.num_diff_noise`` generated variants of every sample.

    Each batch from ``loader`` is deep-copied and passed to ``model[batch]``
    once per noise index. The first pass of *every* batch forces
    ``args.use_gt_textures = True``; the remaining passes use the value the
    caller supplied, which is restored when the function returns or raises.

    For sample ``k`` and noise index ``n`` the prediction is written to
    ``<output_dir>/test_noise[_patch]/<k>/<n:04d>.png``. With
    ``args.save_layout`` a box/class drawing is saved as ``<n:04d>_layout.png``;
    with ``args.save_crop`` and an existing
    ``features_file_name[:-4] + '_crops.pt'`` file, the crop selected for the
    changed object is saved as ``<n:04d>_crop.png``. Processing stops once at
    least ``args.num_val_samples`` samples are done (checked per batch).
    """
    num_samples = 0
    model.forward_D = False
    model.eval()

    img_dir = makedir(
        args.output_dir,
        'test_noise' if args.use_gt_textures else 'test_noise_patch')

    crops_path = os.path.join(args.output_dir,
                              args.features_file_name[:-4] + "_crops.pt")
    print(crops_path)
    if os.path.isfile(crops_path):
        crops_dict = torch.load(crops_path)
    else:
        crops_dict = None
        print('No crops file !!!!!!!!!!!!!')

    # Side length of the blank canvas the layout boxes are drawn on.
    image_size = 256
    use_gt_textures = args.use_gt_textures
    try:
        with torch.no_grad():
            for _batch in loader:
                batch_size = 0
                for noise_index in range(args.num_diff_noise):
                    # Set on every pass: previously only the very first batch
                    # got a ground-truth-texture pass at noise_index 0.
                    args.use_gt_textures = (True if noise_index == 0
                                            else use_gt_textures)
                    # Fresh copy per pass (presumably the model mutates the
                    # batch — confirm).
                    batch = deepcopy(_batch)
                    result = model[batch]
                    objs = result.objs
                    change_indexes = result.change_indexes
                    crop_indexes = result.crop_indexes
                    boxes = result.boxes
                    obj_to_img = result.obj_to_img
                    batch_size = result.imgs.shape[0]

                    imgs_pred = imagenet_deprocess_batch(result.imgs_pred)
                    for i in range(imgs_pred.size(0)):
                        this_img_dir = makedir(img_dir,
                                               "%d" % (num_samples + i))
                        img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
                        imsave(os.path.join(this_img_dir,
                                            '%04d.png' % noise_index),
                               img_pred_np)

                        if args.save_layout:
                            # White canvas with every object of image i boxed
                            # and labelled.
                            image = torch.ones(3, image_size, image_size)
                            image = transforms.ToPILImage()(image).convert(
                                "RGB")
                            draw = ImageDraw.Draw(image)
                            for ind in (obj_to_img == i).nonzero()[:, 0]:
                                draw = draw_box(
                                    draw, boxes[ind] * image_size,
                                    loader.dataset.ind_to_classes[objs[ind] +
                                                                  1],
                                    'normal')

                        if change_indexes is not None:
                            change_index = change_indexes[i]
                            # Looked up unconditionally: save_crop previously
                            # read a `cls` that only existed when save_layout
                            # was enabled.
                            cls = objs[change_index]
                            if args.save_layout:
                                # Redraw the changed object with the
                                # 'special' colour style.
                                draw = draw_box(
                                    draw, boxes[change_index] * image_size,
                                    loader.dataset.ind_to_classes[cls + 1],
                                    'special')

                            if args.save_crop and crops_dict is not None:
                                # int(): a 0-d tensor hashes by identity and
                                # can never match an int dict key.
                                crop = crops_dict[int(cls)][crop_indexes[i]]
                                crop = crop.numpy().transpose(1, 2, 0)
                                imsave(
                                    os.path.join(this_img_dir,
                                                 "%04d_crop.png" % noise_index),
                                    crop)

                        if args.save_layout:
                            image.save(
                                os.path.join(this_img_dir,
                                             "%04d_layout.png" % noise_index))

                num_samples += batch_size
                print('Saved %d images' % num_samples)
                if num_samples >= args.num_val_samples:
                    break
    finally:
        # Do not leak the temporary texture override to the caller.
        args.use_gt_textures = use_gt_textures
Ejemplo n.º 4
0
def run_model(args, checkpoint, output_dir, loader=None):
    """Generate images for every batch in ``loader`` and report box metrics.

    Predicted images are saved under ``<output_dir>/images``; ground-truth
    images, scene-graph drawings and colourised layouts go to ``images_gt``,
    ``graphs`` and ``layouts`` when ``args.save_gt_imgs``,
    ``args.save_graphs`` and ``args.save_layout`` are set. Afterwards the
    accumulated IoU and the ``bigger_05`` / ``bigger_03`` values returned by
    ``jaccard`` are printed divided by the number of evaluated boxes
    (presumably mean IoU and recall at 0.5 / 0.3 — see ``jaccard``), plus the
    object classification accuracy when ``args.accuracy_model_path`` points to
    an existing file.

    Raises:
        ValueError: textures are sampled (``not args.use_gt_textures``) and
            ``features_clustered_001.npy`` is missing next to the checkpoint.
    """
    if args.save_graphs:
        from scene_generation.vis import draw_scene_graph
    dirname = os.path.dirname(args.checkpoint)
    features = None
    if not args.use_gt_textures:
        features_path = os.path.join(dirname, 'features_clustered_001.npy')
        print(features_path)
        if os.path.isfile(features_path):
            # The file stores a pickled object, hence allow_pickle; only load
            # feature files from trusted locations.
            features = np.load(features_path, allow_pickle=True).item()
        else:
            raise ValueError('No features file')
    with torch.no_grad():
        vocab = checkpoint['model_kwargs']['vocab']
        model = build_model(args, checkpoint)
        if loader is None:
            loader = build_loader(args, checkpoint, vocab['is_panoptic'])
        accuracy_model = None
        if args.accuracy_model_path is not None and os.path.isfile(
                args.accuracy_model_path):
            accuracy_model = load_model(args.accuracy_model_path)

        img_dir = makedir(output_dir, 'images')
        graph_dir = makedir(output_dir, 'graphs', args.save_graphs)
        gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs)
        layout_dir = makedir(output_dir, 'layouts', args.save_layout)

        img_idx = 0
        total_iou = 0
        total_boxes = 0
        r_05 = 0
        r_03 = 0
        corrects = 0
        real_objects_count = 0
        num_objs = model.num_objs
        colors = torch.randint(0, 256, [num_objs, 3]).float()
        for batch in loader:
            imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = [
                x.cuda() for x in batch
            ]

            imgs_gt = imagenet_deprocess_batch(imgs)

            masks_gt = masks if args.use_gt_masks else None
            if args.use_gt_textures:
                all_features = None
            else:
                # One feature row per object, drawn at random from the rows
                # stored for that object's class.
                all_features = []
                for obj_name in objs:
                    obj_feature = features[obj_name.item()]
                    random_index = randint(0, obj_feature.shape[0] - 1)
                    feat = torch.from_numpy(obj_feature[random_index, :]).type(
                        torch.float32).cuda()
                    all_features.append(feat)
            if not args.use_gt_attr:
                attributes = torch.zeros_like(attributes)

            # Run the model with predicted masks
            model_out = model(imgs,
                              objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=masks_gt,
                              attributes=attributes,
                              test_mode=True,
                              use_gt_box=args.use_gt_boxes,
                              features=all_features)
            imgs_pred, boxes_pred, masks_pred, _, layout, _ = model_out

            if accuracy_model is not None:
                crop_boxes = boxes if args.use_gt_boxes else boxes_pred
                crops = crop_bbox_batch(imgs_pred, crop_boxes, obj_to_img, 224)

                outputs = accuracy_model(crops)
                # isinstance also unpacks namedtuple outputs (logits, aux).
                if isinstance(outputs, tuple):
                    outputs, _ = outputs
                _, preds = torch.max(outputs, 1)

                # Statistics over objects whose label is not 0, computed on
                # the device instead of one .item() sync per object.
                real = objs != 0
                real_objects_count += int(real.sum().item())
                corrects += int((preds[real] == objs[real]).sum().item())

            # Remove the __image__ object: the last object of each image is
            # dropped, i.e. object k is kept only when object k + 1 belongs to
            # the same image.
            keep = obj_to_img[:-1] == obj_to_img[1:]
            boxes_pred_no_image = boxes_pred[:-1][keep]
            boxes_gt_no_image = boxes[:-1][keep]

            # torch.stack used to crash on batches without regular objects.
            if boxes_pred_no_image.size(0) > 0:
                iou, bigger_05, bigger_03 = jaccard(boxes_pred_no_image,
                                                    boxes_gt_no_image)
                total_iou += iou
                r_05 += bigger_05
                r_03 += bigger_03
                total_boxes += boxes_pred_no_image.size(0)
            imgs_pred = imagenet_deprocess_batch(imgs_pred)

            obj_data = [objs, boxes_pred, masks_pred]
            _, obj_data = split_graph_batch(triples, obj_data, obj_to_img,
                                            triple_to_img)
            objs, boxes_pred, masks_pred = obj_data

            obj_data_gt = [boxes.data]
            if masks is not None:
                obj_data_gt.append(masks.data)
            triples, obj_data_gt = split_graph_batch(triples, obj_data_gt,
                                                     obj_to_img, triple_to_img)
            layouts_3d = one_hot_to_rgb(layout, colors, num_objs)
            for i in range(imgs_pred.size(0)):
                img_filename = '%04d.png' % img_idx
                if args.save_gt_imgs:
                    img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
                    imsave(os.path.join(gt_img_dir, img_filename), img_gt)
                if args.save_layout:
                    layout_3d = layouts_3d[i].numpy().transpose(1, 2, 0)
                    imsave(os.path.join(layout_dir, img_filename), layout_3d)

                img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
                imsave(os.path.join(img_dir, img_filename), img_pred_np)

                if args.save_graphs:
                    graph_img = draw_scene_graph(objs[i], triples[i], vocab)
                    imsave(os.path.join(graph_dir, img_filename), graph_img)

                img_idx += 1

            print('Saved %d images' % img_idx)

        # Guard the averages: an empty loader (or no regular objects) used to
        # raise ZeroDivisionError here.
        if total_boxes > 0:
            avg_iou = total_iou / total_boxes
            print('avg_iou {}'.format(avg_iou.item()))
            print('r0.5 {}'.format(r_05 / total_boxes))
            print('r0.3 {}'.format(r_03 / total_boxes))
        else:
            print('No boxes were evaluated')
        if accuracy_model is not None:
            if real_objects_count > 0:
                print('Accuracy {}'.format(corrects / real_objects_count))
            else:
                print('Accuracy: no real objects were evaluated')
def _user_study_image(mode, img_gt, img_pred_np, removed_name):
    """Compose a captioned 128x128 user-study panel with OpenCV.

    'replace' puts the real and predicted images side by side under a
    "Before / After" caption; any other mode shows only the prediction under
    the question "Is there a <removed_name> in the image?".
    """
    img_pred_np = cv2.resize(img_pred_np, (128, 128))
    if mode == 'replace':
        img_gt = cv2.resize(img_gt, (128, 128))
        wspace = np.zeros([img_pred_np.shape[0], 10, 3])
        text = np.zeros([30, img_pred_np.shape[1] * 2 + 10, 3])
        text = cv2.putText(text,
                           "   Before            After",
                           (17, 25),
                           cv2.FONT_HERSHEY_SIMPLEX,
                           0.5, (255, 255, 255),
                           lineType=cv2.LINE_AA)
        body = np.concatenate([img_gt, wspace, img_pred_np],
                              axis=1).astype('uint8')
    else:
        wspace = np.zeros([img_pred_np.shape[0], 90, 3])
        text = np.zeros([30, img_pred_np.shape[1] + 2 * 90, 3])
        text = cv2.putText(text,
                           "Is there a " + removed_name + " in the image?",
                           (17, 20),
                           cv2.FONT_HERSHEY_SIMPLEX,
                           0.5, (255, 255, 255),
                           lineType=cv2.LINE_AA)
        body = np.concatenate([wspace, img_pred_np, wspace],
                              axis=1).astype('uint8')
    return np.concatenate([text, body], axis=0).astype('uint8')


def run_model(args, checkpoint, output_dir, loader=None):
    """Edit each scene graph according to ``args.mode`` and regenerate the image.

    Modes (``args.mode``):
      * ``'replace'``: set the class of object 0 to each id returned by
        ``change_id_constrained``; skips images whose first gt box spans less
        than 0.1 in coordinates 0/2 or less than 0.15 in coordinates 1/3.
      * ``'reposition'``: for each triple with ``source_predicate`` that does
        not involve the last object, rewrite its predicate to
        ``target_predicate`` and drop other ``source_predicate`` triples
        sharing its subject or object (via ``remove_dub``).
      * ``'remove'``: switch to the ``*_r`` graph of the batch; skips images
        where the removed class still occurs there or its box is too
        small/large.
      * ``'auto'``: regenerate with object 0 kept.

    Predicted images go to ``<output_dir>/images``; the ids used in
    replace/remove mode are logged to ``<output_dir>/result_ids.txt`` and the
    collected predictions are re-saved to ``<output_dir>/data.pt`` after every
    image.

    Raises:
        ValueError: for an unknown ``args.mode``, an unexpected batch length,
            or a batch that carries no source images (``imgs_in``).
    """
    # Validated up front (and with a real exception, not an assert that
    # disappears under -O).
    mode = args.mode
    if mode not in ['auto', 'reposition', 'replace', 'remove']:
        raise ValueError('Unknown mode: %r' % (mode,))

    device = torch.device("cuda:0")
    vocab = checkpoint['model_kwargs']['vocab']
    model = build_model(args, checkpoint)
    if loader is None:
        loader = build_loader(args, checkpoint)

    img_dir = makedir(output_dir, 'images')
    graph_dir = makedir(output_dir, 'graphs', args.save_graphs)
    gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs)
    data_path = os.path.join(output_dir, 'data.pt')

    data = {
        'vocab': vocab,
        'objs': [],
        'masks_pred': [],
        'boxes_pred': [],
        'masks_gt': [],
        'boxes_gt': [],
        'filenames': [],
    }

    # Predicate ids used by 'reposition' mode.
    target_predicate = 15  #31
    source_predicate = 31  #15
    # Switch to True to render captioned side-by-side user-study panels.
    user_study = False

    img_idx = 0
    # os.path.join keeps absolute output dirs working ("./" + "/abs/..." made
    # them relative), and `with` closes the file even on error. Inference only,
    # so no autograd graph is recorded.
    with open(os.path.join(output_dir, "result_ids.txt"), "w") as f, \
            torch.no_grad():
        for batch in loader:
            batch = [tensor.to(device) for tensor in batch]
            masks = None
            attributes = None
            imgs_in = None
            objs_r = boxes_r = triples_r = obj_to_img_r = triple_to_img_r = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 12:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img, \
                objs_r, boxes_r, triples_r, obj_to_img_r, triple_to_img_r, imgs_in = batch
            elif len(batch) == 13:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img, attributes, \
                objs_r, boxes_r, triples_r, obj_to_img_r, triple_to_img_r, imgs_in = batch
            else:
                raise ValueError('Unexpected batch with %d elements' %
                                 len(batch))
            if imgs_in is None:
                # The model call below needs src_image; previously this
                # surfaced as an UnboundLocalError.
                raise ValueError('run_model requires batches that carry '
                                 'source images (imgs_in)')

            imgs_gt = imagenet_deprocess_batch(imgs)

            img_filename = '%04d_gt.png' % img_idx
            if args.save_graphs:
                graph_img = draw_scene_graph(objs, triples, vocab)
                imsave(os.path.join(graph_dir, img_filename), graph_img)

            # Size checks read `boxes` directly: `boxes_gt` used to be None
            # here whenever args.use_gt_boxes was off.
            if mode == 'replace':
                if boxes[0, 2] - boxes[0, 0] < 0.1 or \
                        boxes[0, 3] - boxes[0, 1] < 0.15:
                    img_idx += 1
                    continue
                new_ids = change_id_constrained(objs[0], boxes[0])

            elif mode == 'reposition':
                # Candidate triples: predicate is source_predicate and neither
                # end is the last object.
                last_obj = objs.size(0) - 1
                new_ids = [
                    j for j in range(triples.size(0))
                    if triples[j, 0] != last_obj and triples[j, 2] != last_obj
                    and triples[j, 1] == source_predicate
                ]

            elif mode == 'remove':
                id_removed = objs[0].item()
                box_removed = boxes[0]
                # The user study needs images that contain no other instance
                # of the removed category.
                has_other_instance = bool((objs_r == objs[0]).any())
                height = box_removed[3] - box_removed[1]
                width = box_removed[2] - box_removed[0]
                if has_other_instance or height < 0.2 or width < 0.2 or \
                        (height > 0.8 and width > 0.8):
                    img_idx += 1
                    continue

                objs, boxes, triples = objs_r, boxes_r, triples_r
                obj_to_img, triple_to_img = obj_to_img_r, triple_to_img_r
                new_ids = [objs[0]]

            else:  # auto
                new_ids = [objs[0]]

            if args.with_image_query:
                img, box = query_image_by_semantic_id(new_ids, img_idx, loader)
                # NOTE(review): the returned query features are currently not
                # passed to the model below.
                model.forward_visual_feats(img, box)

                img = imagenet_deprocess_batch(img)
                img_np = img[0].numpy().transpose(1, 2, 0)
                imsave(os.path.join(img_dir, '%04d_query.png' % img_idx),
                       img_np)

            img_subid = 0
            for obj_new_id in new_ids:
                masks_gt = masks if args.use_gt_masks else None

                drop_box_idx = torch.ones_like(objs.unsqueeze(1),
                                               dtype=torch.float)
                drop_feat_idx = torch.ones_like(objs.unsqueeze(1),
                                                dtype=torch.float)

                if mode == 'reposition':
                    triples_changed = triples.clone()
                    triple_to_img_changed = triple_to_img.clone()

                    # Previously np.int64(..., dtype=torch.long), which raises
                    # TypeError; a plain int assigns into the tensor directly.
                    triples_changed[obj_new_id, 1] = int(target_predicate)
                    subject_node = triples_changed[obj_new_id, 0]
                    object_node = triples_changed[obj_new_id, 2]

                    print("subject, object ", subject_node, object_node)
                    indexes = [
                        t_index for t_index in range(triples_changed.size(0))
                        if triples_changed[t_index, 1] == source_predicate and
                        (triples_changed[t_index, 0] == subject_node
                         or triples_changed[t_index, 2] == object_node)
                        and obj_new_id != t_index
                    ]
                    if indexes:
                        triples_changed, triple_to_img_changed = remove_dub(
                            triples_changed, triple_to_img_changed, indexes)

                    img_gt_filename = '%04d_gt.png' % (img_idx)
                    img_pred_filename = '%04d_%d_64_norel_auto.png' % (
                        img_idx, img_subid)

                    triples_ = triples_changed
                    triple_to_img_ = triple_to_img_changed

                    if not args.drop_obj_only:
                        drop_box_idx[subject_node] = 0
                    if not args.drop_subj_only:
                        drop_box_idx[object_node] = 0

                else:
                    objs[0] = int(obj_new_id)

                    img_gt_filename = '%04d_%d_gt.png' % (img_idx, img_subid)
                    img_pred_filename = '%04d_%d_64.png' % (img_idx, img_subid)

                    triples_ = triples
                    triple_to_img_ = triple_to_img

                    subject_node = 0
                    # 'replace' always drops object 0's features; 'auto' does
                    # so unless an image query is used; 'remove' keeps them.
                    # TODO with combined or pred box?
                    if mode == 'replace' or (mode == 'auto'
                                             and not args.with_image_query):
                        drop_feat_idx[subject_node] = 0

                # Feed the (possibly edited) triples without predicate 0. The
                # old code fed the unedited `triples` and rebound it to the
                # filtered tensor, so reposition edits never reached the model
                # and later candidate indices pointed at the wrong rows.
                model_triples = triples_[triples_[:, 1] != 0]

                # NOTE(review): copies the second-to-last object into the last
                # slot. `boxes[:, -1] = boxes[:, -2]` copies box coordinate 2
                # into coordinate 3 for every box; if the intent was to mirror
                # the objs line it should be `boxes[-1] = boxes[-2]` — confirm
                # (same question for attributes).
                objs[-1] = objs[-2]
                boxes[:, -1] = boxes[:, -2]
                if attributes is not None:
                    attributes[:, -1] = attributes[:, -2]
                model_out = model(imgs,
                                  objs,
                                  model_triples,
                                  obj_to_img,
                                  boxes_gt=boxes,
                                  masks_gt=masks_gt,
                                  attributes=attributes,
                                  gt_train=False,
                                  test_mode=False,
                                  use_gt_box=True,
                                  features=None,
                                  drop_box_idx=drop_box_idx,
                                  drop_feat_idx=drop_feat_idx,
                                  src_image=imgs_in)
                imgs_pred, boxes_pred, masks_pred, _, _, _ = model_out
                imgs_pred = imagenet_deprocess_batch(imgs_pred)

                obj_data = [objs, boxes_pred, masks_pred]
                _, obj_data = split_graph_batch(triples_, obj_data, obj_to_img,
                                                triple_to_img_)
                objs, boxes_pred, masks_pred = obj_data

                obj_data_gt = [boxes.data]
                if masks is not None:
                    obj_data_gt.append(masks.data)
                triples_, obj_data_gt = split_graph_batch(
                    triples_, obj_data_gt, obj_to_img, triple_to_img_)

                objs = torch.cat(objs)
                triples_ = torch.cat(triples_)

                boxes_gt = obj_data_gt[0]
                masks_gt = obj_data_gt[1] if masks is not None else None

                for i in range(imgs_pred.size(0)):
                    if args.save_gt_imgs:
                        img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
                        imsave(os.path.join(gt_img_dir, img_gt_filename),
                               img_gt)

                    img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
                    if mode in ('replace', 'remove'):
                        img_pred_filename = '%04d_%d.png' % (img_idx,
                                                             img_subid)
                        if mode == 'replace':
                            obj_name = vocab['object_idx_to_name'][
                                objs[0].item()]
                        else:
                            obj_name = vocab['object_idx_to_name'][id_removed]
                        f.write(
                            str(img_idx) + "_" + str(img_subid) + " " +
                            obj_name + "\n")

                        img_path = os.path.join(img_dir, img_pred_filename)
                        if user_study:
                            img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
                            imsave(
                                img_path,
                                _user_study_image(mode, img_gt, img_pred_np,
                                                  obj_name))
                        else:
                            imsave(img_path, img_pred_np)
                    else:
                        imsave(os.path.join(img_dir, img_pred_filename),
                               img_pred_np)

                    data['objs'].append(objs[i].cpu().clone())
                    data['masks_pred'].append(masks_pred[i].cpu().clone())
                    data['boxes_pred'].append(boxes_pred[i].cpu().clone())
                    data['boxes_gt'].append(boxes_gt[i].cpu().clone())
                    data['filenames'].append(img_filename)

                    cur_masks_gt = None
                    if masks_gt is not None:
                        cur_masks_gt = masks_gt[i].cpu().clone()
                    data['masks_gt'].append(cur_masks_gt)
                    if args.save_graphs:
                        graph_img = draw_scene_graph(objs, triples_, vocab)
                        imsave(os.path.join(graph_dir, img_pred_filename),
                               graph_img)

                img_subid += 1

            img_idx += 1

            # The whole dict is rewritten after every image.
            torch.save(data, data_path)
            print('Saved %d images' % img_idx)
Ejemplo n.º 6
0
def run_model(args, checkpoint, output_dir, loader=None):
    """Generate images for every batch in ``loader`` and report box metrics.

    Builds the model from ``checkpoint`` and runs it in test mode over the
    loader. It writes predicted images under ``output_dir/images``. Optionally
    it also writes ground-truth images, colour-coded layouts and scene-graph
    drawings. At the end it prints the average box IoU and the r0.5 / r0.3
    statistics returned by ``jaccard``, each normalised by the number of
    predicted boxes.

    Args:
        args: Options namespace. Reads ``checkpoint`` (path), ``sample_features``,
            ``use_gt_masks``, ``use_gt_textures``, ``use_gt_attr``,
            ``use_gt_boxes``, ``save_gt_imgs``, ``save_layout`` and
            ``save_graphs``.
        checkpoint: Loaded checkpoint dict; ``checkpoint['model_kwargs']['vocab']``
            is used when drawing scene graphs.
        output_dir: Root directory for the generated files.
        loader: Optional data loader; built from ``checkpoint`` when omitted.

    Raises:
        ValueError: If ``args.sample_features`` is set but
            ``features_clustered_001.npy`` is missing next to the checkpoint.
    """
    dirname = os.path.dirname(args.checkpoint)
    features = None
    if args.sample_features:
        features_path = os.path.join(dirname, 'features_clustered_001.npy')
        print(features_path)
        if not os.path.isfile(features_path):
            raise ValueError('No features file')
        # `.item()` unwraps a 0-d object array, i.e. a pickled Python object.
        # NumPy >= 1.16.3 refuses to unpickle without allow_pickle=True.
        # The file sits next to our own checkpoint, so it is treated as trusted.
        features = np.load(features_path, allow_pickle=True).item()

    with torch.no_grad():
        vocab = checkpoint['model_kwargs']['vocab']
        model = build_model(args, checkpoint)
        if loader is None:
            loader = build_loader(args, checkpoint)

        img_dir = makedir(output_dir, 'images')
        graph_dir = makedir(output_dir, 'graphs', args.save_graphs)
        gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs)
        layout_dir = makedir(output_dir, 'layouts', args.save_layout)

        img_idx = 0
        total_iou = 0
        total_boxes = 0
        r_05 = 0
        r_03 = 0
        num_objs = model.num_objs
        # One random RGB colour per object index, used to render layouts.
        colors = torch.randint(0, 256, [num_objs, 3]).float()
        for batch in loader:
            imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = [
                x.cuda() for x in batch
            ]

            imgs_gt = imagenet_deprocess_batch(imgs)
            masks_gt = masks if args.use_gt_masks else None
            gt_train = bool(args.use_gt_textures)
            if not args.use_gt_attr:
                attributes = torch.zeros_like(attributes)

            all_features = None
            if features is not None:
                # Pick one random pre-clustered feature vector per object.
                all_features = []
                for obj_name in objs:
                    obj_feature = features[obj_name.item()]
                    random_index = randint(0, obj_feature.shape[0] - 1)
                    feat = torch.from_numpy(obj_feature[random_index, :]).type(
                        torch.float32).cuda()
                    all_features.append(feat)

            # Run the model with predicted masks
            model_out = model(imgs,
                              objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=masks_gt,
                              attributes=attributes,
                              gt_train=gt_train,
                              test_mode=True,
                              use_gt_box=args.use_gt_boxes,
                              features=all_features)
            imgs_pred, boxes_pred, masks_pred, _, layout, _ = model_out

            iou, bigger_05, bigger_03 = jaccard(boxes_pred, boxes)
            total_iou += iou
            r_05 += bigger_05
            r_03 += bigger_03
            total_boxes += boxes_pred.size(0)
            imgs_pred = imagenet_deprocess_batch(imgs_pred)

            # Split per-object tensors into per-image groups; only `objs` is
            # needed afterwards (for drawing scene graphs).
            _, obj_data = split_graph_batch(triples, [objs, boxes_pred, masks_pred],
                                            obj_to_img, triple_to_img)
            objs = obj_data[0]
            # Per-image triples, also only needed for drawing scene graphs.
            triples, _ = split_graph_batch(triples, [boxes.data], obj_to_img,
                                           triple_to_img)

            layouts_3d = one_hot_to_rgb(layout, colors, num_objs)
            for i in range(imgs_pred.size(0)):
                img_filename = '%04d.png' % img_idx
                if args.save_gt_imgs:
                    img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
                    imsave(os.path.join(gt_img_dir, img_filename), img_gt)
                if args.save_layout:
                    layout_3d = layouts_3d[i].numpy().transpose(1, 2, 0)
                    imsave(os.path.join(layout_dir, img_filename), layout_3d)

                img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
                imsave(os.path.join(img_dir, img_filename), img_pred_np)

                if args.save_graphs:
                    graph_img = draw_scene_graph(objs[i], triples[i], vocab)
                    imsave(os.path.join(graph_dir, img_filename), graph_img)

                img_idx += 1

            print('Saved %d images' % img_idx)

        if total_boxes == 0:
            # An empty loader would otherwise raise ZeroDivisionError here.
            print('No boxes were evaluated; skipping IoU statistics')
            return
        avg_iou = total_iou / total_boxes
        print(avg_iou)
        print('r0.5 {}'.format(r_05 / total_boxes))
        print('r0.3 {}'.format(r_03 / total_boxes))


def _unpack_eval_batch(batch):
    """Name the fields of a collated evaluation batch.

    Supports the four layouts eval_model accepts (6, 7, 12 or 13 tensors).
    Fields that a layout does not carry (``attributes``, ``imgs_in``) come back
    as ``None``. Masks and the ``*_r`` entries are not used by eval_model and
    are dropped.

    Returns:
        (imgs, objs, boxes, triples, obj_to_img, triple_to_img, attributes,
        imgs_in)

    Raises:
        ValueError: For any other number of fields.
    """
    attributes = None
    imgs_in = None
    num_fields = len(batch)
    if num_fields == 6:
        imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
    elif num_fields == 7:
        imgs, objs, boxes, _masks, triples, obj_to_img, triple_to_img = batch
    elif num_fields == 12:
        (imgs, objs, boxes, triples, obj_to_img, triple_to_img,
         _objs_r, _boxes_r, _triples_r, _obj_to_img_r, _triple_to_img_r,
         imgs_in) = batch
    elif num_fields == 13:
        (imgs, objs, boxes, triples, obj_to_img, triple_to_img, attributes,
         _objs_r, _boxes_r, _triples_r, _obj_to_img_r, _triple_to_img_r,
         imgs_in) = batch
    else:
        raise ValueError('Unexpected batch with %d fields (expected 6, 7, 12 or 13)'
                         % num_fields)
    return imgs, objs, boxes, triples, obj_to_img, triple_to_img, attributes, imgs_in


def eval_model(args,
               model,
               loader,
               device,
               use_gt=False,
               use_feats=False,
               filter_box=False):
    """Evaluate the reconstruction quality of ``model`` over ``loader``.

    The model is run with ground-truth boxes (``use_gt_box=True``) and zeroed
    attributes. The following metrics are collected:

    * box IoU over all boxes, and over the dropped target boxes only
      (the latter only when ``EVAL_ALL`` provides drop indices);
    * per-image and per-RoI mean absolute error;
    * per-image and per-RoI SSIM;
    * per-image LPIPS distance.

    ``calculate_scores`` is called every ``PRINT_EVERY`` batches.
    ``save_results`` is called every ``SAVE_EVERY`` batches and once more at
    the end with the tag ``'final'``.

    Args:
        args: Unused here; kept for interface compatibility.
        model: Generator called with scene-graph inputs; expected to return a
            6-tuple whose first two items are predicted images and boxes.
        loader: Iterable of collated batches (6, 7, 12 or 13 tensors).
        device: Device the batch tensors are moved to. LPIPS runs on the GPU
            only when this is a CUDA device.
        use_gt: Unused in the current implementation.
        use_feats: Forwarded to ``process_batch``.
        filter_box: Forwarded to ``process_batch``.

    Returns:
        None. Results are persisted through ``save_results``.

    Raises:
        ValueError: If a batch has an unsupported number of fields.
    """
    all_boxes = defaultdict(list)
    total_iou = []
    roi_only_iou = []
    mae_per_image = []
    mae_roi_per_image = []
    ssim_per_image = []
    ssim_rois = []
    perceptual_error_image = []
    # RoI-level LPIPS is currently disabled, so this list stays empty. It is
    # still passed on so calculate_scores / save_results keep their arguments.
    perceptual_error_roi = []
    margin = 2  # margin handed to bbox_coordinates_with_margin for RoI crops
    num_batches = 0

    # LPIPS perceptual-distance model; put it on the GPU only if we evaluate there.
    lpips_model = models.PerceptualLoss(model='net-lin',
                                        net='alex',
                                        use_gpu=torch.device(device).type == 'cuda')

    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            num_batches += 1
            batch = [tensor.to(device) for tensor in batch]
            (imgs, objs, boxes, triples, obj_to_img, triple_to_img,
             attributes, imgs_in) = _unpack_eval_batch(batch)

            if EVAL_ALL:
                # process_batch rewrites the batch and reports which boxes and
                # features are dropped. NOTE(review): imgs_in is None for 6/7-field
                # batches; confirm process_batch accepts that.
                (imgs, imgs_in, objs, boxes, triples, obj_to_img,
                 dropbox_indices, dropfeats_indices) = process_batch(
                     imgs, imgs_in, objs, boxes, triples, obj_to_img,
                     triple_to_img, device,
                     use_feats=use_feats, filter_box=filter_box)
            else:
                dropbox_indices = None
                dropfeats_indices = None

            # Ground-truth attributes are never fed to the model during evaluation.
            # NOTE(review): attributes do not go through process_batch. If that
            # changes the number of objects, this shape may not match objs; confirm.
            if attributes is not None:
                attributes = torch.zeros_like(attributes)

            model_out = model(imgs,
                              objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=None,
                              attributes=attributes,
                              gt_train=False,
                              test_mode=False,
                              use_gt_box=True,
                              features=None,
                              drop_box_idx=dropbox_indices,
                              drop_feat_idx=dropfeats_indices,
                              src_image=imgs_in)
            imgs_pred, boxes_pred, _, _, _, _ = model_out

            # Save all box predictions
            all_boxes['boxes_gt'].append(boxes)
            all_boxes['objs'].append(objs)
            all_boxes['boxes_pred'].append(boxes_pred)
            all_boxes['drop_targets'].append(dropbox_indices)

            # IoU over all boxes
            total_iou.append(jaccard(boxes_pred, boxes).cpu().numpy())

            # IoU over the dropped targets only (needs the drop mask)
            if dropbox_indices is not None:
                is_target = dropbox_indices.squeeze() == 0
                roi_only_iou.append(
                    jaccard(boxes_pred[is_target, :],
                            boxes[is_target, :]).detach().cpu().numpy())

            imgs = imagenet_deprocess_batch(imgs).float()
            imgs_pred = imagenet_deprocess_batch(imgs_pred).float()

            # MAE per image
            mae_per_image.append(
                torch.mean(
                    torch.abs(imgs - imgs_pred).view(imgs.shape[0], -1),
                    1).cpu().numpy())

            # Rescaled copies, computed once per batch: /255 for SSIM and
            # /127.5 - 1 for LPIPS. This presumably maps [0, 255] images to
            # [0, 1] and [-1, 1]; confirm the output range of
            # imagenet_deprocess_batch.
            imgs_unit = imgs / 255.0
            preds_unit = imgs_pred / 255.0
            imgs_signed = imgs / 127.5 - 1
            preds_signed = imgs_pred / 127.5 - 1

            for s in range(imgs.shape[0]):
                # NOTE(review): boxes is indexed per object, so boxes[s] is the
                # s-th object in the batch. That is not necessarily image s's
                # target box; confirm against process_batch.
                left, right, top, bottom = bbox_coordinates_with_margin(
                    boxes[s, :], margin, imgs)

                mae_roi_per_image.append(
                    torch.mean(
                        torch.abs(imgs[s, :, top:bottom, left:right] -
                                  imgs_pred[s, :, top:bottom,
                                            left:right])).cpu().item())

                ssim_per_image.append(
                    pytorch_ssim.ssim(imgs_unit[s:s + 1, :, :, :],
                                      preds_unit[s:s + 1, :, :, :],
                                      window_size=3).cpu().item())
                ssim_rois.append(
                    pytorch_ssim.ssim(
                        imgs_unit[s:s + 1, :, top:bottom, left:right],
                        preds_unit[s:s + 1, :, top:bottom, left:right],
                        window_size=3).cpu().item())

                perceptual_error_image.append(
                    lpips_model.forward(preds_signed[s:s + 1, :, :, :],
                                        imgs_signed[s:s + 1, :, :, :]
                                        ).detach().cpu().numpy())

            if num_batches % PRINT_EVERY == 0:
                calculate_scores(mae_per_image, mae_roi_per_image, total_iou,
                                 roi_only_iou, ssim_per_image, ssim_rois,
                                 perceptual_error_image, perceptual_error_roi)

            if num_batches % SAVE_EVERY == 0:
                save_results(mae_per_image, mae_roi_per_image, total_iou,
                             roi_only_iou, ssim_per_image, ssim_rois,
                             perceptual_error_image, perceptual_error_roi,
                             all_boxes, num_batches)

    save_results(mae_per_image, mae_roi_per_image, total_iou, roi_only_iou,
                 ssim_per_image, ssim_rois, perceptual_error_image,
                 perceptual_error_roi, all_boxes, 'final')