def save_class_images(dataset, tgt_image_path, extension=".jpg"):
    db = {"cids":[], "cluster":[], "gtbboxid":[], "classid":[], "imageid":[], "difficult":[], "type":[], "size":[], "bbox":[]}

    for lbl, label_image in tqdm(dataset.gt_images_per_classid.items()):
        # create the file name to be used with cirtorch.datasets.datahelpers.cid2filename and their dataloader
        cid = "lbl{label:05d}{box_type}".format(label = lbl, box_type="CL")
        file_name = cid2filename(cid, prefix=tgt_image_path)

        # save the image
        image_path, _ = os.path.split(file_name)
        mkdir(image_path)
        if extension:
            label_image.save("{}{}".format(file_name, extension))
        else:
            # cirtorch's training dataloader expects files with an empty extension, so support that
            label_image.save("{}".format(file_name), format="jpeg")

        width, height = label_image.size
        box = [0, 0, width, height]  # format (x1,y1,x2,y2)

        # add to the db structure
        db["cids"].append(cid)
        db["cluster"].append(lbl)  # use labels as clusters not to sample negatives from the same object
        db["classid"].append(lbl)
        db["gtbboxid"].append(None)
        db["imageid"].append(None)
        db["difficult"].append(None)
        db["type"].append("classimage")
        db["size"].append(label_image.size)
        db["bbox"].append(box)  # format (x1,y1,x2,y2)

    return db
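# For reference, cirtorch's cid2filename buckets files into nested directories
# derived from the tail of the cid; a minimal sketch, assuming the
# cnnimageretrieval-pytorch implementation (verify against your checkout):
import os

def _cid2filename_sketch(cid, prefix):
    # "lbl00042CL" with prefix "/data/ims" -> "/data/ims/CL/42/00/lbl00042CL"
    return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid)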
Example #2
def main():
    args = parse_opts()
    set_random_seed(args.random_seed)

    logger_name = "retrieval_data"
    retrieval_dataset_name_suffix = "-retrieval"
    logger = setup_logger(logger_name, None)
    data_path = get_data_path()

    script_path = os.path.expanduser(os.path.dirname(os.path.abspath(__file__)))
    target_path = os.path.join(script_path, "cnnimageretrieval-pytorch", "data")
    mkdir(target_path)

    dataset_train = build_dataset_by_name(data_path, args.dataset_train,
                                          eval_scale=args.dataset_train_scale,
                                          logger_prefix=logger_name)
    retrieval_dataset_train_name = dataset_train.get_name() + retrieval_dataset_name_suffix

    dataset_val = build_dataset_by_name(data_path, args.dataset_val,
                                        eval_scale=args.dataset_val_scale,
                                        logger_prefix=logger_name)
    retrieval_dataset_val_name = dataset_val.get_name() + retrieval_dataset_name_suffix


    datasets_test = []
    retrieval_dataset_test_names = []
    if args.datasets_test:
        if len(args.datasets_test_scale) == 1:
            datasets_test_scale = args.datasets_test_scale * len(args.datasets_test)
        else:
            datasets_test_scale = args.datasets_test_scale
        assert len(args.datasets_test) == len(datasets_test_scale), "Arg datasets-test-scale should be of len 1 or of len equal to the len of datasets-test"

        for dataset_name, scale in zip(args.datasets_test, datasets_test_scale):
            dataset = build_dataset_by_name(data_path, dataset_name,
                                            eval_scale=scale,
                                            logger_prefix=logger_name)
            retrieval_dataset_test_names.append(dataset.get_name() + retrieval_dataset_name_suffix)
            datasets_test.append(dataset)

    # create dataset
    if args.num_random_crops_per_image > 0:
        crop_suffix = f"-rndCropPerImage{args.num_random_crops_per_image}"
        retrieval_dataset_train_name = retrieval_dataset_train_name + crop_suffix
        retrieval_dataset_val_name = retrieval_dataset_val_name + crop_suffix
        retrieval_dataset_test_names = [name + crop_suffix for name in retrieval_dataset_test_names]

    prepare_dataset(target_path,
                    retrieval_dataset_train_name, dataset_train,
                    retrieval_dataset_val_name, dataset_val,
                    args.iou_pos_threshold, args.iou_neg_threshold,
                    args.num_queries_image_to_image,
                    logger,
                    retrieval_dataset_test_names=retrieval_dataset_test_names, datasets_test=datasets_test,
                    num_random_crops_per_image=args.num_random_crops_per_image)
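# Hypothetical invocation (the script name and exact flag spellings are
# assumptions inferred from the args used above):
#   python prepare_retrieval_data.py --dataset-train <name> --dataset-train-scale 600 \
#       --dataset-val <name> --dataset-val-scale 600 \
#       --datasets-test <name> --datasets-test-scale 600 \
#       --num-random-crops-per-image 0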
Example #3
def init_logger(args, logger_prefix="detector-retrieval"):
    if args.output_dir:
        mkdir(args.output_dir)

    logger = setup_logger(logger_prefix, args.output_dir if args.output_dir else None)

    if args.config_file:
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
        logger.info("Running with the default eval section:\n{}".format(cfg.eval))
    else:
        logger.info("Launched with no OS2D config file")
    logger.info("Running args:\n{}".format(args))
    return logger
Example #4
def init_logger(cfg, config_file):
    output_dir = cfg.output.path
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("OS2D",
                          output_dir if cfg.output.save_log_to_file else None)

    if config_file:
        logger.info("Loaded configuration file {}".format(config_file))
        with open(config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    else:
        logger.info("Config file was not provided")

    logger.info("Running with config:\n{}".format(cfg))

    # save config file only when training (to run multiple evaluations in the same folder)
    if output_dir and cfg.train.do_training:
        output_config_path = os.path.join(output_dir, "config.yml")
        logger.info("Saving config into: {}".format(output_config_path))
        # save overloaded model config in the output directory
        save_config(cfg, output_config_path)
Example #5
def build_instre_dataset(data_path,
                         name,
                         eval_scale=None,
                         cache_images=False,
                         no_image_reading=False,
                         logger_prefix="OS2D"):
    logger = logging.getLogger(f"{logger_prefix}.dataset")
    logger.info(
        "Preparing the INSTRE dataset: version {0}, eval scale {1}, image caching {2}"
        .format(name, eval_scale, cache_images))
    # INSTRE dataset was downloaded from here: ftp://ftp.irisa.fr/local/texmex/corpus/instre/instre.tar.gz
    # Splits by Iscen et al. (2016) were downloaded from here: ftp://ftp.irisa.fr/local/texmex/corpus/instre/gnd_instre.mat

    image_size = 1000
    import scipy.io as sio
    dataset_path = os.path.join(data_path, "instre")
    annotation_file = os.path.join(dataset_path, "gnd_instre.mat")
    annotation_data = sio.loadmat(annotation_file)
    # annotation_data["qimlist"][0] - 1250 queries - each in annotation_data["qimlist"][0][i][0] file, root - os.path.join(data_path, "instre")
    # annotation_data["imlist"][0] - 27293 database images - each in annotation_data["imlist"][0][i][0] file, root - os.path.join(data_path, "instre")
    # annotation_data["gnd"][0] - 1250 annotations for all queries:
    #   annotation_data["gnd"][0][i][0] - indices of positives in annotation_data["imlist"][0] (WARNING - 1-based)
    #   annotation_data["gnd"][0][i][1] - bbox of the query object, one of the boxes from ent of *.txt
    #   images in subsets INSTRE-S1 and INSTRE-S2 contain exactly one object
    #   images in the subset INSTRE-M contain two objects each

    image_path = dataset_path
    gt_path = os.path.join(dataset_path, "classes")
    gt_image_path = os.path.join(gt_path, "images")
    mkdir(gt_image_path)

    classdatafile = os.path.join(gt_path, "instre.csv")
    if not os.path.isfile(classdatafile):
        logger.info(
            f"Did not find data file {classdatafile}, creating it from INSTRE source data"
        )
        # create the annotation file from the raw dataset
        annotation_data["qimlist"] = annotation_data["qimlist"].flatten()
        annotation_data["imlist"] = annotation_data["imlist"].flatten()
        annotation_data["gnd"] = annotation_data["gnd"].flatten()
        num_classes = len(annotation_data["qimlist"])
        gtboxframe = []  # will be creating dataframe from a list of dicts
        for i_class in range(num_classes):
            query_image_path_original = str(annotation_data["qimlist"][i_class][0])
            if query_image_path_original.split("/")[0].lower() == "instre-m":
                # Query boxes from subset "INSTRE-M" contain both objects, so it is not clear how to use them
                logger.info(
                    f"Skipping query {i_class}: {query_image_path_original}")
                continue
            logger.info(f"Adding query {i_class}: {query_image_path_original}")
            query_bbox = annotation_data["gnd"][i_class][1].flatten()
            query_positives = annotation_data["gnd"][i_class][0].flatten() - 1  # "-1" because of the original 1-based MATLAB indexing

            classid = i_class
            classfilename = f"{i_class:05d}_{'_'.join(query_image_path_original.split('/'))}"

            class_image_file = os.path.join(gt_image_path, classfilename)
            if not os.path.isfile(class_image_file):
                query_img = read_image(
                    os.path.join(dataset_path, query_image_path_original))
                query_img_cropped_box = query_img.crop(query_bbox)
                query_img_cropped_box.save(class_image_file)

            def convert_the_box_from_xywh(box, imsize):
                lx = float(box[0]) / imsize.w
                ty = float(box[1]) / imsize.h
                rx = lx + float(box[2]) / imsize.w
                by = ty + float(box[3]) / imsize.h
                return lx, ty, rx, by
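            # worked example: an xywh box (10, 20, 30, 40) in a 100x200 image
            # becomes (0.1, 0.1, 0.4, 0.3) in relative (lx, ty, rx, by) coordinates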

            def read_boxes_from(file_with_boxes):
                with open(file_with_boxes, "r") as fo:
                    lines = fo.readlines()
                boxes = [[int(s) for s in line.split(" ")] for line in lines
                         if line]
                return boxes

            def get_box_file_for_image_file(image_filename):
                return image_filename.split(".")[0] + ".txt"

            def get_the_boxes(image_filename):
                file_with_boxes = os.path.join(
                    image_path, get_box_file_for_image_file(image_filename))
                # get image size - recompute boxes
                boxes = read_boxes_from(file_with_boxes)
                img = read_image(os.path.join(image_path, image_filename))
                imsize = FeatureMapSize(img=img)
                # choose the correct box if there are two of them
                # From INSTRE documentation:
                # Specially, for each tuple-class in INSTRE-M, there are two corresponding object classes in INSTRE-S1.
                # In each annotation file for a INSTRE-M image, the first line records the object labeled as [a] in INSTRE-S1
                # and the second line records the object labeled as [b] in INSTRE-S1.
                #
                # CAUTION! the matlab file has boxes in x1, y1, x2, y2, but the .txt files in x, y, w, h
                query_path_split = query_image_path_original.split("/")
                image_filename_split = image_filename.split("/")
                if query_path_split[0].lower() == "instre-s1" and image_filename_split[0].lower() == "instre-m":
                    assert len(boxes) == 2, f"INSTRE-M images should have exactly two boxes, but have {boxes}"
                    assert query_path_split[1][2] in ["a", "b"]
                    i_box = 0 if query_path_split[1][2] == "a" else 1
                    boxes = [convert_the_box_from_xywh(boxes[i_box], imsize)]
                elif query_path_split[0].lower() == "instre-s1" and image_filename_split[0].lower() == "instre-s1" or \
                        query_path_split[0].lower() == "instre-s2" and image_filename_split[0].lower() == "instre-s2":
                    boxes = [convert_the_box_from_xywh(box, imsize) for box in boxes]
                else:
                    raise RuntimeError(
                        f"Should not be happening, query {query_image_path_original}, image {image_filename}, boxes {boxes}"
                    )
                return boxes

            for image_id in query_positives:
                # add one bbox to the annotation
                #     required_columns = ["imageid", "imagefilename", "classid", "classfilename", "gtbboxid", "difficult", "lx", "ty", "rx", "by"]
                image_file_name = str(annotation_data["imlist"][image_id][0])
                boxes = get_the_boxes(image_file_name)
                for box in boxes:
                    item = OrderedDict()
                    item["gtbboxid"] = len(gtboxframe)
                    item["classid"] = classid
                    item["classfilename"] = classfilename
                    item["imageid"] = image_id
                    assert annotation_data["imlist"][image_id].size == 1
                    item["imagefilename"] = image_file_name
                    item["difficult"] = 0
                    item["lx"], item["ty"], item["rx"], item["by"] = box
                    gtboxframe.append(item)

        gtboxframe = pd.DataFrame(gtboxframe)
        gtboxframe.to_csv(classdatafile)

    gtboxframe = read_annotation_file(classdatafile)

    # get these automatically from gtboxframe
    image_ids = None
    image_file_names = None

    # define a subset split (using closure)
    subset_name = name.lower()
    assert subset_name.startswith("instre"), f"INSTRE dataset name should start with 'instre', got {name}"
    subset_name = subset_name[len("instre"):]
    subsets = [
        "all", "s1-train", "s1-val", "s1-test", "s2-train", "s2-val", "s2-test"
    ]
    found_subset = False
    for subset in subsets:
        if subset_name == "-" + subset:
            found_subset = subset
            break
    assert found_subset, "Could not identify subset {}".format(subset_name)

    if subset == "all":
        pass
    elif subset in ["s1-train", "s1-val", "s1-test"]:
        gtboxframe = gtboxframe[gtboxframe.classfilename.str.contains(
            "INSTRE-S1")]
        classes = gtboxframe.classfilename.drop_duplicates()
        if subset == "s1-train":
            classes = classes[:len(classes) * 75 // 100]  # first 75%
        elif subset == "s1-test":
            classes = classes[len(classes) * 8 // 10:]  # last 20%
        else:  # "s1-val"
            classes = classes[len(classes) * 75 // 100:len(classes) * 8 //
                              10]  # 5%
        gtboxframe = gtboxframe[gtboxframe.classfilename.isin(classes)]
    elif subset in ["s2-train", "s2-val", "s2-test"]:
        gtboxframe = gtboxframe[gtboxframe.classfilename.str.contains(
            "INSTRE-S2")]
        classes = gtboxframe.classfilename.drop_duplicates()
        if subset == "s2-train":
            classes = classes[:len(classes) * 75 // 100]  # first 75%
        elif subset == "s2-test":
            classes = classes[len(classes) * 8 // 10:]  # last 20%
        else:  # "s2-val"
            classes = classes[len(classes) * 75 // 100:len(classes) * 8 //
                              10]  # 5%
        gtboxframe = gtboxframe[gtboxframe.classfilename.isin(classes)]
    else:
        raise RuntimeError("Unknown subset {0}".format(subset))

    dataset = DatasetOneShotDetection(gtboxframe,
                                      gt_image_path,
                                      image_path,
                                      name,
                                      image_size,
                                      eval_scale,
                                      image_ids=image_ids,
                                      image_file_names=image_file_names,
                                      cache_images=cache_images,
                                      no_image_reading=no_image_reading,
                                      logger_prefix=logger_prefix)
    return dataset
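# Usage sketch (illustrative, not from the original code): the subset is encoded
# in the dataset name parsed above, e.g. "instre-all", "instre-s1-train" or
# "instre-s2-val"; the data path and eval scale below are assumptions.
if __name__ == "__main__":
    instre_train = build_instre_dataset("data", "instre-s1-train", eval_scale=600)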
Example #6
def build_imagenet_test_episodes(subset_name, data_path, logger):
    episode_id = int(subset_name.split('-')[-1])
    epi_data_name = "epi_inloc_in_domain_1_5_10_500"
    image_size = 1000

    dataset_path = os.path.join(data_path, "ImageNet-RepMet")
    roidb_path = os.path.join(dataset_path, "RepMet_CVPR2019_data", "data",
                              "Imagenet_LOC", "voc_inloc_roidb.pkl")
    with open(roidb_path, 'rb') as fid:
        roidb = pickle.load(fid, encoding='latin1')
    episodes_path = os.path.join(dataset_path, "RepMet_CVPR2019_data", "data",
                                 "Imagenet_LOC", "episodes",
                                 f"{epi_data_name}.pkl")
    with open(episodes_path, 'rb') as fid:
        episode_data = pickle.load(fid, encoding='latin1')

    logger.info(f"Extracting episode {episode_id} out of {len(episode_data)}")
    episode = episode_data[episode_id]

    dataset_image_path = os.path.join(data_path, "ImageNet-RepMet", "ILSVRC")

    SWAP_IMG_PATH_SRC = "/dccstor/leonidka1/data/imagenet/ILSVRC/"

    def _get_image_path(image_path):
        image_path = image_path.replace(SWAP_IMG_PATH_SRC, "")
        return image_path

    # episode["epi_cats"] - list of class ids
    # episode["query_images"] - list of path to the episode images
    # episode["epi_cats_names"] - list of names of the episode classes
    # episode["train_boxes"] - list of box data about class boxes

    num_classes = len(episode["epi_cats"])

    gt_path = os.path.join(dataset_path, epi_data_name)
    gt_path = os.path.join(gt_path, f"classes_episode_{episode_id}")
    gt_image_path = os.path.join(gt_path, "images")
    mkdir(gt_image_path)
    classdatafile = os.path.join(
        gt_path, f"classes_{epi_data_name}_episode_{episode_id}.csv")
    if not os.path.isfile(classdatafile):
        logger.info(
            f"Did not find data file {classdatafile}, creating it from the RepMet source data"
        )
        # create the annotation file from the raw dataset
        gtboxframe = []  # will be creating dataframe from a list of dicts

        gt_filename_by_id = {}
        for i_class in range(len(episode["train_boxes"])):
            train_boxes_data = episode["train_boxes"][i_class]
            class_id = train_boxes_data[0]
            assert class_id in episode["epi_cats"], \
                f"class_id={class_id} should be listed in episode['epi_cats']={episode['epi_cats']}"

            query_image_path_original = _get_image_path(train_boxes_data[2])
            query_bbox = train_boxes_data[3]
            query_bbox = query_bbox.flatten()

            classfilename = f"{class_id:05d}_{'_'.join(query_image_path_original.split('/'))}"

            if class_id not in gt_filename_by_id:
                logger.info(
                    f"Adding query #{len(gt_filename_by_id)} - {class_id}: {query_image_path_original}"
                )

                # always regenerate the cropped query image
                query_img = read_image(
                    os.path.join(dataset_image_path,
                                 query_image_path_original))
                query_img_cropped_box = query_img.crop(query_bbox)
                query_img_cropped_box.save(
                    os.path.join(gt_image_path, classfilename))

                gt_filename_by_id[class_id] = classfilename
            else:
                logger.info(
                    f"WARNING: class {class_id} has multiple entries in GT image {query_image_path_original}, using the first box as GT"
                )

        for class_id in episode["epi_cats"]:
            if class_id not in gt_filename_by_id:
                logger.info(
                    f"WARNING: ground truth for class {class_id} not found in episode {episode_id}"
                )

        def convert_the_box_to_relative(box, imsize):
            lx = float(box[0]) / imsize.w
            ty = float(box[1]) / imsize.h
            rx = float(box[2]) / imsize.w
            by = float(box[3]) / imsize.h
            return lx, ty, rx, by
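        # worked example: an xyxy box (10, 20, 30, 40) in a 100x200 image
        # becomes (0.1, 0.1, 0.3, 0.2) in relative (lx, ty, rx, by) coordinates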

        def find_image_path_in_roidb(image_file_name, roidb):
            for i_image, im_data in enumerate(roidb["roidb"]):
                if im_data["flipped"]:
                    raise RuntimeError(
                        f"Image {i_image} data {im_data} has flipped flag on")
                if im_data["image"] == image_file_name:
                    return i_image
            return None

        for image_file_name in episode["query_images"]:
            # add one bbox to the annotation
            #     required_columns = ["imageid", "imagefilename", "classid", "classfilename", "gtbboxid", "difficult", "lx", "ty", "rx", "by"]
            image_id = find_image_path_in_roidb(image_file_name, roidb)
            assert image_id is not None, f"Could not find image {image_file_name} in roidb"
            im_data = roidb["roidb"][image_id]
            image_file_name = _get_image_path(image_file_name)

            imsize = FeatureMapSize(w=int(im_data["width"]),
                                    h=int(im_data["height"]))

            boxes_xyxy = im_data["boxes"]
            classes = im_data["gt_classes"]

            for box, class_id in zip(boxes_xyxy, classes):
                if class_id in gt_filename_by_id:
                    item = OrderedDict()
                    item["imageid"] = int(image_id)
                    item["imagefilename"] = image_file_name
                    item["classid"] = int(class_id)
                    item["classfilename"] = gt_filename_by_id[class_id]
                    item["gtbboxid"] = len(gtboxframe)
                    item["difficult"] = 0
                    item["lx"], item["ty"], item["rx"], item[
                        "by"] = convert_the_box_to_relative(box, imsize)
                    gtboxframe.append(item)

        gtboxframe = pd.DataFrame(gtboxframe)
        gtboxframe.to_csv(classdatafile)

    gtboxframe = pd.read_csv(classdatafile)

    return gtboxframe, gt_image_path, dataset_image_path, image_size
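# Design note: find_image_path_in_roidb above rescans the whole roidb for every
# query image. A hedged alternative (hypothetical helper, not in the original
# code) is to index the roidb once and look paths up in O(1):
def _index_roidb_by_path(roidb):
    # map each stored image path to its position in roidb["roidb"]
    return {im_data["image"]: i for i, im_data in enumerate(roidb["roidb"])}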
Example #7
def save_cropped_boxes(dataset, tgt_image_path, extension=".jpg", num_random_crops_per_image=0):
    # crop all the boxes
    db = {"cids":[], "cluster":[], "gtbboxid":[], "classid":[], "imageid":[], "difficult":[], "type":[], "size":[], "bbox":[]}

    for image_id in tqdm(dataset.image_ids):
        img = dataset._get_dataset_image_by_id(image_id)
        boxes = dataset.get_image_annotation_for_imageid(image_id)

        assert boxes.has_field("labels"), "GT boxes need a field 'labels'"
        # remove all fields except "labels" and "difficult"
        for f in boxes.fields():
            if f not in ["labels", "difficult"]:
                boxes.remove_field(f)
        if not boxes.has_field("difficult"):
            boxes.add_field("difficult", torch.zeros(len(boxes), dtype=torch.bool))

        num_gt_boxes = len(boxes)
        im_size = FeatureMapSize(img=img)
        assert im_size == boxes.image_size

        eval_scale = dataset.get_eval_scale()

        # sample random boxes if needed
        if num_random_crops_per_image > 0:
            boxes_random = torch.rand(num_random_crops_per_image, 4)
            x1 = torch.min(boxes_random[:, 0], boxes_random[:, 2]) * im_size.w
            x2 = torch.max(boxes_random[:, 0], boxes_random[:, 2]) * im_size.w
            y1 = torch.min(boxes_random[:, 1], boxes_random[:, 3]) * im_size.h
            y2 = torch.max(boxes_random[:, 1], boxes_random[:, 3]) * im_size.h
            boxes_random = torch.stack([x1, y1, x2, y2], 1).floor()

            # filter out random crops that are too small
            min_size = 10.0 / eval_scale * max(im_size.w, im_size.h)
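            # e.g. with eval_scale=600 (an illustrative value) and a 1280x960 image,
            # min_size = 10 / 600 * 1280 ≈ 21 px, so smaller crops are dropped below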
            mask_bad_boxes = (boxes_random[:,0] + min_size > boxes_random[:,2]) | (boxes_random[:,1] + min_size > boxes_random[:,3])
            good_boxes = torch.nonzero(~mask_bad_boxes).view(-1)
            boxes_random = boxes_random[good_boxes]

            boxes_random = BoxList(boxes_random, im_size, mode="xyxy")

            boxes_random.add_field("labels", torch.full([len(boxes_random)], -1, dtype=torch.long))
            boxes_random.add_field("difficult", torch.zeros(len(boxes_random), dtype=torch.bool))
            boxes = cat_boxlist([boxes, boxes_random])

        if boxes is not None:
            for i_box in range(len(boxes)):
                # box format: left, top, right, bottom
                box = boxes[i_box].bbox_xyxy.view(-1)
                box = [b.item() for b in box]
                cropped_img = img.crop(box)

                if i_box < num_gt_boxes:
                    lbl = boxes[i_box].get_field("labels").item()
                    dif_flag = boxes[i_box].get_field("difficult").item()
                    box_id = i_box
                    box_type = "GT"
                else:
                    lbl = -1
                    dif_flag = 0
                    box_id = i_box
                    box_type = "RN"

                # create the file name to be used with cirtorch.datasets.datahelpers.cid2filename and their dataloader
                cid = "box{box_id:05d}_lbl{label:05d}_dif{dif:01d}_im{image_id:05d}{box_type}".format(box_id=box_id, image_id = image_id, label = lbl, dif = dif_flag, box_type=box_type)
                file_name = cid2filename(cid, prefix=tgt_image_path)

                # save the image
                image_path, _ = os.path.split(file_name)
                mkdir(image_path)
                if extension:
                    cropped_img.save("{}{}".format(file_name, extension))
                else:
                    # cirtorch's training dataloader expects files with an empty extension, so support that
                    cropped_img.save("{}".format(file_name), format="jpeg")

                # add to the db structure
                db["cids"].append(cid)
                db["cluster"].append(lbl)  # use labels as clusters not to sample negatives from the same object
                db["classid"].append(lbl)
                db["gtbboxid"].append(box_id)
                db["imageid"].append(image_id)
                db["difficult"].append(dif_flag)
                if i_box < num_gt_boxes:
                    db["type"].append("gtproposal")
                else:
                    db["type"].append("randomcrop")
                db["size"].append(cropped_img.size)
                db["bbox"].append(box)  # format (x1,y1,x2,y2)

    return db
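# The cid written above packs the box id, label, difficult flag, image id and
# box type into one string. A hedged helper (hypothetical, not part of the
# original code) to decode it; random crops carry label -1, so the lbl field
# may include a sign:
import re

def _parse_crop_cid(cid):
    m = re.match(r"^box(\d{5})_lbl(-?\d{4,5})_dif(\d)_im(\d{5})(GT|RN)$", cid)
    assert m is not None, f"cid {cid} does not match the expected pattern"
    box_id, lbl, dif, image_id, box_type = m.groups()
    return int(box_id), int(lbl), int(dif), int(image_id), box_type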
Example #8
def prepare_dataset(target_path,
                    retrieval_dataset_train_name, dataset_train,
                    retrieval_dataset_val_name, dataset_val,
                    iou_pos_threshold, iou_neg_threshold,
                    num_queries_image_to_image,
                    logger,
                    retrieval_dataset_test_names=None, datasets_test=None,
                    num_random_crops_per_image=0):
    # prepare data images for train and val
    tgt_image_path_trainval = os.path.join(target_path, "train", retrieval_dataset_train_name, "ims")
    mkdir(tgt_image_path_trainval)
    logger.info(f"Train set {retrieval_dataset_train_name}")
    db_images_train = save_cropped_boxes(dataset_train, tgt_image_path_trainval, extension="", num_random_crops_per_image=num_random_crops_per_image)

    # create val subset: add all boxes from images that have at least one validation box (can add some boxes from train as distractors)
    logger.info(f"Val set {retrieval_dataset_val_name}")
    db_images_val = save_cropped_boxes(dataset_val, tgt_image_path_trainval, extension="", num_random_crops_per_image=num_random_crops_per_image)

    # prepare data images for test
    dbs_images_test = {}
    if datasets_test:
        for dataset_test, dataset_name in zip(datasets_test, retrieval_dataset_test_names):
            tgt_image_path_test = os.path.join(target_path, "test", dataset_name, "jpg")  # cirtorch requires the top-level folder to be named "test"
            mkdir(tgt_image_path_test)
            logger.info(f"Eval dataset: {dataset_name}")
            dbs_images_test[dataset_name] = save_cropped_boxes(dataset_test, tgt_image_path_test, num_random_crops_per_image=num_random_crops_per_image)

    # save GT images from train
    db_classes_train = save_class_images(dataset_train,
                                         os.path.join(target_path, "train", retrieval_dataset_train_name, "ims"), extension="")

    # save GT images from val
    db_classes_val = save_class_images(dataset_val,
                                       os.path.join(target_path, "train", retrieval_dataset_train_name, "ims"), extension="")

    # save GT images for testing
    dbs_classes_test = {}
    if datasets_test:
        for dataset_test, dataset_name in zip(datasets_test, retrieval_dataset_test_names):
            dbs_classes_test[dataset_name] = save_class_images(dataset_test,
                                                    os.path.join(target_path, "test", dataset_name, "jpg"))

    # merge databases
    logger.info(f"Processing trainval set from {retrieval_dataset_train_name} and {retrieval_dataset_val_name}")
    db_train = create_train_database_queries(db_images_train, db_classes_train,
                                             iou_pos_threshold=iou_pos_threshold,
                                             iou_neg_threshold=iou_neg_threshold,
                                             logger=logger,
                                             num_queries_image_to_image=num_queries_image_to_image)
    db_val = create_train_database_queries(db_images_val, db_classes_val,
                                           iou_pos_threshold=iou_pos_threshold,
                                           iou_neg_threshold=iou_neg_threshold,
                                           logger=logger,
                                           num_queries_image_to_image=num_queries_image_to_image)

    dbs_test = {}
    if datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            logger.info(f"Processing test set {dataset_name}")
            dbs_test[dataset_name] = create_test_database_queries(dbs_images_test[dataset_name], dbs_classes_test[dataset_name],
                                                                  iou_pos_threshold=iou_pos_threshold,
                                                                  iou_neg_threshold=iou_neg_threshold,
                                                                  logger=logger,
                                                                  num_queries_image_to_image=num_queries_image_to_image)

    # save trainval to disk
    db_trainval = {"train":db_train, "val":db_val}
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name, f"{retrieval_dataset_train_name}.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_trainval, f)

    # save train separately for whitening
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name, f"{retrieval_dataset_train_name}-whiten.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_train, f)

    # save test to disk
    if datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            db_fn = os.path.join(target_path, "test", dataset_name, f"gnd_{dataset_name}.pkl")
            with open(db_fn, "wb") as f:
                pickle.dump(dbs_test[dataset_name], f)
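# Hedged sketch of reading the pickles written above back in; the path pieces
# mirror the ones used when saving (argument names are illustrative):
def _load_trainval_db(target_path, retrieval_dataset_train_name):
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name,
                         f"{retrieval_dataset_train_name}.pkl")
    with open(db_fn, "rb") as f:
        db_trainval = pickle.load(f)
    return db_trainval["train"], db_trainval["val"]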
Example #9
def main():
    args = parse_opts()
    set_random_seed(args.random_seed)
    crop_suffix = f"-rndCropPerImage{args.num_random_crops_per_image}"

    logger_name = "retrieval_data"
    retrieval_dataset_name_suffix = "-retrieval"
    logger = setup_logger(logger_name, None)
    data_path = get_data_path()

    script_path = os.path.expanduser(os.path.dirname(
        os.path.abspath(__file__)))
    target_path = os.path.join(script_path, "cnnimageretrieval-pytorch",
                               "data")
    mkdir(target_path)

    dataset_train = build_dataset_by_name(data_path,
                                          args.dataset_train,
                                          eval_scale=args.dataset_train_scale,
                                          logger_prefix=logger_name)
    retrieval_dataset_train_name = dataset_train.get_name() + retrieval_dataset_name_suffix

    dataset_val = build_dataset_by_name(data_path,
                                        args.dataset_val,
                                        eval_scale=args.dataset_val_scale,
                                        logger_prefix=logger_name)
    retrieval_dataset_val_name = dataset_val.get_name() + retrieval_dataset_name_suffix

    datasets_test = []
    retrieval_dataset_test_names = []
    if args.datasets_test:
        if len(args.datasets_test_scale) == 1:
            datasets_test_scale = args.datasets_test_scale * len(args.datasets_test)
        else:
            datasets_test_scale = args.datasets_test_scale
        assert len(args.datasets_test) == len(datasets_test_scale), \
            "Arg datasets-test-scale should be of len 1 or of len equal to the len of datasets-test"

        for dataset_name, scale in zip(args.datasets_test, datasets_test_scale):
            dataset = build_dataset_by_name(data_path,
                                            dataset_name,
                                            eval_scale=scale,
                                            logger_prefix=logger_name)
            retrieval_dataset_test_names.append(dataset.get_name() +
                                                retrieval_dataset_name_suffix)
            datasets_test.append(dataset)

    # prepare data images for train and val
    tgt_image_path_trainval = os.path.join(target_path, "train",
                                           retrieval_dataset_train_name, "ims")
    mkdir(tgt_image_path_trainval)
    logger.info(
        f"Train set {retrieval_dataset_train_name} with no random crops")
    db_images_train = save_cropped_boxes(dataset_train,
                                         tgt_image_path_trainval,
                                         extension="")

    # create val subset: add all boxes from images that have at least one validation box (can add some boxes from train as distractors)
    logger.info(f"Val set {retrieval_dataset_val_name} with no random crops")
    db_images_val = save_cropped_boxes(dataset_val,
                                       tgt_image_path_trainval,
                                       extension="")

    # prepare data images for trainval with crops
    tgt_image_path_trainval_randcrops = os.path.join(
        target_path, "train", retrieval_dataset_train_name + crop_suffix,
        "ims")
    mkdir(tgt_image_path_trainval_randcrops)

    logger.info(
        f"Train set {retrieval_dataset_train_name} with {args.num_random_crops_per_image} crops per image"
    )
    db_images_train_randomCrops = save_cropped_boxes(
        dataset_train,
        tgt_image_path_trainval_randcrops,
        extension="",
        num_random_crops_per_image=args.num_random_crops_per_image)

    # create val subset: add all boxes from images that have at least one validation box (can add some boxes from train as distractors)
    logger.info(
        f"Val set {retrieval_dataset_val_name} with {args.num_random_crops_per_image} crops per image"
    )
    db_images_val_randomCrops = save_cropped_boxes(
        dataset_val,
        tgt_image_path_trainval_randcrops,
        extension="",
        num_random_crops_per_image=args.num_random_crops_per_image)

    # prepare data images for test
    dbs_images_test = {}
    if datasets_test:
        for dataset_test, dataset_name in zip(datasets_test,
                                              retrieval_dataset_test_names):
            tgt_image_path_test = os.path.join(
                target_path, "test", dataset_name, "jpg"
            )  # cirtorch requires the top-level folder to be named "test"
            mkdir(tgt_image_path_test)
            logger.info(f"Eval dataset: {dataset_name}")
            dbs_images_test[dataset_name] = save_cropped_boxes(
                dataset_test, tgt_image_path_test)

            # prepare data images for test with random crops
            tgt_image_path_test = os.path.join(
                target_path, "test", dataset_name + crop_suffix, "jpg"
            )  # cirtorch requires the top-level folder to be named "test"
            mkdir(tgt_image_path_test)
            logger.info(f"Eval dataset: {dataset_name + crop_suffix}")
            dbs_images_test[dataset_name + crop_suffix] = save_cropped_boxes(
                dataset_test,
                tgt_image_path_test,
                num_random_crops_per_image=args.num_random_crops_per_image)

    # save GT images from train
    db_classes_train = save_class_images(
        dataset_train,
        os.path.join(target_path, "train", retrieval_dataset_train_name,
                     "ims"),
        extension="")
    db_classes_train_randomCrops = save_class_images(
        dataset_train,
        os.path.join(target_path, "train",
                     retrieval_dataset_train_name + crop_suffix, "ims"),
        extension="")

    # save GT images from val
    db_classes_val = save_class_images(
        dataset_val,
        os.path.join(target_path, "train", retrieval_dataset_train_name,
                     "ims"),
        extension="")
    db_classes_val_randomCrops = save_class_images(
        dataset_val,
        os.path.join(target_path, "train",
                     retrieval_dataset_train_name + crop_suffix, "ims"),
        extension="")

    # save GT images for testing
    dbs_classes_test = {}
    if args.datasets_test:
        for dataset_test, dataset_name in zip(datasets_test,
                                              retrieval_dataset_test_names):
            dbs_classes_test[dataset_name] = save_class_images(
                dataset_test,
                os.path.join(target_path, "test", dataset_name, "jpg"))
            dbs_classes_test[dataset_name + crop_suffix] = save_class_images(
                dataset_test,
                os.path.join(target_path, "test", dataset_name + crop_suffix,
                             "jpg"))

    # merge databases
    logger.info(
        f"Processing trainval set from {retrieval_dataset_train_name} and {retrieval_dataset_val_name}"
    )
    db_train = create_train_database_queries(
        db_images_train,
        db_classes_train,
        iou_pos_threshold=args.iou_pos_threshold,
        iou_neg_threshold=args.iou_neg_threshold,
        logger=logger)
    db_val = create_train_database_queries(
        db_images_val,
        db_classes_val,
        iou_pos_threshold=args.iou_pos_threshold,
        iou_neg_threshold=args.iou_neg_threshold,
        logger=logger)

    logger.info(
        f"Processing trainval set from {retrieval_dataset_train_name} and {retrieval_dataset_val_name} with {args.num_random_crops_per_image} random crops"
    )
    db_train_randomCrops = create_train_database_queries(
        db_images_train_randomCrops,
        db_classes_train_randomCrops,
        iou_pos_threshold=args.iou_pos_threshold,
        iou_neg_threshold=args.iou_neg_threshold,
        logger=logger)
    db_val_randomCrops = create_train_database_queries(
        db_images_val_randomCrops,
        db_classes_val_randomCrops,
        iou_pos_threshold=args.iou_pos_threshold,
        iou_neg_threshold=args.iou_neg_threshold,
        logger=logger)

    dbs_test = {}
    if args.datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            logger.info(
                f"Processing test set {dataset_name} with {args.num_random_crops_per_image} random crops"
            )
            dbs_test[dataset_name] = create_test_database_queries(
                dbs_images_test[dataset_name],
                dbs_classes_test[dataset_name],
                iou_pos_threshold=args.iou_pos_threshold,
                iou_neg_threshold=args.iou_neg_threshold,
                logger=logger)

            logger.info(f"Processing test set {dataset_name + crop_suffix}")
            dbs_test[dataset_name +
                     crop_suffix] = create_test_database_queries(
                         dbs_images_test[dataset_name + crop_suffix],
                         dbs_classes_test[dataset_name + crop_suffix],
                         iou_pos_threshold=args.iou_pos_threshold,
                         iou_neg_threshold=args.iou_neg_threshold,
                         logger=logger)

    # save trainval to disk
    db_trainval = {"train": db_train, "val": db_val}
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name,
                         f"{retrieval_dataset_train_name}.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_trainval, f)

    # save train separately for whitening
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name,
                         f"{retrieval_dataset_train_name}-whiten.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_train, f)

    # save trainval with random crops to disk
    db_trainval_randomCrops = {
        "train": db_train_randomCrops,
        "val": db_val_randomCrops
    }
    db_fn = os.path.join(target_path, "train",
                         retrieval_dataset_train_name + crop_suffix,
                         f"{retrieval_dataset_train_name}{crop_suffix}.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_trainval_randomCrops, f)

    db_fn = os.path.join(target_path, "train",
                         retrieval_dataset_train_name + crop_suffix,
                         f"{retrieval_dataset_train_name}{crop_suffix}-whiten.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_train_randomCrops, f)

    # save test to disk
    if args.datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            db_fn = os.path.join(target_path, "test", dataset_name,
                                 f"gnd_{dataset_name}.pkl")
            with open(db_fn, "wb") as f:
                pickle.dump(dbs_test[dataset_name], f)

            # save test with random crops to disk
            db_fn = os.path.join(target_path, "test", dataset_name + crop_suffix,
                                 f"gnd_{dataset_name}{crop_suffix}.pkl")
                f"gnd_{dataset_name}{crop_suffix}.pkl")
            with open(db_fn, "wb") as f:
                pickle.dump(dbs_test[dataset_name + crop_suffix], f)
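# Presumably each of these scripts ends with the standard entry-point guard
# (not shown in the listing):
if __name__ == "__main__":
    main()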