Example #1
    def get_example(self, idx: int) -> dict:
        """
        Args:
            idx: integer indicating index of dataset

        Returns: example element from dataset

        """
        example = super().get_example(idx)
        output = dict()
        # sample two distinct frame indices from the sequence
        sample_idxs = np.random.choice(np.arange(0, self.sequence_length),
                                       2,
                                       replace=False)
        for i, ex_idx in enumerate(sample_idxs):
            if self.config.get("image_type", "") == "mask":
                image = example["masked_frames"][ex_idx]()
            elif self.config.get("image_type", "") == "white":
                image = example["whitened_frames"][ex_idx]()
            else:
                image = example["frames"][ex_idx]()
            try:
                keypoints = self.labels["kps"][idx][ex_idx]
            except (KeyError, IndexError):
                # no keypoint annotations for this frame; fall back to 20 placeholder joints at the origin
                keypoints = np.zeros((20, 2), dtype=int)
            try:
                bboxes = self.labels["bboxes"][idx][ex_idx]
                bbox_available = True
            except (KeyError, IndexError):
                # self.logger.warning("No bboxes in this dataset!")
                bbox_available = False
            output[f"fid{i}"] = self.labels["fid"][idx][ex_idx]
            # store which keypoints are not present in the dataset
            zero_mask_x = np.where(keypoints[:, 0] <= 0)
            zero_mask_y = np.where(keypoints[:, 1] <= 0)
            # augmentation methods need uint8 images
            image = adjust_support(image, "0->255")
            if "crop" in self.config.keys():
                if self.config["crop"]:
                    if not bbox_available:
                        pass
                        # self.logger.warning("No resizing possible, no bounding box!")
                    else:
                        image, keypoints = crop(image, keypoints, bboxes)
            if self.augmentation and i == 1:
                # randomly augment only the second sampled frame (i == 1) and its keypoints
                image, keypoints = self.seq(image=image,
                                            keypoints=keypoints.reshape(
                                                1, -1, 2))
            # image stays (H, W, C); keypoints are reshaped to (1, J, 2) for resizing, where J == number of joints / keypoints
            image, keypoints = self.resize(image=image,
                                           keypoints=keypoints.reshape(
                                               1, -1, 2))
            # image, keypoints = self.rescale(image, keypoints.reshape(-1, 2))
            keypoints = keypoints.reshape(-1, 2)
            keypoints[zero_mask_x] = np.array([0, 0])
            keypoints[zero_mask_y] = np.array([0, 0])
            # we always work with "0->1" images and np.float32
            # image = adjust_support(image, "0->1")
            height = image.shape[0]
            width = image.shape[1]
            if "as_grey" in self.config.keys():
                if self.config["as_grey"]:
                    output[f"inp{i}"] = adjust_support(
                        skimage.color.rgb2gray(image).reshape(
                            height, width, 1), "0->1")
                    assert (self.data.data.config["n_channels"] == 1), (
                        "n_channels should be 1, got {}".format(
                            self.config["n_channels"]))
                else:
                    output[f"inp{i}"] = adjust_support(image, "0->1")
            else:
                output[f"inp{i}"] = adjust_support(image, "0->1")

            output[f"targets{i}"] = adjust_support(
                make_heatmaps(output[f"inp{i}"], keypoints, sigma=self.sigma),
                "0->1")
            output[f"framename{i}"] = self.labels["frames_"][idx][ex_idx]
        return output
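The snippet above relies on a small bookkeeping trick: keypoints at or below (0, 0) mark missing annotations, so their indices are recorded before cropping/resizing and the transformed coordinates are reset to (0, 0) afterwards. Below is a minimal, self-contained sketch of that idea; `resize_points` is a hypothetical stand-in for the real resize operator.

```python
import numpy as np

def resize_points(kps, scale):
    # hypothetical stand-in for the resize/augmentation operator used above
    return kps * scale

keypoints = np.array([[12.0, 30.0], [0.0, 0.0], [45.0, 7.0]])  # second joint missing
# remember which joints are "missing" (coordinates <= 0)
zero_mask_x = np.where(keypoints[:, 0] <= 0)
zero_mask_y = np.where(keypoints[:, 1] <= 0)

keypoints = resize_points(keypoints, scale=0.5)

# restore the sentinel so missing joints stay at (0, 0) after the transform
keypoints[zero_mask_x] = np.array([0, 0])
keypoints[zero_mask_y] = np.array([0, 0])
print(keypoints)
```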
Example #2
    def get_example(self, idx: int) -> dict:
        """
        Args:
            idx: integer indicating index of dataset

        Returns: example element from dataset

        """
        example = super().get_example(idx)
        output = dict()
        output["global_video_class0"] = example["labels_"][
            "global_video_class"]
        output["fid0"] = example["labels_"]["fid"]
        appearance_example = self.get_appearance_image(
            output["global_video_class0"], output["fid0"])
        output["global_video_class1"] = appearance_example["labels_"][
            "global_video_class"]
        if output["global_video_class0"] != output["global_video_class1"]:
            import pdb
            pdb.set_trace()
        output["fid1"] = appearance_example["labels_"]["fid"]

        for i, ex in enumerate([example, appearance_example]):
            if self.config.get("image_type", "") == "mask":
                image = ex["masked_frames"]()
                output[f"framename{i}"] = ex["labels_"]["masked_frames_"]
            elif self.config.get("image_type", "") == "white":
                image = ex["whitened_frames"]()
                output[f"framename{i}"] = ex["labels_"]["whitened_frames_"]
            else:
                image = ex["frames"]()
                output[f"framename{i}"] = ex["labels_"]["frames_"]
            try:
                keypoints = ex["kps"]
            except (KeyError, IndexError):
                # no keypoint annotations available; fall back to 20 placeholder joints at the origin
                keypoints = np.zeros((20, 2), dtype=int)
            try:
                bboxes = ex["bboxes"]
                bbox_available = True
            except (KeyError, IndexError):
                # self.logger.warning("No bboxes in this dataset!")
                bbox_available = False

            # store which keypoints are not present in the dataset
            zero_mask_x = np.where(keypoints[:, 0] <= 0)
            zero_mask_y = np.where(keypoints[:, 1] <= 0)
            # augmentation methods need uint8 images
            image = adjust_support(image, "0->255")
            if "crop" in self.config.keys():
                if self.config["crop"]:
                    if not bbox_available:
                        pass
                        # self.logger.warning("No resizing possible, no bounding box!")
                    else:
                        image, keypoints = crop(image, keypoints, bboxes)
            if self.augmentation and i == 1:
                # randomly augment only the appearance frame (i == 1) and its keypoints
                image, keypoints = self.seq(image=image,
                                            keypoints=keypoints.reshape(
                                                1, -1, 2))
            # image stays (H, W, C); keypoints are reshaped to (1, J, 2) for resizing, where J == number of joints / keypoints
            image, keypoints = self.resize(image=image,
                                           keypoints=keypoints.reshape(
                                               1, -1, 2))
            # image, keypoints = self.rescale(image, keypoints.reshape(-1, 2))
            keypoints = keypoints.reshape(-1, 2)
            keypoints[zero_mask_x] = np.array([0, 0])
            keypoints[zero_mask_y] = np.array([0, 0])
            # we always work with "0->1" images and np.float32
            height = image.shape[0]
            width = image.shape[1]

            output[f"inp{i}"] = adjust_support(image, "0->1")
            output[f"kps{i}"] = keypoints
            output[f"targets{i}"] = adjust_support(
                make_heatmaps(output[f"inp{i}"], keypoints, sigma=self.sigma),
                "0->1")
        assert output["global_video_class0"] == output["global_video_class1"], (
            f"Video classes need to be the same! Got {output['global_video_class0']} "
            f"and {output['global_video_class1']}")
        return output
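Both examples turn the (possibly augmented) keypoints into per-joint heatmap targets via `make_heatmaps`. That function is not shown here; as an assumption, a typical implementation places an unnormalized Gaussian of width `sigma` at each keypoint, one channel per joint. A minimal sketch of that idea:

```python
import numpy as np

def gaussian_heatmaps(image_hw, keypoints, sigma=2.0):
    """Sketch of per-joint Gaussian heatmaps; assumed behaviour, not the real make_heatmaps."""
    h, w = image_hw
    ys, xs = np.mgrid[0:h, 0:w]
    maps = np.zeros((h, w, len(keypoints)), dtype=np.float32)
    for j, (x, y) in enumerate(keypoints):
        if x <= 0 and y <= 0:
            continue  # missing joint -> empty channel
        maps[..., j] = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))
    return maps

kps = np.array([[10.0, 12.0], [0.0, 0.0], [40.0, 20.0]])
targets = gaussian_heatmaps((64, 64), kps, sigma=2.0)
print(targets.shape)  # (64, 64, 3), one channel per keypoint
```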
Example #3
    def get_example(self, idx: int) -> dict:
        """
        Args:
            idx: integer indicating index of dataset

        Returns: example element from dataset

        """
        example = super().get_example(idx)
        output = {}
        if self.config.get("image_type", "") == "mask":
            image = example["masked_frames"]()
        elif self.config.get("image_type", "") == "white":
            image = example["whitened_frames"]()
        else:
            image = example["frames"]()
        keypoints = self.labels["kps"][idx]
        if "synthetic" in self.data.data.config:
            # append two placeholder keypoints for the synthetic setting
            keypoints = np.append(keypoints, [[0, 0], [0, 0]], axis=0)
        try:
            bboxes = self.labels["bboxes"][idx]
            bbox_available = True
        except (KeyError, IndexError):
            # self.logger.warning("No bboxes in this dataset!")
            # no annotated bounding box, so estimate one from the keypoints
            bboxes = bboxes_from_kps(keypoints)
            # clip the bbox to the image extent; bbox layout is [x, y, width, height]
            height, width, _ = image.shape
            bboxes[0] = bboxes[0].clip(0, width)
            bboxes[1] = bboxes[1].clip(0, height)
            bboxes[2] = bboxes[2].clip(0, width)
            bboxes[3] = bboxes[3].clip(0, height)
            bboxes = bboxes.astype(np.float32)
            bbox_available = True

        # image, keypoints, bboxes = example["frames"](), self.labels["kps"][idx], self.labels["bboxes"][idx]
        # store which keypoints are not present in the dataset
        zero_mask_x = np.where(keypoints[:, 0] <= 0)
        zero_mask_y = np.where(keypoints[:, 1] <= 0)
        # augmentation methods need uint8 images
        image = adjust_support(image, "0->255")
        if "crop" in self.config.keys():
            if self.data.data.config["crop"]:
                if not bbox_available:
                    pass
                    # self.logger.warning("No resizing possible, no bounding box!")
                else:
                    image, keypoints = crop(image, keypoints, bboxes)
        if self.augmentation:
            # randomly augment the image and its keypoints
            image, keypoints = self.seq(image=image, keypoints=keypoints.reshape(1, -1, 2))

        # image stays (H, W, C); keypoints are reshaped to (1, J, 2) for resizing, where J == number of joints / keypoints
        image, keypoints = self.resize(image=image, keypoints=keypoints.reshape(1, -1, 2))
        keypoints = keypoints.reshape(-1, 2)
        # image, keypoints = self.data.data.rescale(image, keypoints.reshape(-1, 2))
        keypoints[zero_mask_x] = np.array([0, 0])
        keypoints[zero_mask_y] = np.array([0, 0])
        # we always work with "0->1" images and np.float32
        height = image.shape[0]
        width = image.shape[1]
        if "as_grey" in self.data.data.config.keys():
            if self.data.data.config["as_grey"]:
                output["inp0"] = adjust_support(skimage.color.rgb2gray(image).reshape(height, width, 1), "0->1")
                assert (self.data.data.config["n_channels"] == 1), (
                    "n_channels should be 1, got {}".format(self.data.data.config["n_channels"]))
            else:
                output["inp0"] = adjust_support(image, "0->1")
        else:
            output["inp0"] = adjust_support(image, "0->1")

        output["kps"] = keypoints
        output["targets"] = make_heatmaps(output["inp0"], keypoints, sigma=self.sigma)
        output["targets_vis"] = heatmaps_to_image(
            np.expand_dims(make_heatmaps(output["inp0"], keypoints, sigma=self.sigma), 0)).squeeze()
        output["animal_class"] = np.array(animal_class[self.data.data.animal])
        output["stickanmial"] = make_stickanimal(np.expand_dims(output["inp0"], 0), np.expand_dims(keypoints, 0))
        # workaround: duplicate the input so the encoder-decoder (VAE) can be evaluated on its own
        output["inp1"] = output["inp0"]
        output["framename"] = self.labels["frames_"][idx]
        # output["joints"] = self.joints
        return output
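When no bounding-box annotation exists, Example #3 falls back to `bboxes_from_kps` and then clips the box to the image. That helper is not shown; the sketch below is a plausible stand-in under the assumption of the `[x, y, width, height]` layout from the comments, deriving the box from the extent of the valid keypoints. The function name and `margin` parameter are hypothetical.

```python
import numpy as np

def bbox_from_keypoints(keypoints, image_hw, margin=0.1):
    """Hypothetical sketch of a bboxes_from_kps-style helper ([x, y, width, height])."""
    # keep only annotated joints (both coordinates > 0)
    valid = keypoints[(keypoints[:, 0] > 0) & (keypoints[:, 1] > 0)]
    x_min, y_min = valid.min(axis=0)
    x_max, y_max = valid.max(axis=0)
    # pad the tight keypoint extent by a relative margin
    pad_x = margin * (x_max - x_min)
    pad_y = margin * (y_max - y_min)
    bbox = np.array([x_min - pad_x, y_min - pad_y,
                     (x_max - x_min) + 2 * pad_x,
                     (y_max - y_min) + 2 * pad_y])
    # clip to the image extent, mirroring the clipping in the snippet above
    height, width = image_hw
    bbox[0] = np.clip(bbox[0], 0, width)
    bbox[1] = np.clip(bbox[1], 0, height)
    bbox[2] = np.clip(bbox[2], 0, width)
    bbox[3] = np.clip(bbox[3], 0, height)
    return bbox.astype(np.float32)

kps = np.array([[30.0, 40.0], [80.0, 90.0], [0.0, 0.0]])
print(bbox_from_keypoints(kps, image_hw=(128, 128)))
```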