Exemplo n.º 1
0
def fetch(in_vol, voxel_size, roi_offset, roi_shape, out_file, out_ds,
          num_workers, block_shape=(4800, 1280, 1280)):
    '''Copy ``in_vol`` into a newly prepared uint8 dataset, block by block.

    Args:

        in_vol:
            Source volume; handed unchanged to ``fetch_in_block`` for every
            block.

        voxel_size:
            Voxel size of the output dataset.

        roi_offset, roi_shape:
            Offset and shape of the total ROI to fetch (world units).

        out_file, out_ds:
            Container path and dataset name of the output.

        num_workers (``int``):
            Number of parallel blockwise workers.

        block_shape (``tuple`` of ``int``, optional):
            Shape of each read/write block in world units. Defaults to the
            previously hard-coded ``(4800, 1280, 1280)``.

    Returns:

        The value returned by ``daisy.run_blockwise`` (previously discarded),
        so callers can check for failed blocks.
    '''
    total_roi = daisy.Roi(roi_offset, roi_shape)

    # Read and write ROIs coincide: every block is processed without context.
    read_roi = daisy.Roi((0, ) * len(block_shape), block_shape)
    write_roi = read_roi

    logging.info('Creating out dataset...')

    raw_out = daisy.prepare_ds(out_file,
                               out_ds,
                               total_roi,
                               voxel_size,
                               dtype=np.uint8,
                               write_roi=write_roi)

    logging.info('Writing to dataset...')

    # 'shrink' lets blocks at the border of total_roi be cut down instead of
    # being dropped or overhanging.
    return daisy.run_blockwise(total_roi,
                               read_roi,
                               write_roi,
                               process_function=lambda b: fetch_in_block(
                                   b, voxel_size, in_vol, raw_out),
                               fit='shrink',
                               num_workers=num_workers)
    def test_relabel_connected_components(self):
        """Disconnected slabs must not share a label after relabelling.

        A 100^3 uint64 volume holds one-voxel-thick slabs at rows 20 and 40
        (both label 1) and at row 60 (label 2). After blockwise relabelling
        with 25^3 blocks, neighbouring slabs must carry different ids, every
        slab must be uniform, and everything else must stay 0.
        """
        total_roi = daisy.Roi((0, 0, 0), (100, 100, 100))
        chunk = (25, 25, 25)
        unit_voxel = (1, 1, 1)
        slab_rows = (20, 40, 60)

        with tempfile.TemporaryDirectory() as workdir:

            source = daisy.prepare_ds(os.path.join(workdir, 'array_in.zarr'),
                                      'volumes/in',
                                      total_roi,
                                      voxel_size=unit_voxel,
                                      dtype=np.uint64)

            labels = np.zeros((100, 100, 100), dtype=np.uint64)
            for row, value in zip(slab_rows, (1, 1, 2)):
                labels[row] = value
            source[total_roi] = labels

            target = daisy.prepare_ds(os.path.join(workdir, 'array_out.zarr'),
                                      'volumes/out',
                                      total_roi,
                                      voxel_size=unit_voxel,
                                      write_size=chunk,
                                      dtype=np.uint64)

            segment.arrays.relabel_connected_components(source,
                                                        target,
                                                        block_size=chunk,
                                                        num_workers=10)

            result = target.to_ndarray(total_roi)

        # neighbouring slabs must differ voxel-wise
        for upper, lower in ((20, 40), (40, 60)):
            np.testing.assert_array_equal(result[upper] == result[lower],
                                          False)
        # the gaps between slabs stay background
        for start, stop in ((0, 20), (21, 40), (41, 60), (61, None)):
            self.assertEqual(result[start:stop].sum(), 0)
        # each slab is a single connected component -> a single id
        for row in slab_rows:
            self.assertEqual(len(np.unique(result[row])), 1)
Exemplo n.º 3
0
def create_test_array():
    '''Create the dataset ``test`` in ``test_array.zarr`` with a step pattern.

    The float32 volume spans ROI (0, 0, 0) + (20, 40, 80) at unit voxel size,
    with write size (2, 4, 4). Along the last axis it holds 0.5 in [0, 40),
    0.75 in [40, 60) and 1.0 in [60, 80). Those values are then halved for
    the first 20 entries of the second axis.

    Returns the prepared daisy array.
    '''
    array = daisy.prepare_ds('test_array.zarr',
                             'test',
                             total_roi=daisy.Roi((0, 0, 0), (20, 40, 80)),
                             voxel_size=(1, 1, 1),
                             write_size=(2, 4, 4),
                             dtype=np.float32)

    bands = (((0, 40), 0.5), ((40, 60), 0.75), ((60, 80), 1.0))
    for (start, stop), level in bands:
        array.data[:, :, start:stop] = level

    array.data[:, 0:20, :] *= 0.5

    return array
Exemplo n.º 4
0
    def _task_init(self):
        """Open the input dataset, derive block ROIs and prepare the output.

        Reads ``self.in_file``/``self.in_ds_name`` plus
        ``self.block_read_size`` and ``self.block_write_size``, which are
        given in voxels and scaled by the dataset voxel size here. Defaults
        ``self.out_file`` to the input file and ``self.out_ds_name`` to the
        input name with ``'_smoothed'`` appended. Creates the output dataset
        over the input ROI shrunk by the context on every side.

        Stores ``total_roi``, ``block_read_roi``, ``block_write_roi``,
        ``dataset`` and ``output_dataset`` on ``self``.
        """
        source = daisy.open_ds(self.in_file, self.in_ds_name)
        source_roi = source.roi
        ndims = len(source_roi.get_offset())

        assert len(self.block_read_size) == ndims,\
            "Read size must have same dimensions as in_file"
        assert len(self.block_write_size) == ndims,\
            "Write size must have same dimensions as in_file"

        # voxel counts -> world units
        read_size = daisy.Coordinate(self.block_read_size) * source.voxel_size
        write_size = daisy.Coordinate(self.block_write_size) * source.voxel_size

        # half of the surplus read area surrounds the write ROI on each side
        margin = (read_size - write_size) / 2
        read_roi = daisy.Roi((0,) * ndims, read_size)
        write_roi = daisy.Roi(margin, write_size)

        if self.out_file is None:
            self.out_file = self.in_file
        if self.out_ds_name is None:
            self.out_ds_name = self.in_ds_name + '_smoothed'

        logger.info(f'Processing data to {self.out_file}/{self.out_ds_name}')

        output = daisy.prepare_ds(
                self.out_file,
                self.out_ds_name,
                total_roi=source_roi.grow(-margin, -margin),
                voxel_size=source.voxel_size,
                dtype=source.dtype,
                write_size=write_roi.get_shape())

        self.total_roi = source_roi
        self.block_read_roi = read_roi
        self.block_write_roi = write_roi
        self.dataset = source
        self.output_dataset = output
    def test_minimal(self):
        """Relabel a 1x1x12 row in blocks of two voxels and check neighbours.

        Runs of equal input labels must keep a shared id. Adjacent voxels
        whose input labels differ must end up with different ids.
        """
        labels = np.array([[[1, 1, 1, 2, 2, 3, 2, 2, 1, 140, 140, 0]]],
                          dtype=np.uint64)

        roi = daisy.Roi((0, 0, 0), labels.shape)
        voxel_size = (1, 1, 1)
        block_size = (1, 1, 2)

        with tempfile.TemporaryDirectory() as workdir:

            source = daisy.Array(labels, roi=roi, voxel_size=voxel_size)
            target = daisy.prepare_ds(os.path.join(workdir, 'array_out.zarr'),
                                      '/volumes/b',
                                      total_roi=roi,
                                      voxel_size=voxel_size,
                                      write_size=block_size,
                                      dtype=np.uint64)
            target.data[:] = 0

            segment.arrays.relabel_connected_components(source, target,
                                                        block_size, 1)

            flat = target.data[:].flatten()

            self.assertTrue(flat[0] == flat[1] == flat[2])
            for i, j in ((3, 4), (6, 7), (9, 10)):
                self.assertTrue(flat[i] == flat[j])
            for i, j in ((2, 3), (4, 5), (5, 6), (7, 8), (8, 9), (10, 11)):
                self.assertTrue(flat[i] != flat[j])
Exemplo n.º 6
0
def predict_blockwise(base_dir, experiment, train_number, predict_number,
                      iteration, in_container_spec, in_container, in_dataset,
                      in_offset, in_size, out_container, db_name, db_host,
                      singularity_container, num_cpus, num_cache_workers,
                      num_block_workers, queue, mount_dirs, **kwargs):
    '''Run prediction in parallel blocks. Within blocks, predict in chunks.

    Args:

        base_dir (``string``):

            Root directory. Setups are read from
            ``<base_dir>/<experiment>/01_train/setup_t<train_number>`` and
            ``<base_dir>/<experiment>/02_predict/setup_t<train_number>_p<predict_number>``;
            the latter must contain ``predict_net.json``.

        experiment (``string``):

            Name of the experiment (cremi, fib19, fib25, ...).

        train_number (``int``):
        predict_number (``int``):

            Numbers used to build the setup directory names above. Both are
            also forwarded to ``predict_worker``.

        iteration (``int``):

            Training iteration to predict from (forwarded to the workers).

        in_container_spec (``string``):

            Container opened here with ``daisy.open_ds`` to obtain the ROI
            and voxel size of the input.

        in_container (``string``):
        in_dataset (``string``):

            Input container and dataset forwarded to the workers.
            ``in_dataset`` is also used when opening ``in_container_spec``.

        in_offset, in_size:

            Currently unused.

        out_container (``string``):

            Path of the container to write to. One dataset
            ``volumes/<name>`` is created per entry of ``outputs`` in
            ``predict_net.json``, so no output dataset name is passed in.

        db_name (``string``):
        db_host (``string``):

            MongoDB database name and host. The ``blocks_predicted``
            collection (indexed on ``block_id``) is handed to
            ``check_block`` for every block.

        singularity_container, num_cpus, num_cache_workers, mount_dirs:

            Forwarded unchanged to ``predict_worker``.

        num_block_workers (``int``):

            How many blocks to run in parallel.

        queue (``string``):

            Name of queue to run inference on (e.g. slowpoke, gpu_rtx,
            gpu_any, gpu_tesla, gpu_tesla_large); forwarded to the workers.

        **kwargs:

            Ignored.

    Raises:

        RuntimeError: if ``daisy.run_blockwise`` reports a failed block.
    '''
    experiment_dir = os.path.join(base_dir, experiment)
    predict_setup_dir = os.path.join(
        experiment_dir,
        "02_predict/setup_t{}_p{}".format(train_number, predict_number))
    train_setup_dir = os.path.join(experiment_dir,
                                   "01_train/setup_t{}".format(train_number))

    # from here on, all values are in world units (unless explicitly mentioned)
    # get ROI of source
    source = daisy.open_ds(in_container_spec, in_dataset)
    logger.info('Source dataset has shape %s, ROI %s, voxel size %s',
                source.shape, source.roi, source.voxel_size)

    # Read network config
    predict_net_config = os.path.join(predict_setup_dir, 'predict_net.json')
    with open(predict_net_config) as f:
        logger.info('Reading setup config from {}'.format(predict_net_config))
        net_config = json.load(f)
    outputs = net_config['outputs']

    # get chunk size and context
    net_input_size = daisy.Coordinate(
        net_config['input_shape']) * source.voxel_size
    net_output_size = daisy.Coordinate(
        net_config['output_shape']) * source.voxel_size
    context = (net_input_size - net_output_size) / 2
    logger.info('Network context: {}'.format(context))

    # The input ROI is grown by the context so that border blocks still get
    # full network input; outputs cover exactly the source ROI.
    input_roi = source.roi.grow(context, context)
    output_roi = source.roi

    # create read and write ROI
    block_read_roi = daisy.Roi((0, 0, 0), net_input_size) - context
    block_write_roi = daisy.Roi((0, 0, 0), net_output_size)

    logger.info('Preparing output dataset...')

    for output_name, val in outputs.items():
        out_dims = val['out_dims']
        out_dtype = val['out_dtype']
        daisy.prepare_ds(out_container,
                         'volumes/%s' % output_name,
                         output_roi,
                         source.voxel_size,
                         out_dtype,
                         write_roi=block_write_roi,
                         num_channels=out_dims,
                         compressor={
                             'id': 'gzip',
                             'level': 5
                         })

    logger.info('Starting block-wise processing...')

    client = pymongo.MongoClient(db_host)
    db = client[db_name]
    # Only index the collection when it is being created here.
    collection_exists = 'blocks_predicted' in db.list_collection_names()
    blocks_predicted = db['blocks_predicted']
    if not collection_exists:
        blocks_predicted.create_index([('block_id', pymongo.ASCENDING)],
                                      name='block_id')

    # process block-wise
    succeeded = daisy.run_blockwise(
        input_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda: predict_worker(
            train_setup_dir, predict_setup_dir, predict_number, train_number,
            experiment, iteration, in_container, in_dataset, out_container,
            db_host, db_name, queue, singularity_container, num_cpus,
            num_cache_workers, mount_dirs),
        check_function=lambda b: check_block(blocks_predicted, b),
        num_workers=num_block_workers,
        read_write_conflict=False,
        fit='overhang')

    if not succeeded:
        raise RuntimeError("Prediction failed for (at least) one block")
Exemplo n.º 7
0
    def create_from_array_identifier(
        cls,
        array_identifier,
        axes,
        roi,
        num_channels,
        voxel_size,
        dtype,
        write_size=None,
        name=None,
    ):
        """
        Create a new ZarrArray given an array identifier. It is assumed that
        this array_identifier points to a dataset that does not yet exist.

        If creating the dataset raises ``zarr.errors.ContainsArrayError``
        (the dataset already exists), the method does not fail. It asserts
        that the stored ``offset``, ``resolution`` and ``axes`` attributes
        and the dataset shape match the arguments, and then overwrites the
        existing contents with zeros.

        Args:
            array_identifier: object providing ``.container`` (path of the
                zarr container) and ``.dataset`` (dataset name inside it).
            axes: axis names; stored in the ``axes`` attribute.
            roi: total ROI; ``roi.offset`` is stored as ``offset``.
            num_channels: number of channels, or ``None`` for no leading
                channel axis.
            voxel_size: voxel size; stored as ``resolution``.
            dtype: data type (anything accepted by ``np.dtype``).
            write_size: chunk size in world units. If ``None``, a cubic chunk
                of roughly 5 MiB is chosen. In every case it is clipped to
                ``roi.shape`` per axis.
            name: not used by this method.

        Returns:
            An instance of ``cls`` built via ``cls.__new__``; ``__init__``
            is not called. Its fields are set directly below.
        """
        if write_size is None:
            # total storage per block is approx c*x*y*z*dtype_size
            # appropriate block size about 5MB.
            # Edge length in voxels of a cube holding ~5 MiB:
            # (5 MiB / channels / itemsize) ** (1 / dims), floored.
            axis_length = (
                (1024**2 * 5 /
                 (num_channels if num_channels is not None else 1) /
                 np.dtype(dtype).itemsize)**(1 / voxel_size.dims)) // 1
            # cubic chunk in voxels, converted to world units
            write_size = Coordinate(
                (axis_length, ) * voxel_size.dims) * voxel_size
        # never let a chunk exceed the ROI along any axis
        write_size = Coordinate(
            (min(a, b) for a, b in zip(write_size, roi.shape)))
        zarr_container = zarr.open(array_identifier.container, "a")
        try:
            daisy.prepare_ds(
                f"{array_identifier.container}",
                array_identifier.dataset,
                roi,
                voxel_size,
                dtype,
                num_channels=num_channels,
                write_size=write_size,
            )
            # record the geometry as zarr attributes on the new dataset
            zarr_dataset = zarr_container[array_identifier.dataset]
            zarr_dataset.attrs["offset"] = roi.offset
            zarr_dataset.attrs["resolution"] = voxel_size
            zarr_dataset.attrs["axes"] = axes
        except zarr.errors.ContainsArrayError:
            # Dataset already exists: verify it matches the request.
            # NOTE(review): these checks are asserts and disappear under
            # ``python -O``.
            zarr_dataset = zarr_container[array_identifier.dataset]
            assert (tuple(zarr_dataset.attrs["offset"]) == roi.offset
                    ), f"{zarr_dataset.attrs['offset']}, {roi.offset}"
            assert (tuple(zarr_dataset.attrs["resolution"]) == voxel_size
                    ), f"{zarr_dataset.attrs['resolution']}, {voxel_size}"
            assert tuple(zarr_dataset.attrs["axes"]) == tuple(
                axes), f"{zarr_dataset.attrs['axes']}, {axes}"
            # expected shape: optional channel axis + ROI shape in voxels
            assert (
                zarr_dataset.shape == (
                    (num_channels, ) if num_channels is not None else
                    ()) + roi.shape / voxel_size
            ), f"{zarr_dataset.shape}, {((num_channels,) if num_channels is not None else ()) + roi.shape / voxel_size}"
            # discard whatever the existing dataset contained
            zarr_dataset[:] = np.zeros(zarr_dataset.shape, dtype)

        # Build the instance without running __init__ and fill in its fields.
        zarr_array = cls.__new__(cls)
        zarr_array.file_name = array_identifier.container
        zarr_array.dataset = array_identifier.dataset
        zarr_array._axes = None
        zarr_array._attributes = zarr_array.data.attrs
        zarr_array.snap_to_grid = None
        return zarr_array
Exemplo n.º 8
0
def predict_blockwise(experiment,
                      setup,
                      iteration,
                      raw_file,
                      raw_dataset,
                      out_directory,
                      out_filename,
                      num_workers,
                      db_host,
                      db_name,
                      worker_config=None,
                      out_properties=None,
                      overwrite=False,
                      configname='test'):
    '''Run prediction in parallel blocks. Within blocks, predict in chunks.

    Args:

        experiment (``string``):

            Name of the experiment (cremi, fib19, fib25, ...).

        setup (``string``):

            Name of the setup to predict. It is looked up in
            ``../train/<experiment>/`` if that directory exists, otherwise in
            ``../train/``. Both paths are relative to the working directory.

        iteration (``int``):

            Training iteration to predict from.

        raw_file (``string``):

            Input raw file for network to predict from.

        raw_dataset (``string``):

            Dataset name of the raw data. If opening it fails,
            ``<raw_dataset>/s0`` is tried instead.

        out_directory (``string``):

            Output base directory.

        out_filename (``string``):

            File is written to
            <out_directory>[/<experiment>]/<setup>/<iteration>/<out_filename>.
            The <experiment> level is added unless experiment is 'cremi'.

        num_workers (``int``):

            How many blocks to run in parallel.

        db_host (``string``):

            MongoDB host. This is used to monitor block completeness.

        db_name (``string``):

            Base name of the MongoDB database. ``_<setup>_<iteration>`` is
            appended, plus ``_<experiment>`` unless experiment is 'cremi'. A
            collection ``blocks_predicted`` is created for monitoring blocks.

        worker_config (optional):

            Forwarded unchanged to ``predict_worker``.

        out_properties (``dict``, optional):

            Use this to set properties for the output data. The dictionary
            maps from tensorflow name to properties (all optional):
            - dsname: sets the dataset name; if not provided, the tensorflow
              name is used
            - dtype: sets the dtype (make sure that you scale and clip
              accordingly). If not provided, the dtype in config file is used.
            - scale: scales the data and sets dataset attribute 'scale'
            Defaults to no overrides.

        configname (``string``, optional):

            Name of the configfile: Network setups (such as input_size and
            output_size) are loaded from file: <configname>_net_config.json.
            This should be the same file that the predict script loads.
            Train usually indicates a smaller network than test (test network
            is written out for larger datasets/production).

        overwrite (``bool``, optional):

            If set to True, inference is started from scratch and log info
            from `blocks_predicted` is ignored, database collection
            `blocks_predicted` is overwritten.

    Raises:

        RuntimeError: if ``daisy.run_blockwise`` reports a failed block.
    '''
    # Avoid a shared mutable default; a fresh dict per call.
    if out_properties is None:
        out_properties = {}

    experiment_dir = '../'
    train_dir = os.path.join(experiment_dir, 'train', experiment)
    if not os.path.exists(train_dir):
        train_dir = os.path.join(experiment_dir, 'train')

    db_name = db_name + '_{}_{}'.format(setup, iteration)
    if experiment != 'cremi':
        db_name += f'_{experiment}'

    network_dir = os.path.join(experiment, setup, str(iteration))
    if experiment != 'cremi':  # backwards compatibility
        out_directory = os.path.join(out_directory, experiment)

    raw_file = os.path.abspath(raw_file)
    out_file = os.path.abspath(
        os.path.join(out_directory, setup, str(iteration), out_filename))

    setup = os.path.abspath(os.path.join(train_dir, setup))

    print('Input file path: ', raw_file)
    print('Output file path: ', out_file)
    # from here on, all values are in world units (unless explicitly mentioned)

    # get ROI of source. If the dataset cannot be opened, retry with '/s0'
    # appended (presumably the finest level of a scale pyramid). Catch
    # Exception rather than everything so Ctrl-C/SystemExit still propagate.
    try:
        source = daisy.open_ds(raw_file, raw_dataset)
    except Exception:
        raw_dataset = raw_dataset + '/s0'
        source = daisy.open_ds(raw_file, raw_dataset)
    print("Source dataset has shape %s, ROI %s, voxel size %s" %
          (source.shape, source.roi, source.voxel_size))

    # load config
    config_path = os.path.join(setup,
                               '{}_net_config.json'.format(configname))
    with open(config_path) as f:
        print("Reading setup config from %s" % config_path)
        net_config = json.load(f)
    outputs = net_config['outputs']

    # get chunk size and context
    net_input_size = daisy.Coordinate(
        net_config['input_shape']) * source.voxel_size
    net_output_size = daisy.Coordinate(
        net_config['output_shape']) * source.voxel_size
    context = (net_input_size - net_output_size) / 2

    # get total input and output ROIs
    input_roi = source.roi.grow(context, context)
    output_roi = source.roi

    print("Following sizes in world units:")
    print("net input size  = %s" % (net_input_size, ))
    print("net output size = %s" % (net_output_size, ))
    print("context         = %s" % (context, ))

    # create read and write ROI
    block_read_roi = daisy.Roi((0, 0, 0), net_input_size) - context
    block_write_roi = daisy.Roi((0, 0, 0), net_output_size)

    print("Following ROIs in world units:")
    print("Block read  ROI  = %s" % block_read_roi)
    print("Block write ROI  = %s" % block_write_roi)
    print("Total input  ROI  = %s" % input_roi)
    print("Total output ROI  = %s" % output_roi)

    logging.info('Preparing output dataset')
    print("Preparing output dataset...")
    for outputname, val in outputs.items():
        out_dims = val['out_dims']
        out_dtype = val['out_dtype']
        scale = None
        print(outputname)
        if outputname in out_properties:
            out_property = out_properties[outputname]
            out_dtype = out_property.get('dtype', out_dtype)
            scale = out_property.get('scale')
            outputname = out_property.get('dsname', outputname)
        print('setting dtype to {}'.format(out_dtype))
        out_dataset = 'volumes/%s' % outputname
        print('Creating dataset: {}'.format(out_dataset))
        print('Number of dimensions is %i' % out_dims)
        ds = daisy.prepare_ds(
            out_file,
            out_dataset,
            output_roi,
            source.voxel_size,
            out_dtype,
            write_roi=block_write_roi,
            num_channels=out_dims,
            # temporary fix until
            # https://github.com/zarr-developers/numcodecs/pull/87 gets approved
            # (we want gzip to be the default)
            compressor={
                'id': 'gzip',
                'level': 5
            })
        if scale is not None:
            ds.data.attrs['scale'] = scale

    print("Starting block-wise processing...")

    client = pymongo.MongoClient(db_host)
    db = client[db_name]

    if overwrite:
        db.drop_collection('blocks_predicted')

    # Only index the collection when it is being created here.
    collection_exists = 'blocks_predicted' in db.list_collection_names()
    blocks_predicted = db['blocks_predicted']
    if not collection_exists:
        blocks_predicted.create_index([('block_id', pymongo.ASCENDING)],
                                      name='block_id')

    # process block-wise
    succeeded = daisy.run_blockwise(
        input_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda: predict_worker(
            setup, network_dir, iteration, raw_file, raw_dataset, out_file,
            out_properties, db_host, db_name, configname, worker_config),
        check_function=lambda b: check_block(blocks_predicted, b),
        num_workers=num_workers,
        read_write_conflict=False,
        fit='overhang')

    if not succeeded:
        raise RuntimeError("Prediction failed for (at least) one block")
Exemplo n.º 9
0
def save_samples(pred_affs, pred_affs_ds, segmentation, labels, labels_dataset,
                 fragments, boundarys, distances, seeds, history, threshold,
                 curr_log_dir, checkpoint, is_2d):
    '''Write one evaluation sample to a zarr container plus a CSV history.

    The container is
    ``<curr_log_dir>/samples/sample_<checkpoint>_thresh_<threshold>.zarr``.
    It receives the datasets ``segmentation``, ``gt`` and ``affs``, created
    with ``daisy.prepare_ds``. It also receives ``fragment``, ``boundary``,
    ``dist_trfm`` and ``seeds``, written with ``utils.save_zarr``.

    ``history`` is flattened by one level, turned into a DataFrame, and saved
    as ``<curr_log_dir>/samples/sample_<checkpoint>_hist_<threshold>.csv``.

    With ``is_2d`` set, a leading singleton axis is prepended to the voxel
    size, to the ROI (offset 0, size 1), to ``labels`` and to
    ``segmentation``.
    '''
    voxel_size = labels_dataset.voxel_size
    roi = labels_dataset.roi
    if is_2d:
        # lift 2D inputs to a single-slice 3D volume
        voxel_size = (1, *voxel_size)
        roi = daisy.Roi((0, *roi.get_offset()), (1, *roi.get_shape()))
        labels = labels[np.newaxis]
        segmentation = segmentation[np.newaxis]

    zarr_file = os.path.join(
        curr_log_dir,
        f'samples/sample_{checkpoint}_thresh_{threshold}.zarr')

    seg = daisy.prepare_ds(zarr_file,
                           ds_name='segmentation',
                           total_roi=roi,
                           voxel_size=voxel_size,
                           dtype=np.uint64,
                           num_channels=1)

    gt = daisy.prepare_ds(zarr_file,
                          ds_name='gt',
                          total_roi=roi,
                          voxel_size=voxel_size,
                          dtype=np.uint64,
                          num_channels=1)

    pred = daisy.prepare_ds(zarr_file,
                            ds_name='affs',
                            total_roi=pred_affs_ds.roi,
                            voxel_size=pred_affs_ds.voxel_size,
                            dtype=pred_affs_ds.dtype,
                            num_channels=pred_affs.shape[0])

    extras = ((fragments, 'fragment'),
              (boundarys, 'boundary'),
              (distances, 'dist_trfm'),
              (seeds, 'seeds'))
    for volume, ds_name in extras:
        utils.save_zarr(volume,
                        zarr_file,
                        ds=ds_name,
                        roi=volume.shape,
                        voxel_size=voxel_size,
                        fit_voxel=True)

    seg.data[:] = segmentation
    gt.data[:] = labels
    pred.data[:] = pred_affs

    merges = []
    for per_threshold in history:
        merges.extend(per_threshold)
    csv_path = os.path.join(
        curr_log_dir, f'samples/sample_{checkpoint}_hist_{threshold}.csv')
    pd.DataFrame(merges).to_csv(csv_path)
Exemplo n.º 10
0
    def _task_init(self):
        """Open the input dataset, derive the scheduling geometry, prepare output.

        Opens ``self.in_ds_name`` in ``self.in_file`` and aborts the process
        with exit status 1 if that fails.

        Computes ``self.schedule_roi`` as follows:
        - If a sub-ROI is given via ``self.roi_offset``/``self.roi_shape``,
          both must be set, and that ROI is used.
        - Otherwise the full dataset ROI is used.
        - The result is snapped outwards to the voxel grid, and also to the
          write size when a sub-ROI is used.

        ``self.write_size`` is ``self.chunk_shape_voxel`` (computed from
        ``self.max_voxel_count`` if unset) times the voxel size.

        Finally prepares ``self.out_ds``. By default it goes into
        ``<in_file without extension>.zarr`` under the input dataset name.

        Raises:
            RuntimeError: if the input has more than one channel dimension.
        """
        import os
        import sys

        logger.info(f"Accessing {self.in_ds_name} in {self.in_file}")
        try:
            self.in_ds = daisy.open_ds(self.in_file, self.in_ds_name)
        except Exception as e:
            # Log the failure with its traceback at error level, then exit.
            # sys.exit is used because the bare builtin ``exit`` is only
            # installed by the ``site`` module.
            logger.exception("EXCEPTION: %s", e)
            sys.exit(1)

        voxel_size = self.in_ds.voxel_size

        if self.in_ds.n_channel_dims == 0:
            num_channels = 1
        elif self.in_ds.n_channel_dims == 1:
            num_channels = self.in_ds.shape[0]
        else:
            raise RuntimeError(
                "more than one channel not yet implemented, sorry...")

        self.ds_roi = self.in_ds.roi

        use_sub_roi = (self.roi_offset is not None
                       or self.roi_shape is not None)
        if use_sub_roi:
            # a sub-ROI must be fully specified
            assert self.roi_offset is not None and self.roi_shape is not None
            self.schedule_roi = daisy.Roi(
                tuple(self.roi_offset), tuple(self.roi_shape))
        else:
            self.schedule_roi = self.in_ds.roi

        if self.chunk_shape_voxel is None:
            self.chunk_shape_voxel = calculateNearIsotropicDimensions(
                voxel_size, self.max_voxel_count)
            logger.info(voxel_size)
            logger.info(self.chunk_shape_voxel)
        self.chunk_shape_voxel = Coordinate(self.chunk_shape_voxel)

        self.schedule_roi = self.schedule_roi.snap_to_grid(
            voxel_size,
            mode='grow')
        out_ds_roi = self.ds_roi.snap_to_grid(
            voxel_size,
            mode='grow')

        self.write_size = self.chunk_shape_voxel*voxel_size

        scheduling_block_size = self.write_size
        # origin with as many dimensions as the block size (was fixed to 3D)
        self.write_roi = daisy.Roi(
            (0,) * len(scheduling_block_size), scheduling_block_size)

        if use_sub_roi:
            # with sub_roi, the coordinates are absolute
            # so we'd need to align total_roi to the write size too
            self.schedule_roi = self.schedule_roi.snap_to_grid(
                self.write_size, mode='grow')
            out_ds_roi = out_ds_roi.snap_to_grid(
                self.write_size, mode='grow')

        logger.info(f"out_ds_roi: {out_ds_roi}")
        logger.info(f"schedule_roi: {self.schedule_roi}")
        logger.info(f"write_size: {self.write_size}")
        logger.info(f"voxel_size: {voxel_size}")

        if self.out_file is None:
            # Replace only the extension of the last path component. The
            # previous '.'.join(split('.')[0:-1]) turned extension-less names
            # such as 'raw' or './raw' into '.zarr', and cut at dots inside
            # directory names. A trailing '/' (directory-style containers)
            # is stripped first.
            in_root, _ = os.path.splitext(self.in_file.rstrip('/'))
            self.out_file = in_root + '.zarr'
        if self.out_ds_name is None:
            self.out_ds_name = self.in_ds_name

        # overwrite level 2 means: delete any existing output dataset
        delete = self.overwrite == 2

        self.out_ds = daisy.prepare_ds(
            self.out_file,
            self.out_ds_name,
            total_roi=out_ds_roi,
            voxel_size=voxel_size,
            write_size=self.write_size,
            dtype=self.in_ds.dtype,
            num_channels=num_channels,
            force_exact_write_size=True,
            compressor={'id': 'blosc', 'clevel': 3},
            delete=delete,
            )
def extract_segmentation(fragments_file,
                         fragments_dataset,
                         edges_collection,
                         threshold,
                         block_size,
                         out_file,
                         out_dataset,
                         num_workers,
                         roi_offset=None,
                         roi_shape=None,
                         run_type=None,
                         **kwargs):
    '''Relabel fragments blockwise into a segmentation using a stored LUT.

    The fragment-segment lookup table is read from
    ``<fragments_file>/luts/fragment_segment[/<run_type>]/seg_<edges_collection>_<int(threshold*100)>.npz``
    (array ``fragment_segment_lut``).

    Args:

        fragments_file (``string``):

            Path to file (zarr/n5) containing fragments (supervoxels).

        fragments_dataset (``string``):

            Name of fragments dataset (e.g `volumes/fragments`)

        edges_collection (``string``):

            The name of the MongoDB database edges collection to use. Here it
            only selects the LUT file name.

        threshold (``float``):

            The threshold to use for generating a segmentation.

        block_size (``tuple`` of ``int``):

            The size of one block in world units (must be multiple of voxel
            size).

        out_file (``string``):

            Path to file (zarr/n5) to write segmentation to.

        out_dataset (``string``):

            Name of segmentation dataset (e.g `volumes/segmentation`).

        num_workers (``int``):

            How many workers to use for the blockwise relabelling.

        roi_offset (array-like of ``int``, optional):

            The starting point (inclusive) of the ROI. Entries can be ``None``
            to indicate unboundedness.

        roi_shape (array-like of ``int``, optional):

            The shape of the ROI. Entries can be ``None`` to indicate
            unboundedness. Only used together with ``roi_offset``.

        run_type (``string``, optional):

            Can be used to direct luts into directory (e.g testing, validation,
            etc).

        **kwargs:

            Ignored.

    Returns:

        The value returned by ``daisy.run_blockwise`` (previously discarded).

    Raises:

        ValueError: if ``roi_offset`` is given without ``roi_shape``.
        FileNotFoundError: if the LUT file does not exist.
    '''
    # open fragments
    fragments = daisy.open_ds(fragments_file, fragments_dataset)

    total_roi = fragments.roi
    if roi_offset is not None:
        if roi_shape is None:
            raise ValueError("If roi_offset is set, roi_shape "
                             "also needs to be provided")
        total_roi = daisy.Roi(offset=roi_offset, shape=roi_shape)

    # Blocks are processed without context; the origin gets as many
    # dimensions as block_size (was hard-coded to 3).
    block_size = daisy.Coordinate(block_size)
    read_roi = daisy.Roi((0, ) * len(block_size), block_size)
    write_roi = read_roi

    logging.info("Preparing segmentation dataset...")
    segmentation = daisy.prepare_ds(out_file,
                                    out_dataset,
                                    total_roi,
                                    voxel_size=fragments.voxel_size,
                                    dtype=np.uint64,
                                    write_roi=write_roi)

    # NOTE(review): int(threshold*100) truncates (e.g. 0.29*100 is
    # 28.999..., giving '_28'). This is kept unchanged because the LUT files
    # are presumably named with the same expression by their writer; verify
    # before changing.
    lut_filename = f'seg_{edges_collection}_{int(threshold*100)}'

    lut_dir = os.path.join(fragments_file, 'luts', 'fragment_segment')

    if run_type:
        lut_dir = os.path.join(lut_dir, run_type)
        logging.info(f"Run type set, using luts from {run_type} data")

    lut_path = os.path.join(lut_dir, lut_filename + '.npz')

    if not os.path.exists(lut_path):
        raise FileNotFoundError(f"{lut_path} does not exist")

    logging.info("Reading fragment-segment LUT...")

    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it once the array has been read into memory.
    with np.load(lut_path) as lut_archive:
        lut = lut_archive['fragment_segment_lut']

    logging.info(f"Found {len(lut[0])} fragments in LUT")

    num_segments = len(np.unique(lut[1]))
    logging.info(f"Relabelling fragments to {num_segments} segments")

    return daisy.run_blockwise(total_roi,
                               read_roi,
                               write_roi,
                               lambda b: segment_in_block(
                                   b, fragments_file, segmentation,
                                   fragments, lut),
                               fit='shrink',
                               num_workers=num_workers)
Exemplo n.º 12
0
    block_read_roi = daisy.Roi((0,)*ndims, block_read_size)
    block_write_roi = daisy.Roi(context, block_write_size)

    # prepare output dataset
    output_roi = total_roi.grow(-context, -context)
    if config.out_file is None:
        config.out_file = config.in_file
    if config.out_ds_name is None:
        config.out_ds_name = config.in_ds_name + '_smoothed'

    logger.info(f'Processing data to {config.out_file}/{config.out_ds_name}')

    output_dataset = daisy.prepare_ds(
            config.out_file,
            config.out_ds_name,
            total_roi=output_roi,
            voxel_size=dataset.voxel_size,
            dtype=dataset.dtype,
            write_size=block_write_roi.get_shape())

    # make task
    task = daisy.Task(
            'GaussianSmoothingTask',
            total_roi,
            block_read_roi,
            block_write_roi,
            process_function=lambda b: smooth(
                b, dataset, output_dataset, sigma=config.sigma),
            read_write_conflict=False,
            num_workers=config.num_workers,
            fit='shrink'
Exemplo n.º 13
0
def extract_fragments(experiment,
                      setup,
                      iteration,
                      affs_file,
                      affs_dataset,
                      fragments_file,
                      fragments_dataset,
                      block_size,
                      context,
                      db_host,
                      db_name,
                      num_workers,
                      fragments_in_xy,
                      queue,
                      epsilon_agglomerate=0,
                      mask_file=None,
                      mask_dataset=None,
                      filter_fragments=0,
                      replace_sections=None,
                      **kwargs):
    '''

    Extract fragments in parallel blocks. Requires that affinities have been
    predicted before.

    When running parallel inference, the worker files are located in the setup
    directory of each experiment since that is where the training was done and
    checkpoints are located. When running watershed (and agglomeration) in
    parallel, we call a worker file which can be located anywhere. By default,
    we assume there is a workers directory inside the current directory that
    contains worker scripts (e.g `workers/extract_fragments_worker.py`).

    Args:

        * following three params just used to build out file directory *

        experiment (``string``):

            Name of the experiment (fib25, hemi, zfinch, ...).

        setup (``string``):

            Name of the setup to predict (setup01, setup02, ...).

        iteration (``int``):

            Training iteration.

        affs_file (``string``):

            Path to file (zarr/n5) where predictions are stored.

        affs_dataset (``string``):

            Predictions dataset to use (e.g 'volumes/affs'). If the dataset
            cannot be opened, scale zero of a pyramid is tried instead
            (e.g 'volumes/affs/s0').

        fragments_file (``string``):

            Path to file (zarr/n5) to store fragments (supervoxels) - generally
            a good idea to store in the same place as affs.

        fragments_dataset (``string``):

            Name of dataset to write fragments (supervoxels) to (e.g
            'volumes/fragments').

        block_size (``tuple`` of ``int``):

            The size of one block in world units (must be multiple of voxel
            size).

        context (``tuple`` of ``int``):

            The context to consider for fragment extraction in world units.

        db_host (``string``):

            Name of MongoDB client.

        db_name (``string``):

            Name of MongoDB database to use (for logging successful blocks in
            check function and writing nodes to the region adjacency graph).

        num_workers (``int``):

            How many blocks to run in parallel.

        fragments_in_xy (``bool``):

            Whether to extract fragments for each xy-section separately.

        queue (``string``):

            Name of cpu queue to use (e.g local)

        epsilon_agglomerate (``float``, optional):

            Perform an initial waterz agglomeration on the extracted fragments
            to this threshold. Skip if 0 (default).

        mask_file (``string``, optional):

            Path to file (zarr/n5) containing mask.

        mask_dataset (``string``, optional):

            Name of mask dataset. Data should be uint8 where 1 == masked in, 0
            == masked out.

        filter_fragments (``float``, optional):

            Filter fragments that have an average affinity lower than this
            value.

        replace_sections (``list`` of ``int``, optional):

            Replace fragments data with zero in given sections (useful if large
            artifacts are causing issues). List of section numbers (in voxels).

        **kwargs:

            Accepted and ignored here (presumably so that a larger config
            dict can be unpacked into this call - verify against callers).

    '''

    logging.info(f"Reading affs from {affs_file}")

    # Fall back to scale level zero of a pyramid if the plain dataset cannot
    # be opened. The result must be bound to ``affs``: everything below reads
    # affs.roi and affs.voxel_size.
    try:
        affs = daisy.open_ds(affs_file, affs_dataset)
    except Exception:
        affs_dataset = affs_dataset + '/s0'
        affs = daisy.open_ds(affs_file, affs_dataset)

    # Forwarded to the worker start function.
    network_dir = os.path.join(experiment, setup, str(iteration))

    client = pymongo.MongoClient(db_host)
    db = client[db_name]

    # Collection handed to check_block below. The block_id index is only
    # created the first time, when the collection does not exist yet.
    blocks_extracted = db['blocks_extracted']
    if 'blocks_extracted' not in db.list_collection_names():
        blocks_extracted.create_index([('block_id', pymongo.ASCENDING)],
                                      name='block_id')

    # prepare fragments dataset. By default use same roi as affinities, change
    # roi if extracting fragments in cropped region
    fragments = daisy.prepare_ds(fragments_file,
                                 fragments_dataset,
                                 affs.roi,
                                 affs.voxel_size,
                                 np.uint64,
                                 daisy.Roi((0, 0, 0), block_size),
                                 compressor={
                                     'id': 'zlib',
                                     'level': 5
                                 })

    context = daisy.Coordinate(context)
    total_roi = affs.roi.grow(context, context)

    # Each block reads its write region plus `context` on every side.
    read_roi = daisy.Roi((0, ) * affs.roi.dims(),
                         block_size).grow(context, context)
    write_roi = daisy.Roi((0, ) * affs.roi.dims(), block_size)

    # get number of voxels in block
    num_voxels_in_block = (write_roi / affs.voxel_size).size()

    # blockwise watershed; process_function takes no block argument, i.e. it
    # starts a worker rather than processing a single block.
    daisy.run_blockwise(
        total_roi=total_roi,
        read_roi=read_roi,
        write_roi=write_roi,
        process_function=lambda: start_worker(
            affs_file, affs_dataset, fragments_file, fragments_dataset,
            db_host, db_name, context, fragments_in_xy, queue, network_dir,
            epsilon_agglomerate, mask_file, mask_dataset, filter_fragments,
            replace_sections, num_voxels_in_block),
        check_function=lambda b: check_block(blocks_extracted, b),
        num_workers=num_workers,
        read_write_conflict=False,
        fit='shrink')
Exemplo n.º 14
0
def predict_blockwise(config_file, iteration):
    """Run network prediction block-wise over the configured ROI.

    Builds a config from built-in defaults updated by the 'general' and
    'predict' sections of the master config file. It then derives the
    prediction ROI, optionally restricted by 'frames' and/or
    'limit_to_roi_offset'/'limit_to_roi_shape'. If 'output_zarr' is set, it
    prepares the output datasets. Finally it runs ``predict_worker`` via
    ``daisy.run_blockwise``. A check function is only registered when
    'db_name' is configured.

    Args:
        config_file: Path passed to ``load_config`` and to every worker.
        iteration: Training iteration passed to every worker.
    """
    # Built-in defaults; later config sections override them.
    config = {
        "solve_context": daisy.Coordinate((2, 100, 100, 100)),
        "num_workers": 16,
        "data_dir": '../01_data',
        "setups_dir": '../02_setups',
    }
    master_config = load_config(config_file)
    for section_name in ('general', 'predict'):
        config.update(master_config[section_name])

    sample = config['sample']
    data_dir = config['data_dir']
    setup = config['setup']
    setup_dir = os.path.abspath(os.path.join(config['setups_dir'], setup))
    voxel_size, source_roi = get_source_roi(data_dir, sample)

    # The ROIs below describe the prediction region, not the solve region.
    predict_roi = source_roi
    if 'frames' in config:
        frames = config['frames']
        logger.info("Limiting prediction to frames %s" % str(frames))
        first_frame, last_frame = frames
        frames_roi = daisy.Roi((first_frame, None, None, None),
                               (last_frame - first_frame, None, None, None))
        predict_roi = predict_roi.intersect(frames_roi)
    if 'limit_to_roi_offset' in config:
        assert 'limit_to_roi_shape' in config, \
            "Must specify shape and offset in config file"
        limit_offset = daisy.Coordinate(config['limit_to_roi_offset'])
        limit_shape = daisy.Coordinate(config['limit_to_roi_shape'])
        predict_roi = predict_roi.intersect(
            daisy.Roi(limit_offset, limit_shape))

    # Network input/output sizes (world units) determine the block context.
    net_config_path = os.path.join(setup_dir, 'test_net_config.json')
    with open(net_config_path, 'r') as net_config_file:
        net_config = json.load(net_config_file)
    input_shape = net_config['input_shape']
    output_shape = net_config['output_shape_2']
    net_input_size = daisy.Coordinate(input_shape) * voxel_size
    net_output_size = daisy.Coordinate(output_shape) * voxel_size
    context = (net_input_size - net_output_size) / 2

    # Grow the prediction ROI to a multiple of the block write size.
    predict_roi = predict_roi.snap_to_grid(net_output_size, mode='grow')
    input_roi = predict_roi.grow(context, context)
    output_roi = predict_roi

    if 'output_zarr' in config:
        output_path = os.path.join(setup_dir, config['output_zarr'])
        logger.debug("Preparing zarr at %s" % output_path)
        for dataset_name, channel_count in (
                ('volumes/parent_vectors', 3),
                ('volumes/cell_indicator', 1)):
            daisy.prepare_ds(output_path,
                             dataset_name,
                             output_roi,
                             voxel_size,
                             dtype=np.float32,
                             write_size=net_output_size,
                             num_channels=channel_count)

    block_write_roi = daisy.Roi((0, 0, 0, 0), net_output_size)
    block_read_roi = block_write_roi.grow(context, context)

    logger.info("Following ROIs in world units:")
    logger.info("Input ROI       = %s" % input_roi)
    logger.info("Block read  ROI = %s" % block_read_roi)
    logger.info("Block write ROI = %s" % block_write_roi)
    logger.info("Output ROI      = %s" % output_roi)

    logger.info("Starting block-wise processing...")

    run_options = dict(
        process_function=lambda: predict_worker(config_file, iteration),
        num_workers=config['num_workers'],
        read_write_conflict=False,
        max_retries=0,
        fit='valid')
    # Block completion is only checked when a database is configured.
    if 'db_name' in config:
        run_options['check_function'] = lambda b: check_function(
            b, 'predict', config['db_name'], config['db_host'])

    daisy.run_blockwise(input_roi, block_read_roi, block_write_roi,
                        **run_options)
Exemplo n.º 15
0
def save_zarr(data,
              zarr_file,
              ds,
              roi,
              voxel_size=(1, 1, 1),
              num_channels=1,
              dtype=None,
              fit_voxel=False):
    """
        Helper function to save_zarr files using daisy.

        Args:

            data (`numpy array`):

                The data that you want to save.

            zarr_file (`string`):

                The zarr file you want to save to.

            ds (`string`):

                The dataset that the data should be saved as.

            roi (`daisy.Roi` or `list-like`):

                The roi to save the dataset as. A list-like value is
                interpreted as a shape with zero offset.

            voxel_size (`tuple`, default=(1, 1, 1)):

                The voxel size to save the dataset as.

            num_channels (`int`, default=1):

                How many channels the data has.
                Note: Daisy only supports saving zarrs with a single
                channel dim, so (num_channels, roi) is the only possible
                shape of the data. For num_channels == 1, a leading
                singleton channel axis is dropped before writing.

            dtype (`numpy dtype`, optional):

                The datatype to save the data as. Defaults to data.dtype.

            fit_voxel (`bool`):

                If true then the roi will be multiplied by the voxel_size.
                This is useful if the ROI is in unit voxels and you want it
                to be in world units.

    """
    if not isinstance(roi, daisy.Roi):
        roi = daisy.Roi((0,) * len(roi), roi)

    if fit_voxel:
        roi = roi * voxel_size

    if dtype is None:
        dtype = data.dtype

    # Use the requested dtype for the dataset (previously data.dtype was
    # always used, silently ignoring the `dtype` argument).
    dataset = daisy.prepare_ds(zarr_file,
                               ds_name=ds,
                               total_roi=roi,
                               voxel_size=voxel_size,
                               dtype=dtype,
                               num_channels=num_channels)

    # Single-channel datasets are created without a channel axis, so data of
    # shape (1, *roi_shape) has one dimension more than the ROI. Drop that
    # leading channel axis before writing.
    if num_channels == 1 and len(data.shape) > roi.dims():
        data = np.squeeze(data, 0)
    dataset.data[:] = data.astype(dtype, copy=False)
Exemplo n.º 16
0
def create_scale_pyramid(in_file, in_ds_name, scales, chunk_shape):
    """Downsample a dataset into a scale pyramid ``<name>/s0``, ``s1``, ...

    If ``in_ds_name`` does not already end in ``/s0``, the dataset is moved to
    ``<in_ds_name>/s0`` first. Each entry of ``scales`` then produces the next
    level ``s<k>`` from the previous one via ``downscale``.

    Args:
        in_file: Path of the zarr container.
        in_ds_name: Dataset to downsample (with or without ``/s0`` suffix).
        scales: Sequence of per-level factors, each either a per-axis
            sequence or a scalar applied to every axis.
        chunk_shape: Chunk shape (in voxels) for new levels, or None to reuse
            the chunk shape of the source dataset.

    Raises:
        RuntimeError: If ``in_ds_name`` cannot be opened as a dataset, or the
            dataset has more than one channel dimension.
    """
    root = zarr.open(in_file)

    # make sure in_ds_name points to a dataset
    try:
        daisy.open_ds(in_file, in_ds_name)
    except Exception:
        raise RuntimeError("%s does not seem to be a dataset" % in_ds_name)

    if in_ds_name.endswith('/s0'):
        ds_name = in_ds_name
        in_ds_name = in_ds_name[:-3]
    else:
        ds_name = in_ds_name + '/s0'
        print("Moving %s to %s" % (in_ds_name, ds_name))
        # The target lies inside the source path, so go via a temporary name.
        temporary_name = in_ds_name + '__tmp'
        root.store.rename(in_ds_name, temporary_name)
        root.store.rename(temporary_name, ds_name)

    print("Scaling %s by a factor of %s" % (in_file, scales))

    prev_array = daisy.open_ds(in_file, ds_name)

    if chunk_shape is None:
        chunk_shape = daisy.Coordinate(prev_array.data.chunks)
        print("Reusing chunk shape of %s for new datasets" % (chunk_shape, ))
    else:
        chunk_shape = daisy.Coordinate(chunk_shape)

    channel_dims = prev_array.n_channel_dims
    if channel_dims == 0:
        num_channels = 1
    elif channel_dims == 1:
        num_channels = prev_array.shape[0]
    else:
        raise RuntimeError(
            "more than one channel not yet implemented, sorry...")

    for level, scale in enumerate(scales, start=1):

        # Accept either a per-axis factor or a single scalar factor.
        try:
            scale = daisy.Coordinate(scale)
        except Exception:
            scale = daisy.Coordinate((scale, ) * chunk_shape.dims())

        next_voxel_size = prev_array.voxel_size * scale
        next_total_roi = prev_array.roi.snap_to_grid(next_voxel_size,
                                                     mode='grow')
        next_write_size = chunk_shape * next_voxel_size

        print("Next voxel size: %s" % (next_voxel_size, ))
        print("Next total ROI: %s" % next_total_roi)
        print("Next chunk size: %s" % (next_write_size, ))

        next_ds_name = '%s/s%d' % (in_ds_name, level)
        print("Preparing %s" % (next_ds_name, ))

        next_array = daisy.prepare_ds(in_file,
                                      next_ds_name,
                                      total_roi=next_total_roi,
                                      voxel_size=next_voxel_size,
                                      write_size=next_write_size,
                                      dtype=prev_array.dtype,
                                      num_channels=num_channels)

        downscale(prev_array, next_array, scale, next_write_size)

        # The level just written is the source for the next one.
        prev_array = next_array
def extract_segmentation(fragments_file,
                         fragments_dataset,
                         edges_collection,
                         threshold,
                         out_file,
                         out_dataset,
                         num_workers,
                         lut_fragment_segment,
                         roi_offset=None,
                         roi_shape=None,
                         run_type=None,
                         **kwargs):
    """Relabel fragments to segments block-wise using a fragment->segment LUT.

    The LUT is read from
    ``<fragments_file>/<lut_fragment_segment>[/<run_type>]/seg_<edges_collection>_<int(threshold*100)>.npz``
    (key ``'fragment_segment_lut'``). Each block is then processed by
    ``segment_in_block``, which writes into ``out_file/out_dataset``.

    Args:
        fragments_file, fragments_dataset: Location of the fragments array.
        edges_collection: Name used to build the LUT filename.
        threshold: Merge threshold; ``int(threshold * 100)`` is used in the
            LUT filename.
        out_file, out_dataset: Segmentation output (uint64) location.
        num_workers: Number of parallel block workers.
        lut_fragment_segment: LUT directory relative to ``fragments_file``.
        roi_offset, roi_shape: Optional ROI restriction; both must be given
            together. Defaults to the full fragments ROI.
        run_type: Optional sub-directory of the LUT directory.
        **kwargs: Accepted and ignored here.

    Raises:
        AssertionError: If only ``roi_offset`` is given, or the LUT file does
            not exist.
    """

    # open fragments
    fragments = daisy.open_ds(fragments_file, fragments_dataset)

    total_roi = fragments.roi
    if roi_offset is not None:
        assert roi_shape is not None, "If roi_offset is set, roi_shape " \
                                      "also needs to be provided"
        total_roi = daisy.Roi(offset=roi_offset, shape=roi_shape)

    read_roi = daisy.Roi((0, 0, 0), (5000, 5000, 5000))
    write_roi = daisy.Roi((0, 0, 0), (5000, 5000, 5000))

    logging.info("Preparing segmentation dataset...")
    segmentation = daisy.prepare_ds(out_file,
                                    out_dataset,
                                    total_roi,
                                    voxel_size=fragments.voxel_size,
                                    dtype=np.uint64,
                                    write_roi=write_roi)

    lut_filename = 'seg_%s_%d' % (edges_collection, int(threshold * 100))

    lut_dir = os.path.join(fragments_file, lut_fragment_segment)

    if run_type:
        lut_dir = os.path.join(lut_dir, run_type)
        logging.info("Run type set, using luts from %s data" % run_type)

    lut_path = os.path.join(lut_dir, lut_filename + '.npz')

    assert os.path.exists(lut_path), "%s does not exist" % lut_path

    start = time.time()
    logging.info("Reading fragment-segment LUT...")
    # np.load on an .npz returns an NpzFile that holds the file open; use it
    # as a context manager so the handle is released once the array is read.
    with np.load(lut_path) as npz:
        lut = npz['fragment_segment_lut']
    logging.info("%.3fs" % (time.time() - start))

    logging.info("Found %d fragments in LUT" % len(lut[0]))

    daisy.run_blockwise(total_roi,
                        read_roi,
                        write_roi,
                        lambda b: segment_in_block(
                            b, fragments_file, segmentation, fragments, lut),
                        fit='shrink',
                        num_workers=num_workers,
                        processes=True,
                        read_write_conflict=False)
Exemplo n.º 18
0
def predict_blockwise(base_dir,
                      experiment,
                      setup,
                      iteration,
                      raw_file,
                      raw_dataset,
                      out_base,
                      file_name,
                      num_workers,
                      db_host,
                      db_name,
                      queue,
                      auto_file=None,
                      auto_dataset=None,
                      singularity_image=None):
    '''

    Run prediction in parallel blocks. Within blocks, predict in chunks.


    Assumes a general directory structure:


    base
    ├── fib25 (experiment dir)
    │   │
    │   ├── 01_data (data dir)
    │   │   └── training_data (i.e zarr/n5, etc)
    │   │
    │   └── 02_train (train/predict dir)
    │       │
    │       ├── setup01 (setup dir - e.g baseline affinities)
    │       │   │
    │       │   ├── config.json (specifies meta data for inference)
    │       │   │
    │       │   ├── mknet.py (creates network, jsons to be used)
    │       │   │
    │       │   ├── model_checkpoint (saved network checkpoint for inference)
    │       │   │
    │       │   ├── predict.py (worker inference file - logic to be distributed)
    │       │   │
    │       │   ├── train_net.json (specifies meta data for training)
    │       │   │
    │       │   └── train.py (network training script)
    │       │
    │       ├──    .
    │       ├──    .
    │       ├──    .
    │       └── setup{n}
    │
    ├── hemi-brain
    ├── zebrafinch
    ├──     .
    ├──     .
    ├──     .
    └── experiment{n}

    Args:

        base_dir (``string``):

            Path to base directory containing experiment sub directories.

        experiment (``string``):

            Name of the experiment (fib25, hemi, zfinch, ...).

        setup (``string``):

            Name of the setup to predict (setup01, setup02, ...).

        iteration (``int``):

            Training iteration to predict from.

        raw_file (``string``):

            Path to raw file (zarr/n5) - can also be a json container
            specifying a crop, where offset and size are in world units:

                {
                    "container": "path/to/raw",
                    "offset": [z, y, x],
                    "size": [z, y, x]
                }

        raw_dataset (``string``):

            Raw dataset to use (e.g 'volumes/raw'). If the dataset cannot be
            opened, scale zero of a pyramid is tried instead (e.g
            'volumes/raw/s0')

        out_base (``string``):

            Path to base directory where zarr/n5 should be stored. The out_file
            will be built from this directory, setup, iteration, file name

            **Note:

                out_dataset no longer needed as input, build out_dataset from config
                outputs dictionary generated in mknet.py (config.json for
                example)

        file_name (``string``):

            Name of output zarr/n5

        num_workers (``int``):

            How many blocks to run in parallel.

        db_host (``string``):

            Name of MongoDB client.

        db_name (``string``):

            Name of MongoDB database to use (for logging successful blocks in
            check function and DaisyRequestBlocks node inside worker predict
            script).

        queue (``string``):

            Name of gpu queue to run inference on (i.e gpu_rtx, gpu_tesla, etc)

        auto_file (``string``, optional):

            Path to zarr/n5 containing first pass predictions to use as input to
            autocontext network (i.e aclsd / acrlsd). None if not needed

        auto_dataset (``string``, optional):

            Input dataset to use if running autocontext (e.g 'volumes/lsds').
            None if not needed

        singularity_image (``string``, optional):

            Path to singularity image. None if not needed

    Raises:

        RuntimeError: If ``daisy.run_blockwise`` reports that not all blocks
            succeeded.

    '''

    #get relevant dirs + files

    experiment_dir = os.path.join(base_dir, experiment)
    train_dir = os.path.join(experiment_dir, '02_train')
    # Built from the setup *name*, before `setup` is rebound to a path below.
    network_dir = os.path.join(experiment, setup, str(iteration))

    raw_file = os.path.abspath(raw_file)
    out_file = os.path.abspath(
        os.path.join(out_base, setup, str(iteration), file_name))

    setup = os.path.abspath(os.path.join(train_dir, setup))

    # from here on, all values are in world units (unless explicitly mentioned)

    # get ROI of source; fall back to scale level zero of a pyramid if the
    # plain dataset cannot be opened
    try:
        source = daisy.open_ds(raw_file, raw_dataset)
    except Exception:
        raw_dataset = raw_dataset + '/s0'
        source = daisy.open_ds(raw_file, raw_dataset)

    logging.info(f'Source shape: {source.shape}')
    logging.info(f'Source roi: {source.roi}')
    logging.info(f'Source voxel size: {source.voxel_size}')

    # load config
    with open(os.path.join(setup, 'config.json')) as f:
        net_config = json.load(f)

    outputs = net_config['outputs']

    # get chunk size and context for network (since unet has smaller output size
    # than input size
    net_input_size = daisy.Coordinate(
        net_config['input_shape']) * source.voxel_size
    net_output_size = daisy.Coordinate(
        net_config['output_shape']) * source.voxel_size

    context = (net_input_size - net_output_size) / 2

    # get total input and output ROIs
    input_roi = source.roi.grow(context, context)
    output_roi = source.roi

    # create read and write ROI
    block_read_roi = daisy.Roi((0, 0, 0), net_input_size) - context
    block_write_roi = daisy.Roi((0, 0, 0), net_output_size)

    logging.info('Preparing output dataset...')

    # get output file(s) meta data from config.json, prepare dataset(s)
    for output_name, val in outputs.items():
        out_dims = val['out_dims']
        out_dtype = val['out_dtype']
        out_dataset = 'volumes/%s' % output_name

        daisy.prepare_ds(out_file,
                         out_dataset,
                         output_roi,
                         source.voxel_size,
                         out_dtype,
                         write_roi=block_write_roi,
                         num_channels=out_dims,
                         compressor={
                             'id': 'gzip',
                             'level': 5
                         })

    logging.info('Starting block-wise processing...')

    # for logging successful blocks (see check_block function). if anything
    # fails, blocks which completed will be skipped when re-running

    client = pymongo.MongoClient(db_host)
    db = client[db_name]
    blocks_predicted = db['blocks_predicted']
    if 'blocks_predicted' not in db.list_collection_names():
        blocks_predicted.create_index([('block_id', pymongo.ASCENDING)],
                                      name='block_id')

    # process block-wise. NOTE(review): `out_dataset` below is whatever the
    # loop above bound last, i.e. only the final output dataset name reaches
    # the worker (and it is unbound if `outputs` is empty). Presumably the
    # worker derives its outputs from config.json instead - confirm.
    succeeded = daisy.run_blockwise(
        input_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda: predict_worker(
            experiment, setup, network_dir, iteration, raw_file, raw_dataset,
            auto_file, auto_dataset, out_file, out_dataset, db_host, db_name,
            queue, singularity_image),
        check_function=lambda b: check_block(blocks_predicted, b),
        num_workers=num_workers,
        read_write_conflict=False,
        fit='overhang')

    if not succeeded:
        raise RuntimeError("Prediction failed for (at least) one block")