Exemple #1
0
    def __init__(
        self,
        root_dir: str,
        section: str,
        transform: Callable[..., Any] = LoadPNGd("image"),
        download: bool = False,
        seed: int = 0,
        val_frac: float = 0.1,
        test_frac: float = 0.1,
        cache_num: int = sys.maxsize,
        cache_rate: float = 1.0,
        num_workers: int = 0,
    ):
        """Locate (optionally download) the dataset and initialise the parent.

        Args:
            root_dir: Existing directory; the archive path and the dataset
                folder path are both built beneath it.
            section: Stored as ``self.section`` (presumably selects the split
                inside ``_generate_data_list`` -- confirm against that method).
            transform: Forwarded unchanged to the parent initialiser.
            download: When true, ``download_and_extract`` is invoked with
                ``self.resource``, the archive path, ``root_dir`` and
                ``self.md5`` before the dataset folder is checked.
            seed: Forwarded to ``self.set_random_state``.
            val_frac: Stored as ``self.val_frac``.
            test_frac: Stored as ``self.test_frac``.
            cache_num: Forwarded to the parent initialiser.
            cache_rate: Forwarded to the parent initialiser.
            num_workers: Forwarded to the parent initialiser.

        Raises:
            ValueError: If ``root_dir`` is not a directory.
            RuntimeError: If the dataset folder does not exist once the
                optional download step has finished.
        """
        if not os.path.isdir(root_dir):
            raise ValueError("root_dir must be a directory.")

        # Split configuration must be in place before the data list is built.
        self.section, self.val_frac, self.test_frac = section, val_frac, test_frac
        self.set_random_state(seed=seed)

        archive_path = os.path.join(root_dir, self.compressed_file_name)
        dataset_dir = os.path.join(root_dir, self.dataset_folder_name)

        if download:
            download_and_extract(
                self.resource, archive_path, root_dir, self.md5
            )

        dataset_present = os.path.exists(dataset_dir)
        if not dataset_present:
            raise RuntimeError(
                f"can not find dataset directory: {dataset_dir}, please use download=True to download it."
            )

        items = self._generate_data_list(dataset_dir)
        super().__init__(
            items,
            transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
        )
Exemple #2
0
 def test_shape(self, input_param, expected_shape):
     """Write one random PNG per key, load them all, and check the shapes.

     ``input_param`` is expanded into the ``LoadPNGd`` constructor and
     ``expected_shape`` is the shape every loaded entry must have.
     """
     pixels = np.random.randint(0, 256, size=[128, 128, 3])
     as_bytes = pixels.astype("uint8")
     with tempfile.TemporaryDirectory() as tempdir:
         png_paths = {}
         for key in KEYS:
             png_path = os.path.join(tempdir, key + ".png")
             Image.fromarray(as_bytes).save(png_path)
             png_paths[key] = png_path
         # Load while the temporary files still exist.
         loader = LoadPNGd(**input_param)
         result = loader(png_paths)
     for key in KEYS:
         loaded_shape = result[key].shape
         self.assertTupleEqual(loaded_shape, expected_shape)
    def test_values(self):
        """Check MedNISTDataset contents with and without downloading.

        The dataset is placed in a ``testing_data`` folder next to this file.
        If the download fails (network error or an md5-check RuntimeError),
        the test returns early instead of failing. Finally the ``MedNIST``
        folder is removed and the constructor is required to raise
        ``RuntimeError`` when ``download=False``.
        """
        testing_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "testing_data"
        )
        transform = Compose([
            LoadPNGd(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            # Shared expectations for the "test" section of the dataset.
            self.assertEqual(len(dataset), 5986)
            first_item = dataset[0]
            self.assertIn("image", first_item)
            self.assertIn("label", first_item)
            self.assertIn("image_meta_dict", first_item)
            self.assertTupleEqual(first_item["image"].shape, (1, 64, 64))

        try:  # will start downloading if testing_dir doesn't have the MedNIST files
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=True)
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("md5 check"))
            return  # skipping this test due the network connection errors

        _test_dataset(data)

        # Re-open the already extracted dataset without downloading.
        data = MedNISTDataset(root_dir=testing_dir,
                              transform=transform,
                              section="test",
                              download=False)
        _test_dataset(data)

        # Without an explicit transform the default one is used.
        data = MedNISTDataset(root_dir=testing_dir,
                              section="test",
                              download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))

        shutil.rmtree(os.path.join(testing_dir, "MedNIST"))
        # assertRaises fails the test if the constructor unexpectedly succeeds;
        # a plain try/except would pass silently in that case.
        # NOTE(review): the expected prefix must match the message raised by
        # MedNISTDataset -- confirm against the dataset implementation in use.
        with self.assertRaises(RuntimeError) as raised:
            MedNISTDataset(root_dir=testing_dir,
                           transform=transform,
                           section="test",
                           download=False)
        message = str(raised.exception)
        print(message)
        self.assertTrue(message.startswith("Cannot find dataset directory"))
Exemple #4
0
    def test_values(self):
        """Download MedNIST into a temporary folder and validate the dataset.

        The dataset is built three times (download, re-open, default
        transform). Then the ``MedNIST`` folder is deleted and the constructor
        is required to raise ``RuntimeError`` when ``download=False``.
        """
        transform = Compose([
            LoadPNGd(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            # Shared expectations for the "test" section of the dataset.
            self.assertEqual(len(dataset), 5986)
            first_item = dataset[0]
            self.assertIn("image", first_item)
            self.assertIn("label", first_item)
            self.assertIn("image_meta_dict", first_item)
            self.assertTupleEqual(first_item["image"].shape, (1, 64, 64))

        # The context manager removes the temporary folder even when an
        # assertion fails part-way through, so downloaded data never leaks.
        with tempfile.TemporaryDirectory() as tempdir:
            data = MedNISTDataset(root_dir=tempdir,
                                  transform=transform,
                                  section="test",
                                  download=True)
            _test_dataset(data)

            data = MedNISTDataset(root_dir=tempdir,
                                  transform=transform,
                                  section="test",
                                  download=False)
            _test_dataset(data)

            # Without an explicit transform the default one is used.
            data = MedNISTDataset(root_dir=tempdir,
                                  section="test",
                                  download=False)
            self.assertTupleEqual(data[0]["image"].shape, (64, 64))

            shutil.rmtree(os.path.join(tempdir, "MedNIST"))
            # assertRaises fails the test if the constructor unexpectedly
            # succeeds; a plain try/except would pass silently in that case.
            with self.assertRaises(RuntimeError) as raised:
                MedNISTDataset(root_dir=tempdir,
                               transform=transform,
                               section="test",
                               download=False)
            message = str(raised.exception)
            print(message)
            self.assertTrue(
                message.startswith("can not find dataset directory"))
Exemple #5
0
def run(file, inputMount, outputMount):
    """Segment one image with a pretrained 2D UNet and save the mask.

    The image named by the last ``/``-separated component of ``file`` is read
    from ``inputMount``, resized to the network resolution, and passed through
    the UNet. The per-pixel argmax over the two output channels is written to
    ``outputMount`` under the same file name, resized back to the input
    image's dimensions.

    Args:
        file: Path or key of the image; only its last ``/`` component is used.
        inputMount: Directory containing the input image.
        outputMount: Directory receiving the segmentation image.

    An ``OSError`` raised anywhere in the pipeline is reported with ``print``
    and swallowed; other exceptions propagate.
    """
    file_name = file.split('/')[-1]
    inputPath = os.path.join(inputMount, file_name)
    outputPath = os.path.join(outputMount, file_name)
    print("attempting to infer cell clusters")
    try:

        # Open the image only to find its size. ``size`` is the real pixel
        # extent; ``getbbox()`` would return the bounding box of the non-zero
        # pixels (or None for an all-zero image), which is not the image size.
        with Image.open(inputPath) as input_img:
            input_width, input_height = input_img.size

        # Instantiate the model; use the GPU when present, otherwise the CPU.
        device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        model = monai.networks.nets.UNet(dimensions=2,
                                         in_channels=3,
                                         out_channels=2,
                                         channels=(16, 32, 64, 128, 256),
                                         strides=(2, 2, 2, 2),
                                         num_res_units=2,
                                         norm=Norm.BATCH).to(device)

        print('attempting load of pretrained network from /tmp directory')
        # map_location lets weights saved on a GPU load on a CPU-only host.
        model.load_state_dict(
            torch.load('/tmp/unet_368x368_segment_model.pth',
                       map_location=device))
        model.eval()
        NETWORK_IMAGE_SIZE = 368

        # define xforms in MONAI to prepare imagery for PyTorch inferencing
        infer_transforms = Compose([
            LoadPNGd(keys=['image']),
            AsChannelFirstd(keys=['image']),
            Resized(keys=['image'],
                    spatial_size=(NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE),
                    mode='bilinear',
                    align_corners=False),
            CastToTyped(keys=['image'], dtype='float32'),
            ScaleIntensityd(keys=['image'], minv=0.0, maxv=1.0),
            ToTensord(keys=['image'])
        ])

        # Wrap the single image in a MONAI dataset so the transforms are
        # applied automatically when the item is fetched.
        infer_files = [{'image': inputPath}]
        infer_ds = monai.data.Dataset(infer_files, transform=infer_transforms)
        infer_loader = monai.data.DataLoader(infer_ds, batch_size=1)

        # get the transformed single image out of the dataset
        input_data = monai.utils.misc.first(infer_loader)

        # move the tensor to the selected device
        input_tensor = input_data['image'].to(device)

        # Run the forward prediction; no gradients are needed for inference.
        with torch.no_grad():
            predict_tensor = model(input_tensor)

        # bring the result back to the CPU and drop the batch dimension
        infer_array = predict_tensor.cpu().squeeze()

        # Argmax over the channel axis gives a single-channel label image;
        # this equals moving channels last and taking argmax over that axis.
        pred_array = torch.argmax(infer_array, dim=0)

        # convert type back from torch to numpy
        prediction = pred_array.numpy()

        # Write the file out in a viewable way. The output image is resized to
        # match the input image size for convenience, even though inferencing
        # is always done at the size of the pretrained network.
        outimg = Image.fromarray(prediction.astype('uint8') * 255)
        resized = outimg.resize((input_width, input_height))
        print('saving segmentation to:', outputPath)
        resized.save(outputPath)

    except OSError:
        print("cannot create inference image for", file)