示例#1
0
def test_create_slice_data():
    """Check the shapes returned by ``reshape_data.create_slice_data``.

    A 40-frame stack split into slices of 4 frames should produce
    ``ceil(40 / 4)`` slices on axis 3, with ``slice_stack_len`` frames on
    axis 1, for both the X and the y data.
    """
    # test output shape with even division of slice
    fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3
    slice_stack_len = 4

    X_data = _blank_data_xr(fov_len=fov_len,
                            stack_len=stack_len,
                            crop_num=num_crops,
                            slice_num=num_slices,
                            row_len=row_len,
                            col_len=col_len,
                            chan_len=chan_len)

    y_data = _blank_data_xr(fov_len=fov_len,
                            stack_len=stack_len,
                            crop_num=num_crops,
                            slice_num=num_slices,
                            row_len=row_len,
                            col_len=col_len,
                            chan_len=chan_len,
                            last_dim_name='compartments')

    # the third return value is the slicing log, which this test does not inspect
    X_slice, y_slice, _ = reshape_data.create_slice_data(
        X_data, y_data, slice_stack_len)

    expected_shape = (fov_len, slice_stack_len, num_crops,
                      int(np.ceil(stack_len / slice_stack_len)),
                      row_len, col_len, chan_len)

    assert X_slice.shape == expected_shape
    # y_data was built with the same dimensions as X_data, so it must be
    # sliced into the same shape
    assert y_slice.shape == expected_shape
示例#2
0
def test_save_npzs_for_caliban():
    """Exercise ``io_utils.save_npzs_for_caliban`` on sliced and cropped data.

    Covers the shape of one saved npz, the contents of ``log_data.json``,
    the file count when slicing and cropping are combined, and the three
    ``blank_labels`` modes: "include", "skip" and "separate".
    """
    fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3
    slice_stack_len = 4

    # dimensions shared by the X and y fixtures; only channels differ
    shared_dims = dict(fov_len=fov_len,
                       stack_len=stack_len,
                       crop_num=num_crops,
                       slice_num=num_slices,
                       row_len=row_len,
                       col_len=col_len)
    X_data = _blank_data_xr(chan_len=chan_len, **shared_dims)
    y_data = _blank_data_xr(chan_len=1, last_dim_name='compartments', **shared_dims)

    sliced_X, sliced_y, log_data = reshape_data.create_slice_data(
        X_data=X_data, y_data=y_data, slice_stack_len=slice_stack_len)

    def save(directory, X, y, log, blank_labels):
        # every call in this test saves npz files quietly against X_data
        io_utils.save_npzs_for_caliban(X_data=X,
                                       y_data=y,
                                       original_data=X_data,
                                       log_data=log,
                                       save_dir=directory,
                                       blank_labels=blank_labels,
                                       save_format="npz",
                                       verbose=False)

    def count_npz(directory):
        # number of entries in `directory` whose name contains "npz"
        return len([name for name in os.listdir(directory) if "npz" in name])

    with tempfile.TemporaryDirectory() as out_dir:
        save(out_dir, sliced_X, sliced_y, copy.copy(log_data), "include")

        # one saved file holds one slice of labels
        saved = np.load(os.path.join(out_dir, "fov_fov1_crop_0_slice_0.npz"))
        assert saved["y"].shape == (slice_stack_len, row_len, col_len, 1)
        # X and y agree on every dimension except channels
        assert saved["y"].shape[:-1] == saved["X"].shape[:-1]

        # the log is written alongside the npz files
        with open(os.path.join(out_dir, "log_data.json")) as handle:
            saved_log = json.load(handle)
        assert saved_log["original_shape"] == list(X_data.shape)

    with tempfile.TemporaryDirectory() as out_dir:
        # crop the already-sliced data and save crops and slices together
        X_cropped, y_cropped, crop_log = reshape_data.crop_multichannel_data(
            X_data=sliced_X,
            y_data=sliced_y,
            crop_size=(10, 10),
            overlap_frac=0.2,
            test_parameters=False)
        save(out_dir, X_cropped, y_cropped, {**log_data, **crop_log}, "include")

        # one file per (crop, slice) pair
        assert count_npz(out_dir) == X_cropped.shape[2] * X_cropped.shape[3]

    # mark three slices as non-blank so the blank_labels modes can be told apart
    sliced_y[0, 0, 0, [1, 4, 7], 0, 0, 0] = 27
    total = sliced_X.shape[2] * sliced_X.shape[3]

    # "include": every crop/slice is written, blank or not
    with tempfile.TemporaryDirectory() as out_dir:
        save(out_dir, sliced_X, sliced_y, copy.copy(log_data), "include")
        assert count_npz(out_dir) == total

    # "skip": only the three non-blank slices are written
    with tempfile.TemporaryDirectory() as out_dir:
        save(out_dir, sliced_X, sliced_y, copy.copy(log_data), "skip")
        assert count_npz(out_dir) == 3

    # "separate": non-blank slices go to save_dir, the rest to save_dir/separate
    with tempfile.TemporaryDirectory() as out_dir:
        save(out_dir, sliced_X, sliced_y, copy.copy(log_data), "separate")
        assert count_npz(out_dir) == 3
        assert count_npz(os.path.join(out_dir, "separate")) == total - 3
示例#3
0
def _check_load_npzs_roundtrip(slice_stack_len):
    """Slice and crop a blank stack, tag it, save it, reload it, and check the tags.

    Args:
        slice_stack_len: number of frames per slice. A value that does not
            evenly divide the 40-frame stack exercises an unequal last slice.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # first generate image stack that will be sliced up
        fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1
        row_len, col_len, chan_len = 50, 50, 3

        X_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=1,
                                last_dim_name='compartments')

        # slice the data
        X_slice, y_slice, log_data = reshape_data.create_slice_data(
            X_data, y_data, slice_stack_len)

        # crop the data
        X_cropped, y_cropped, log_data_crop = \
            reshape_data.crop_multichannel_data(
                X_data=X_slice,
                y_data=y_slice,
                crop_size=(10, 10),
                overlap_frac=0.2,
                test_parameters=False)

        # tag the upper left hand corner of the label in each crop and slice;
        # both tag arrays start at 0, so the shared [0, 0] element agrees
        slice_tags = np.arange(y_cropped.shape[3])
        crop_tags = np.arange(y_cropped.shape[2])
        y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags
        y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags

        # save the tagged data
        io_utils.save_npzs_for_caliban(X_data=X_cropped,
                                       y_data=y_cropped,
                                       original_data=X_data,
                                       log_data={**log_data, **log_data_crop},
                                       save_dir=temp_dir,
                                       blank_labels="include",
                                       save_format="npz",
                                       verbose=False)

        # reload using the log written to disk, not the in-memory dict
        with open(os.path.join(temp_dir, "log_data.json")) as json_file:
            saved_log_data = json.load(json_file)

        loaded_slices = io_utils.load_npzs(temp_dir,
                                           saved_log_data,
                                           verbose=False)

        # dims other than channels are the same
        assert loaded_slices.shape[:-1] == X_cropped.shape[:-1]

        assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags))
        assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags))


def test_load_npzs():
    """Round-trip sliced and cropped data through save_npzs_for_caliban and load_npzs."""
    # slices that evenly divide the stack
    _check_load_npzs_roundtrip(slice_stack_len=4)
    # slices with unequal last length
    _check_load_npzs_roundtrip(slice_stack_len=7)
示例#4
0
def test_slice_helper():
    """Check the shape and contents of ``slice_utils.slice_helper`` output.

    The output is indexed as (fov, frame within slice, crop, slice, row,
    col, channel). Axis 1 has length ``slice_stack_len`` and axis 3 counts
    the slices.
    """
    fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3

    def make_input():
        # a fresh blank stack for each scenario
        return _blank_data_xr(fov_len=fov_len,
                              stack_len=stack_len,
                              crop_num=crop_num,
                              slice_num=slice_num,
                              row_len=row_len,
                              col_len=col_len,
                              chan_len=chan_len)

    def expected_shape(slice_stack_len, num_slices):
        return (fov_len, slice_stack_len, crop_num, num_slices,
                row_len, col_len, chan_len)

    # test output shape with even division of slice
    slice_stack_len = 4
    slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(
        stack_len, slice_stack_len, 0)
    slice_output = slice_utils.slice_helper(make_input(), slice_start_indices,
                                            slice_end_indices)
    assert slice_output.shape == expected_shape(
        slice_stack_len, int(np.ceil(stack_len / slice_stack_len)))

    # test output shape with uneven division of slice: the partial final
    # slice is still counted, hence the ceil
    slice_stack_len = 6
    slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(
        stack_len, slice_stack_len, 0)
    slice_output = slice_utils.slice_helper(make_input(), slice_start_indices,
                                            slice_end_indices)
    assert slice_output.shape == expected_shape(
        slice_stack_len, int(np.ceil(stack_len / slice_stack_len)))

    # test output shape with slice overlaps; the expected count assumes each
    # slice starts (slice_stack_len - slice_overlap) frames after the previous
    slice_stack_len = 6
    slice_overlap = 1
    slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(
        stack_len, slice_stack_len, slice_overlap)
    slice_output = slice_utils.slice_helper(make_input(), slice_start_indices,
                                            slice_end_indices)
    assert slice_output.shape == expected_shape(
        slice_stack_len,
        int(np.ceil(stack_len / (slice_stack_len - slice_overlap))))

    # test output values
    slice_stack_len = 4
    slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(
        stack_len, slice_stack_len, 0)
    input_data = make_input()

    # tag upper left hand corner of each image with its frame index
    tags = np.arange(stack_len)
    input_data[0, :, 0, 0, 0, 0, 0] = tags

    slice_output = slice_utils.slice_helper(input_data, slice_start_indices,
                                            slice_end_indices)

    # slices live on axis 3 (axis 1 is the frame within a slice), so iterate
    # over every slice; each one holds the next slice_stack_len frame tags
    for i in range(slice_output.shape[3]):
        expected_tags = tags[i * slice_stack_len:(i + 1) * slice_stack_len]
        assert np.all(
            np.equal(slice_output[0, :, 0, i, 0, 0, 0], expected_tags))
示例#5
0
def _check_stitch_slices(slice_stack_len):
    """Slice an ordered label stack, stitch it back, and compare with the original.

    Args:
        slice_stack_len: number of frames per slice. A value that does not
            evenly divide the 40-frame stack exercises the partial last slice.
    """
    fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3

    X_data = _blank_data_xr(fov_len=fov_len,
                            stack_len=stack_len,
                            crop_num=crop_num,
                            slice_num=slice_num,
                            row_len=row_len,
                            col_len=col_len,
                            chan_len=chan_len)

    y_data = _blank_data_xr(fov_len=fov_len,
                            stack_len=stack_len,
                            crop_num=crop_num,
                            slice_num=slice_num,
                            row_len=row_len,
                            col_len=col_len,
                            chan_len=1,
                            last_dim_name='compartments')

    # fill the labels with unique, ordered values so any misplaced pixel shows up
    linear_seq = np.arange(stack_len * row_len * col_len)
    test_vals = linear_seq.reshape((stack_len, row_len, col_len))
    y_data[0, :, 0, 0, :, :, 0] = test_vals

    _, y_slice, log_data = reshape_data.create_slice_data(
        X_data=X_data, y_data=y_data, slice_stack_len=slice_stack_len)

    # it is the labels that get stitched, so describe the original label stack
    log_data["original_shape"] = y_data.shape
    log_data["fov_names"] = y_data.fovs.values
    stitched_slices = slice_utils.stitch_slices(y_slice, log_data)

    # dims are the same
    assert stitched_slices.shape == y_data.shape

    assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals))


def test_stitch_slices():
    """Stitching sliced labels must reproduce the original label stack."""
    # slices that evenly divide the stack
    _check_stitch_slices(slice_stack_len=4)
    # slices that do not evenly divide the stack
    _check_stitch_slices(slice_stack_len=7)
示例#6
0
def test_crop_multichannel_data():
    """Check crop_multichannel_data output shapes, its log data, and argument validation."""
    # img params
    fov_len, stack_len, crop_num, slice_num, row_len = 2, 1, 1, 1, 200
    col_len, channel_len = 200, 1
    crop_size = (50, 50)
    overlap_frac = 0.2

    # input with a single crop per FOV, to be split into many crops
    test_X_data = _blank_data_xr(fov_len=fov_len,
                                 stack_len=stack_len,
                                 crop_num=crop_num,
                                 slice_num=slice_num,
                                 row_len=row_len,
                                 col_len=col_len,
                                 chan_len=channel_len)

    test_y_data = _blank_data_xr(fov_len=fov_len,
                                 stack_len=stack_len,
                                 crop_num=crop_num,
                                 slice_num=slice_num,
                                 row_len=row_len,
                                 col_len=col_len,
                                 chan_len=channel_len,
                                 last_dim_name='compartments')

    X_data_cropped, y_data_cropped, log_data = \
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=crop_size,
                                            overlap_frac=overlap_frac,
                                            test_parameters=False)

    # the image and the crop are square, so rows and cols split into the
    # same number of crops
    expected_crop_num = len(
        crop_utils.compute_crop_indices(img_len=row_len,
                                        crop_size=crop_size[0],
                                        overlap_frac=overlap_frac)[0])**2
    expected_shape = (fov_len, stack_len, expected_crop_num, slice_num,
                      crop_size[0], crop_size[1], channel_len)

    assert X_data_cropped.shape == expected_shape
    # y was built with the same dimensions as X, so it is cropped identically
    assert y_data_cropped.shape == expected_shape

    assert log_data["num_crops"] == expected_crop_num

    # invalid arguments

    # no crop_size or crop_num
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data)

    # both crop_size and crop_num
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=(20, 20),
                                            crop_num=(20, 20))
    # bad crop_size dtype
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=5)
    # bad crop_size shape
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=(10, 5, 2))
    # non-positive crop_size value
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=(0, 5))
    # non-integer crop_size value
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=(1.5, 5))
    # bad crop_num dtype
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_num=5)
    # bad crop_num shape
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_num=(10, 5, 2))
    # non-positive crop_num value
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_num=(0, 5))
    # non-integer crop_num value
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_num=(1.5, 5))
    # bad overlap_frac value; a valid crop_size is supplied so the error cannot
    # come from the "no crop_size or crop_num" check instead
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=(20, 20),
                                            overlap_frac=1.2)
    # bad X_data dims
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data[0],
                                            y_data=test_y_data,
                                            crop_size=(5, 5))
    # bad y_data dims
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data[0],
                                            crop_num=(5, 5))
示例#7
0
def _assert_same_segmentation(stitched_imgs, y_data):
    """Assert a reconstructed label stack matches ``y_data`` in shape, foreground and cell count."""
    # dims are the same
    assert stitched_imgs.shape == y_data.shape

    # all the same pixels are marked
    assert np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0))

    # there are the same number of cells
    assert len(np.unique(stitched_imgs)) == len(np.unique(y_data))


def _check_reconstruct_crops():
    """Crop a two-FOV grid of labeled cells, save the crops, reconstruct, and compare."""
    with tempfile.TemporaryDirectory() as temp_dir:
        (fov_len, stack_len, crop_num, slice_num, row_len, col_len,
         chan_len) = 2, 1, 1, 1, 400, 400, 4

        X_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=1,
                                last_dim_name='compartments')

        # create a grid of artificial cells to be segmented; cells in the
        # second FOV extend 10 px further in each direction
        cell_idx = 1
        for i in range(12):
            for j in range(11):
                for fov in range(y_data.shape[0]):
                    y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10),
                           (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx
                cell_idx += 1

        # Crop the data
        crop_size, overlap_frac = 100, 0.2
        X_cropped, y_cropped, log_data = \
            reshape_data.crop_multichannel_data(X_data=X_data,
                                                y_data=y_data,
                                                crop_size=(crop_size, crop_size),
                                                overlap_frac=overlap_frac)

        io_utils.save_npzs_for_caliban(X_data=X_cropped,
                                       y_data=y_cropped,
                                       original_data=X_data,
                                       log_data=log_data,
                                       save_dir=temp_dir)

        stitched_imgs = reshape_data.reconstruct_image_stack(crop_dir=temp_dir)

        _assert_same_segmentation(stitched_imgs, y_data)


def _check_reconstruct_slices():
    """Slice a stack whose frames are corner-tagged, save, reconstruct, and check the tags."""
    with tempfile.TemporaryDirectory() as temp_dir:
        fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1
        row_len, col_len, chan_len = 50, 50, 3
        slice_stack_len = 4

        X_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=1,
                                last_dim_name='compartments')

        # tag upper left hand corner of the label in each image with its frame index
        tags = np.arange(stack_len)
        y_data[0, :, 0, 0, 0, 0, 0] = tags

        X_slice, y_slice, slice_log_data = \
            reshape_data.create_slice_data(X_data=X_data,
                                           y_data=y_data,
                                           slice_stack_len=slice_stack_len)

        io_utils.save_npzs_for_caliban(X_data=X_slice,
                                       y_data=y_slice,
                                       original_data=X_data,
                                       log_data={**slice_log_data},
                                       save_dir=temp_dir,
                                       blank_labels="include",
                                       save_format="npz",
                                       verbose=False)

        stitched_imgs = reshape_data.reconstruct_image_stack(temp_dir)

        assert stitched_imgs.shape == y_data.shape
        assert np.all(np.equal(stitched_imgs[0, :, 0, 0, 0, 0, 0], tags))


def _check_reconstruct_crops_and_slices():
    """Crop and then slice a labeled stack, save, reconstruct, and compare."""
    with tempfile.TemporaryDirectory() as temp_dir:
        (fov_len, stack_len, crop_num, slice_num, row_len, col_len,
         chan_len) = 1, 8, 1, 1, 400, 400, 4
        # frames per slice: the 8-frame stack becomes two slices
        slice_stack_len = 4

        X_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len,
                                stack_len=stack_len,
                                crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len,
                                col_len=col_len,
                                chan_len=1,
                                last_dim_name='compartments')

        # create artificial cells whose extent grows by 2 px per frame; the
        # grid starts at index 1 so the top-left corner stays free for tags
        cell_idx = 1
        for i in range(1, 12):
            for j in range(1, 11):
                for stack in range(stack_len):
                    y_data[:, stack, :, :, (i * 35):(i * 35 + 10 + stack * 2),
                           (j * 37):(j * 37 + 8 + stack * 2), 0] = cell_idx
                cell_idx += 1

        # tag upper left hand corner of each image with a square whose side
        # equals the frame index
        for stack in range(stack_len):
            y_data[0, stack, 0, 0, :stack, :stack, 0] = 1

        # Crop the data
        crop_size, overlap_frac = 100, 0.2
        X_cropped, y_cropped, log_data = \
            reshape_data.crop_multichannel_data(X_data=X_data,
                                                y_data=y_data,
                                                crop_size=(crop_size, crop_size),
                                                overlap_frac=overlap_frac)

        X_slice, y_slice, slice_log_data = \
            reshape_data.create_slice_data(X_data=X_cropped,
                                           y_data=y_cropped,
                                           slice_stack_len=slice_stack_len)

        io_utils.save_npzs_for_caliban(X_data=X_slice,
                                       y_data=y_slice,
                                       original_data=X_data,
                                       log_data={
                                           **slice_log_data,
                                           **log_data
                                       },
                                       save_dir=temp_dir,
                                       blank_labels="include",
                                       save_format="npz",
                                       verbose=False)

        stitched_imgs = reshape_data.reconstruct_image_stack(temp_dir)

        _assert_same_segmentation(stitched_imgs, y_data)

        # check mark in upper left hand corner of image
        for stack in range(stack_len):
            original = np.zeros((10, 10))
            original[:stack, :stack] = 1
            new = stitched_imgs[0, stack, 0, 0, :10, :10, 0]
            assert np.array_equal(original > 0, new > 0)


def test_reconstruct_image_stack():
    """Reconstruct label stacks saved as crops, as slices, and as sliced crops."""
    _check_reconstruct_crops()
    _check_reconstruct_slices()
    _check_reconstruct_crops_and_slices()