示例#1
0
    def test_tiff_dataset(self):
        """Check that a TiffVolume opens, has the right length, and round-trips.

        The volume is entered with ``with`` so it is closed again even when an
        assertion fails.
        """
        testDataset = TiffVolume(os.path.join(IMAGE_PATH, "inputs.tif"),
                                 BoundingBox(Vector(0, 0, 0),
                                             Vector(1024, 512, 50)),
                                 iteration_size=BoundingBox(
                                     Vector(0, 0, 0), Vector(128, 128, 20)),
                                 stride=Vector(128, 128, 20))
        with testDataset:
            # Test that TiffVolume has the correct length
            self.assertEqual(64, len(testDataset),
                             "TIFF dataset size does not match correct size")

            # Test that TiffVolume outputs the correct samples.
            # np.array_equal really evaluates the comparison (and handles a
            # shape mismatch); an un-called ``(a == b).all`` is a bound method
            # and therefore always truthy.
            sample = testDataset[10].getArray()
            expected = tif.imread(os.path.join(IMAGE_PATH, "test_sample.tif"))
            self.assertTrue(np.array_equal(expected, sample),
                            "TIFF dataset value does not match correct value")

            # Test that TiffVolume can read and write consistent samples
            write_path = os.path.join(IMAGE_PATH, "test_write.tif")
            tif.imsave(write_path, sample)
            self.assertTrue(np.array_equal(tif.imread(write_path), sample),
                            "TIFF dataset output does not match written output")
示例#2
0
    def test_checkpoint(self):
        """Train while writing checkpoints, then resume from a checkpoint."""
        # makedirs also creates a missing parent ``tests`` directory, and
        # exist_ok avoids the check-then-create race of isdir + mkdir.
        os.makedirs('./tests/checkpoints', exist_ok=True)

        net = RSUNet()
        inputs_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "inputs.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        labels_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "labels.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        # ``with`` enters both volumes and guarantees they are exited after
        # training, even if training raises.
        with inputs_dataset, labels_dataset:
            trainer = Trainer(net,
                              inputs_dataset,
                              labels_dataset,
                              max_epochs=10,
                              gpu_device=1)
            # Presumably writes a checkpoint every 5 iterations into
            # ./tests/checkpoints -- TODO confirm against CheckpointWriter.
            trainer = CheckpointWriter(trainer,
                                       checkpoint_dir='./tests/checkpoints',
                                       checkpoint_period=5)
            trainer.run_training()

            # Resume from a checkpoint expected to have been produced by the
            # CheckpointWriter run above.
            trainer = Trainer(net,
                              inputs_dataset,
                              labels_dataset,
                              max_epochs=10,
                              checkpoint='./tests/checkpoints/iteration_5.ckpt',
                              gpu_device=1)
            trainer.run_training()
示例#3
0
    def test_pooled_volume(self):
        """Check that a PooledVolume query spanning two volumes matches a fixture."""
        pooled_volume = PooledVolume(stack_size=5)
        # The same TIFF stack is added twice, placed at z 0-50 and z 50-100.
        pooled_volume.add(
            TiffVolume(os.path.join(IMAGE_PATH, "inputs.tif"),
                       BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50))))
        pooled_volume.add(
            TiffVolume(os.path.join(IMAGE_PATH, "inputs.tif"),
                       BoundingBox(Vector(0, 0, 50), Vector(1024, 512, 100))))
        # The z-range 40-60 straddles the boundary between the two volumes.
        output = pooled_volume.get(
            BoundingBox(Vector(0, 0, 40), Vector(128, 128, 60)))

        # np.array_equal really evaluates the comparison; an un-called
        # ``(a == b).all`` is a bound method and therefore always truthy.
        expected = tif.imread(
            os.path.join(IMAGE_PATH, "test_pooled_volume.tif"))
        self.assertTrue(np.array_equal(expected, output.getArray()),
                        "PooledVolume output does not match test case")
示例#4
0
    def test_stitcher(self):
        """Blend every chunk of a TIFF volume into an Array and compare them."""
        # Stitch a test TIFF dataset
        inputDataset = TiffVolume(
            os.path.join(IMAGE_PATH, "inputs.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        # Zero-filled destination with the input's numpy dimensions.
        outputDataset = Array(
            np.zeros(inputDataset.getBoundingBox().getNumpyDim()))
        # ``with`` guarantees the volume is exited even if an assertion fails.
        with inputDataset:
            for data in inputDataset:
                outputDataset.blend(data)

            # np.array_equal really evaluates the comparison; an un-called
            # ``(a == b).all`` is a bound method and therefore always truthy.
            self.assertTrue(
                np.array_equal(inputDataset[20].getArray(),
                               outputDataset[20].getArray()),
                "Blending output does not match input")

            tif.imsave(os.path.join(IMAGE_PATH, "test_stitch.tif"),
                       outputDataset[100].getArray().astype(np.uint16))
示例#5
0
 def test_cpu_training(self):
     """Run a one-epoch training smoke test without passing ``gpu_device``."""
     net = RSUNet()
     inputs_dataset = TiffVolume(
         os.path.join(IMAGE_PATH, "inputs.tif"),
         BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
     labels_dataset = TiffVolume(
         os.path.join(IMAGE_PATH, "labels.tif"),
         BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
     # ``with`` enters both volumes and guarantees they are exited afterwards,
     # even if training raises.
     with inputs_dataset, labels_dataset:
         # No gpu_device is given; presumably Trainer then runs on the CPU, as
         # the test name implies -- TODO confirm Trainer's default.
         trainer = Trainer(net, inputs_dataset, labels_dataset, max_epochs=1)
         trainer.run_training()
示例#6
0
    def test_prediction(self):
        """Run a Predictor over a TIFF volume and save the result as a TIFF."""
        # makedirs also creates a missing parent ``tests`` directory, and
        # exist_ok avoids the check-then-create race of isdir + mkdir.
        os.makedirs('./tests/checkpoints', exist_ok=True)

        net = RSUNet()

        # This test does not create the checkpoint; it is expected to exist
        # already (e.g. from an earlier training run).
        checkpoint = './tests/checkpoints/iteration_10.ckpt'
        inputs_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "inputs.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        # ``with`` guarantees the volume is exited even if prediction raises.
        with inputs_dataset:
            predictor = Predictor(net, checkpoint, gpu_device=1)

            # Zero-filled output buffer with the input's numpy dimensions.
            output_volume = Array(
                np.zeros(inputs_dataset.getBoundingBox().getNumpyDim()))

            predictor.run(inputs_dataset, output_volume, batch_size=5)

            tif.imsave(os.path.join(IMAGE_PATH, "test_prediction.tif"),
                       output_volume.getArray().astype(np.float32))
示例#7
0
 def test_loss(self):
     """Train for 10 epochs with SimplePointBCEWithLogitsLoss as the criterion."""
     net = RSUNet()
     inputs_dataset = TiffVolume(
         os.path.join(IMAGE_PATH, "inputs.tif"),
         BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
     labels_dataset = TiffVolume(
         os.path.join(IMAGE_PATH, "labels.tif"),
         BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
     # ``with`` enters both volumes and guarantees they are exited afterwards,
     # even if training raises.
     with inputs_dataset, labels_dataset:
         trainer = Trainer(net,
                           inputs_dataset,
                           labels_dataset,
                           max_epochs=10,
                           gpu_device=1,
                           criterion=SimplePointBCEWithLogitsLoss())
         trainer.run_training()
示例#8
0
    def test_torch_dataset(self):
        """Build an AlignedVolume over inputs/labels and dump one sample pair."""
        input_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "inputs.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        label_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "labels.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        # ``with`` enters both volumes and guarantees they are exited afterwards.
        with input_dataset, label_dataset:
            training_dataset = AlignedVolume(
                (input_dataset, label_dataset),
                iteration_size=BoundingBox(Vector(0, 0, 0),
                                           Vector(128, 128, 20)),
                stride=Vector(128, 128, 20))

            # Element 0 of a sample comes from the input volume and element 1
            # from the label volume. The label is scaled by 255, presumably so
            # 0/1 labels are visible when viewed as an image -- verify.
            sample = training_dataset[10]
            tif.imsave(os.path.join(IMAGE_PATH, "test_input.tif"),
                       sample[0].getArray())
            tif.imsave(os.path.join(IMAGE_PATH, "test_label.tif"),
                       sample[1].getArray() * 255)
示例#9
0
    def test_loss_writer(self):
        """Train for one epoch with a LossWriter logging to a fresh directory."""
        # Remove output from any previous run so the experiment starts clean.
        # There is no need to create the directory before deleting it, and
        # os.mkdir would also fail when ./tests itself is missing.
        experiment_dir = './tests/test_experiment'
        if os.path.isdir(experiment_dir):
            shutil.rmtree(experiment_dir)

        net = RSUNet()
        inputs_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "inputs.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        labels_dataset = TiffVolume(
            os.path.join(IMAGE_PATH, "labels.tif"),
            BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50)))
        # ``with`` enters both volumes and guarantees they are exited afterwards,
        # even if training raises.
        with inputs_dataset, labels_dataset:
            trainer = Trainer(net,
                              inputs_dataset,
                              labels_dataset,
                              max_epochs=1,
                              gpu_device=1)
            # Presumably logs losses under ./tests/test_experiment -- confirm
            # against LossWriter.
            trainer = LossWriter(trainer, './tests/', "test_experiment")
            trainer.run_training()
示例#10
0
    def test_memory_free(self):
        """Check that an open TiffVolume raises RSS and closing it releases it."""
        process = Process(getpid())
        initial_memory = process.memory_info().rss

        start = time.perf_counter()
        with TiffVolume(os.path.join(IMAGE_PATH, "inputs.tif"),
                        BoundingBox(Vector(0, 0, 0), Vector(1024, 512, 50))):
            # Sample memory while the volume is still open.
            volume_memory = process.memory_info().rss
        # The elapsed time covers both entering and exiting the volume.
        end = time.perf_counter()
        print("Load time: {} secs".format(end - start))

        final_memory = process.memory_info().rss

        # After the volume is closed, memory must return to within 20% of the
        # starting value.
        self.assertAlmostEqual(initial_memory,
                               final_memory,
                               delta=initial_memory * 0.2,
                               msg=("memory leakage: final memory usage is " +
                                    "larger than the initial memory usage"))
        # While the volume is open, memory must be strictly above the start;
        # this fails when volume memory is NOT GREATER than the initial value.
        self.assertLess(initial_memory,
                        volume_memory,
                        msg=("volume loading error: volume memory usage is " +
                             "not greater than the initial memory usage"))
示例#11
0
    def openVolume(self, volume_spec):
        """
        Opens a volume from a volume specification

        A filename ending in ``.tif`` yields a single TiffVolume. A filename
        ending in ``.hdf5`` yields a PooledVolume holding one Hdf5Volume per
        entry of ``volume_spec["datasets"]``.

        :param volume_spec: A dictionary specifying the volume's parameters.
            Always reads ``"filename"``. For ``.tif`` files it also reads
            ``"bounding_box"``, a pair of corner coordinate sequences. For
            ``.hdf5`` files it also reads ``"datasets"``, a list of dicts that
            each carry their own ``"bounding_box"``.

        :return: The volume corresponding to the volume dataset

        :raises ValueError: if the file extension is unsupported, or if a
            required key is missing from ``volume_spec``; in the latter case
            the original KeyError is chained as the cause
        """
        try:
            filename = os.path.abspath(volume_spec["filename"])

            if filename.endswith(".tif"):
                edges = volume_spec["bounding_box"]
                bounding_box = BoundingBox(Vector(*edges[0]),
                                           Vector(*edges[1]))
                return TiffVolume(filename, bounding_box)

            elif filename.endswith(".hdf5"):
                pooled_volume = PooledVolume()
                for dataset in volume_spec["datasets"]:
                    edges = dataset["bounding_box"]
                    bounding_box = BoundingBox(Vector(*edges[0]),
                                               Vector(*edges[1]))
                    # NOTE(review): the whole dataset dict is passed here;
                    # confirm Hdf5Volume expects the dict, not a dataset name.
                    volume = Hdf5Volume(filename, dataset, bounding_box)
                    pooled_volume.add(volume)

                return pooled_volume

            else:
                # Name the offending file so the caller can see which spec
                # was rejected.
                error_string = "{} is an unsupported filetype".format(filename)
                raise ValueError(error_string)

        except KeyError as err:
            error_string = "given volume_spec is corrupt"
            # Chain the KeyError so the missing key stays visible in tracebacks.
            raise ValueError(error_string) from err