Example #1

# Standard-library / third-party imports used below; MetaGeneratorModel,
# deprocess_batch and decode_image are project-specific helpers assumed to be
# importable from this repository.
import ast
import json
import os

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image


def main(args):
    if not os.path.isfile(args.checkpoint):
        print('ERROR: Checkpoint file "%s" not found' % args.checkpoint)
        print('Maybe you forgot to download the pretrained models? Try running:')
        print('bash scripts/download_models.sh')
        return

    if args.device == 'cpu':
        device = torch.device('cpu')
    else:
        # 'gpu' selects the first entry of --gpu_ids; any other value is taken
        # as an explicit CUDA device index.
        gpu_id = args.gpu_ids[0] if args.device == 'gpu' else args.device
        device = torch.device('cuda:{}'.format(gpu_id))
        if not torch.cuda.is_available():
            print('WARNING: CUDA not available; falling back to CPU')
            device = torch.device('cpu')

    # Load the model, with a bit of care in case there are no GPUs
    map_location = 'cpu' if device == torch.device('cpu') else device
    print("loading: %s"%args.checkpoint)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')

    # Use only the layout-to-image stage: boxes are read from the CSV, so the
    # graph model is skipped and image generation is enabled.
    args.skip_graph_model = 1
    args.skip_generation = 0
    model = MetaGeneratorModel(args, device)
    model.load_state_dict(checkpoint['model_state'], strict=True)
    model.eval()
    model.to(device)
    # Predicted layouts (one row per image) and the vocab of the run that produced them
    df = pd.read_csv(os.path.join(args.base_dir, "results_objs.csv"))
    with open(os.path.join(args.base_dir, "run_args.json"), "r") as f:
        run_args = json.load(f)
    vocab = run_args["vocab"]

    # Vocab used by the generator checkpoint itself
    gen_args_path = os.path.join(os.path.dirname(args.checkpoint), "run_args.json")
    with open(gen_args_path, "r") as f:
        gen_run_args = json.load(f)
    gen_vocab = gen_run_args["vocab"]

    with torch.no_grad():
        for i, row in df.iterrows():

            # bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
            # bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
            # bbox = np.concatenate([bbox, [[-0.6, -0.6, 0.5, 0.5]]], axis=0)
            # Boxes and classes are stored in the CSV as stringified Python lists
            bbox = np.array(ast.literal_eval(row['predicted_boxes']))
            bbox = torch.FloatTensor(bbox).unsqueeze(0)

            object_class = ast.literal_eval(row['class'])
            labels = torch.LongTensor([
                gen_vocab["object_name_to_idx"][c]
                for c in object_class if c != "__image__"
            ]).unsqueeze(0).unsqueeze(-1)

            model_out = model.layout_to_image_model.forward(None, labels, bbox, None, test_mode=True)
            img_pred = model_out
            image = deprocess_batch(img_pred, deprocess_func=decode_image)[0]
            image = np.transpose(image.cpu().numpy(), [1, 2, 0])
            out_dir = os.path.join(args.base_dir, 'samples_roei')
            os.makedirs(out_dir, exist_ok=True)
            Image.fromarray(image).save(
                os.path.join(out_dir, os.path.basename(row['image_id'])))
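
# A hypothetical CLI wrapper (a sketch, not part of the original script): the
# flags are inferred from how `args` is used inside main(); the real entry
# point almost certainly also defines the model hyper-parameters that
# MetaGeneratorModel expects.
def build_inference_arg_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    # 'cpu', 'gpu' (uses the first id in --gpu_ids), or an explicit CUDA device index
    parser.add_argument('--device', default='gpu')
    parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0])
    parser.add_argument('--base_dir', required=True)
    return parser
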
def draw_datasets(samples, output_dir, deprocess_func, image_ids):
    # Convert each batch of image tensors to HWC numpy arrays
    for k, v in samples.items():
        samples[k] = np.transpose(
            deprocess_batch(v, deprocess_func=deprocess_func).cpu().numpy(),
            [0, 2, 3, 1])
    for k, v in samples.items():
        # Set the output path
        if k == 'gt_img':
            path = os.path.join(output_dir, "gt")
        else:
            path = os.path.join(output_dir, "generation", k)

        os.makedirs(path, exist_ok=True)
        for i in range(v.shape[0]):
            # deprocess_batch yields RGB images; cv2.imwrite expects BGR
            bgr_img_i = cv2.cvtColor(v[i], cv2.COLOR_RGB2BGR)
            cv2.imwrite("{}/{}.jpg".format(path, image_ids[i]), bgr_img_i)
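
# Usage sketch for draw_datasets (illustrative only): `samples` maps a name such
# as 'gt_img' or 'gt_box_gt_mask' to a batch of image tensors still in model
# space, and `image_ids` supplies one filename stem per batch element.
# Ground-truth images land in <output_dir>/gt, everything else in
# <output_dir>/generation/<name>:
#
#   draw_datasets({'gt_img': gt_imgs, 'gt_box_gt_mask': generated_imgs},
#                 output_dir='out', deprocess_func=decode_image,
#                 image_ids=['000001', '000002'])
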
Example #3

from collections import defaultdict

# numpy (np), pandas (pd) and torch are assumed to be imported as in Example #1;
# batch_to, remove_dummies_and_padding and jaccard are project helpers.
def check_model(args,
                loader,
                model,
                gans_model,
                inception_score,
                use_gt=True,
                full_test=False):
    model.eval()
    num_samples = 0
    all_losses = defaultdict(list)
    total_iou = 0.
    total_iou_masks = 0.
    total_iou_05 = 0.
    total_iou_03 = 0.
    total_boxes = 0.
    inception_score.clean()
    image_df = {
        'image_id': [],
        'avg_iou': [],
        'iou03': [],
        'iou05': [],
        "predicted_boxes": [],
        "gt_boxes": [],
        "number_of_objects": [],
        "class": []
    }
    with torch.no_grad():
        for batch in loader:
            try:
                batch = batch_to(batch)
                imgs, objs, boxes, triplets, _, triplet_type, masks, image_ids = batch

                # Run the model as it has been run during training
                if use_gt:
                    model_out = model(objs,
                                      triplets,
                                      triplet_type,
                                      boxes_gt=boxes,
                                      masks_gt=masks,
                                      test_mode=True)
                else:
                    model_out = model(objs,
                                      triplets,
                                      triplet_type,
                                      test_mode=True)
                imgs_pred, boxes_pred, masks_pred = model_out
                G_losses = gans_model(batch,
                                      model_out,
                                      mode='compute_generator_loss')

                if boxes_pred is not None:
                    boxes_pred = torch.clamp(boxes_pred, 0., 1.)
                if imgs_pred is not None:
                    inception_score(imgs_pred)

                if not args.skip_graph_model:
                    image_df['image_id'].extend(image_ids)

                    for i in range(boxes.size(0)):
                        # masks_sample = masks[i]
                        # masks_pred_sample = masks_pred[i]
                        boxes_sample = boxes[i]
                        boxes_pred_sample = boxes_pred[i]
                        boxes_pred_sample, boxes_sample = \
                            remove_dummies_and_padding(boxes_sample, objs[i], args.vocab,
                                                       [boxes_pred_sample, boxes_sample])
                        iou, iou05, iou03 = jaccard(boxes_pred_sample,
                                                    boxes_sample)
                        # iou_masks = jaccard_masks(masks_pred_sample, masks_sample)
                        total_iou += iou.sum()
                        # total_iou_masks += iou_masks.sum()
                        total_iou_05 += iou05.sum()
                        total_iou_03 += iou03.sum()
                        total_boxes += float(iou.shape[0])

                        image_df['avg_iou'].append(np.mean(iou))
                        image_df['iou03'].append(np.mean(iou03))
                        image_df['iou05'].append(np.mean(iou05))
                        image_df['predicted_boxes'].append(
                            str(boxes_pred_sample.cpu().numpy().tolist()))
                        image_df['gt_boxes'].append(
                            str(boxes_sample.cpu().numpy().tolist()))
                        image_df["number_of_objects"].append(len(objs[i]))
                        if objs.shape[-1] == 1:
                            image_df["class"].append(
                                str([
                                    args.vocab["object_idx_to_name"][obj_index]
                                    for obj_index in objs[i]
                                ]))
                        else:
                            image_df["class"].append(
                                str([
                                    args.vocab["reverse_attributes"]['shape'][
                                        str(int(objs[i][obj_index][2]))]
                                    for obj_index in range(objs[i].shape[0])
                                ]))

                for loss_name, loss_val in G_losses.items():
                    all_losses[loss_name].append(loss_val)

                num_samples += imgs.size(0)
                if not full_test and args.num_val_samples and num_samples >= args.num_val_samples:
                    break
            except Exception as e:
                print("Error in {}".format(str(e)))

        samples = {}
        if not args.skip_generation and not args.skip_graph_model:
            samples['pred_box_pred_mask'] = model(objs,
                                                  triplets,
                                                  triplet_type,
                                                  test_mode=True)[0]
            samples['pred_box_gt_mask'] = model(objs,
                                                triplets,
                                                triplet_type,
                                                masks_gt=masks,
                                                test_mode=True)[0]

        if not args.skip_generation:
            samples['gt_img'] = imgs
            samples['gt_box_gt_mask'] = \
                model(objs, triplets, triplet_type, boxes_gt=boxes, masks_gt=masks, test_mode=True)[0]
            samples['gt_box_pred_mask'] = model(objs,
                                                triplets,
                                                triplet_type,
                                                boxes_gt=boxes,
                                                test_mode=True)[0]

            for k, v in samples.items():
                samples[k] = np.transpose(
                    deprocess_batch(
                        v, deprocess_func=args.deprocess_func).cpu().numpy(),
                    [0, 2, 3, 1])

        mean_losses = {
            k: torch.stack(v).mean()
            for k, v in all_losses.items() if k != 'bbox_pred_all'
        }
        if not args.skip_graph_model:
            mean_losses.update({
                'avg_iou': total_iou / total_boxes,
                'total_iou_05': total_iou_05 / total_boxes,
                'total_iou_03': total_iou_03 / total_boxes
            })
            mean_losses.update({'inception_mean': 0.0})
            mean_losses.update({'inception_std': 0.0})

        if not args.skip_generation:
            inception_mean, inception_std = inception_score.compute_score(
                splits=5)
            mean_losses.update({'inception_mean': inception_mean})
            mean_losses.update({'inception_std': inception_std})

    model.train()
    return mean_losses, samples, pd.DataFrame.from_dict(image_df)
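
# Usage sketch (illustrative, not from the original training loop): check_model
# is an evaluation pass, so a typical caller runs it periodically on a
# validation loader and persists the per-image dataframe. `val_loader`, `model`,
# `gans_model` and `inception_score` are the project's own objects, and
# args.output_dir is an assumed flag.
#
#   mean_losses, samples, image_df = check_model(args, val_loader, model,
#                                                gans_model, inception_score,
#                                                use_gt=True)
#   image_df.to_csv(os.path.join(args.output_dir, "results_objs.csv"), index=False)
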
Example #4

# A small smoke test for CocoSceneGraphDataset; cv2, numpy (np), torch and the
# project helpers deprocess_batch, decode_image and draw_item are assumed to be
# imported at module level.
if __name__ == "__main__":
    dset = CocoSceneGraphDataset(
        image_dir="/home/roeiherz/Datasets/MSCoco/images/val2017",
        instances_json="/home/roeiherz/Datasets/MSCoco/annotations/instances_val2017.json",
        stuff_json="/home/roeiherz/Datasets/MSCoco/annotations/stuff_val2017.json",
        stuff_only=True,
        image_size=(256, 256),
        normalize_images=True,
        max_samples=None,
        include_relationships=True,
        min_object_size=0.02,
        min_objects=3,
        max_objects=8,
        include_other=False,
        instance_whitelist=None,
        stuff_whitelist=None)
    idx = 100
    item = dset[idx]
    image, objs, boxes, triplets = item
    image = deprocess_batch(torch.unsqueeze(image, 0),
                            deprocess_func=decode_image)[0]
    cv2.imwrite('img.png', np.transpose(image.cpu().numpy(), [1, 2, 0]))
    objs_text = [
        dset.vocab['object_idx_to_name'][k]
        for k in objs['objects'].cpu().numpy()
    ]
    draw_item(item, image_size=dset.image_size, text=objs_text)
    print(item)
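
    # Sketch (illustrative): dump a few more samples the same way. Note that
    # cv2.imwrite expects BGR, so if decode_image yields RGB (as draw_datasets
    # in Example #1 assumes), a cv2.cvtColor(..., cv2.COLOR_RGB2BGR) call is
    # needed before writing:
    #
    #   for idx in range(5):
    #       image, objs, boxes, triplets = dset[idx]
    #       image = deprocess_batch(torch.unsqueeze(image, 0),
    #                               deprocess_func=decode_image)[0]
    #       img_np = np.transpose(image.cpu().numpy(), [1, 2, 0])
    #       cv2.imwrite('img_%d.png' % idx,
    #                   cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))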