コード例 #1
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_shift_and_crop_static(self):
        """An all-zero shift array with a full-size ROI returns the input as-is.

        Every entry of ``sub_shift_array`` is zero and ``roi_shape`` equals the
        input shape. ``shift_and_crop`` is therefore expected to return an array
        identical to ``upstream_arr``.
        """
        shift_node = ShiftAugment(sigma=1, shift_axis=0)
        # prepare() is not called in this unit test, so set ndim by hand.
        shift_node.ndim = 2
        upstream_arr = np.arange(16).reshape(4, 4)
        sub_shift_array = np.zeros((4, 2), dtype=int)
        roi_shape = (4, 4)
        voxel_size = Coordinate((1, 1))

        downstream_arr = np.arange(16).reshape(4, 4)

        result = shift_node.shift_and_crop(upstream_arr, roi_shape,
                                           sub_shift_array, voxel_size)
        # Unlike assertTrue(np.array_equal(...)), this reports the mismatching
        # shapes/elements when the test fails.
        np.testing.assert_array_equal(result, downstream_arr)
コード例 #2
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_pipeline3(self):
        """Check that shifted points still land on values from the source data.

        An HDF5 array and a CSV point source are merged and passed through
        RandomLocation and ShiftAugment. For every point returned in the batch,
        the array value at that point's location must be one of the values
        found in ``self.fake_data`` at the original ``self.fake_points``.
        """
        array_key = ArrayKey("TEST_ARRAY")
        points_key = GraphKey("TEST_POINTS")
        voxel_size = Coordinate((1, 1))
        spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {array_key: "testdata"},
                                 array_specs={array_key: spec})
        csv_source = CsvPointsSource(
            self.fake_points_file,
            points_key,
            GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
        )

        request = BatchRequest()
        shape = Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=voxel_size)
        request.add(points_key, shape)

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=5,
                                  shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + MergeProvider() +
                    RandomLocation(ensure_nonempty=points_key) + shift_node)
        with build(pipeline) as b:
            # Keep the returned batch distinct from the request that produced it.
            batch = b.request_batch(request)

        # Values found in the source data at the original point locations.
        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points
        ]

        # NOTE(review): if no points come back, this loop checks nothing;
        # consider asserting non-emptiness if that can never legitimately occur.
        for result_val in result_vals:
            self.assertIn(
                result_val,
                target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(
                    result_val,
                    result_points,
                    target_vals,
                    self.fake_points,
                ),
            )
コード例 #3
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_shift_and_crop3(self):
        """Per-column shifts along axis 0 with shift_axis=1, cropped to (2, 4).

        Only column 0 (the axis-0 shift) of ``sub_shift_array`` is non-zero.
        Each expected output column j equals
        ``upstream_arr[m - s_j : m - s_j + 2, j]``, where
        ``s_j = sub_shift_array[j, 0]`` and ``m = max(s_j) = 2``.
        For example, column 3 has s=2 and gives [3, 7]; column 0 has s=0 and
        gives [8, 12].
        """
        shift_node = ShiftAugment(sigma=1, shift_axis=1)
        # prepare() is not called in this unit test, so set ndim by hand.
        shift_node.ndim = 2
        upstream_arr = np.arange(16).reshape(4, 4)
        sub_shift_array = np.zeros((4, 2), dtype=int)
        sub_shift_array[:, 0] = np.array([0, 1, 0, 2], dtype=int)
        roi_shape = (2, 4)
        voxel_size = Coordinate((1, 1))

        downstream_arr = np.array([[8, 5, 10, 3], [12, 9, 14, 7]], dtype=int)

        result = shift_node.shift_and_crop(upstream_arr, roi_shape,
                                           sub_shift_array, voxel_size)
        # Reports the mismatching shapes/elements on failure.
        np.testing.assert_array_equal(result, downstream_arr)
コード例 #4
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_shift_and_crop2(self):
        """Per-row shifts along axis 1 with shift_axis=0, cropped to (4, 2).

        Only column 1 (the axis-1 shift) of ``sub_shift_array`` is non-zero.
        Each expected output row i equals
        ``upstream_arr[i, m - s_i : m - s_i + 2]``, where
        ``s_i = sub_shift_array[i, 1]`` and ``m = max(s_i) = 0``.
        For example, row 2 has s=-2 and gives [10, 11].
        """
        shift_node = ShiftAugment(sigma=1, shift_axis=0)
        # prepare() is not called in this unit test, so set ndim by hand.
        shift_node.ndim = 2
        upstream_arr = np.arange(16).reshape(4, 4)
        sub_shift_array = np.zeros((4, 2), dtype=int)
        sub_shift_array[:, 1] = np.array([0, -1, -2, 0], dtype=int)
        roi_shape = (4, 2)
        voxel_size = Coordinate((1, 1))

        downstream_arr = np.array([[0, 1], [5, 6], [10, 11], [12, 13]],
                                  dtype=int)

        result = shift_node.shift_and_crop(upstream_arr, roi_shape,
                                           sub_shift_array, voxel_size)
        # Reports the mismatching shapes/elements on failure.
        np.testing.assert_array_equal(result, downstream_arr)
コード例 #5
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_prepare1(self):
        """prepare() records the dimensionality and per-axis shift sigmas.

        For a 2-D request with ``sigma=1`` and ``shift_axis=0``, the expected
        ``shift_sigmas`` are 0.0 along the shift axis and 1.0 along the other
        axis.
        """
        key = ArrayKey("TEST_ARRAY")
        spec = ArraySpec(voxel_size=Coordinate((1, 1)), interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {key: "testdata"},
                                 array_specs={key: spec})

        request = BatchRequest()
        shape = Coordinate((3, 3))
        request.add(key, shape, voxel_size=Coordinate((1, 1)))

        shift_node = ShiftAugment(sigma=1, shift_axis=0)
        with build(hdf5_source + shift_node):
            shift_node.prepare(request)
            # assertEqual shows both values on failure, unlike assertTrue(a == b).
            self.assertEqual(shift_node.ndim, 2)
            self.assertEqual(shift_node.shift_sigmas, (0.0, 1.0))
コード例 #6
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_pipeline2(self):
        """Smoke test: ShiftAugment over an HDF5 source with voxel size (3, 1).

        A (3, 3) request is pulled through the pipeline. The test passes as
        long as building the pipeline and requesting the batch do not raise.
        The batch contents are not inspected.
        """
        voxel_size = Coordinate((3, 1))
        key = ArrayKey("TEST_ARRAY")

        hdf5_source = Hdf5Source(
            self.fake_data_file,
            {key: "testdata"},
            array_specs={
                key: ArraySpec(voxel_size=voxel_size, interpolatable=True),
            },
        )

        request = BatchRequest()
        request.add(key, Coordinate((3, 3)), voxel_size=voxel_size)

        shift_node = ShiftAugment(
            prob_slip=0.2, prob_shift=0.2, sigma=1, shift_axis=0
        )
        pipeline = hdf5_source + shift_node
        with build(pipeline) as b:
            b.request_batch(request)