def save_class_images(dataset, tgt_image_path, extension=".jpg"):
    db = {"cids": [], "cluster": [], "gtbboxid": [], "classid": [], "imageid": [],
          "difficult": [], "type": [], "size": [], "bbox": []}
    for lbl, label_image in tqdm(dataset.gt_images_per_classid.items()):
        # create the file name to be used with cirtorch.datasets.datahelpers.cid2filename and their dataloader
        cid = "lbl{label:05d}{box_type}".format(label=lbl, box_type="CL")
        file_name = cid2filename(cid, prefix=tgt_image_path)

        # save the image
        image_path, _ = os.path.split(file_name)
        mkdir(image_path)
        if extension:
            label_image.save("{}{}".format(file_name, extension))
        else:
            # cirtorch uses files with an empty extension for training, need to support that
            label_image.save("{}".format(file_name), format="jpeg")

        width, height = label_image.size
        box = [0, 0, width, height]  # format: (x1, y1, x2, y2)

        # add to the db structure
        db["cids"].append(cid)
        db["cluster"].append(lbl)  # use labels as clusters not to sample negatives from the same object
        db["classid"].append(lbl)
        db["gtbboxid"].append(None)
        db["imageid"].append(None)
        db["difficult"].append(None)
        db["type"].append("classimage")
        db["size"].append(label_image.size)
        db["bbox"].append(box)  # format: (x1, y1, x2, y2)
    return db
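# Illustrative sketch (not part of the original script): how the class-image cids
# built above turn into file paths. cid2filename comes from
# cirtorch.datasets.datahelpers and nests files into subdirectories derived from
# the trailing characters of the cid; treat the exact layout as an assumption
# and check the cirtorch source. The label and path below are placeholders.
def _example_class_image_cid(label=42, tgt_image_path="/tmp/retrieval/ims"):
    cid = "lbl{label:05d}{box_type}".format(label=label, box_type="CL")  # -> "lbl00042CL"
    file_name = cid2filename(cid, prefix=tgt_image_path)  # nested path under tgt_image_path
    return cid, file_name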
def main():
    args = parse_opts()
    set_random_seed(args.random_seed)

    logger_name = "retrieval_data"
    retrieval_dataset_name_suffix = "-retrieval"
    logger = setup_logger(logger_name, None)

    data_path = get_data_path()
    script_path = os.path.expanduser(os.path.dirname(os.path.abspath(__file__)))
    target_path = os.path.join(script_path, "cnnimageretrieval-pytorch", "data")
    mkdir(target_path)

    dataset_train = build_dataset_by_name(data_path, args.dataset_train,
                                          eval_scale=args.dataset_train_scale,
                                          logger_prefix=logger_name)
    retrieval_dataset_train_name = dataset_train.get_name() + retrieval_dataset_name_suffix

    dataset_val = build_dataset_by_name(data_path, args.dataset_val,
                                        eval_scale=args.dataset_val_scale,
                                        logger_prefix=logger_name)
    retrieval_dataset_val_name = dataset_val.get_name() + retrieval_dataset_name_suffix

    datasets_test = []
    retrieval_dataset_test_names = []
    if args.datasets_test:
        if len(args.datasets_test_scale) == 1:
            datasets_test_scale = args.datasets_test_scale * len(args.datasets_test)
        else:
            datasets_test_scale = args.datasets_test_scale
        assert len(args.datasets_test) == len(datasets_test_scale), \
            "Arg datasets-test-scale should be of len 1 or of len equal to the len of datasets-test"

        for dataset_name, scale in zip(args.datasets_test, datasets_test_scale):
            dataset = build_dataset_by_name(data_path, dataset_name,
                                            eval_scale=scale,
                                            logger_prefix=logger_name)
            retrieval_dataset_test_names.append(dataset.get_name() + retrieval_dataset_name_suffix)
            datasets_test.append(dataset)

    # append the random-crop suffix to the dataset names if random crops are requested
    if args.num_random_crops_per_image > 0:
        crop_suffix = f"-rndCropPerImage{args.num_random_crops_per_image}"
        retrieval_dataset_train_name = retrieval_dataset_train_name + crop_suffix
        retrieval_dataset_val_name = retrieval_dataset_val_name + crop_suffix
        retrieval_dataset_test_names = [name + crop_suffix for name in retrieval_dataset_test_names]

    prepare_dataset(target_path,
                    retrieval_dataset_train_name, dataset_train,
                    retrieval_dataset_val_name, dataset_val,
                    args.iou_pos_threshold, args.iou_neg_threshold,
                    args.num_queries_image_to_image,
                    logger,
                    retrieval_dataset_test_names=retrieval_dataset_test_names,
                    datasets_test=datasets_test,
                    num_random_crops_per_image=args.num_random_crops_per_image)
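# Hypothetical command line for the main() above. The script name and flag
# spellings are assumptions (they mirror the args.* attributes and the
# hyphenated name in the assert message); check parse_opts for the exact
# interface. Dataset names, scales, and thresholds below are placeholders.
#
#   python prepare_dataset_retrieval.py \
#       --dataset-train <train-dataset> --dataset-train-scale 1280 \
#       --dataset-val <val-dataset> --dataset-val-scale 1280 \
#       --datasets-test <test-dataset> --datasets-test-scale 1280 \
#       --num-random-crops-per-image 0 \
#       --iou-pos-threshold 0.7 --iou-neg-threshold 0.3 \
#       --num-queries-image-to-image 200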
def init_logger(args, logger_prefix="detector-retrieval"):
    if args.output_dir:
        mkdir(args.output_dir)

    logger = setup_logger(logger_prefix, args.output_dir if args.output_dir else None)

    if args.config_file:
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
        logger.info(config_str)
        logger.info("Running with the default eval section:\n{}".format(cfg.eval))
    else:
        logger.info("Launched with no OS2D config file")
    logger.info("Running args:\n{}".format(args))
    return logger
def init_logger(cfg, config_file):
    output_dir = cfg.output.path
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("OS2D", output_dir if cfg.output.save_log_to_file else None)

    if config_file:
        logger.info("Loaded configuration file {}".format(config_file))
        with open(config_file, "r") as cf:
            config_str = "\n" + cf.read()
        logger.info(config_str)
    else:
        logger.info("Config file was not provided")

    logger.info("Running with config:\n{}".format(cfg))

    # save the config file only when training (to run multiple evaluations in the same folder)
    if output_dir and cfg.train.do_training:
        output_config_path = os.path.join(output_dir, "config.yml")
        logger.info("Saving config into: {}".format(output_config_path))
        # save the overloaded model config in the output directory
        save_config(cfg, output_config_path)
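# Illustrative sketch (not part of the original file): the cfg fields that
# init_logger above actually touches. Building a stand-in config with yacs is an
# assumption; the real project config defines many more fields.
from yacs.config import CfgNode as CN

def _example_minimal_cfg():
    cfg = CN()
    cfg.output = CN()
    cfg.output.path = ""               # empty -> nothing is written to disk
    cfg.output.save_log_to_file = False
    cfg.train = CN()
    cfg.train.do_training = False      # config.yml is only saved when training
    return cfg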
def build_instre_dataset(data_path, name, eval_scale=None, cache_images=False, no_image_reading=False,
                         logger_prefix="OS2D"):
    logger = logging.getLogger(f"{logger_prefix}.dataset")
    logger.info("Preparing the INSTRE dataset: version {0}, eval scale {1}, image caching {2}".format(
        name, eval_scale, cache_images))
    # The INSTRE dataset was downloaded from ftp://ftp.irisa.fr/local/texmex/corpus/instre/instre.tar.gz
    # Splits by Iscen et al. (2016) were downloaded from ftp://ftp.irisa.fr/local/texmex/corpus/instre/gnd_instre.mat

    image_size = 1000

    import scipy.io as sio
    dataset_path = os.path.join(data_path, "instre")
    annotation_file = os.path.join(dataset_path, "gnd_instre.mat")
    annotation_data = sio.loadmat(annotation_file)
    # annotation_data["qimlist"][0] - 1250 queries - each file in annotation_data["qimlist"][0][i][0], root - os.path.join(data_path, "instre")
    # annotation_data["imlist"][0] - 27293 database images - each file in annotation_data["imlist"][0][i][0], root - os.path.join(data_path, "instre")
    # annotation_data["gnd"][0] - 1250 annotations for all queries:
    #   annotation_data["gnd"][0][i][0] - indices of positives in annotation_data["imlist"][0] (WARNING: 1-based)
    #   annotation_data["gnd"][0][i][1] - bbox of the query object, one of the boxes from the per-image *.txt file
    # images in subsets INSTRE-S1 and INSTRE-S2 contain exactly one object
    # images in the subset INSTRE-M contain two objects each

    image_path = dataset_path
    gt_path = os.path.join(dataset_path, "classes")
    gt_image_path = os.path.join(gt_path, "images")
    mkdir(gt_image_path)
    classdatafile = os.path.join(gt_path, "instre.csv")
    if not os.path.isfile(classdatafile):
        logger.info(f"Did not find data file {classdatafile}, creating it from the INSTRE source data")
        # create the annotation file from the raw dataset
        annotation_data["qimlist"] = annotation_data["qimlist"].flatten()
        annotation_data["imlist"] = annotation_data["imlist"].flatten()
        annotation_data["gnd"] = annotation_data["gnd"].flatten()
        num_classes = len(annotation_data["qimlist"])
        gtboxframe = []  # will be creating a dataframe from a list of dicts
        for i_class in range(num_classes):
            query_image_path_original = str(annotation_data["qimlist"][i_class][0])
            if query_image_path_original.split("/")[0].lower() == "instre-m":
                # query boxes from the subset INSTRE-M contain both objects, so it is not clear how to use them
                logger.info(f"Skipping query {i_class}: {query_image_path_original}")
                continue
            logger.info(f"Adding query {i_class}: {query_image_path_original}")
            query_bbox = annotation_data["gnd"][i_class][1].flatten()
            query_positives = annotation_data["gnd"][i_class][0].flatten() - 1  # "-1" because of the original 1-based MATLAB indexing

            classid = i_class
            classfilename = f"{i_class:05d}_{'_'.join(query_image_path_original.split('/'))}"

            # check the actual save location (the original check used the bare file name, which looks in the CWD)
            if not os.path.isfile(os.path.join(gt_image_path, classfilename)):
                query_img = read_image(os.path.join(dataset_path, query_image_path_original))
                query_img_cropped_box = query_img.crop(query_bbox)
                query_img_cropped_box.save(os.path.join(gt_image_path, classfilename))

            def convert_the_box_from_xywh(box, imsize):
                lx = float(box[0]) / imsize.w
                ty = float(box[1]) / imsize.h
                rx = lx + float(box[2]) / imsize.w
                by = ty + float(box[3]) / imsize.h
                return lx, ty, rx, by

            def read_boxes_from(file_with_boxes):
                with open(file_with_boxes, "r") as fo:
                    lines = fo.readlines()
                boxes = [[int(s) for s in line.split(" ")] for line in lines if line]
                return boxes

            def get_box_file_for_image_file(image_filename):
                return image_filename.split(".")[0] + ".txt"

            def get_the_boxes(image_filename):
                file_with_boxes = os.path.join(image_path, get_box_file_for_image_file(image_filename))
                # read the boxes and the image size to recompute the boxes into relative coordinates
                boxes = read_boxes_from(file_with_boxes)
                img = read_image(os.path.join(image_path, image_filename))
                imsize = FeatureMapSize(img=img)
                # choose the correct box if there are two of them
                # From the INSTRE documentation:
                #   Specially, for each tuple-class in INSTRE-M, there are two corresponding object classes in INSTRE-S1.
                #   In each annotation file for an INSTRE-M image, the first line records the object labeled as [a] in INSTRE-S1
                #   and the second line records the object labeled as [b] in INSTRE-S1.
                # CAUTION! The MATLAB file stores boxes as (x1, y1, x2, y2), but the *.txt files store (x, y, w, h).
                query_path_split = query_image_path_original.split("/")
                image_filename_split = image_filename.split("/")
                if query_path_split[0].lower() == "instre-s1" and image_filename_split[0].lower() == "instre-m":
                    assert len(boxes) == 2, f"INSTRE-M images should have exactly two boxes, but have {boxes}"
                    assert query_path_split[1][2] in ["a", "b"]
                    i_box = 0 if query_path_split[1][2] == "a" else 1
                    boxes = [convert_the_box_from_xywh(boxes[i_box], imsize)]
                elif query_path_split[0].lower() == "instre-s1" and image_filename_split[0].lower() == "instre-s1" or \
                     query_path_split[0].lower() == "instre-s2" and image_filename_split[0].lower() == "instre-s2":
                    boxes = [convert_the_box_from_xywh(box, imsize) for box in boxes]
                else:
                    raise RuntimeError(f"Should not be happening, query {query_image_path_original}, image {image_filename}, boxes {boxes}")
                return boxes

            for image_id in query_positives:
                # add one bbox to the annotation
                # required_columns = ["imageid", "imagefilename", "classid", "classfilename", "gtbboxid", "difficult", "lx", "ty", "rx", "by"]
                image_file_name = str(annotation_data["imlist"][image_id][0])
                boxes = get_the_boxes(image_file_name)
                for box in boxes:
                    item = OrderedDict()
                    item["gtbboxid"] = len(gtboxframe)
                    item["classid"] = classid
                    item["classfilename"] = classfilename
                    item["imageid"] = image_id
                    assert annotation_data["imlist"][image_id].size == 1
                    item["imagefilename"] = image_file_name
                    item["difficult"] = 0
                    item["lx"], item["ty"], item["rx"], item["by"] = box
                    gtboxframe.append(item)

        gtboxframe = pd.DataFrame(gtboxframe)
        gtboxframe.to_csv(classdatafile)

    gtboxframe = read_annotation_file(classdatafile)

    # get these automatically from gtboxframe
    image_ids = None
    image_file_names = None

    # define the subset split (using a closure)
    subset_name = name.lower()
    assert subset_name.startswith("instre"), f"Subset name {name} should start with 'instre'"
    subset_name = subset_name[len("instre"):]
    subsets = ["all", "s1-train", "s1-val", "s1-test", "s2-train", "s2-val", "s2-test"]
    found_subset = False
    for subset in subsets:
        if subset_name == "-" + subset:
            found_subset = subset
            break
    assert found_subset, "Could not identify subset {}".format(subset_name)

    if subset == "all":
        pass
    elif subset in ["s1-train", "s1-val", "s1-test"]:
        gtboxframe = gtboxframe[gtboxframe.classfilename.str.contains("INSTRE-S1")]
        classes = gtboxframe.classfilename.drop_duplicates()
        if subset == "s1-train":
            classes = classes[:len(classes) * 75 // 100]  # first 75%
        elif subset == "s1-test":
            classes = classes[len(classes) * 8 // 10:]  # last 20%
        else:  # "s1-val"
            classes = classes[len(classes) * 75 // 100: len(classes) * 8 // 10]  # 5% between train and test
        gtboxframe = gtboxframe[gtboxframe.classfilename.isin(classes)]
    elif subset in ["s2-train", "s2-val", "s2-test"]:
        gtboxframe = gtboxframe[gtboxframe.classfilename.str.contains("INSTRE-S2")]
        classes = gtboxframe.classfilename.drop_duplicates()
        if subset == "s2-train":
            classes = classes[:len(classes) * 75 // 100]  # first 75%
        elif subset == "s2-test":
            classes = classes[len(classes) * 8 // 10:]  # last 20%
        else:  # "s2-val"
            classes = classes[len(classes) * 75 // 100: len(classes) * 8 // 10]  # 5% between train and test
        gtboxframe = gtboxframe[gtboxframe.classfilename.isin(classes)]
    else:
        raise RuntimeError("Unknown subset {0}".format(subset))

    dataset = DatasetOneShotDetection(gtboxframe, gt_image_path, image_path, name, image_size, eval_scale,
                                      image_ids=image_ids, image_file_names=image_file_names,
                                      cache_images=cache_images, no_image_reading=no_image_reading,
                                      logger_prefix=logger_prefix)
    return dataset
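# Illustrative sketch (not part of the original file): the subset naming
# convention accepted by build_instre_dataset above - "instre-all" or
# "instre-s{1,2}-{train,val,test}". The eval scale of 600 is a placeholder, not
# a value taken from the original code.
def _example_build_instre_splits(data_path):
    train_set = build_instre_dataset(data_path, "instre-s1-train", eval_scale=600)
    val_set = build_instre_dataset(data_path, "instre-s1-val", eval_scale=600)
    test_set = build_instre_dataset(data_path, "instre-s2-test", eval_scale=600)
    return train_set, val_set, test_set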
def build_imagenet_test_episodes(subset_name, data_path, logger):
    episode_id = int(subset_name.split('-')[-1])
    epi_data_name = "epi_inloc_in_domain_1_5_10_500"
    image_size = 1000

    dataset_path = os.path.join(data_path, "ImageNet-RepMet")
    roidb_path = os.path.join(dataset_path, "RepMet_CVPR2019_data", "data", "Imagenet_LOC", "voc_inloc_roidb.pkl")
    with open(roidb_path, 'rb') as fid:
        roidb = pickle.load(fid, encoding='latin1')
    episodes_path = os.path.join(dataset_path, "RepMet_CVPR2019_data", "data", "Imagenet_LOC",
                                 "episodes", f"{epi_data_name}.pkl")
    with open(episodes_path, 'rb') as fid:
        episode_data = pickle.load(fid, encoding='latin1')

    logger.info(f"Extracting episode {episode_id} out of {len(episode_data)}")
    episode = episode_data[episode_id]

    dataset_image_path = os.path.join(data_path, "ImageNet-RepMet", "ILSVRC")

    SWAP_IMG_PATH_SRC = "/dccstor/leonidka1/data/imagenet/ILSVRC/"
    def _get_image_path(image_path):
        image_path = image_path.replace(SWAP_IMG_PATH_SRC, "")
        return image_path

    # episode["epi_cats"] - list of class ids
    # episode["query_images"] - list of paths to the episode images
    # episode["epi_cats_names"] - list of names of the episode classes
    # episode["train_boxes"] - list of box data about class boxes

    num_classes = len(episode["epi_cats"])

    gt_path = os.path.join(dataset_path, epi_data_name)
    gt_path = os.path.join(gt_path, f"classes_episode_{episode_id}")
    gt_image_path = os.path.join(gt_path, "images")
    mkdir(gt_image_path)
    classdatafile = os.path.join(gt_path, f"classes_{epi_data_name}_episode_{episode_id}.csv")
    if not os.path.isfile(classdatafile):
        logger.info(f"Did not find data file {classdatafile}, creating it from the RepMet source data")
        # create the annotation file from the raw dataset
        gtboxframe = []  # will be creating a dataframe from a list of dicts
        gt_filename_by_id = {}
        for i_class in range(len(episode["train_boxes"])):
            train_boxes_data = episode["train_boxes"][i_class]
            class_id = train_boxes_data[0]
            assert class_id in episode["epi_cats"], \
                f"class_id={class_id} should be listed in episode['epi_cats']={episode['epi_cats']}"
            query_image_path_original = _get_image_path(train_boxes_data[2])
            query_bbox = train_boxes_data[3]
            query_bbox = query_bbox.flatten()

            classfilename = f"{class_id:05d}_{'_'.join(query_image_path_original.split('/'))}"

            if class_id not in gt_filename_by_id:
                logger.info(f"Adding query #{len(gt_filename_by_id)} - {class_id}: {query_image_path_original}")
                # NOTE: "or True" disables the cache check, so the query crop is always re-saved
                if not os.path.isfile(classfilename) or True:
                    query_img = read_image(os.path.join(dataset_image_path, query_image_path_original))
                    query_img_cropped_box = query_img.crop(query_bbox)
                    query_img_cropped_box.save(os.path.join(gt_image_path, classfilename))
                gt_filename_by_id[class_id] = classfilename
            else:
                logger.info(f"WARNING: class {class_id} has multiple entries in GT image {query_image_path_original}, using the first box as GT")

        for class_id in episode["epi_cats"]:
            if class_id not in gt_filename_by_id:
                logger.info(f"WARNING: ground truth for class {class_id} not found in episode {episode_id}")

        def convert_the_box_to_relative(box, imsize):
            lx = float(box[0]) / imsize.w
            ty = float(box[1]) / imsize.h
            rx = float(box[2]) / imsize.w
            by = float(box[3]) / imsize.h
            return lx, ty, rx, by

        def find_image_path_in_roidb(image_file_name, roidb):
            for i_image, im_data in enumerate(roidb["roidb"]):
                if im_data["flipped"]:
                    raise RuntimeError(f"Image {i_image} data {im_data} has the flipped flag on")
                if im_data["image"] == image_file_name:
                    return i_image
            return None

        for image_file_name in episode["query_images"]:
            # add one bbox to the annotation
            # required_columns = ["imageid", "imagefilename", "classid", "classfilename", "gtbboxid", "difficult", "lx", "ty", "rx", "by"]
            image_id = find_image_path_in_roidb(image_file_name, roidb)
            im_data = roidb["roidb"][image_id]
            image_file_name = _get_image_path(image_file_name)
            imsize = FeatureMapSize(w=int(im_data["width"]), h=int(im_data["height"]))

            boxes_xyxy = im_data["boxes"]
            classes = im_data["gt_classes"]
            for box, class_id in zip(boxes_xyxy, classes):
                if class_id in gt_filename_by_id:
                    item = OrderedDict()
                    item["imageid"] = int(image_id)
                    item["imagefilename"] = image_file_name
                    item["classid"] = int(class_id)
                    item["classfilename"] = gt_filename_by_id[class_id]
                    item["gtbboxid"] = len(gtboxframe)
                    item["difficult"] = 0
                    item["lx"], item["ty"], item["rx"], item["by"] = convert_the_box_to_relative(box, imsize)
                    gtboxframe.append(item)

        gtboxframe = pd.DataFrame(gtboxframe)
        gtboxframe.to_csv(classdatafile)

    gtboxframe = pd.read_csv(classdatafile)

    return gtboxframe, gt_image_path, dataset_image_path, image_size
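# Illustrative sketch (not part of the original file): the subset_name passed to
# build_imagenet_test_episodes above encodes the episode index as its last
# dash-separated token. The name below is a made-up example, not one taken from
# the original code.
def _example_episode_id(subset_name="imagenet-repmet-test-episode-3"):
    return int(subset_name.split('-')[-1])  # -> 3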
def save_cropped_boxes(dataset, tgt_image_path, extension=".jpg", num_random_crops_per_image=0):
    # crop all the boxes
    db = {"cids": [], "cluster": [], "gtbboxid": [], "classid": [], "imageid": [],
          "difficult": [], "type": [], "size": [], "bbox": []}

    for image_id in tqdm(dataset.image_ids):
        img = dataset._get_dataset_image_by_id(image_id)
        boxes = dataset.get_image_annotation_for_imageid(image_id)

        assert boxes.has_field("labels"), "GT boxes need a field 'labels'"
        # remove all fields except "labels" and "difficult"
        for f in boxes.fields():
            if f not in ["labels", "difficult"]:
                boxes.remove_field(f)
        if not boxes.has_field("difficult"):
            boxes.add_field("difficult", torch.zeros(len(boxes), dtype=torch.bool))
        num_gt_boxes = len(boxes)
        im_size = FeatureMapSize(img=img)
        assert im_size == boxes.image_size

        eval_scale = dataset.get_eval_scale()

        # sample random boxes if needed
        if num_random_crops_per_image > 0:
            boxes_random = torch.rand(num_random_crops_per_image, 4)
            x1 = torch.min(boxes_random[:, 0], boxes_random[:, 2]) * im_size.w
            x2 = torch.max(boxes_random[:, 0], boxes_random[:, 2]) * im_size.w
            y1 = torch.min(boxes_random[:, 1], boxes_random[:, 3]) * im_size.h
            y2 = torch.max(boxes_random[:, 1], boxes_random[:, 3]) * im_size.h
            boxes_random = torch.stack([x1, y1, x2, y2], 1).floor()

            # drop boxes that are too small
            min_size = 10.0 / eval_scale * max(im_size.w, im_size.h)
            mask_bad_boxes = (boxes_random[:, 0] + min_size > boxes_random[:, 2]) | \
                             (boxes_random[:, 1] + min_size > boxes_random[:, 3])
            good_boxes = torch.nonzero(~mask_bad_boxes).view(-1)
            boxes_random = boxes_random[good_boxes]

            boxes_random = BoxList(boxes_random, im_size, mode="xyxy")
            boxes_random.add_field("labels", torch.full([len(boxes_random)], -1, dtype=torch.long))
            boxes_random.add_field("difficult", torch.zeros(len(boxes_random), dtype=torch.bool))
            boxes = cat_boxlist([boxes, boxes_random])

        if boxes is not None:
            for i_box in range(len(boxes)):
                # box format: left, top, right, bottom
                box = boxes[i_box].bbox_xyxy.view(-1)
                box = [b.item() for b in box]
                cropped_img = img.crop(box)

                if i_box < num_gt_boxes:
                    lbl = boxes[i_box].get_field("labels").item()
                    dif_flag = boxes[i_box].get_field("difficult").item()
                    box_id = i_box
                    box_type = "GT"
                else:
                    lbl = -1
                    dif_flag = 0
                    box_id = i_box
                    box_type = "RN"

                # create the file name to be used with cirtorch.datasets.datahelpers.cid2filename and their dataloader
                cid = "box{box_id:05d}_lbl{label:05d}_dif{dif:01d}_im{image_id:05d}{box_type}".format(
                    box_id=box_id, image_id=image_id, label=lbl, dif=dif_flag, box_type=box_type)
                file_name = cid2filename(cid, prefix=tgt_image_path)

                # save the image
                image_path, _ = os.path.split(file_name)
                mkdir(image_path)
                if extension:
                    cropped_img.save("{}{}".format(file_name, extension))
                else:
                    # cirtorch uses files with an empty extension for training, need to support that
                    cropped_img.save("{}".format(file_name), format="jpeg")

                # add to the db structure
                db["cids"].append(cid)
                db["cluster"].append(lbl)  # use labels as clusters not to sample negatives from the same object
                db["classid"].append(lbl)
                db["gtbboxid"].append(box_id)
                db["imageid"].append(image_id)
                db["difficult"].append(dif_flag)
                db["type"].append("gtproposal" if i_box < num_gt_boxes else "randomcrop")
                db["size"].append(cropped_img.size)
                db["bbox"].append(box)  # format: (x1, y1, x2, y2)
    return db
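# Illustrative, self-contained sketch (not part of the original file) of the
# random-crop sampling and filtering used above: sample xyxy boxes uniformly,
# then drop any whose width or height is below min_size. The image size and
# eval scale here are placeholders.
def _example_sample_random_boxes(num_crops=8, w=640, h=480, eval_scale=600):
    boxes_random = torch.rand(num_crops, 4)
    x1 = torch.min(boxes_random[:, 0], boxes_random[:, 2]) * w
    x2 = torch.max(boxes_random[:, 0], boxes_random[:, 2]) * w
    y1 = torch.min(boxes_random[:, 1], boxes_random[:, 3]) * h
    y2 = torch.max(boxes_random[:, 1], boxes_random[:, 3]) * h
    boxes_random = torch.stack([x1, y1, x2, y2], 1).floor()
    # boxes smaller than ~10 px at the evaluation scale are unusable as crops
    min_size = 10.0 / eval_scale * max(w, h)
    mask_bad = (boxes_random[:, 0] + min_size > boxes_random[:, 2]) | \
               (boxes_random[:, 1] + min_size > boxes_random[:, 3])
    return boxes_random[~mask_bad]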
def prepare_dataset(target_path, retrieval_dataset_train_name, dataset_train, retrieval_dataset_val_name, dataset_val,
                    iou_pos_threshold, iou_neg_threshold, num_queries_image_to_image,
                    logger, retrieval_dataset_test_names=None, datasets_test=None, num_random_crops_per_image=0):
    # prepare data images for train and val
    tgt_image_path_trainval = os.path.join(target_path, "train", retrieval_dataset_train_name, "ims")
    mkdir(tgt_image_path_trainval)
    logger.info(f"Train set {retrieval_dataset_train_name}")
    db_images_train = save_cropped_boxes(dataset_train, tgt_image_path_trainval, extension="",
                                         num_random_crops_per_image=num_random_crops_per_image)

    # create the val subset: add all boxes from images that have at least one validation box (can add some boxes from train as distractors)
    logger.info(f"Val set {retrieval_dataset_val_name}")
    db_images_val = save_cropped_boxes(dataset_val, tgt_image_path_trainval, extension="",
                                       num_random_crops_per_image=num_random_crops_per_image)

    # prepare data images for test
    dbs_images_test = {}
    if datasets_test:
        for dataset_test, dataset_name in zip(datasets_test, retrieval_dataset_test_names):
            tgt_image_path_test = os.path.join(target_path, "test", dataset_name, "jpg")  # the folder name should always be "test" - required by cirtorch
            mkdir(tgt_image_path_test)
            logger.info(f"Eval dataset: {dataset_name}")
            dbs_images_test[dataset_name] = save_cropped_boxes(dataset_test, tgt_image_path_test,
                                                               num_random_crops_per_image=num_random_crops_per_image)

    # save GT images from train
    db_classes_train = save_class_images(dataset_train,
                                         os.path.join(target_path, "train", retrieval_dataset_train_name, "ims"),
                                         extension="")
    # save GT images from val
    db_classes_val = save_class_images(dataset_val,
                                       os.path.join(target_path, "train", retrieval_dataset_train_name, "ims"),
                                       extension="")
    # save GT images for testing
    dbs_classes_test = {}
    if datasets_test:
        for dataset_test, dataset_name in zip(datasets_test, retrieval_dataset_test_names):
            dbs_classes_test[dataset_name] = save_class_images(dataset_test,
                                                               os.path.join(target_path, "test", dataset_name, "jpg"))

    # merge the databases
    logger.info(f"Processing trainval set from {retrieval_dataset_train_name} and {retrieval_dataset_val_name}")
    db_train = create_train_database_queries(db_images_train, db_classes_train,
                                             iou_pos_threshold=iou_pos_threshold,
                                             iou_neg_threshold=iou_neg_threshold,
                                             logger=logger,
                                             num_queries_image_to_image=num_queries_image_to_image)
    db_val = create_train_database_queries(db_images_val, db_classes_val,
                                           iou_pos_threshold=iou_pos_threshold,
                                           iou_neg_threshold=iou_neg_threshold,
                                           logger=logger,
                                           num_queries_image_to_image=num_queries_image_to_image)
    dbs_test = {}
    if datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            logger.info(f"Processing test set {dataset_name}")
            dbs_test[dataset_name] = create_test_database_queries(dbs_images_test[dataset_name],
                                                                  dbs_classes_test[dataset_name],
                                                                  iou_pos_threshold=iou_pos_threshold,
                                                                  iou_neg_threshold=iou_neg_threshold,
                                                                  logger=logger,
                                                                  num_queries_image_to_image=num_queries_image_to_image)

    # save trainval to disk
    db_trainval = {"train": db_train, "val": db_val}
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name, f"{retrieval_dataset_train_name}.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_trainval, f)

    # save train separately for whitening
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name, f"{retrieval_dataset_train_name}-whiten.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_train, f)

    # save test to disk
    if datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            db_fn = os.path.join(target_path, "test", dataset_name, f"gnd_{dataset_name}.pkl")
            with open(db_fn, "wb") as f:
                pickle.dump(dbs_test[dataset_name], f)
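# Illustrative sketch (not part of the original file): reading back the trainval
# database that prepare_dataset writes above. The {"train": ..., "val": ...}
# key layout matches the db_trainval dict constructed in prepare_dataset.
def _example_load_trainval_db(target_path, retrieval_dataset_train_name):
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name,
                         f"{retrieval_dataset_train_name}.pkl")
    with open(db_fn, "rb") as f:
        db_trainval = pickle.load(f)
    return db_trainval["train"], db_trainval["val"]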
def main():
    args = parse_opts()
    set_random_seed(args.random_seed)

    crop_suffix = f"-rndCropPerImage{args.num_random_crops_per_image}"

    logger_name = "retrieval_data"
    retrieval_dataset_name_suffix = "-retrieval"
    logger = setup_logger(logger_name, None)

    data_path = get_data_path()
    script_path = os.path.expanduser(os.path.dirname(os.path.abspath(__file__)))
    target_path = os.path.join(script_path, "cnnimageretrieval-pytorch", "data")
    mkdir(target_path)

    dataset_train = build_dataset_by_name(data_path, args.dataset_train,
                                          eval_scale=args.dataset_train_scale,
                                          logger_prefix=logger_name)
    retrieval_dataset_train_name = dataset_train.get_name() + retrieval_dataset_name_suffix

    dataset_val = build_dataset_by_name(data_path, args.dataset_val,
                                        eval_scale=args.dataset_val_scale,
                                        logger_prefix=logger_name)
    retrieval_dataset_val_name = dataset_val.get_name() + retrieval_dataset_name_suffix

    # initialize before the if-block so that the later checks do not fail when no test sets are given
    datasets_test = []
    retrieval_dataset_test_names = []
    if args.datasets_test:
        if len(args.datasets_test_scale) == 1:
            datasets_test_scale = args.datasets_test_scale * len(args.datasets_test)
        else:
            datasets_test_scale = args.datasets_test_scale
        assert len(args.datasets_test) == len(datasets_test_scale), \
            "Arg datasets-test-scale should be of len 1 or of len equal to the len of datasets-test"

        for dataset_name, scale in zip(args.datasets_test, datasets_test_scale):
            dataset = build_dataset_by_name(data_path, dataset_name,
                                            eval_scale=scale,
                                            logger_prefix=logger_name)
            retrieval_dataset_test_names.append(dataset.get_name() + retrieval_dataset_name_suffix)
            datasets_test.append(dataset)

    # prepare data images for train and val
    tgt_image_path_trainval = os.path.join(target_path, "train", retrieval_dataset_train_name, "ims")
    mkdir(tgt_image_path_trainval)
    logger.info(f"Train set {retrieval_dataset_train_name} with no random crops")
    db_images_train = save_cropped_boxes(dataset_train, tgt_image_path_trainval, extension="")

    # create the val subset: add all boxes from images that have at least one validation box (can add some boxes from train as distractors)
    logger.info(f"Val set {retrieval_dataset_val_name} with no random crops")
    db_images_val = save_cropped_boxes(dataset_val, tgt_image_path_trainval, extension="")

    # prepare data images for trainval with random crops
    tgt_image_path_trainval_randcrops = os.path.join(target_path, "train",
                                                     retrieval_dataset_train_name + crop_suffix, "ims")
    mkdir(tgt_image_path_trainval_randcrops)
    logger.info(f"Train set {retrieval_dataset_train_name} with {args.num_random_crops_per_image} crops per image")
    db_images_train_randomCrops = save_cropped_boxes(dataset_train, tgt_image_path_trainval_randcrops, extension="",
                                                     num_random_crops_per_image=args.num_random_crops_per_image)

    # create the val subset: add all boxes from images that have at least one validation box (can add some boxes from train as distractors)
    logger.info(f"Val set {retrieval_dataset_val_name} with {args.num_random_crops_per_image} crops per image")
    db_images_val_randomCrops = save_cropped_boxes(dataset_val, tgt_image_path_trainval_randcrops, extension="",
                                                   num_random_crops_per_image=args.num_random_crops_per_image)

    # prepare data images for test
    dbs_images_test = {}
    if datasets_test:
        for dataset_test, dataset_name in zip(datasets_test, retrieval_dataset_test_names):
            tgt_image_path_test = os.path.join(target_path, "test", dataset_name, "jpg")  # the folder name should always be "test" - required by cirtorch
            mkdir(tgt_image_path_test)
            logger.info(f"Eval dataset: {dataset_name}")
            dbs_images_test[dataset_name] = save_cropped_boxes(dataset_test, tgt_image_path_test)

            # prepare data images for test with random crops
            tgt_image_path_test = os.path.join(target_path, "test", dataset_name + crop_suffix, "jpg")  # the folder name should always be "test" - required by cirtorch
            mkdir(tgt_image_path_test)
            logger.info(f"Eval dataset: {dataset_name + crop_suffix}")
            dbs_images_test[dataset_name + crop_suffix] = save_cropped_boxes(dataset_test, tgt_image_path_test,
                                                                             num_random_crops_per_image=args.num_random_crops_per_image)

    # save GT images from train
    db_classes_train = save_class_images(dataset_train,
                                         os.path.join(target_path, "train", retrieval_dataset_train_name, "ims"),
                                         extension="")
    db_classes_train_randomCrops = save_class_images(dataset_train,
                                                     os.path.join(target_path, "train",
                                                                  retrieval_dataset_train_name + crop_suffix, "ims"),
                                                     extension="")
    # save GT images from val
    db_classes_val = save_class_images(dataset_val,
                                       os.path.join(target_path, "train", retrieval_dataset_train_name, "ims"),
                                       extension="")
    db_classes_val_randomCrops = save_class_images(dataset_val,
                                                   os.path.join(target_path, "train",
                                                                retrieval_dataset_train_name + crop_suffix, "ims"),
                                                   extension="")
    # save GT images for testing
    dbs_classes_test = {}
    if args.datasets_test:
        for dataset_test, dataset_name in zip(datasets_test, retrieval_dataset_test_names):
            dbs_classes_test[dataset_name] = save_class_images(dataset_test,
                                                               os.path.join(target_path, "test", dataset_name, "jpg"))
            dbs_classes_test[dataset_name + crop_suffix] = save_class_images(dataset_test,
                                                                             os.path.join(target_path, "test",
                                                                                          dataset_name + crop_suffix, "jpg"))

    # merge the databases
    logger.info(f"Processing trainval set from {retrieval_dataset_train_name} and {retrieval_dataset_val_name}")
    db_train = create_train_database_queries(db_images_train, db_classes_train,
                                             iou_pos_threshold=args.iou_pos_threshold,
                                             iou_neg_threshold=args.iou_neg_threshold,
                                             logger=logger)
    db_val = create_train_database_queries(db_images_val, db_classes_val,
                                           iou_pos_threshold=args.iou_pos_threshold,
                                           iou_neg_threshold=args.iou_neg_threshold,
                                           logger=logger)

    logger.info(f"Processing trainval set from {retrieval_dataset_train_name} and {retrieval_dataset_val_name} with {args.num_random_crops_per_image} random crops")
    db_train_randomCrops = create_train_database_queries(db_images_train_randomCrops, db_classes_train_randomCrops,
                                                         iou_pos_threshold=args.iou_pos_threshold,
                                                         iou_neg_threshold=args.iou_neg_threshold,
                                                         logger=logger)
    db_val_randomCrops = create_train_database_queries(db_images_val_randomCrops, db_classes_val_randomCrops,
                                                       iou_pos_threshold=args.iou_pos_threshold,
                                                       iou_neg_threshold=args.iou_neg_threshold,
                                                       logger=logger)

    dbs_test = {}
    if args.datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            logger.info(f"Processing test set {dataset_name}")
            dbs_test[dataset_name] = create_test_database_queries(dbs_images_test[dataset_name],
                                                                  dbs_classes_test[dataset_name],
                                                                  iou_pos_threshold=args.iou_pos_threshold,
                                                                  iou_neg_threshold=args.iou_neg_threshold,
                                                                  logger=logger)

            logger.info(f"Processing test set {dataset_name + crop_suffix}")
            dbs_test[dataset_name + crop_suffix] = create_test_database_queries(dbs_images_test[dataset_name + crop_suffix],
                                                                                dbs_classes_test[dataset_name + crop_suffix],
                                                                                iou_pos_threshold=args.iou_pos_threshold,
                                                                                iou_neg_threshold=args.iou_neg_threshold,
                                                                                logger=logger)

    # save trainval to disk
    db_trainval = {"train": db_train, "val": db_val}
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name, f"{retrieval_dataset_train_name}.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_trainval, f)

    # save train separately for whitening
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name, f"{retrieval_dataset_train_name}-whiten.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_train, f)

    # save trainval with random crops to disk
    db_trainval_randomCrops = {"train": db_train_randomCrops, "val": db_val_randomCrops}
    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name + crop_suffix,
                         f"{retrieval_dataset_train_name}{crop_suffix}.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_trainval_randomCrops, f)

    db_fn = os.path.join(target_path, "train", retrieval_dataset_train_name + crop_suffix,
                         f"{retrieval_dataset_train_name}{crop_suffix}-whiten.pkl")
    with open(db_fn, "wb") as f:
        pickle.dump(db_train_randomCrops, f)

    # save test to disk
    if args.datasets_test:
        for dataset_name in retrieval_dataset_test_names:
            db_fn = os.path.join(target_path, "test", dataset_name, f"gnd_{dataset_name}.pkl")
            with open(db_fn, "wb") as f:
                pickle.dump(dbs_test[dataset_name], f)

            # save test with random crops to disk
            db_fn = os.path.join(target_path, "test", dataset_name + crop_suffix,
                                 f"gnd_{dataset_name}{crop_suffix}.pkl")
            with open(db_fn, "wb") as f:
                pickle.dump(dbs_test[dataset_name + crop_suffix], f)
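# Illustrative note (not part of the original file): with, e.g.,
# args.num_random_crops_per_image = 5, the main() above writes every database
# twice - once under the plain name and once under the "-rndCropPerImage5"
# variant, e.g. gnd_<name>-retrieval-rndCropPerImage5.pkl under data/test/.
def _example_crop_suffix(num_random_crops_per_image=5):
    return f"-rndCropPerImage{num_random_crops_per_image}"  # -> "-rndCropPerImage5"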