Ejemplo n.º 1
0
def human_baseline(gold_dataset_dir: str, annotation_path: str):
    """
    Shows the images from the dataset and asks the human to label them.

    Progress is persisted to ``annotation_path`` every 20 annotations, when
    the user interrupts with Ctrl+C, and at the end of the session, so an
    interrupted labeling session can be resumed later.

    :param str gold_dataset_dir: The directory of the gold dataset
    :param str annotation_path: The path to the annotation file, if it exists we will resume the labeling session
    """
    files = glob.glob(os.path.join(gold_dataset_dir, "*.jpeg"))
    io_handler = IOHandler()
    input_dictionary = restore_dictionary(annotation_path=annotation_path)

    def _save_progress():
        # Single place that persists the annotation state (the original code
        # duplicated this block three times and wrote the file twice on
        # KeyboardInterrupt).
        with open(annotation_path, "w+", encoding="utf8") as annotation_file:
            json.dump(input_dictionary, annotation_file)

    try:
        pbar_desc = (
            "-1"  # placeholder shown before any accuracy can be computed
            if input_dictionary["total"] == 0
            else f"Current human accuracy: {round((input_dictionary['correct']/input_dictionary['total'])*100,2)}%"
        )
        with tqdm(total=len(files) - input_dictionary["total"], desc=pbar_desc) as pbar:
            for image_name in files:
                # File names encode metadata as "<H><image_no>%<values>.jpeg";
                # strip the 5-char ".jpeg" extension before parsing.
                metadata = os.path.basename(image_name)[:-5]
                header, values = metadata.split("%")
                image_no = int(header[1:])
                # Guard clause: skip images labeled in a previous session.
                if image_no in input_dictionary["human_predictions"]:
                    continue

                gold_key = io_handler.imagename_input_conversion(
                    image_name=image_name, output_type="keyboard"
                )

                # NOTE(review): shell command built from a file path — fine
                # for trusted local datasets, but not injection-safe.
                os.system(f"xv {image_name} &")
                user_key = keys_to_id(input("Push the keys: "))

                input_dictionary["human_predictions"][image_no] = user_key
                input_dictionary["total"] += 1
                if user_key == gold_key:
                    input_dictionary["correct"] += 1

                pbar.update(1)
                pbar.set_description(
                    f"Current human accuracy: {round((input_dictionary['correct']/input_dictionary['total'])*100,2)}%"
                )

                # Periodic checkpoint so little work is lost on a crash.
                if input_dictionary["total"] % 20 == 0:
                    _save_progress()

    except KeyboardInterrupt:
        pass  # fall through to the final save below

    _save_progress()
Ejemplo n.º 2
0
class Tedd1104Dataset(Dataset):
    """TEDD1104 dataset.

    Loads sequence images stored as ``*.jpeg`` files in ``dataset_dir`` and
    returns transformed samples together with the user input (keyboard or
    controller) decoded from each file name.
    """

    def __init__(
        self,
        dataset_dir: str,
        hide_map_prob: float,
        dropout_images_prob: List[float],
        control_mode: str = "keyboard",
        train: bool = False,
    ):
        """
        INIT

        :param str dataset_dir: The directory of the dataset.
        :param float hide_map_prob: Probability of hiding the minimap (0<=hide_map_prob<=1)
        :param List[float] dropout_images_prob: Probability of dropping each of the 5 images (0<=dropout_images_prob<=1)
        :param str control_mode: Type of the user input: "keyboard" or "controller"
        :param bool train: If True, the dataset is used for training.
        """

        self.dataset_dir = dataset_dir
        self.hide_map_prob = hide_map_prob
        self.dropout_images_prob = dropout_images_prob
        self.control_mode = control_mode.lower()

        assert self.control_mode in [
            "keyboard",
            "controller",
        ], f"{self.control_mode} control mode not supported. Supported dataset types: [keyboard, controller].  "

        assert 0 <= hide_map_prob <= 1.0, (
            f"hide_map_prob not in 0 <= hide_map_prob <= 1.0 range. "
            f"hide_map_prob: {hide_map_prob}")

        assert len(dropout_images_prob) == 5, (
            f"dropout_images_prob must have 5 probabilities, one for each image in the sequence. "
            f"dropout_images_prob len: {len(dropout_images_prob)}")

        for dropout_image_prob in dropout_images_prob:
            assert 0 <= dropout_image_prob <= 1.0, (
                f"All probabilities in dropout_image_prob must be in the range 0 <= dropout_image_prob <= 1.0. "
                f"dropout_images_prob: {dropout_images_prob}")

        if train:
            # Training pipeline: data augmentation (minimap removal, image
            # dropout, color jitter) plus normalization.
            self.transform = transforms.Compose([
                RemoveMinimap(hide_map_prob=hide_map_prob),
                RemoveImage(dropout_images_prob=dropout_images_prob),
                SplitImages(),
                ToTensor(),
                SequenceColorJitter(),
                Normalize(),
                MergeImages(),
            ])
        else:
            # Evaluation pipeline: no augmentation, only split/normalize/merge.
            self.transform = transforms.Compose([
                SplitImages(),
                ToTensor(),
                Normalize(),
                MergeImages(),
            ])

        self.dataset_files = glob.glob(os.path.join(dataset_dir, "*.jpeg"))

        self.IOHandler = IOHandler()

    def __len__(self):
        """
        Returns the length of the dataset.

        :return: int - Number of ``*.jpeg`` files found in the dataset dir.
        """
        return len(self.dataset_files)

    def __getitem__(self, idx):
        """
        Returns a sample from the dataset.

        If the image at ``idx`` is corrupted or missing, a random image from
        the dataset is loaded instead (retried until a readable one is found).

        :param int idx: Index of the sample.
        :return: Dict[str, torch.tensor]- Transformed sequence of images
        """
        if torch.is_tensor(idx):
            idx = int(idx)

        img_name = self.dataset_files[idx]
        image = None
        while image is None:
            try:
                image = io.imread(img_name)
            except (ValueError, FileNotFoundError) as err:
                error_message = str(err).split("\n")[-1]
                print(
                    f"Error reading image: {img_name} probably a corrupted file.\n"
                    f"Exception: {error_message}\n"
                    f"We will load a random image instead.")
                # torch.rand(1) is in [0, 1), so the index stays in range.
                img_name = self.dataset_files[int(
                    len(self.dataset_files) * torch.rand(1))]

        # The gold user input is encoded in the image file name.
        y = self.IOHandler.imagename_input_conversion(
            image_name=img_name,
            output_type=self.control_mode,
        )

        sample = {"image": image, "y": y}

        return self.transform(sample)