Example #1
    def test_open(self) -> None:
        with self._patch_download():
            with PathManager.open(self._remote_uri, "rb") as f:
                self.assertTrue(os.path.exists(f.name))
                self.assertTrue(os.path.isfile(f.name))
                # the file is opened in binary mode, so compare against bytes
                self.assertTrue(f.read() != b"")
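A hedged sketch of what the test above exercises: with an HTTP handler registered, PathManager first downloads a remote URI to a local cache file, so f.name points at a real local path. The URL below is a placeholder and the instance-based API from iopath is an assumption (the test itself uses the global PathManager).

from iopath.common.file_io import HTTPURLHandler, PathManager as PathManagerCls

pm = PathManagerCls()
pm.register_handler(HTTPURLHandler())
with pm.open("https://example.com/checkpoint.bin", "rb") as f:  # placeholder URL
    payload = f.read()  # the handler downloads the file to a local cache before opening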
Example #2
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
    """
    Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
    treated as ground truth annotations and all files under "image_root" with "image_ext" extension
    as input images. Ground truth and input images are matched using file paths relative to
    "gt_root" and "image_root" respectively without taking into account file extensions.

    Args:
        gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
            annotations are stored as images with integer values in pixels that represent
            corresponding semantic labels.
        image_root (str): the directory where the input images are.
        gt_ext (str): file extension for ground truth annotations.
        image_ext (str): file extension for input images.

    Returns:
        list[dict]:
            a list of dicts in detectron2 standard format without instance-level
            annotation.

    Notes:
        1. This function does not read the image and ground truth files.
           The results do not have the "image" and "sem_seg" fields.
    """

    # We match input images with ground truth based on their relative filepaths (without file
    # extensions) starting from 'image_root' and 'gt_root' respectively. SOBA API works with integer
    # IDs, hence, we try to convert these paths to int if possible.
    def file2id(folder_path, file_path):
        # TODO id is not used.
        # extract relative path starting from `folder_path`
        image_id = os.path.normpath(
            os.path.relpath(file_path, start=folder_path))
        # remove file extension
        image_id = os.path.splitext(image_id)[0]
        try:
            image_id = int(image_id)
        except ValueError:
            pass
        return image_id

    input_files = sorted(
        (os.path.join(image_root, f)
         for f in PathManager.ls(image_root) if f.endswith(image_ext)),
        key=lambda file_path: file2id(image_root, file_path),
    )
    gt_files = sorted(
        (os.path.join(gt_root, f)
         for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
        key=lambda file_path: file2id(gt_root, file_path),
    )

    assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)

    # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
    if len(input_files) != len(gt_files):
        logger.warn(
            "Directory {} and {} has {} and {} files, respectively.".format(
                image_root, gt_root, len(input_files), len(gt_files)))
        input_basenames = [
            os.path.basename(f)[:-len(image_ext)] for f in input_files
        ]
        gt_basenames = [os.path.basename(f)[:-len(gt_ext)] for f in gt_files]
        intersect = list(set(input_basenames) & set(gt_basenames))
        # sort, otherwise each worker may obtain a list[dict] in different order
        intersect = sorted(intersect)
        logger.warn("Will use their intersection of {} files.".format(
            len(intersect)))
        input_files = [
            os.path.join(image_root, f + image_ext) for f in intersect
        ]
        gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]

    logger.info("Loaded {} images with semantic segmentation from {}".format(
        len(input_files), image_root))

    dataset_dicts = []
    for (img_path, gt_path) in zip(input_files, gt_files):
        record = {}
        record["file_name"] = img_path
        record["sem_seg_file_name"] = gt_path
        record["image_id"] = file2id(image_root, img_path)
        assert record["image_id"] == file2id(
            gt_root,
            gt_path), "there is no ground truth for {}".format(img_path)
        with PathManager.open(gt_path, "rb") as f:
            img = Image.open(f)
            w, h = img.size
        record["height"] = h
        record["width"] = w
        dataset_dicts.append(record)

    return dataset_dicts
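A hedged usage sketch for the loader above, assuming detectron2's DatasetCatalog/MetadataCatalog; the dataset name and directory paths are placeholders.

from detectron2.data import DatasetCatalog, MetadataCatalog

# Register the loader so detectron2 dataloaders can consume it lazily.
DatasetCatalog.register(
    "my_sem_seg_train",  # hypothetical dataset name
    lambda: load_sem_seg("datasets/my/annotations", "datasets/my/images"),
)
MetadataCatalog.get("my_sem_seg_train").set(
    stuff_classes=["background", "foreground"],  # placeholder label names
    ignore_label=255,
    evaluator_type="sem_seg",
)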
Example #3
    def _construct_loader(self):
        """
        Construct the video loader.
        """
        # Loading label names.
        with PathManager.open(
            os.path.join(
                self.cfg.DATA.PATH_TO_DATA_DIR,
                "something-something-v2-labels.json",
            ),
            "r",
        ) as f:
            label_dict = json.load(f)

        # Loading labels.
        label_file = os.path.join(
            self.cfg.DATA.PATH_TO_DATA_DIR,
            "something-something-v2-{}.json".format(
                "train" if self.mode == "train" else "validation"
            ),
        )
        with PathManager.open(label_file, "r") as f:
            label_json = json.load(f)

        self._video_names = []
        self._labels = []
        for video in label_json:
            video_name = video["id"]
            template = video["template"]
            template = template.replace("[", "")
            template = template.replace("]", "")
            label = int(label_dict[template])
            self._video_names.append(video_name)
            self._labels.append(label)

        path_to_file = os.path.join(
            self.cfg.DATA.PATH_TO_DATA_DIR,
            "{}.csv".format("train" if self.mode == "train" else "val"),
        )
        assert PathManager.exists(path_to_file), "{} dir not found".format(
            path_to_file
        )

        self._path_to_videos, _ = utils.load_image_lists(
            path_to_file, self.cfg.DATA.PATH_PREFIX
        )

        assert len(self._path_to_videos) == len(self._video_names), (
            len(self._path_to_videos),
            len(self._video_names),
        )

        # From dict to list.
        new_paths, new_labels = [], []
        for index in range(len(self._video_names)):
            if self._video_names[index] in self._path_to_videos:
                new_paths.append(
                    self._path_to_videos[self._video_names[index]])
                new_labels.append(self._labels[index])

        self._labels = new_labels
        self._path_to_videos = new_paths

        # Extend self when self._num_clips > 1 (during testing).
        self._path_to_videos = list(
            chain.from_iterable(
                [[x] * self._num_clips for x in self._path_to_videos]
            )
        )
        self._labels = list(
            chain.from_iterable([[x] * self._num_clips for x in self._labels])
        )
        self._spatial_temporal_idx = list(
            chain.from_iterable(
                [
                    range(self._num_clips)
                    for _ in range(len(self._path_to_videos))
                ]
            )
        )
        logger.info(
            "Something-Something V2 dataloader constructed "
            " (size: {}) from {}".format(
                len(self._path_to_videos), path_to_file
            )
        )
Example #4
    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)

        # TODO YONK add coordinates augmentation & NVDI augmentation here ...

        utils.check_image_size(dataset_dict, image)

        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            # Crop around an instance if there are instances in the image.
            # USER: Remove if you don't use cropping
            if self.crop_gen:
                crop_tfm = utils.gen_crop_transform_with_instance(
                    self.crop_gen.get_crop_size(image.shape[:2]),
                    image.shape[:2],
                    np.random.choice(dataset_dict["annotations"]),
                )
                image = crop_tfm.apply_image(image)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        # Can use uint8 if it turns out to be slow some day

        # USER: Remove if you don't use pre-computed proposals.
        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
            )

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                if not self.keypoint_on:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.mask_format
            )
            # Create a tight bounding box from masks, useful when image is cropped
            if self.crop_gen and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
                sem_seg_gt = Image.open(f)
                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
            dataset_dict["sem_seg"] = sem_seg_gt
        return dataset_dict
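A hedged wiring sketch: a mapper like the one above is typically passed to detectron2's build_detection_train_loader. MyDatasetMapper, cfg and model are placeholders.

from detectron2.data import build_detection_train_loader

train_loader = build_detection_train_loader(cfg, mapper=MyDatasetMapper(cfg, True))
for batched_inputs in train_loader:
    # builtin detectron2 models accept the list of dicts produced by the mapper's __call__
    loss_dict = model(batched_inputs)
    break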
Example #5
def load_checkpoint(
    path_to_checkpoint,
    model,
    data_parallel=True,
    optimizer=None,
    inflation=False,
    convert_from_caffe2=False,
):
    """
    Load the checkpoint from the given file. If inflation is True, inflate the
    2D Conv weights from the checkpoint to 3D Conv.
    Args:
        path_to_checkpoint (string): path to the checkpoint to load.
        model (model): model to load the weights from the checkpoint.
        data_parallel (bool): if true, model is wrapped by
        torch.nn.parallel.DistributedDataParallel.
        optimizer (optim): optimizer to load the historical state.
        inflation (bool): if True, inflate the weights from the checkpoint.
        convert_from_caffe2 (bool): if True, load the model from caffe2 and
            convert it to pytorch.
    Returns:
        (int): the number of training epoch of the checkpoint.
    """
    assert PathManager.exists(
        path_to_checkpoint), "Checkpoint '{}' not found".format(
            path_to_checkpoint)
    # Account for the DDP wrapper in the multi-gpu setting.
    ms = model.module if data_parallel else model
    if convert_from_caffe2:
        with PathManager.open(path_to_checkpoint, "rb") as f:
            caffe2_checkpoint = pickle.load(f, encoding="latin1")
        state_dict = OrderedDict()
        name_convert_func = get_name_convert_func()
        for key in caffe2_checkpoint["blobs"].keys():
            converted_key = name_convert_func(key)
            converted_key = c2_normal_to_sub_bn(converted_key, ms.state_dict())
            if converted_key in ms.state_dict():
                c2_blob_shape = caffe2_checkpoint["blobs"][key].shape
                model_blob_shape = ms.state_dict()[converted_key].shape
                # Load BN stats to Sub-BN.
                if (len(model_blob_shape) == 1 and len(c2_blob_shape) == 1
                        and model_blob_shape[0] > c2_blob_shape[0]
                        and model_blob_shape[0] % c2_blob_shape[0] == 0):
                    caffe2_checkpoint["blobs"][key] = np.concatenate(
                        [caffe2_checkpoint["blobs"][key]] *
                        (model_blob_shape[0] // c2_blob_shape[0]))
                    c2_blob_shape = caffe2_checkpoint["blobs"][key].shape

                if c2_blob_shape == tuple(model_blob_shape):
                    state_dict[converted_key] = torch.tensor(
                        caffe2_checkpoint["blobs"][key]).clone()
                    logger.info("{}: {} => {}: {}".format(
                        key,
                        c2_blob_shape,
                        converted_key,
                        tuple(model_blob_shape),
                    ))
                else:
                    logger.warn("!! {}: {} does not match {}: {}".format(
                        key,
                        c2_blob_shape,
                        converted_key,
                        tuple(model_blob_shape),
                    ))
            else:
                if not any(prefix in key
                           for prefix in ["momentum", "lr", "model_iter"]):
                    logger.warn("!! {}: can not be converted, got {}".format(
                        key, converted_key))
        ms.load_state_dict(state_dict, strict=False)
        epoch = -1
    else:
        # Load the checkpoint on CPU to avoid GPU mem spike.
        with PathManager.open(path_to_checkpoint, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        model_state_dict_3d = (model.module.state_dict()
                               if data_parallel else model.state_dict())
        checkpoint["model_state"] = normal_to_sub_bn(checkpoint["model_state"],
                                                     model_state_dict_3d)
        if inflation:
            # Try to inflate the model.
            inflated_model_dict = inflate_weight(checkpoint["model_state"],
                                                 model_state_dict_3d)
            ms.load_state_dict(inflated_model_dict, strict=False)
        else:
            ms.load_state_dict(checkpoint["model_state"])
            # Load the optimizer state (commonly not done when fine-tuning)
            if optimizer:
                optimizer.load_state_dict(checkpoint["optimizer_state"])
        if "epoch" in checkpoint.keys():
            epoch = checkpoint["epoch"]
        else:
            epoch = -1
    return epoch
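A minimal call sketch for the checkpoint loader above; the checkpoint path and model are placeholders.

start_epoch = load_checkpoint(
    "checkpoints/slowfast_r50.pyth",  # hypothetical path
    model,
    data_parallel=False,
    optimizer=None,
    inflation=False,
    convert_from_caffe2=False,
)
print("Checkpoint epoch:", start_epoch)  # -1 when the checkpoint stores no epoch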
Example #6
    def _open(self, path, mode="r", **kwargs):
        return PathManager.open(self._get_local_path(path), mode, **kwargs)
def _read_image(file_name,
                format=None,
                vision_type='trichromat',
                contrast=None,
                opponent_space='lab',
                mosaic_pattern=None,
                apply_net=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR"

    Returns:
        image (np.ndarray): an HWC image in the given format.
    """
    if apply_net is not None:
        file_name = apply_net(file_name)
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f).convert('RGB')

        if contrast is not None:
            image = np.asarray(image).copy()
            # FIXME: nicer solution
            if isinstance(contrast, list):
                # assumes a [low, high] range to sample the contrast amount from
                amount = random.uniform(contrast[0], contrast[1])
            else:
                amount = contrast
            image = imutils.adjust_contrast(image, amount)
            image = Image.fromarray(np.uint8(image))

        if vision_type != 'trichromat':
            image = np.asarray(image).copy()
            if vision_type == 'monochromat':
                image = imutils.reduce_chromaticity(image, 0, opponent_space)
            elif vision_type == 'dichromat_yb':
                image = imutils.reduce_yellow_blue(image, 0, opponent_space)
            elif vision_type == 'dichromat_rg':
                image = imutils.reduce_red_green(image, 0, opponent_space)
            else:
                sys.exit('Not supported vision type %s' % vision_type)
            image = Image.fromarray(image)

        if mosaic_pattern != "" and mosaic_pattern is not None:
            image = np.asarray(image).copy()
            image = imutils.im2mosaic(image, mosaic_pattern)
            image = Image.fromarray(np.uint8(image))

        # capture and ignore this bug:
        # https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        if format is not None:
            # PIL only supports RGB, so convert to RGB and flip channels over
            # below
            conversion_format = format
            if format == "BGR":
                conversion_format = "RGB"
            image = image.convert(conversion_format)
        image = np.asarray(image)
        if format == "BGR":
            # flip channels if needed
            image = image[:, :, ::-1]
        # PIL squeezes out the channel dimension for "L", so make it HWC
        if format == "L":
            image = np.expand_dims(image, -1)
        return image
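A hedged call sketch for _read_image; the image path is a placeholder and the keyword values follow the docstring.

img = _read_image("images/sample.jpg", format="BGR", vision_type="trichromat")
print(img.shape, img.dtype)  # HWC array with channels flipped to BGR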
Example #8
def inference_on_dataset(model,
                         data_loader,
                         distributed=True,
                         output_dir=None):
    num_devices = get_world_size()
    logger = logging.getLogger("detectron2")
    logger.info("Start inference on {} images".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length

    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_compute_time = 0
    predictions = []
    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0

            start_compute_time = time.perf_counter()
            outputs = forward_warpper(model, inputs)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            predictions.extend(process(inputs, outputs))

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() -
                                         start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img *
                                                     (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)),
                    n=5,
                    name="detectron2",
                )
            # Measure the time only for this worker (before the synchronization barrier)
        total_time = time.perf_counter() - start_time
        total_time_str = str(datetime.timedelta(seconds=total_time))
        # NOTE this format is parsed by grep
        logger.info(
            "Total inference time: {} ({:.6f} s / img per device, on {} devices)"
            .format(total_time_str, total_time / (total - num_warmup),
                    num_devices))
        total_compute_time_str = str(
            datetime.timedelta(seconds=int(total_compute_time)))
        logger.info(
            "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)"
            .format(total_compute_time_str,
                    total_compute_time / (total - num_warmup), num_devices))

    if distributed:
        comm.synchronize()
        predictions = comm.gather(predictions, dst=0)
        predictions = list(itertools.chain(*predictions))

        if not comm.is_main_process():
            return {}

    if output_dir:
        PathManager.mkdirs(output_dir)
        file_path = os.path.join(output_dir, "instances_predictions.pth")
        logger.info("Saving results to {}".format(file_path))
        with PathManager.open(file_path, "wb") as f:
            torch.save(predictions, f)

    coco_results = list(itertools.chain(*[x["instances"]
                                          for x in predictions]))
    logger.info(
        "Start converting obj365 results to coco type annotation json file...")
    coco_dict = convert_obj365_res_to_coco_json(coco_results)

    return coco_dict
Example #9
def load_voc_instances(dirname: str, split: str,
                       class_names: Union[List[str], Tuple[str, ...]]):
    """
    Load Pascal VOC detection annotations to Detectron2 format.

    Args:
        dirname: Contain "Annotations", "ImageSets", "JPEGImages"
        split (str): one of "train", "test", "val", "trainval"
        class_names: list or tuple of class names
    """
    with PathManager.open(
            os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        fileids = np.loadtxt(f, dtype=str)  # np.str was removed in newer NumPy

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(
        os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        try:
            with PathManager.open(anno_file) as f:
                tree = ET.parse(f)
        except Exception:
            logger = logging.getLogger(__name__)
            logger.info('Not able to load: ' + anno_file +
                        '. Continuing without aborting...')
            continue

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            if cls in VOC_CLASS_NAMES_COCOFIED:
                cls = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls)]
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            # continue
            bbox = obj.find("bndbox")
            bbox = [
                float(bbox.find(x).text)
                for x in ["xmin", "ymin", "xmax", "ymax"]
            ]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append({
                "category_id": class_names.index(cls),
                "bbox": bbox,
                "bbox_mode": BoxMode.XYXY_ABS
            })
        r["annotations"] = instances
        dicts.append(r)
    return dicts
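A hedged registration sketch, again assuming detectron2's catalogs; the dataset name, directory and class list are placeholders.

from detectron2.data import DatasetCatalog, MetadataCatalog

CLASS_NAMES = ("aeroplane", "bicycle", "bird")  # truncated placeholder list
DatasetCatalog.register(
    "voc_2007_trainval_demo",  # hypothetical name
    lambda: load_voc_instances("datasets/VOC2007", "trainval", CLASS_NAMES),
)
MetadataCatalog.get("voc_2007_trainval_demo").set(
    thing_classes=list(CLASS_NAMES), dirname="datasets/VOC2007", split="trainval", year=2007
)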
Example #10
def load_checkpoint(
        path_to_checkpoint,
        model,
        data_parallel=True,
        optimizer=None,
        epoch_reset=False,
        clear_name_pattern=(),
):
    """
    Load the checkpoint from the given file.
    Args:
        path_to_checkpoint (string): path to the checkpoint to load.
        model (model): model to load the weights from the checkpoint.
        data_parallel (bool): if true, model is wrapped by
        torch.nn.parallel.DistributedDataParallel.
        optimizer (optim): optimizer to load the historical state.
        epoch_reset (bool): if True, reset #train iterations from the checkpoint.
        clear_name_pattern (string): if given, this (sub)string will be cleared
            from a layer name if it can be matched.
    Returns:
        (int): the number of training epoch of the checkpoint.
    """
    assert PathManager.exists(
        path_to_checkpoint), "Checkpoint '{}' not found".format(
            path_to_checkpoint)
    logger.info("Loading network weights from {}.".format(path_to_checkpoint))

    # Account for the DDP wrapper in the multi-gpu setting.
    ms = model.module if data_parallel else model
    # Load the checkpoint on CPU to avoid GPU mem spike.
    with PathManager.open(path_to_checkpoint, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    model_state_dict = (model.module.state_dict()
                        if data_parallel else model.state_dict())
    checkpoint["model_state"] = normal_to_sub_bn(checkpoint["model_state"],
                                                 model_state_dict)
    if clear_name_pattern:
        for item in clear_name_pattern:
            model_state_dict_new = OrderedDict()
            for k in checkpoint["model_state"]:
                if item in k:
                    k_re = k.replace(item, "")
                    model_state_dict_new[k_re] = checkpoint["model_state"][k]
                    logger.info("renaming: {} -> {}".format(k, k_re))
                else:
                    model_state_dict_new[k] = checkpoint["model_state"][k]
            checkpoint["model_state"] = model_state_dict_new

    pre_train_dict = checkpoint["model_state"]
    model_dict = ms.state_dict()
    # Match pre-trained weights that have same shape as current model.
    pre_train_dict_match = {
        k: v
        for k, v in pre_train_dict.items()
        if k in model_dict and v.size() == model_dict[k].size()
    }
    # Weights that do not have match from the pre-trained model.
    not_load_layers = [
        k for k in model_dict.keys() if k not in pre_train_dict_match.keys()
    ]
    # Log weights that are not loaded with the pre-trained weights.
    if not_load_layers:
        for k in not_load_layers:
            logger.info("Network weights {} not loaded.".format(k))
    # Load pre-trained weights.
    ms.load_state_dict(pre_train_dict_match, strict=False)

    # Load the optimizer state (commonly not done when fine-tuning)
    if "epoch" in checkpoint.keys() and not epoch_reset:
        epoch = checkpoint["epoch"]
        if optimizer:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    else:
        epoch = -1
    return epoch
Example #11
def get_class_names(path, parent_path=None, subset_path=None):
    """
    Read json file with entries {classname: index} and return
    an array of class names in order.
    If parent_path is provided, load and map all children to their ids.
    Args:
        path (str): path to class ids json file.
            File must be in the format {"class1": id1, "class2": id2, ...}
        parent_path (Optional[str]): path to parent-child json file.
            File must be in the format {"parent1": ["child1", "child2", ...], ...}
        subset_path (Optional[str]): path to text file containing a subset
            of class names, separated by newline characters.
    Returns:
        class_names (list of strs): list of class names.
        class_parents (dict): a dictionary where key is the name of the parent class
            and value is a list of ids of the children classes.
        subset_ids (list of ints): list of ids of the classes provided in the
            subset file.
    """
    try:
        with PathManager.open(path, "r") as f:
            class2idx = json.load(f)
    except Exception as err:
        print("Fail to load file from {} with error {}".format(path, err))
        return

    class_names = [None] * len(class2idx)

    for k, i in class2idx.items():
        class_names[i] = k

    class_parent = None
    if parent_path is not None and parent_path != "":
        try:
            with PathManager.open(parent_path, "r") as f:
                d_parent = json.load(f)
        except EnvironmentError as err:
            print(
                "Fail to load file from {} with error {}".format(
                    parent_path, err
                )
            )
            return
        class_parent = {}
        for parent, children in d_parent.items():
            indices = [
                class2idx[c] for c in children if class2idx.get(c) is not None
            ]
            class_parent[parent] = indices

    subset_ids = None
    if subset_path is not None and subset_path != "":
        try:
            with PathManager.open(subset_path, "r") as f:
                subset = f.read().split("\n")
                subset_ids = [
                    class2idx[name]
                    for name in subset
                    if class2idx.get(name) is not None
                ]
        except EnvironmentError as err:
            print(
                "Fail to load file from {} with error {}".format(
                    subset_path, err
                )
            )
            return

    return class_names, class_parent, subset_ids
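A hedged call sketch; the three file paths are placeholders in the formats described in the docstring.

class_names, class_parent, subset_ids = get_class_names(
    "labels/class2idx.json",
    parent_path="labels/parent_child.json",
    subset_path="labels/subset.txt",
)
print(len(class_names), "classes loaded")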
Example #12
        cfg.freeze()
        default_setup(
            cfg, args
        )  # if you don't like any of the default setup, write your own setup code
        global CONFIG
        CONFIG = cfg

    # if 'build_model(cfg)':
    #     meta_arch = cfg.MODEL.META_ARCHITECTURE
    #     model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    #     model.to(torch.device(cfg.MODEL.DEVICE))
    #
    # if 'load_pretrained_weights':
    #     DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
    #         cfg.MODEL.WEIGHTS, resume=False
    #     )

    if 'extract_weights_without_roi_heads':
        with open(cfg.MODEL.WEIGHTS_PATH, 'rb') as f:
            ckpt = pickle.load(f, encoding='latin1')

        weight_names = deepcopy(list(ckpt["model"].keys()))
        for k in weight_names:
            if k.startswith('roi_heads'):
                ckpt["model"].pop(k)

        save_path = cfg.MODEL.WEIGHTS_PATH.replace('.pkl',
                                                   '_without_roi_heads.pkl')
        with PathManager.open(save_path, "wb") as f:
            pickle.dump(ckpt, f)
Example #13
    def test_open(self) -> None:
        # pyre-ignore
        with PathManager.open(self._tmpfile, "r") as f:
            self.assertEqual(f.read(), self._tmpfile_contents)
Example #14
    def test_open_writes(self) -> None:
        # HTTPURLHandler does not support writing, only reading.
        with self.assertRaises(AssertionError):
            with PathManager.open(self._remote_uri, "w") as f:
                f.write("foobar")  # pyre-ignore
Example #15
    def train_cls(self, features, targets, cls_num):
        """
        Train SVM on the input features and targets for a given class.
        The SVMs are trained for all costs values for the given class. We
        also save the cross-validation AP at each cost value for the given
        class.
        """
        logging.info(f"Training cls: {cls_num}")
        for cost_idx in range(len(self.costs_list)):
            cost = self.costs_list[cost_idx]
            out_file, ap_out_file = self._get_svm_model_filename(cls_num, cost)
            if (PathManager.exists(out_file)
                    and PathManager.exists(ap_out_file)
                    and not self.config.force_retrain):
                logging.info(f"SVM model exists: {out_file}")
                logging.info(f"AP file exists: {ap_out_file}")
                continue

            logging.info(
                f"Training model with the cost: {cost} cls: {cls_num}")
            clf = LinearSVC(
                C=cost,
                class_weight={
                    1: 2,
                    -1: 1
                },
                intercept_scaling=1.0,
                verbose=1,
                penalty=self.config["penalty"],
                loss=self.config["loss"],
                tol=0.0001,
                dual=self.config["dual"],
                max_iter=self.config["max_iter"],
            )
            cls_labels = targets[:, cls_num].astype(dtype=np.int32, copy=True)
            # meaning of labels in VOC/COCO original loaded target files:
            # label 0 = not present, set it to -1 as svm train target
            # label 1 = present. Make the svm train target labels as -1, 1.
            cls_labels[np.where(cls_labels == 0)] = -1
            num_positives = len(np.where(cls_labels == 1)[0])
            num_negatives = len(cls_labels) - num_positives
            logging.info(
                f"cls: {cls_num} has +ve: {num_positives} -ve: {num_negatives} "
                f"ratio: {float(num_positives) / num_negatives} "
                f"features: {features.shape} cls_labels: {cls_labels.shape}")
            ap_scores = cross_val_score(
                clf,
                features,
                cls_labels,
                cv=self.config["cross_val_folds"],
                scoring="average_precision",
            )
            self.train_ap_matrix[cls_num][cost_idx] = ap_scores.mean()
            clf.fit(features, cls_labels)
            logging.info(f"cls: {cls_num} cost: {cost} AP: {ap_scores} "
                         f"mean:{ap_scores.mean()}")
            logging.info(f"Saving cls cost AP to: {ap_out_file}")
            save_file(np.array([ap_scores.mean()]), ap_out_file)
            logging.info(f"Saving SVM model to: {out_file}")
            with PathManager.open(out_file, "wb") as fwrite:
                pickle.dump(clf, fwrite)
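A hedged counterpart sketch: reload one of the SVMs saved above with PathManager and pickle, then score held-out features. out_file and test_features are placeholders.

with PathManager.open(out_file, "rb") as fread:
    clf = pickle.load(fread)
# LinearSVC exposes decision_function; larger values mean more confident positives.
scores = clf.decision_function(test_features)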
Example #16
                        help="JSON file produced by the model")
    parser.add_argument("--output", required=True, help="output directory")
    parser.add_argument("--dataset",
                        help="name of the dataset",
                        default="indiscapes_val")
    parser.add_argument("--conf-threshold",
                        default=0.5,
                        type=float,
                        help="confidence threshold")
    args = parser.parse_args()

    logger = setup_logger()

    predictions = list()
    for input in args.inputs:
        with PathManager.open(input, "r") as f:
            predictions.append(json.load(f))

    pred_by_input = list()  # will hold one per-image defaultdict for each input file
    for prediction in predictions:
        pred_by_image = defaultdict(list)
        for p in prediction:
            pred_by_image[p["image_id"]].append(p)
        pred_by_input.append(pred_by_image)

    dicts = list(DatasetCatalog.get(args.dataset))
    metadata = MetadataCatalog.get(args.dataset)

    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):

        def dataset_id_map(ds_id):
Example #17
    parser.add_argument("--input",
                        required=True,
                        help="JSON file produced by the model")
    parser.add_argument("--output", required=True, help="output directory")
    parser.add_argument("--dataset",
                        help="name of the dataset",
                        default="coco_2017_val")
    parser.add_argument("--conf-threshold",
                        default=0.5,
                        type=float,
                        help="confidence threshold")
    args = parser.parse_args()

    logger = setup_logger()

    with PathManager.open(args.input, "r") as f:
        predictions = json.load(f)

    pred_by_image = defaultdict(list)
    for p in predictions:
        pred_by_image[p["image_id"]].append(p)

    dicts = list(DatasetCatalog.get(args.dataset))
    metadata = MetadataCatalog.get(args.dataset)
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):

        def dataset_id_map(ds_id):
            return metadata.thing_dataset_id_to_contiguous_id[ds_id]

    elif "lvis" in args.dataset:
        # LVIS results are in the same format as COCO results, but have a different
Example #18
    def draw_dataset_dict(self, dic):
        """
        Draw annotations/segmentations in Detectron2 Dataset format.

        Args:
            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.

        Returns:
            output (VisImage): image object with visualizations.
        """
        annos = dic.get("annotations", None)
        if annos:
            if "segmentation" in annos[0]:
                masks = [x["segmentation"] for x in annos]
            else:
                masks = None
            if "keypoints" in annos[0]:
                keypts = [x["keypoints"] for x in annos]
                keypts = np.array(keypts).reshape(len(annos), -1, 3)
            else:
                keypts = None

            boxes = [
                BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
                if len(x["bbox"]) == 4 else x["bbox"] for x in annos
            ]

            colors = None
            category_ids = [x["category_id"] for x in annos]
            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
                    "thing_colors"):
                colors = [
                    self._jitter(
                        [x / 255 for x in self.metadata.thing_colors[c]])
                    for c in category_ids
                ]
            names = self.metadata.get("thing_classes", None)
            labels = _create_text_labels(
                category_ids,
                scores=None,
                class_names=[
                    "Hv", "Hp", "CLS", "BL", "PD", "PB", "CC", "LM", "D/P"
                ],
                is_crowd=[x.get("iscrowd", 0) for x in annos],
            )
            boxes = None
            alpha = 0
            self.overlay_instances(labels=labels,
                                   boxes=boxes,
                                   masks=masks,
                                   keypoints=keypts,
                                   assigned_colors=colors,
                                   alpha=alpha)

        sem_seg = dic.get("sem_seg", None)
        if sem_seg is None and "sem_seg_file_name" in dic:
            with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
                sem_seg = Image.open(f)
                sem_seg = np.asarray(sem_seg, dtype="uint8")
        if sem_seg is not None:
            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)

        pan_seg = dic.get("pan_seg", None)
        if pan_seg is None and "pan_seg_file_name" in dic:
            with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
                pan_seg = Image.open(f)
                pan_seg = np.asarray(pan_seg)
                from panopticapi.utils import rgb2id

                pan_seg = rgb2id(pan_seg)
        if pan_seg is not None:
            segments_info = dic["segments_info"]
            pan_seg = torch.Tensor(pan_seg)
            self.draw_panoptic_seg(pan_seg,
                                   segments_info,
                                   area_threshold=0,
                                   alpha=0.5)
        return self.output
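A hedged usage sketch, assuming a detectron2 Visualizer subclass with the method above; img, metadata and dic are placeholders and the image is expected in RGB order.

from detectron2.utils.visualizer import ColorMode

vis = MyVisualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.SEGMENTATION)
out = vis.draw_dataset_dict(dic)
out.save("vis.jpg")  # VisImage.save writes the rendered image to disk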
Example #19
def perform_test(test_loader, model, test_meter, cfg, writer=None):
    """
    For classification:
    Perform multi-view testing that uniformly samples N clips from a video along
    its temporal axis. For each clip, it takes 3 crops to cover the spatial
    dimension, followed by averaging the softmax scores across all Nx3 views to
    form a video-level prediction. All video predictions are compared to
    ground-truth labels and the final testing performance is logged.
    For detection:
    Perform fully-convolutional testing on the full frames without crop.
    Args:
        test_loader (loader): video testing loader.
        model (model): the pretrained video model to test.
        test_meter (TestMeter): testing meters to log and ensemble the testing
            results.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        writer (TensorboardWriter object, optional): TensorboardWriter object
            to write Tensorboard logs.
    """
    # Enable eval mode.
    model.eval()
    test_meter.iter_tic()

    for cur_iter, (inputs, labels, video_idx, meta) in enumerate(test_loader):
        if cfg.NUM_GPUS:
            # Transfer the data to the current GPU device.
            if isinstance(inputs, (list,)):
                for i in range(len(inputs)):
                    inputs[i] = inputs[i].cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)

            # Transfer the data to the current GPU device.
            labels = labels.cuda()
            video_idx = video_idx.cuda()
            for key, val in meta.items():
                if isinstance(val, (list,)):
                    for i in range(len(val)):
                        val[i] = val[i].cuda(non_blocking=True)
                else:
                    meta[key] = val.cuda(non_blocking=True)
        test_meter.data_toc()

        if cfg.DETECTION.ENABLE:
            # Compute the predictions.
            preds = model(inputs, meta["boxes"])
            ori_boxes = meta["ori_boxes"]
            metadata = meta["metadata"]

            preds = preds.detach().cpu() if cfg.NUM_GPUS else preds.detach()
            ori_boxes = (
                ori_boxes.detach().cpu() if cfg.NUM_GPUS else ori_boxes.detach()
            )
            metadata = (
                metadata.detach().cpu() if cfg.NUM_GPUS else metadata.detach()
            )

            if cfg.NUM_GPUS > 1:
                preds = torch.cat(du.all_gather_unaligned(preds), dim=0)
                ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes), dim=0)
                metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0)

            test_meter.iter_toc()
            # Update and log stats.
            test_meter.update_stats(preds, ori_boxes, metadata)
            test_meter.log_iter_stats(None, cur_iter)
        else:
            # Perform the forward pass.
            preds = model(inputs)

            # Gather all the predictions across all the devices to perform ensemble.
            if cfg.NUM_GPUS > 1:
                preds, labels, video_idx = du.all_gather(
                    [preds, labels, video_idx]
                )
            if cfg.NUM_GPUS:
                preds = preds.cpu()
                labels = labels.cpu()
                video_idx = video_idx.cpu()

            test_meter.iter_toc()
            # Update and log stats.
            test_meter.update_stats(
                preds.detach(), labels.detach(), video_idx.detach()
            )
            test_meter.log_iter_stats(cur_iter)

        test_meter.iter_tic()

    # Log epoch stats and print the final testing results.
    if not cfg.DETECTION.ENABLE:
        all_preds = test_meter.video_preds.clone().detach()
        all_labels = test_meter.video_labels
        if cfg.NUM_GPUS:
            all_preds = all_preds.cpu()
            all_labels = all_labels.cpu()
        if writer is not None:
            writer.plot_eval(preds=all_preds, labels=all_labels)

        if cfg.TEST.SAVE_RESULTS_PATH != "":
            save_path = os.path.join(cfg.OUTPUT_DIR, cfg.TEST.SAVE_RESULTS_PATH)

            with PathManager.open(save_path, "wb") as f:
                pickle.dump([all_preds, all_labels], f)

            logger.info(
                "Successfully saved prediction results to {}".format(save_path)
            )

    test_meter.finalize_metrics()
    return test_meter
Example #20
    parser.add_argument("--ins_input",
                        required=True,
                        help="JSON file produced by the model")
    parser.add_argument("--ass_input", required=True)
    parser.add_argument("--output", required=True, help="output directory")
    parser.add_argument("--dataset",
                        help="name of the dataset",
                        default="coco_2017_val")
    parser.add_argument("--conf-threshold",
                        default=0.5,
                        type=float,
                        help="confidence threshold")
    args = parser.parse_args()

    logger = setup_logger()
    with PathManager.open(args.ins_input, "r") as f:
        ins_predictions = json.load(f)
    with PathManager.open(args.ass_input, 'r') as f:
        ass_predictions = json.load(f)

    ins_pred_by_image = defaultdict(list)
    ass_pred_by_image = defaultdict(list)

    for p in ins_predictions:
        ins_pred_by_image[p["image_id"]].append(p)
    for p in ass_predictions:
        ass_pred_by_image[p["image_id"]].append(p)

    dicts = list(DatasetCatalog.get(args.dataset))
    metadata = MetadataCatalog.get(args.dataset)
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
Example #21
def load_checkpoint(path_to_checkpoint,
                    model,
                    data_parallel=True,
                    optimizer=None,
                    inflation=False,
                    convert_from_caffe2=False,
                    epoch_reset=False,
                    clear_name_pattern=(),
                    load_projection=True):
    """
    Load the checkpoint from the given file. If inflation is True, inflate the
    2D Conv weights from the checkpoint to 3D Conv.
    Args:
        path_to_checkpoint (string): path to the checkpoint to load.
        model (model): model to load the weights from the checkpoint.
        data_parallel (bool): if true, model is wrapped by
        torch.nn.parallel.DistributedDataParallel.
        optimizer (optim): optimizer to load the historical state.
        inflation (bool): if True, inflate the weights from the checkpoint.
        convert_from_caffe2 (bool): if True, load the model from caffe2 and
            convert it to pytorch.
        epoch_reset (bool): if True, reset #train iterations from the checkpoint.
        clear_name_pattern (string): if given, this (sub)string will be cleared
            from a layer name if it can be matched.
    Returns:
        (int): the number of training epoch of the checkpoint.
    """
    assert PathManager.exists(
        path_to_checkpoint), "Checkpoint '{}' not found".format(
            path_to_checkpoint)
    logger.info("Loading network weights from {}.".format(path_to_checkpoint))

    # Account for the DDP wrapper in the multi-gpu setting.
    ms = model.module if data_parallel else model
    if convert_from_caffe2:
        with PathManager.open(path_to_checkpoint, "rb") as f:
            caffe2_checkpoint = pickle.load(f, encoding="latin1")
        state_dict = OrderedDict()
        name_convert_func = get_name_convert_func()
        for key in caffe2_checkpoint["blobs"].keys():
            converted_key = name_convert_func(key)
            converted_key = c2_normal_to_sub_bn(converted_key, ms.state_dict())
            if converted_key in ms.state_dict():
                c2_blob_shape = caffe2_checkpoint["blobs"][key].shape
                model_blob_shape = ms.state_dict()[converted_key].shape

                # expand shape dims if they differ (eg for converting linear to conv params)
                if len(c2_blob_shape) < len(model_blob_shape):
                    c2_blob_shape += (1, ) * (len(model_blob_shape) -
                                              len(c2_blob_shape))
                    caffe2_checkpoint["blobs"][key] = np.reshape(
                        caffe2_checkpoint["blobs"][key], c2_blob_shape)
                # Load BN stats to Sub-BN.
                if (len(model_blob_shape) == 1 and len(c2_blob_shape) == 1
                        and model_blob_shape[0] > c2_blob_shape[0]
                        and model_blob_shape[0] % c2_blob_shape[0] == 0):
                    caffe2_checkpoint["blobs"][key] = np.concatenate(
                        [caffe2_checkpoint["blobs"][key]] *
                        (model_blob_shape[0] // c2_blob_shape[0]))
                    c2_blob_shape = caffe2_checkpoint["blobs"][key].shape

                if c2_blob_shape == tuple(model_blob_shape):
                    state_dict[converted_key] = torch.tensor(
                        caffe2_checkpoint["blobs"][key]).clone()
                    logger.info("{}: {} => {}: {}".format(
                        key,
                        c2_blob_shape,
                        converted_key,
                        tuple(model_blob_shape),
                    ))
                else:
                    logger.warn("!! {}: {} does not match {}: {}".format(
                        key,
                        c2_blob_shape,
                        converted_key,
                        tuple(model_blob_shape),
                    ))
            else:
                if not any(prefix in key
                           for prefix in ["momentum", "lr", "model_iter"]):
                    logger.warn("!! {}: can not be converted, got {}".format(
                        key, converted_key))
        diff = set(ms.state_dict()) - set(state_dict)
        diff = {d for d in diff if "num_batches_tracked" not in d}
        if len(diff) > 0:
            logger.warn("Not loaded {}".format(diff))
        ms.load_state_dict(state_dict, strict=False)
        epoch = -1
    else:
        # Load the checkpoint on CPU to avoid GPU mem spike.
        with PathManager.open(path_to_checkpoint, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        model_state_dict_3d = (model.module.state_dict()
                               if data_parallel else model.state_dict())
        key_name = 'model_state' if 'model_state' in checkpoint else 'state_dict'
        checkpoint[key_name] = normal_to_sub_bn(checkpoint[key_name],
                                                model_state_dict_3d)
        if inflation:
            # Try to inflate the model.
            inflated_model_dict = inflate_weight(checkpoint[key_name],
                                                 model_state_dict_3d)
            ms.load_state_dict(inflated_model_dict, strict=False)
        else:
            if clear_name_pattern:
                for item in clear_name_pattern:
                    model_state_dict_new = OrderedDict()
                    for k in checkpoint["model_state"]:
                        if item in k:
                            k_re = k.replace(item, "")
                            model_state_dict_new[k_re] = checkpoint[
                                "model_state"][k]
                            logger.info("renaming: {} -> {}".format(k, k_re))
                        else:
                            model_state_dict_new[k] = checkpoint[
                                "model_state"][k]
                    checkpoint["model_state"] = model_state_dict_new

            pre_train_dict = checkpoint["model_state"]
            model_dict = ms.state_dict()
            # Match pre-trained weights that have same shape as current model.
            pre_train_dict_match_buffer = {
                k: v
                for k, v in pre_train_dict.items()
                if k in model_dict and v.size() == model_dict[k].size()
            }
            pre_train_dict_match = {}
            for k in pre_train_dict_match_buffer:
                if 'projection' in k and not load_projection:
                    continue
                else:
                    pre_train_dict_match[k] = pre_train_dict_match_buffer[k]

            # Weights that do not have match from the pre-trained model.
            not_load_layers = [
                k for k in model_dict.keys()
                if k not in pre_train_dict_match.keys()
            ]
            # Log weights that are not loaded with the pre-trained weights.
            if not_load_layers:
                for k in not_load_layers:
                    logger.info("Network weights {} not loaded.".format(k))
            # Load pre-trained weights.
            ms.load_state_dict(pre_train_dict_match, strict=False)
            epoch = -1

            # Load the optimizer state (commonly not done when fine-tuning)
        if "epoch" in checkpoint.keys() and not epoch_reset:
            epoch = checkpoint["epoch"]
            if optimizer:
                optimizer.load_state_dict(checkpoint["optimizer_state"])
        else:
            epoch = -1
    return epoch
Example #22
def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation.

    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name (duh)
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name

    # first load gt
    # read list of images
    with PathManager.open(imagesetfile, "r") as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]

    # load annots
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_rec(annopath.format(imagename))['objects']

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == classname]
        bbox = np.array([x["bbox"] for x in R])
        difficult = np.array([x["difficult"] for x in R]).astype(bool)  # np.bool is removed in newer NumPy
        # difficult = np.array([False for x in R]).astype(np.bool)  # treat all "difficult" as GT
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {
            "bbox": bbox,
            "difficult": difficult,
            "det": det
        }

    # read dets
    detfile = detpath.format(classname)
    # print(f'-----{detfile}-----------')  # /tmp/pascal_voc_eval_pzeoxwon/face-head.txt
    with open(detfile, "r") as f:
        lines = f.readlines()
    # print(f'-----{lines}-----------')
    splitlines = [x.strip().split(" ") for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]]
                   for x in splitlines]).reshape(-1, 4)

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    # print(f'----------{sorted_ind}-------------')   #[2 0 1 3]
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]
    # print(f'----------{image_ids}-------------') #['9', '1', '6', '9']

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih

            # union
            uni = ((bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.0) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.0) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[d] = 1.0
                    R["det"][jmax] = 1
                else:
                    fp[d] = 1.0
        else:
            fp[d] = 1.0

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
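A hedged usage sketch for the function above, assuming per-class detection files were written with one detection per line ("image_id score xmin ymin xmax ymax"); all paths and the class name are placeholders:

# Illustrative call (paths are placeholders, not taken from this snippet).
rec, prec, ap = voc_eval(
    detpath="/tmp/voc_dets/{}.txt",                       # formatted with the class name
    annopath="datasets/VOC2007/Annotations/{}.xml",       # formatted with each image id
    imagesetfile="datasets/VOC2007/ImageSets/Main/test.txt",
    classname="person",
    ovthresh=0.5,
    use_07_metric=True,
)
print("AP for person: {:.4f}".format(ap))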
Example No. 23
    def evaluate(self):
        if self._distributed:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            aug_gts = comm.gather(self._TTA_gts, dst=0)
            aug_gts = list(itertools.chain(*aug_gts))

            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions
            aug_gts = self._TTA_gts

        if len(predictions) == 0:
            self._logger.warning(
                "[COCOEvaluator] Did not receive valid predictions.")
            return {}

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir,
                                     "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        # A non-empty aug_gts means the TTA-augmented ground truth is used
        if len(aug_gts) > 0:
            tta_json_file = os.path.join(self._output_dir, "tta_dataset.json")
            aug_gt_convert_to_coco_json(aug_gts, output_file=tta_json_file)

            # update dataset
            with contextlib.redirect_stdout(io.StringIO()):
                self._coco_api = COCO(tta_json_file)

            self._id2anno.clear()
            for anno in self._coco_api.dataset["annotations"]:
                ann = copy.deepcopy(anno)
                ann['bbox_mode'] = BoxMode.XYWH_ABS
                ann['category_id'] -= 1
                self._id2anno[ann['image_id']].append(ann)

        # run evaluation
        self._results = OrderedDict()

        if "instances" in predictions[0]:
            # run coco evaluation
            self._eval_predictions(set(self._tasks), predictions)

            # self._results['bbox'] = {}
            # run wheat evaluation
            oof_score = calculate_final_score(
                predictions, self._id2anno, score_threshold=self._score_thresh)
            self._results['bbox'].update({"OOFScore": oof_score})
            self._logger.info(
                f"OOF score at threshold {self._score_thresh} is {oof_score}")

            # calculate the best threshold
            # metric = calculate_best_threshold(predictions, self._id2anno)
            # self._logger.info(f"best score is {best_score}, best threshold {best_threshold}")
            # if 'bbox' not in self._results:
            #     self._results['bbox'] = {}
            # self._results['bbox'].update(metric)

        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)
Example No. 24
    logger = setup_logger(name=__name__)

    dirname = "cityscapes-data-vis"
    os.makedirs(dirname, exist_ok=True)

    if args.type == "instance":
        dicts = load_cityscapes_instances(
            args.image_dir, args.gt_dir, from_json=True, to_polygons=True
        )
        logger.info("Done loading {} samples.".format(len(dicts)))

        thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval]
        meta = Metadata().set(thing_classes=thing_classes)

    else:
        dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir)
        logger.info("Done loading {} samples.".format(len(dicts)))

        stuff_names = [k.name for k in labels if k.trainId != 255]
        stuff_colors = [k.color for k in labels if k.trainId != 255]
        meta = Metadata().set(stuff_names=stuff_names, stuff_colors=stuff_colors)

    for d in dicts:
        img = np.array(Image.open(PathManager.open(d["file_name"], "rb")))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        # cv2.imshow("a", vis.get_image()[:, :, ::-1])
        # cv2.waitKey()
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)
Example No. 25
def _cached_log_stream(filename):
    return PathManager.open(filename, "a")
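A short sketch of how such a stream is typically attached to a logging handler (the file name and format string are illustrative):

# Illustrative use: route a logging handler through the PathManager-backed stream.
import logging

fh = logging.StreamHandler(_cached_log_stream("log.txt"))
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
logging.getLogger(__name__).addHandler(fh)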
Example No. 26
def cityscapes_files_to_dict(files, from_json, to_polygons):
    """
    Parse cityscapes annotation files to an instance segmentation dataset dict.

    Args:
        files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file)
        from_json (bool): whether to read annotations from the raw json file or the png files.
        to_polygons (bool): whether to represent the segmentation as polygons
            (COCO's format) instead of masks (cityscapes's format).

    Returns:
        A dict in Detectron2 Dataset format.
    """
    from cityscapesscripts.helpers.labels import id2label, name2label

    image_file, instance_id_file, _, json_file = files

    annos = []

    if from_json:
        from shapely.geometry import MultiPolygon, Polygon

        with PathManager.open(json_file, "r") as f:
            jsonobj = json.load(f)
        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": jsonobj["imgHeight"],
            "width": jsonobj["imgWidth"],
        }

        # `polygons_union` contains the union of all valid polygons.
        polygons_union = Polygon()

        # CityscapesScripts draw the polygons in sequential order
        # and each polygon *overwrites* existing ones. See
        # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py) # noqa
        # We use reverse order, and each polygon *avoids* earlier ones.
        # This resolves the polygon overlaps in the same way as CityscapesScripts.
        for obj in jsonobj["objects"][::-1]:
            if "deleted" in obj:  # cityscapes data format specific
                continue
            label_name = obj["label"]

            try:
                label = name2label[label_name]
            except KeyError:
                if label_name.endswith("group"):  # crowd area
                    label = name2label[label_name[: -len("group")]]
                else:
                    raise
            if label.id < 0:  # cityscapes data format
                continue

            # Cityscapes' raw annotations use integer coordinates,
            # therefore +0.5 here
            poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5
            # CityscapesScript uses PIL.ImageDraw.polygon to rasterize
            # polygons for evaluation. This function operates in integer space
            # and draws each pixel whose center falls into the polygon.
            # Therefore it draws a polygon which is 0.5 "fatter" in expectation.
            # We therefore dilate the input polygon by 0.5 as our input.
            poly = Polygon(poly_coord).buffer(0.5, resolution=4)

            if not label.hasInstances or label.ignoreInEval:
                # even if we won't store the polygon it still contributes to overlaps resolution
                polygons_union = polygons_union.union(poly)
                continue

            # Take non-overlapping part of the polygon
            poly_wo_overlaps = poly.difference(polygons_union)
            if poly_wo_overlaps.is_empty:
                continue
            polygons_union = polygons_union.union(poly)

            anno = {}
            anno["iscrowd"] = label_name.endswith("group")
            anno["category_id"] = label.id

            if isinstance(poly_wo_overlaps, Polygon):
                poly_list = [poly_wo_overlaps]
            elif isinstance(poly_wo_overlaps, MultiPolygon):
                poly_list = poly_wo_overlaps.geoms
            else:
                raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps))

            poly_coord = []
            for poly_el in poly_list:
                # COCO API can work only with exterior boundaries now, hence we store only them.
                # TODO: store both exterior and interior boundaries once other parts of the
                # codebase support holes in polygons.
                poly_coord.append(list(chain(*poly_el.exterior.coords)))
            anno["segmentation"] = poly_coord
            (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds

            anno["bbox"] = (xmin, ymin, xmax, ymax)
            anno["bbox_mode"] = BoxMode.XYXY_ABS

            annos.append(anno)
    else:
        # See also the official annotation parsing scripts at
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py  # noqa
        with PathManager.open(instance_id_file, "rb") as f:
            inst_image = np.asarray(Image.open(f), order="F")
        # ids < 24 are stuff labels (filtering them first is about 5% faster)
        flattened_ids = np.unique(inst_image[inst_image >= 24])

        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": inst_image.shape[0],
            "width": inst_image.shape[1],
        }

        for instance_id in flattened_ids:
            # For non-crowd annotations, instance_id // 1000 is the label_id
            # Crowd annotations have <1000 instance ids
            label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
            label = id2label[label_id]
            if not label.hasInstances or label.ignoreInEval:
                continue

            anno = {}
            anno["iscrowd"] = instance_id < 1000
            anno["category_id"] = label.id

            mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F")

            inds = np.nonzero(mask)
            ymin, ymax = inds[0].min(), inds[0].max()
            xmin, xmax = inds[1].min(), inds[1].max()
            anno["bbox"] = (xmin, ymin, xmax, ymax)
            if xmax <= xmin or ymax <= ymin:
                continue
            anno["bbox_mode"] = BoxMode.XYXY_ABS
            if to_polygons:
                # This conversion comes from D4809743 and D5171122,
                # when Mask-RCNN was first developed.
                contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[
                    -2
                ]
                polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3]
                # opencv's findContours can produce invalid polygons
                if len(polygons) == 0:
                    continue
                anno["segmentation"] = polygons
            else:
                anno["segmentation"] = mask_util.encode(mask[:, :, None])[0]
            annos.append(anno)
    ret["annotations"] = annos
    return ret
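The reverse-order overlap handling in the from_json branch can be seen in isolation with a minimal shapely sketch (the two boxes below are made up):

# Minimal illustration of the overlap resolution above: later-drawn polygons are
# processed first, and each new polygon keeps only the part not already covered.
from shapely.geometry import Polygon, box

polygons_union = Polygon()
for poly in [box(5, 0, 15, 10), box(0, 0, 10, 10)]:  # reverse drawing order
    visible = poly.difference(polygons_union)         # drop the occluded part
    polygons_union = polygons_union.union(poly)
    print(visible.area)  # 100.0, then 50.0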
Example No. 27
def _cached_log_stream(filename):
    io = PathManager.open(filename, "a", buffering=1024)
    atexit.register(io.close)
    return io
Example No. 28
    def forward(self, scores: torch.Tensor, head_id: int):
        assert scores.shape[0] % self.num_crops == 0
        bs = scores.shape[0] // self.num_crops

        total_loss = 0
        n_term_loss = 0

        # 2 big crops are normally used for the assignment
        for i, crop_id in enumerate(self.crops_for_assign):

            # Compute the target assignments, taking crop_id as the features
            # used to compute the codes to which other crops will be mapped
            with torch.no_grad():
                scores_this_crop = scores[bs * crop_id:bs * (crop_id + 1)]

                # Add representations of the queue (this option is useful when
                # the batch size is small, to increase the number of samples
                # in sinkhornknopp to make equal repartition possible)
                if self.use_queue:
                    queue = getattr(self,
                                    "local_queue" + str(head_id))[i].clone()
                    scores_this_crop = torch.cat((scores_this_crop, queue))

                # Divide by epsilon (which can be seen as a temperature which
                # helps to sharpen the distribution of the assignments)
                if self.use_double_prec:
                    assignments = torch.exp(scores_this_crop.double() /
                                            np.float64(self.epsilon)).t()
                    assignments = assignments.double()
                else:
                    assignments = scores_this_crop / self.epsilon
                    # use the log-sum-exp trick for numerical stability.
                    M = torch.max(assignments)
                    all_reduce_max(M)
                    assignments -= M
                    assignments = torch.exp(assignments).t()

                # Apply the Sinkhorn-Knopp algorithm to distribute the
                # assignments equally across the prototypes
                assignments = distributed_sinkhornknopp(
                    Q=assignments,
                    hard_assignment=self.num_iteration <
                    self.temp_hard_assignment_iters,
                    world_size=self.world_size,
                    num_iter=self.nmb_sinkhornknopp_iters,
                    use_gpu=self.use_gpu,
                    use_double_prec=self.use_double_prec,
                )
                assignments = assignments[:bs]

            # For each crop other than the one used as target assignment
            # compute the cross entropy between the target assignment and
            # the soft-max of the dot product of each crop to the prototypes
            loss = 0
            idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
            for p in idx_crop_pred:
                if self.use_double_prec:
                    loss -= torch.mean(
                        torch.sum(
                            assignments *
                            self.log_softmax(scores[bs * p:bs *
                                                    (p + 1)].double() /
                                             np.float64(self.temperature)),
                            dim=1,
                            dtype=assignments.dtype,
                        ))
                else:
                    loss -= torch.mean(
                        torch.sum(
                            assignments * self.log_softmax(
                                scores[bs * p:bs *
                                       (p + 1)] / self.temperature),
                            dim=1,
                            dtype=assignments.dtype,
                        ))

            # Average the contribution of each crop (we don't want an
            # increase in the number of crops to impact the loss magnitude
            # and force us to update the LR)
            loss /= len(idx_crop_pred)

            # Average the contribution of each swapped assignment (the
            # division by 'n_term_loss' is done at the end of the loop)
            # for the same reason as above
            total_loss += loss
            n_term_loss += 1

            # If the loss becomes NaN, log it and dump the scores/assignments to help debugging
            # TODO (prigoyal): extract the logic to be common for all losses
            # debug_state() method that all losses can override
            if torch.isnan(loss):
                logging.info(
                    f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
                )
                scores_output_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
                )
                assignments_out_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_assignments" + str(i) +
                    ".pth",
                )
                with PathManager.open(scores_output_file, "wb") as fwrite:
                    torch.save(scores, fwrite)
                with PathManager.open(assignments_out_file, "wb") as fwrite:
                    torch.save(assignments, fwrite)
                logging.info(
                    f"Saved the scores matrix to: {scores_output_file}")
                logging.info(
                    f"Saved the assignment matrix to: {assignments_out_file}")

        total_loss /= n_term_loss
        return total_loss
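The max-subtraction above only rescales every entry of exp(score / epsilon) by a common constant, which the subsequent Sinkhorn-Knopp normalization cancels out. A tiny numeric sketch of that stabilization (the values below are made up):

# Tiny check of the stabilization used above: subtracting the max before exp()
# rescales all entries by the same factor, so any later normalization (as in
# Sinkhorn-Knopp) yields the same distribution, while the naive exp() overflows.
import numpy as np

scores = np.array([50.0, 52.0, 49.0])
epsilon = 0.05
logits = scores / epsilon                  # 1000, 1040, 980 -> np.exp() would overflow
stable = np.exp(logits - logits.max())     # finite; every entry rescaled by exp(-max)
print(stable / stable.sum())               # the distribution the naive exp() would give

Example No. 29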
def generate_seeds(args):
    data = []
    data_per_cat = {c: [] for c in VOC_CLASSES}
    for year in [2007, 2012]:
        # for year in [2007]:
        data_file = 'datasets/VOC{}/ImageSets/Main/trainval.txt'.format(year)
        with PathManager.open(data_file) as f:
            fileids = np.loadtxt(f, dtype=str).tolist()
        data.extend(fileids)
    # print(data)
    # import sys
    # sys.exit(0)
    for fileid in data:
        year = "2012" if "_" in fileid else "2007"
        # year="2007"
        dirname = os.path.join("datasets", "VOC{}".format(year))
        anno_file = os.path.join(dirname, "Annotations", fileid + ".xml")
        tree = ET.parse(anno_file)
        clses = []
        for obj in tree.findall("object"):
            cls = obj.find("name").text
            clses.append(cls)
        for cls in set(clses):
            data_per_cat[cls].append(anno_file)
    # print('data_per_cat:',data_per_cat,'clses',clses)

    result = {cls: {} for cls in data_per_cat.keys()}
    shots = [1, 2, 3, 5, 10]
    # print ('result',result)
    for i in range(args.seeds[0], args.seeds[1]):
        random.seed(i)
        for c in data_per_cat.keys():
            c_data = []
            for j, shot in enumerate(shots):
                diff_shot = shots[j] - shots[j - 1] if j != 0 else 1
                shots_c = random.sample(data_per_cat[c], diff_shot)
                num_objs = 0
                for s in shots_c:
                    if s not in c_data:
                        tree = ET.parse(s)
                        file = tree.find("filename").text
                        year = tree.find("folder").text
                        # year = "2007"
                        name = 'datasets/VOC{}/JPEGImages/{}'.format(
                            year, file)
                        c_data.append(name)
                        for obj in tree.findall("object"):
                            if obj.find("name").text == c:
                                num_objs += 1
                        if num_objs >= diff_shot:
                            break
                result[c][shot] = copy.deepcopy(c_data)
        save_path = 'datasets/vocsplit/seed{}'.format(i)
        os.makedirs(save_path, exist_ok=True)
        for c in result.keys():
            # print('result_keys:',c)
            for shot in result[c].keys():
                # print('result[c].keys():',shot)
                filename = 'box_{}shot_{}_train.txt'.format(shot, c)
                with open(os.path.join(save_path, filename), 'w') as fp:
                    print('Writing File at ',
                          os.path.join(save_path, filename))
                    fp.write('\n'.join(result[c][shot]) + '\n')
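The loop above grows each per-class list shot by shot: for shots [1, 2, 3, 5, 10] only the difference from the previous count is newly sampled, so the smaller splits are built incrementally inside the larger ones. A minimal sketch of the diff_shot logic in isolation (the pool of annotation files is made up):

# Illustration of the incremental sampling used above; resampling of already
# chosen files is ignored here for brevity.
import random

pool = ["img_{:03d}.xml".format(k) for k in range(100)]  # made-up class pool
shots = [1, 2, 3, 5, 10]
c_data = []
for j, shot in enumerate(shots):
    diff_shot = shots[j] - shots[j - 1] if j != 0 else 1
    c_data.extend(random.sample(pool, diff_shot))
    print(shot, len(c_data))  # 1 1, 2 2, 3 3, 5 5, 10 10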
Example No. 30
def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR"

    Returns:
        image (np.ndarray): an HWC image
    """
    with PathManager.open(file_name, "rb") as f:
        if format == "BGRT":
            """
            # KAIST
            folder = file_name.split('visible')[0]
            img_name = file_name.split('visible/')[1]
            path_rgb = file_name
            path_thermal = folder + 'lwir/' + img_name
            img_rgb = cv2.imread(path_rgb)
            img_thermal = cv2.imread(path_thermal)
                        
            image = np.zeros((img_thermal.shape[0], img_thermal.shape[1], 4))
            image [:,:,0:3] = img_rgb
            image [:,:,3] = img_thermal[:,:,0]
            
            """
            # FLIR
            folder = file_name.split('thermal_8_bit/')[0]
            img_name = file_name.split('thermal_8_bit/')[1]
            img_name = img_name.split('.')[0] + '.jpg'
            rgb_path = folder + 'RGB/' + img_name
            #print(rgb_path)
            rgb_img = cv2.imread(rgb_path)
            thermal_img = cv2.imread(file_name)
            #import pdb; pdb.set_trace()
            rgb_img = cv2.resize(rgb_img,
                                 (thermal_img.shape[1], thermal_img.shape[0]))
            image = np.zeros((thermal_img.shape[0], thermal_img.shape[1], 4))
            image[:, :, 0:3] = rgb_img
            image[:, :, 3] = thermal_img[:, :, 0]
            #"""
        elif format == 'BGR_only':
            folder = file_name.split('thermal_8_bit/')[0]
            img_name = file_name.split('thermal_8_bit/')[1]
            img_name = img_name.split('.')[0] + '.jpg'
            rgb_path = folder + 'resized_RGB/' + img_name
            image = cv2.imread(rgb_path)
        elif format == 'BGRTTT':  # middle fusion
            """
            # KAIST
            folder = file_name.split('visible')[0]
            img_name = file_name.split('visible/')[1]
            path_rgb = file_name
            path_thermal = folder + 'lwir/' + img_name
            img_rgb = cv2.imread(path_rgb)
            img_thermal = cv2.imread(path_thermal)                        
            image = np.zeros((img_thermal.shape[0], img_thermal.shape[1], 6))
            image [:,:,0:3] = img_rgb
            image [:,:,3:] = img_thermal
            """
            # FLIR
            folder = file_name.split('thermal_8_bit/')[0]
            img_name = file_name.split('thermal_8_bit/')[-1]

            img_name = img_name.split('.')[0] + '.jpg'
            rgb_path = folder + 'RGB/' + img_name
            rgb_img = cv2.imread(rgb_path)
            thermal_img = cv2.imread(file_name)

            rgb_img = cv2.resize(rgb_img,
                                 (thermal_img.shape[1], thermal_img.shape[0]))
            image = np.zeros((thermal_img.shape[0], thermal_img.shape[1], 6))
            image[:, :, 0:3] = rgb_img
            image[:, :, 3:6] = thermal_img
            #"""
        elif format == 'BGRTTT_perturb':

            folder = file_name.split('thermal_8_bit/')[0]
            img_name = file_name.split('thermal_8_bit/')[1]
            img_name = img_name.split('.')[0] + '.jpg'
            rgb_path = folder + 'RGB/' + img_name
            rgb_img = cv2.imread(rgb_path)

            import os
            number = int(file_name.split('video_')[-1].split('.')[0])
            #number = int(file_name.split('FLIR')[-1].split('_')[1].split('.')[0])
            number -= 1
            number_str = '{:05d}'.format(number)
            new_file_name = file_name.split('thermal')[
                0] + 'thermal_8_bit/FLIR_video_' + number_str + '.jpeg'
            if os.path.exists(new_file_name):
                thermal_img = cv2.imread(new_file_name)
                print(new_file_name, '  RGB: ', rgb_path)
            else:
                thermal_img = cv2.imread(file_name)
                print(file_name, '  RGB: ', rgb_path)
            rgb_img = cv2.resize(rgb_img,
                                 (thermal_img.shape[1], thermal_img.shape[0]))
            """
            # Random resize
            import random
            ratio = random.randrange(100,121) / 100
            width_new = int(640*ratio+0.5)
            height_new = int(512*ratio+0.5)            
            rgb_img = cv2.resize(rgb_img, (width_new, height_new))
            
            # Random crop
            [height, width, _] = thermal_img.shape
            diff_w = width_new - width
            diff_h = height_new - height
            if diff_w > 0: shift_x = random.randrange(0, diff_w)
            else: shift_x = 0
            if diff_h > 0: shift_y = random.randrange(0, diff_h)
            else: shift_y = 0            
            
            rgb_img = rgb_img[shift_y:shift_y+height, shift_x:shift_x+width, :]
            """
            #import pdb; pdb.set_trace()
            image = np.zeros((thermal_img.shape[0], thermal_img.shape[1], 6))
            image[:, :, 0:3] = rgb_img
            image[:, :, 3:6] = thermal_img
        elif format == "mid_RGB_out":
            thermal_img = cv2.imread(file_name)
            image = np.zeros((thermal_img.shape[0], thermal_img.shape[1], 6))
            image[:, :, 3:6] = thermal_img
        elif format == 'T_TCONV':
            #import pdb;pdb.set_trace()
            folder = file_name.split('thermal_8_bit/')[0]
            img_name = file_name.split('thermal_8_bit/')[1]
            img_name = img_name.split('.')[0] + '.jpeg'
            t_conv_path = folder + 'thermal_convert/' + img_name
            t_conv_img = cv2.imread(t_conv_path)
            thermal_img = cv2.imread(file_name)
            image = np.zeros((thermal_img.shape[0], thermal_img.shape[1], 2))
            image[:, :, 0] = t_conv_img[:, :, 0]
            image[:, :, 1] = thermal_img[:, :, 0]
        elif format == 'T_TCONV_MASK':
            folder = file_name.split('thermal_convert/')[0]
            img_name = file_name.split('thermal_convert/')[1]
            #img_name = img_name.split('.')[0] + '.jpeg'
            t_conv_path = folder + 'thermal_convert/' + img_name
            t_conv_img = cv2.imread(t_conv_path)
            t_mask_path = folder + 'thermal_analysis/' + file_name.split(
                'thermal_convert/')[1].split(".")[0] + '_mask.jpg'
            mask_img = cv2.imread(t_mask_path)
            thermal_img = cv2.imread(file_name)
            image = np.zeros((thermal_img.shape[0], thermal_img.shape[1], 3))
            image[:, :, 0] = t_conv_img[:, :, 0]
            image[:, :, 1] = thermal_img[:, :, 0]
            image[:, :, 2] = mask_img[:, :, 0]
        elif format == 'UVV':  # U and V in the first two channels, V duplicated in the third
            if 'train' in file_name:
                folder = '../../../Datasets/KAIST/train/KAIST_flow_train_sanitized/'
            else:
                folder = '../../../Datasets/KAIST/test/KAIST_flow_test_sanitized/'

            fname = file_name.split('/')[-1].split('.')[0] + '.flo'
            fpath = folder + fname
            flow = readFlow(fpath)
            image = np.zeros((flow.shape[0], flow.shape[1], 3))
            image[:, :, 0] = flow[:, :, 0]
            image[:, :, 1] = flow[:, :, 1]
            image[:, :, 2] = flow[:, :, 1]
            image *= 4.0
            #image += 128.0
            image[image > 255] = 255.0
            #pdb.set_trace()
            """
            image = np.abs(image) / 40.0 * 255.0
            image[image>255] = 255.0
            """
        elif format == 'UVM':  # UV + magnitude(uv)
            if 'train' in file_name:
                folder = '../../../Datasets/KAIST/train/KAIST_flow_train_sanitized/'
            else:
                folder = '../../../Datasets/KAIST/test/KAIST_flow_test_sanitized/'

            fname = file_name.split('/')[-1].split('.')[0] + '.flo'
            fpath = folder + fname
            flow = readFlow(fpath)
            flow_s = flow * flow
            magnitude = np.sqrt(flow_s[:, :, 0] + flow_s[:, :, 1])

            image = np.zeros((flow.shape[0], flow.shape[1], 3))
            image[:, :, 0] = flow[:, :, 0]
            image[:, :, 1] = flow[:, :, 1]
            image[:, :, 2] = magnitude
            image *= 4.0
            #image += 128.0
            image[image > 255] = 255.0
            """
            image = np.abs(image) / 40.0 * 255.0
            image[image>255] = 255.0
            """
        elif format == 'BGRTUV':
            if 'train' in file_name:
                flow_folder = '../../../Datasets/KAIST/train/KAIST_flow_train_sanitized/'
                img_folder = '../../../Datasets/KAIST/train/'
            else:
                flow_folder = '../../../Datasets/KAIST/test/KAIST_flow_test_sanitized/'
                img_folder = '../../../Datasets/KAIST/test/'

            fname = file_name.split('/')[-1].split('.')[0] + '.flo'
            fpath = flow_folder + fname
            flow = readFlow(fpath)

            image = np.zeros((flow.shape[0], flow.shape[1], 6))
            image[:, :, 4] = flow[:, :, 0]
            image[:, :, 5] = flow[:, :, 1]
            image *= 3
            image += 128.0
            image[image > 255] = 255.0

            set_name = file_name.split('/')[-1].split('_')[0]
            V_name = file_name.split('/')[-1].split('_')[1]
            img_name = file_name.split('/')[-1].split('_')[2]

            fname_bgr = img_folder + set_name + '/' + V_name + '/visible/' + img_name
            fname_thr = img_folder + set_name + '/' + V_name + '/lwir/' + img_name
            bgr = cv2.imread(fname_bgr)
            thr = cv2.imread(fname_thr)

            image[:, :, 0:3] = bgr
            image[:, :, 3] = thr[:, :, 0]

        else:
            #import pdb; pdb.set_trace()
            image = Image.open(f)

            # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
            try:
                image = ImageOps.exif_transpose(image)
            except Exception:
                pass

            if format is not None:
                # PIL only supports RGB, so convert to RGB and flip channels over below
                conversion_format = format
                if format == "BGR":
                    conversion_format = "RGB"
                image = image.convert(conversion_format)
            image = np.asarray(image)
            if format == "BGR":
                # flip channels if needed
                image = image[:, :, ::-1]
            # PIL squeezes out the channel dimension for "L", so make it HWC
            if format == "L":
                image = np.expand_dims(image, -1)
        return image
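A short usage sketch for the default PIL branch of the function above (the file path is a placeholder):

# Illustrative call: returns an HWC ndarray with channels flipped to BGR.
img = read_image("datasets/coco/val2017/000000000139.jpg", format="BGR")
print(img.shape, img.dtype)  # e.g. (H, W, 3) uint8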