示例#1
0
    def load_dataset(self,
                     csv_file,
                     root_dir=None,
                     augment=False,
                     shuffle=True,
                     batch_size=1):
        """Create a data loader over a tree-annotation csv file.

        The csv file must have the columns "image_path", "xmin", "ymin",
        "xmax", "ymax" giving the image name and bounding box position, one
        bounding box per line. "image_path" is a filename relative to
        ``root_dir``, not an absolute path.

        Args:
            csv_file: path to csv file
            root_dir: directory containing the images referenced in
                ``csv_file``. It is passed unchanged to
                ``dataset.TreeDataset``. This method does not substitute a
                config default when it is None, so callers should supply it.
            augment: if True, the transforms are built with data
                augmentation enabled (intended for training datasets)
            shuffle: whether the data loader shuffles the samples
            batch_size: number of samples per batch
        Returns:
            data_loader: a ``torch.utils.data.DataLoader`` over a
                ``dataset.TreeDataset``, batched with
                ``utilities.collate_fn`` and using ``self.config["workers"]``
                as ``num_workers``
        """

        ds = dataset.TreeDataset(
            csv_file=csv_file,
            root_dir=root_dir,
            transforms=dataset.get_transform(augment=augment),
            label_dict=self.label_dict)

        data_loader = torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            # Custom collate keeps per-sample (path, image, targets) fields
            # as tuples instead of stacking them.
            collate_fn=utilities.collate_fn,
            num_workers=self.config["workers"],
        )

        return data_loader
示例#2
0
def run():
    """Repeatedly pull a sample from an augmented OSBS_029 TreeDataset.

    The dataset is built from the OSBS_029.csv file located via ``get_data``,
    with augmentation enabled. A sample is then fetched 1000 times
    (presumably a profiling / stress harness for the augmentation path).
    """
    annotations = get_data("OSBS_029.csv")
    image_dir = os.path.dirname(annotations)
    augmenting_transform = dataset.get_transform(augment=True)
    ds = dataset.TreeDataset(csv_file=annotations,
                             root_dir=image_dir,
                             transforms=augmenting_transform)

    # A fresh iterator is created on every pass, so each call presumably
    # yields the first sample again (assuming a map-style dataset).
    for _ in range(1000):
        next(iter(ds))
示例#3
0
def test_collate():
    """Check utilities.collate_fn on individual TreeDataset samples.

    Due to data augmentations the dataset class may yield empty bounding box
    annotations, so each sample is collated on its own. ``collate_fn`` takes
    a list of samples and returns one tuple per field (paths, images,
    targets).
    """
    csv_file = get_data("example.csv")
    root_dir = os.path.dirname(csv_file)
    ds = dataset.TreeDataset(csv_file=csv_file,
                             root_dir=root_dir,
                             transforms=dataset.get_transform(augment=False))

    for i in range(len(ds)):
        sample = ds[i]
        # Wrap the single (path, image, targets) sample in a list: passing the
        # bare tuple would zip its fields together instead of collating it.
        collated_batch = utilities.collate_fn([sample])
        assert len(collated_batch) == 3
        assert all(len(field) == 1 for field in collated_batch)
def test_TreeDataset_transform(augment):
    """Every example.csv sample is scaled to [0, 1] and carries 79 boxes.

    NOTE(review): ``augment`` is presumably supplied by a pytest
    parametrize decorator or fixture not visible here; confirm.
    """
    annotations = get_data("example.csv")
    ds = dataset.TreeDataset(csv_file=annotations,
                             root_dir=os.path.dirname(annotations),
                             transforms=dataset.get_transform(augment=augment))

    for idx in range(len(ds)):
        _, img, target = ds[idx]
        # Pixel values must lie between 0 and 1.
        assert img.max() <= 1
        assert img.min() >= 0
        # example.csv annotates 79 boxes, each with one label.
        assert target["boxes"].shape == (79, 4)
        assert target["labels"].shape == (79, )
示例#5
0
    def log_images(self, pl_module):
        """Plot predictions against targets for the first ``self.n`` images.

        An un-augmented ``dataset.TreeDataset`` is built from
        ``self.csv_file`` / ``self.root_dir``, and ``self.n`` is clamped to
        the dataset length. ``pl_module.model`` is run on the first
        ``self.n`` samples, and each image, prediction and target is passed
        to ``visualize.plot_prediction_and_targets`` with
        ``savedir=self.savedir``. Finally, every ``*.png`` in
        ``self.savedir`` is logged through ``pl_module.logger.experiment``.
        If the upload fails, a message is printed instead.

        The model runs in eval mode without gradient tracking. Its previous
        train/eval mode is restored afterwards, so a call made mid-training
        does not leave the model in eval mode.

        Args:
            pl_module: module whose ``label_dict``, ``model``, ``device`` and
                ``logger`` attributes are read (presumably a LightningModule).
        """
        ds = dataset.TreeDataset(
            csv_file=self.csv_file,
            root_dir=self.root_dir,
            transforms=dataset.get_transform(augment=False),
            label_dict=pl_module.label_dict)

        # Never ask for more samples than the dataset holds.
        self.n = min(self.n, len(ds))

        ds = torch.utils.data.Subset(ds, np.arange(0, self.n, 1))

        data_loader = torch.utils.data.DataLoader(
            ds, batch_size=1, shuffle=False, collate_fn=utilities.collate_fn)

        model = pl_module.model
        was_training = model.training
        model.eval()
        try:
            # Inference only: do not build an autograd graph.
            with torch.no_grad():
                for paths, images, targets in data_loader:
                    if pl_module.device.type != "cpu":
                        images = [x.to(pl_module.device) for x in images]

                    predictions = model(images)

                    for path, image, prediction, target in zip(
                            paths, images, predictions, targets):
                        # CHW tensor -> HWC on the CPU for plotting.
                        image = image.permute(1, 2, 0).cpu()
                        visualize.plot_prediction_and_targets(
                            image=image,
                            predictions=prediction,
                            targets=target,
                            image_name=path,
                            savedir=self.savedir)
                        plt.close()
        finally:
            # Restore whatever mode the caller had the model in.
            model.train(was_training)

        # Best-effort upload: a missing or incompatible logger must not
        # break the caller, so failures are only reported.
        try:
            saved_plots = glob.glob("{}/*.png".format(self.savedir))
            for x in saved_plots:
                pl_module.logger.experiment.log_image(x)
        except Exception as e:
            print(
                "Could not find logger in ligthning module, skipping upload, images were saved to {}, error was rasied {}"
                .format(self.savedir, e))
示例#6
0
def test_multi_image_warning():
    """Collate samples from a csv that mixes annotations of two images.

    example.csv and OSBS_029.csv are concatenated into one temporary csv,
    and a TreeDataset is built over it. For each sample, collating
    ``[None, sample, sample]`` must give two entries in the first field,
    i.e. the ``None`` entry is discarded.
    """
    tmpdir = tempfile.gettempdir()
    csv_file1 = get_data("example.csv")
    csv_file2 = get_data("OSBS_029.csv")
    df1 = pd.read_csv(csv_file1)
    df2 = pd.read_csv(csv_file2)
    df = pd.concat([df1, df2])
    csv_file = os.path.join(tmpdir, "multiple.csv")
    df.to_csv(csv_file)

    # Both source csvs live in the same data directory.
    root_dir = os.path.dirname(csv_file1)
    ds = dataset.TreeDataset(csv_file=csv_file,
                             root_dir=root_dir,
                             transforms=dataset.get_transform(augment=False))

    for i in range(len(ds)):
        batch = ds[i]
        collated_batch = utilities.collate_fn([None, batch, batch])
        # Previously a bare comparison whose result was discarded.
        assert len(collated_batch[0]) == 2