Example #1
class PreCache(BatchFilter):
    def __init__(self, request, cache_size=50, num_workers=20):
        '''Args:

            request (:class:`BatchRequest`):

                The request used to pre-cache batches.

            cache_size (``int``):

                How many batches to pre-cache.

            num_workers (``int``):

                How many processes to spawn to fill the cache.
        '''
        self.request = copy.deepcopy(request)
        # note: this queue is never read below; the ProducerPool maintains its
        # own cache queue of size cache_size
        self.batches = multiprocessing.Queue(maxsize=cache_size)
        self.workers = ProducerPool(
            [lambda i=i: self.__run_worker(i) for i in range(num_workers)],
            queue_size=cache_size)

    def setup(self):
        self.workers.start()

    def teardown(self):
        self.workers.stop()

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        logger.debug("getting batch from queue...")
        batch = self.workers.get()

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch

    def __run_worker(self, i):

        request = copy.deepcopy(self.request)
        return self.get_upstream_provider().request_batch(request)
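
This node is an instance of plain producer/consumer pre-caching: worker
processes keep a bounded queue filled so that the consumer only blocks while
the cache is empty. Below is a minimal, self-contained sketch of the same
pattern using only the standard library; `make_batch` is a hypothetical
stand-in for the upstream `request_batch` call, and `ProducerPool` itself
belongs to the surrounding library and is not shown.

import multiprocessing
import time

def make_batch(worker_id):
    # hypothetical stand-in for an expensive upstream request
    time.sleep(0.1)
    return "batch from worker %d" % worker_id

def worker(worker_id, cache):
    # each worker repeatedly produces batches into the shared, bounded queue
    while True:
        cache.put(make_batch(worker_id))

if __name__ == '__main__':
    cache = multiprocessing.Queue(maxsize=50)  # the pre-cache
    producers = [
        multiprocessing.Process(target=worker, args=(i, cache), daemon=True)
        for i in range(4)
    ]
    for p in producers:
        p.start()

    # the consumer only blocks while the cache is empty
    for _ in range(10):
        print(cache.get())
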
Example #2
class GenericTrain(BatchFilter):
    '''Generic train node to perform one training iteration for each batch that
    passes through. This node alone does nothing and should be subclassed for
    concrete implementations.

    Args:

        inputs (dict): Dictionary from names of input layers in the network to
            :class:`ArrayKey` or to a batch attribute name given as a string.

        outputs (dict): Dictionary from the names of output layers in the
            network to :class:`ArrayKey`. New arrays will be generated by
            this node for each entry (if requested downstream).

        gradients (dict): Dictionary from the names of output layers in the
            network to :class:`ArrayKey`. New arrays containing the
            gradient of an output with respect to the loss will be generated
            by this node for each entry (if requested downstream).

        array_specs (dict, optional): An optional dictionary of
            :class:`ArrayKey` to :class:`ArraySpec` to set the specs of the
            generated arrays (``outputs`` and ``gradients``). This is useful
            to set the ``voxel_size``, for example, if it differs from the
            voxel size of the input arrays. Only fields that are not ``None``
            in the given :class:`ArraySpec` will be used.

        spawn_subprocess (bool, optional): Whether to run the ``train_step`` in
            a separate process. Default is false.
    '''
    def __init__(self,
                 inputs,
                 outputs,
                 gradients,
                 array_specs=None,
                 spawn_subprocess=False):

        self.initialized = False

        self.inputs = inputs
        self.outputs = outputs
        self.gradients = gradients
        self.array_specs = {} if array_specs is None else array_specs
        self.spawn_subprocess = spawn_subprocess

        self.provided_arrays = list(self.outputs.values()) + list(
            self.gradients.values())

    def setup(self):

        # get common voxel size of inputs, or None if they differ
        common_voxel_size = None
        for key in self.inputs.values():

            if not isinstance(key, ArrayKey):
                continue

            voxel_size = self.spec[key].voxel_size

            if common_voxel_size is None:
                common_voxel_size = voxel_size
            elif common_voxel_size != voxel_size:
                common_voxel_size = None
                break

        # announce provided outputs
        for key in self.provided_arrays:

            if key in self.array_specs:
                spec = self.array_specs[key].copy()
            else:
                spec = ArraySpec()

            if spec.voxel_size is None and not spec.nonspatial:

                assert common_voxel_size is not None, (
                    "There is no common voxel size of the inputs, and no "
                    "ArraySpec has been given for %s that defines "
                    "voxel_size." % key)

                spec.voxel_size = common_voxel_size

            if spec.interpolatable is None:

                # default for predictions
                spec.interpolatable = False

            self.provides(key, spec)

        if self.spawn_subprocess:
            # start training as a producer pool, so that we can gracefully exit if
            # anything goes wrong
            self.worker = ProducerPool([self.__produce_train_batch],
                                       queue_size=1)
            self.batch_in = multiprocessing.Queue(maxsize=1)
            self.worker.start()
        else:
            self.start()
            self.initialized = True

    def prepare(self, request):
        deps = BatchRequest()
        for key in self.inputs.values():
            deps[key] = request[key]
        return deps

    def teardown(self):
        if self.spawn_subprocess:
            # signal "stop"
            self.batch_in.put((None, None))
            try:
                self.worker.get(timeout=2)
            except NoResult:
                pass
            self.worker.stop()
        else:
            self.stop()

    def process(self, batch, request):

        start = time.time()

        if self.spawn_subprocess:

            self.batch_in.put((batch, request))

            try:
                out = self.worker.get()
            except WorkersDied:
                raise TrainProcessDied()

            for array_key in self.provided_arrays:
                if array_key in request:
                    batch.arrays[array_key] = out.arrays[array_key]

            batch.loss = out.loss
            batch.iteration = out.iteration

        else:

            self.train_step(batch, request)

        time_of_iteration = time.time() - start

        logger.info("Train process: iteration=%d loss=%f time=%f",
                    batch.iteration, batch.loss, time_of_iteration)

    def start(self):
        '''To be implemented in subclasses.

        This method will be called before the first call to
        :func:`train_step`, from the same process that :func:`train_step`
        will be called from. Use this to initialize your solver and training
        hardware.
        '''
        pass

    def train_step(self, batch, request):
        '''To be implemented in subclasses.

        In this method, an implementation should perform one training iteration
        on the given batch. ``batch.loss`` and ``batch.iteration`` should be
        set. Output arrays should be created according to the given request
        and added to ``batch``.'''
        raise NotImplementedError("Class %s does not implement 'train_step'" %
                                  self.name())

    def stop(self):
        '''To be implemented in subclasses.

        This method will be called after the last call to :func:`train_step`,
        from the same process that :func:`train_step` will be called from. Use
        this to tear down your solver and free training hardware.
        '''
        pass

    def _checkpoint_name(self, basename, iteration):
        return basename + '_checkpoint_' + '%i' % iteration

    def _get_latest_checkpoint(self, basename):
        def atoi(text):
            return int(text) if text.isdigit() else text

        def natural_keys(text):
            return [atoi(c) for c in re.split(r'(\d+)', text)]

        checkpoints = glob.glob(basename + '_checkpoint_*')
        checkpoints.sort(key=natural_keys)

        if len(checkpoints) > 0:

            checkpoint = checkpoints[-1]
            iteration = int(checkpoint.split('_')[-1])
            return checkpoint, iteration

        return None, 0

    def __produce_train_batch(self):
        '''Process one train batch.'''

        if not self.initialized:

            self.start()
            self.initialized = True

        batch, request = self.batch_in.get()

        # stop signal
        if batch is None:
            self.stop()
            return None

        self.train_step(batch, request)

        return batch
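
`GenericTrain` defers `start`, `train_step`, and `stop` to subclasses, so a
concrete node only has to fill in those three hooks. The sketch below shows
the override contract against the class defined above; the training call is a
placeholder, and the `Array` handling mirrors what the other nodes in this
document do (e.g. `Scan.__setup_batch`).

import numpy as np

class DummyTrain(GenericTrain):
    '''Minimal subclass sketch; a real implementation would run a
    framework-specific optimizer step instead of the placeholders below.'''

    def start(self):
        # called once, in the process that will run train_step():
        # initialize your solver/model and training hardware here
        self.iteration = 0

    def train_step(self, batch, request):
        # one "training iteration": set loss and iteration on the batch...
        self.iteration += 1
        batch.loss = 0.0  # placeholder loss value
        batch.iteration = self.iteration

        # ...and create the requested output arrays
        for array_key in self.provided_arrays:
            if array_key in request:
                spec = self.spec[array_key].copy()
                spec.roi = request[array_key].roi
                shape = spec.roi.get_shape() // spec.voxel_size
                batch.arrays[array_key] = Array(
                    data=np.zeros(shape), spec=spec)

    def stop(self):
        # called once after the last train_step(): free solver and hardware
        pass
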
Example #3
class Train(BatchFilter):
    '''Performs one training iteration for each batch that passes through. 
    Adds the predicted affinities to the batch.
    '''
    def __init__(self, solver_parameters, use_gpu=None):

        # start training as a producer pool, so that we can gracefully exit if
        # anything goes wrong
        self.worker = ProducerPool([lambda gpu=use_gpu: self.__train(gpu)],
                                   queue_size=1)
        self.batch_in = multiprocessing.Queue(maxsize=1)

        self.solver_parameters = solver_parameters
        self.solver_initialized = False

    def setup(self):
        self.worker.start()

    def teardown(self):
        self.worker.stop()

    def prepare(self, request):

        # remove request parts that we provide
        for volume_type in [
                VolumeType.LOSS_GRADIENT, VolumeType.PRED_AFFINITIES
        ]:
            if volume_type in request.volumes:
                del request.volumes[volume_type]

    def process(self, batch, request):

        self.batch_in.put((batch, request))

        try:
            out = self.worker.get()
        except WorkersDied:
            raise TrainProcessDied()

        batch.volumes[VolumeType.PRED_AFFINITIES] = out.volumes[
            VolumeType.PRED_AFFINITIES]
        if VolumeType.LOSS_GRADIENT in request.volumes:
            batch.volumes[VolumeType.LOSS_GRADIENT] = out.volumes[
                VolumeType.LOSS_GRADIENT]
        batch.loss = out.loss
        batch.iteration = out.iteration

    def __train(self, use_gpu):

        start = time.time()

        if not self.solver_initialized:

            logger.info("Initializing solver...")

            if use_gpu is not None:

                logger.debug("Train process: using GPU %d" % use_gpu)
                caffe.enumerate_devices(False)
                caffe.set_devices((use_gpu, ))
                caffe.set_mode_gpu()
                caffe.select_device(use_gpu, False)

            self.solver = caffe.get_solver(self.solver_parameters)
            if self.solver_parameters.resume_from is not None:
                logger.debug("Train process: restoring solver state from " +
                             self.solver_parameters.resume_from)
                self.solver.restore(self.solver_parameters.resume_from)

            self.net_io = NetIoWrapper(self.solver.net)

            self.solver_initialized = True

        batch, request = self.batch_in.get()

        data = {
            'data':
            batch.volumes[VolumeType.RAW].data[np.newaxis, np.newaxis, :],
            'aff_label':
            batch.volumes[VolumeType.GT_AFFINITIES].data[np.newaxis, :],
        }

        if self.solver_parameters.train_state.get_stage(0) == 'euclid':
            logger.debug(
                "Train process: preparing input data for Euclidean training")
            self.__prepare_euclidean(batch, data)
        else:
            logger.debug(
                "Train process: preparing input data for Malis training")
            self.__prepare_malis(batch, data)

        self.net_io.set_inputs(data)

        loss = self.solver.step(1)
        # self.__consistency_check()
        output = self.net_io.get_outputs()
        batch.volumes[VolumeType.PRED_AFFINITIES] = Volume(
            output['aff_pred'][0],
            batch.volumes[VolumeType.GT_AFFINITIES].roi,
            batch.volumes[VolumeType.GT_AFFINITIES].resolution,
            interpolate=True)
        batch.loss = loss
        batch.iteration = self.solver.iter
        if VolumeType.LOSS_GRADIENT in request.volumes:
            diffs = self.net_io.get_output_diffs()
            batch.volumes[VolumeType.LOSS_GRADIENT] = Volume(
                diffs['aff_pred'][0],
                batch.volumes[VolumeType.GT_AFFINITIES].roi,
                batch.volumes[VolumeType.GT_AFFINITIES].resolution,
                interpolate=True)

        time_of_iteration = time.time() - start
        logger.info("Train process: iteration=%d loss=%f time=%f" %
                    (self.solver.iter, batch.loss, time_of_iteration))

        return batch

    def __prepare_euclidean(self, batch, data):

        gt_affinities = batch.volumes[VolumeType.GT_AFFINITIES]

        # initialize error scale with 1s (np.float64; the old np.float alias
        # was removed in recent NumPy)
        error_scale = np.ones(gt_affinities.data.shape, dtype=np.float64)

        # set error_scale to 0 in masked-out areas
        if VolumeType.GT_MASK in batch.volumes:
            self.__mask_error_scale(error_scale,
                                    batch.volumes[VolumeType.GT_MASK].data)
        if VolumeType.GT_IGNORE in batch.volumes:
            self.__mask_error_scale(error_scale,
                                    batch.volumes[VolumeType.GT_IGNORE].data)

        # in the masked-in area, compute the fraction of positive samples
        masked_in = error_scale.sum()
        num_pos = (gt_affinities.data * error_scale).sum()
        frac_pos = float(num_pos) / masked_in if masked_in > 0 else 0
        frac_pos = np.clip(frac_pos, 0.05, 0.95)
        frac_neg = 1.0 - frac_pos

        # compute the class weights for positive and negative samples
        w_pos = 1.0 / (2.0 * frac_pos)
        w_neg = 1.0 / (2.0 * frac_neg)

        # scale the masked-in error_scale with the class weights (compare
        # against the ground-truth affinities, not the input dict `data`)
        error_scale *= ((gt_affinities.data >= 0.5) * w_pos +
                        (gt_affinities.data < 0.5) * w_neg)

        data['scale'] = error_scale[np.newaxis, :]

    def __mask_error_scale(self, error_scale, mask):
        for d in range(error_scale.shape[0]):
            error_scale[d] *= mask

    def __prepare_malis(self, batch, data):

        gt_labels = batch.volumes[VolumeType.GT_LABELS]
        next_id = gt_labels.data.max() + 1

        gt_pos_pass = gt_labels.data

        if VolumeType.GT_IGNORE in batch.volumes:

            gt_neg_pass = np.array(gt_labels.data)
            gt_neg_pass[batch.volumes[VolumeType.GT_IGNORE].data ==
                        0] = next_id

        else:

            gt_neg_pass = gt_pos_pass

        data['comp_label'] = np.array([[gt_neg_pass, gt_pos_pass]])
        data['nhood'] = batch.affinity_neighborhood[np.newaxis, np.newaxis, :]

        # Why don't we update gt_affinities in the same way?
        # -> not needed
        #
        # GT affinities are all 0 in the masked area (because masked area is
        # assumed to be set to background in batch.gt).
        #
        # In the negative pass:
        #
        #   We set all affinities inside GT regions to 1. Affinities in the
        #   masked area are as predicted; they belong to one foreground region
        #   (introduced above). But we only count loss on edges connecting
        #   different labels -> loss in masked-out area only from outside
        #   regions.
        #
        # In the positive pass:
        #
        #   We set all affinities outside GT regions to 0 -> no loss in masked
        #   out area.

    def __consistency_check(self):

        diffs = self.net_io.get_outputs()
        for k in diffs:
            assert not np.isnan(
                diffs[k]).any(), "Detected NaN in output diff " + k
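
The Euclidean branch above implements class balancing: if a fraction
`frac_pos` of the (masked-in) affinities is positive, weighting positives by
`1/(2*frac_pos)` and negatives by `1/(2*(1-frac_pos))` makes each class
contribute half of the total loss weight. A standalone NumPy sketch of the
weighting, leaving out the mask handling:

import numpy as np

def euclidean_error_scale(gt_affinities):
    # fraction of positive samples, clipped as in the node above
    frac_pos = np.clip(gt_affinities.mean(), 0.05, 0.95)
    frac_neg = 1.0 - frac_pos
    w_pos = 1.0 / (2.0 * frac_pos)  # rare positives are up-weighted
    w_neg = 1.0 / (2.0 * frac_neg)
    return (gt_affinities >= 0.5) * w_pos + (gt_affinities < 0.5) * w_neg

gt = (np.random.rand(3, 10, 10, 10) > 0.8).astype(np.float64)
scale = euclidean_error_scale(gt)
# both classes now carry (roughly) equal total weight:
print(scale[gt >= 0.5].sum(), scale[gt < 0.5].sum())
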
Example #4
class TrainSyntist(BatchFilter):
    '''Performs one training iteration for each batch that passes through.
    Adds the predicted affinities to the batch.
    '''
    def __init__(self, solver_parameters, use_gpu=None):

        # start training as a producer pool, so that we can gracefully exit if
        # anything goes wrong
        self.worker = ProducerPool([lambda gpu=use_gpu: self.__train(gpu)],
                                   queue_size=1)
        self.batch_in = multiprocessing.Queue(maxsize=1)

        self.solver_parameters = solver_parameters
        self.solver_initialized = False

    def setup(self):
        self.worker.start()

    def teardown(self):
        self.worker.stop()

    def prepare(self, request):

        # remove request parts that we provide
        for volume_type in [
                VolumeTypes.LOSS_GRADIENT, VolumeTypes.PRED_AFFINITIES
        ]:
            if volume_type in request.volumes:
                del request.volumes[volume_type]

    def process(self, batch, request):

        self.batch_in.put((batch, request))

        try:
            out = self.worker.get()
        except WorkersDied:
            raise TrainProcessDied()

        batch.volumes[VolumeTypes.PRED_BM_PRESYN] = out.volumes[
            VolumeTypes.PRED_BM_PRESYN]
        if VolumeTypes.LOSS_GRADIENT in request.volumes:
            batch.volumes[VolumeTypes.LOSS_GRADIENT] = out.volumes[
                VolumeTypes.LOSS_GRADIENT]
        batch.loss = out.loss
        batch.iteration = out.iteration

    def __train(self, use_gpu):

        start = time.time()

        if not self.solver_initialized:

            logger.info("Initializing solver...")

            if use_gpu is not None:

                logger.debug("Train process: using GPU %d" % use_gpu)
                caffe.enumerate_devices(False)
                caffe.set_devices((use_gpu, ))
                caffe.set_mode_gpu()
                caffe.select_device(use_gpu, False)

            self.solver = caffe.get_solver(self.solver_parameters)
            if self.solver_parameters.resume_from is not None:
                logger.debug("Train process: restoring solver state from " +
                             self.solver_parameters.resume_from)
                self.solver.restore(self.solver_parameters.resume_from)

            self.net_io = NetIoWrapper(self.solver.net)

            self.solver_initialized = True

        batch, request = self.batch_in.get()

        data = {
            'data':
            batch.volumes[VolumeTypes.RAW].data[np.newaxis, np.newaxis, :],
            'bm_presyn_label':
            batch.volumes[VolumeTypes.GT_BM_PRESYN].data[np.newaxis, :],
        }

        if self.solver_parameters.train_state.get_stage(0) == 'euclid':
            logger.debug(
                "Train process: preparing input data for Euclidean training")
            self.__prepare_euclidean(batch, data)
        else:
            raise Exception("should not enter malis phase at all")
            # unreachable (and __prepare_malis is commented out below):
            # logger.debug(
            #     "Train process: preparing input data for Malis training")
            # self.__prepare_malis(batch, data)

        self.net_io.set_inputs(data)

        loss = self.solver.step(1)
        # self.__consistency_check()
        output = self.net_io.get_outputs()
        batch.volumes[VolumeTypes.PRED_BM_PRESYN] = Volume(
            output['bm_presyn_pred'],  # [0]
            batch.volumes[VolumeTypes.GT_BM_PRESYN].roi,
            batch.volumes[VolumeTypes.GT_BM_PRESYN].resolution,
            # interpolate=True
        )
        batch.loss = loss
        batch.iteration = self.solver.iter
        if VolumeTypes.LOSS_GRADIENT in request.volumes:
            diffs = self.net_io.get_output_diffs()
            batch.volumes[VolumeTypes.LOSS_GRADIENT] = Volume(
                diffs['bm_presyn_pred'][0],
                batch.volumes[VolumeTypes.GT_BM_PRESYN].roi,
                batch.volumes[VolumeTypes.GT_BM_PRESYN].resolution,
                # interpolate=True
            )

        time_of_iteration = time.time() - start
        logger.info("Train process: iteration=%d loss=%f time=%f" %
                    (self.solver.iter, batch.loss, time_of_iteration))

        return batch

    def __prepare_euclidean(self, batch, data):

        gt_bm_presyn = batch.volumes[VolumeTypes.GT_BM_PRESYN]
        frac_pos = np.clip(gt_bm_presyn.data.mean(), 0.05, 0.95)
        w_pos = 1.0 / (2.0 * frac_pos)
        w_neg = 1.0 / (2.0 * (1.0 - frac_pos))
        error_scale = self.__scale_errors(gt_bm_presyn.data, w_neg, w_pos)

        # if VolumeTypes.GT_MASK in batch.volumes:
        #     error_scale[batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN].data==0] = 0
        # if VolumeTypes.GT_IGNORE in batch.volumes:
        #     error_scale[batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN].data==0] = 0

        # if VolumeTypes.GT_MASK in batch.volumes:
        #     self.__mask_errors(batch, error_scale, batch.volumes[VolumeTypes.GT_MASK].data)
        # if VolumeTypes.GT_IGNORE in batch.volumes:
        #     self.__mask_errors(batch, error_scale, batch.volumes[VolumeTypes.GT_IGNORE].data)

        if VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN in batch.volumes:
            error_scale = np.multiply(
                error_scale,
                batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN].data)
        if VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN in batch.volumes:
            error_scale = np.multiply(
                error_scale,
                batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN].data)

        data['scale'] = error_scale[np.newaxis, :]

    def __scale_errors(self, data, factor_low, factor_high):
        scaled_data = np.add((data >= 0.5) * factor_high,
                             (data < 0.5) * factor_low)
        return scaled_data

    def __mask_errors(self, batch, error_scale, mask):
        for d in range(len(batch.affinity_neighborhood)):
            error_scale[d] = np.multiply(error_scale[d], mask)

    # def __prepare_malis(self, batch, data):
    #
    #     gt_labels = batch.volumes[VolumeTypes.GT_LABELS]
    #     next_id = gt_labels.data.max() + 1
    #
    #     gt_pos_pass = gt_labels.data
    #
    #     if VolumeTypes.GT_IGNORE in batch.volumes:
    #
    #         gt_neg_pass = np.array(gt_labels.data)
    #         gt_neg_pass[batch.volumes[VolumeTypes.GT_IGNORE].data==0] = next_id
    #
    #     else:
    #
    #         gt_neg_pass = gt_pos_pass
    #
    #     data['comp_label'] = np.array([[gt_neg_pass, gt_pos_pass]])
    #     data['nhood'] = batch.affinity_neighborhood[np.newaxis,np.newaxis,:]
    #
    #     # Why don't we update gt_affinities in the same way?
    #     # -> not needed
    #     #
    #     # GT affinities are all 0 in the masked area (because masked area is
    #     # assumed to be set to background in batch.gt).
    #     #
    #     # In the negative pass:
    #     #
    #     #   We set all affinities inside GT regions to 1. Affinities in masked
    #     #   area as predicted. Belongs to one foreground region (introduced
    #     #   above). But we only count loss on edges connecting different labels
    #     #   -> loss in masked-out area only from outside regions.
    #     #
    #     # In the positive pass:
    #     #
    #     #   We set all affinities outside GT regions to 0 -> no loss in masked
    #     #   out area.

    def __consistency_check(self):

        diffs = self.net_io.get_outputs()
        for k in diffs:
            assert not np.isnan(
                diffs[k]).any(), "Detected NaN in output diff " + k
Example #5
class Scan(BatchFilter):
    '''Iteratively requests batches of size ``reference`` from upstream
    providers in a scanning fashion, until all requested ROIs are covered. If
    the batch request to this node is empty, it will scan the complete upstream
    ROIs (and return nothing). Otherwise, it scans only the requested ROIs and
    returns a batch assembled from the smaller requests. In either case, the
    upstream requests will be contained in the downstream requested ROI or
    upstream ROIs.

    Args:

        reference(:class:`BatchRequest`): A reference :class:`BatchRequest`.
            This request will be shifted in a scanning fashion over the
            upstream ROIs of the requested arrays or points.

        num_workers (int, optional): If set to >1, upstream requests are made
            in parallel with that number of workers.

        cache_size (int, optional): If multiple workers are used, how many
            batches to hold at most.
    '''
    def __init__(self, reference, num_workers=1, cache_size=50):

        self.reference = reference.copy()
        self.num_workers = num_workers
        self.cache_size = cache_size
        self.workers = None
        if num_workers > 1:
            self.request_queue = multiprocessing.Queue(maxsize=0)
        self.batch = None

    def setup(self):

        if self.num_workers > 1:
            self.workers = ProducerPool(
                [self.__worker_get_chunk for _ in range(self.num_workers)],
                queue_size=self.cache_size)
            self.workers.start()

    def teardown(self):

        if self.num_workers > 1:
            self.workers.stop()

    def provide(self, request):

        empty_request = (len(request) == 0)
        if empty_request:
            scan_spec = self.spec
        else:
            scan_spec = request

        stride = self.__get_stride()
        shift_roi = self.__get_shift_roi(scan_spec)

        shifts = self.__enumerate_shifts(shift_roi, stride)
        num_chunks = len(shifts)

        logger.info("scanning over %d chunks", num_chunks)

        # the batch to return
        self.batch = Batch()

        if self.num_workers > 1:

            for shift in shifts:
                shifted_reference = self.__shift_request(self.reference, shift)
                self.request_queue.put(shifted_reference)

            for i in range(num_chunks):

                chunk = self.workers.get()

                if not empty_request:
                    self.__add_to_batch(request, chunk)

                logger.info("processed chunk %d/%d", i, num_chunks)

        else:

            for i, shift in enumerate(shifts):

                shifted_reference = self.__shift_request(self.reference, shift)
                chunk = self.__get_chunk(shifted_reference)

                if not empty_request:
                    self.__add_to_batch(request, chunk)

                logger.info("processed chunk %d/%d", i, num_chunks)

        batch = self.batch
        self.batch = None

        logger.debug("returning batch %s", batch)

        return batch

    def __get_stride(self):
        '''Get the maximal amount by which ``reference`` can be moved, such
        that it tiles the space.'''

        stride = None

        # get the least common multiple of all voxel sizes, we have to stride
        # at least that far
        lcm_voxel_size = self.spec.get_lcm_voxel_size(
            self.reference.array_specs.keys())

        # that's just the minimal size in each dimension
        for key, reference_spec in self.reference.items():

            shape = reference_spec.roi.get_shape()

            for d in range(len(lcm_voxel_size)):
                assert shape[d] >= lcm_voxel_size[d], (
                    "Shape of reference ROI %s for %s is smaller than the "
                    "least common multiple of the voxel sizes %s" %
                    (reference_spec.roi, key, lcm_voxel_size))

            if stride is None:
                stride = shape
            else:
                stride = Coordinate((min(a, b) for a, b in zip(stride, shape)))

        return stride

    def __get_shift_roi(self, spec):
        '''Get the minimal and maximal shift (as a ROI) to apply to
        ``self.reference``, such that it is still fully contained in ``spec``.
        '''

        total_shift_roi = None

        # get individual shift ROIs and intersect them
        for key, reference_spec in self.reference.items():

            if key not in spec:
                continue
            if spec[key].roi is None:
                continue

            # we have a reference ROI
            #
            #    [--------) [9]
            #    3        12
            #
            # and a spec ROI
            #
            #                 [---------------) [16]
            #                 16              32
            #
            # min and max shifts of reference are
            #
            #                 [--------) [9]
            #                 16       25
            #                        [--------) [9]
            #                        23       32
            #
            # therefore, all possible ways to shift the reference such that it
            # is contained in the spec are at least 16-3=13 and at most 23-3=20
            # (inclusive)
            #
            #              [-------) [8]
            #              13      21
            #
            # 1. the starting point is beginning of spec - beginning of reference
            # 2. the length is length of spec - length of reference + 1

            # 1. get the starting point of the shift ROI
            shift_begin = (spec[key].roi.get_begin() -
                           reference_spec.roi.get_begin())

            # 2. get the shape of the shift ROI
            shift_shape = (spec[key].roi.get_shape() -
                           reference_spec.roi.get_shape() +
                           (1, ) * reference_spec.roi.dims())

            # create a ROI...
            shift_roi = Roi(shift_begin, shift_shape)

            # ...and intersect it with previous shift ROIs
            if total_shift_roi is None:
                total_shift_roi = shift_roi
            else:
                total_shift_roi = total_shift_roi.intersect(shift_roi)
                if total_shift_roi.empty():
                    raise RuntimeError("There is no location where the ROIs "
                                       "the reference %s are contained in the "
                                       "request/upstream ROIs "
                                       "%s." % (self.reference, spec))

        if total_shift_roi is None:
            raise RuntimeError("None of the upstream ROIs are bounded (all "
                               "ROIs are None). Scan needs at least one "
                               "bounded upstream ROI.")

        return total_shift_roi

    def __enumerate_shifts(self, shift_roi, stride):
        '''Produces a sequence of shift coordinates starting at the beginning
        of ``shift_roi``, progressing with ``stride``. The maximum shift
        coordinate in any dimension will be the last point inside the shift roi
        in this dimension.'''

        min_shift = shift_roi.get_offset()
        max_shift = Coordinate(m - 1 for m in shift_roi.get_end())

        shift = np.array(min_shift)
        shifts = []

        dims = len(min_shift)

        logger.debug("enumerating possible shifts of %s in %s", stride,
                     shift_roi)

        while True:

            logger.debug("adding %s", shift)
            shifts.append(Coordinate(shift))

            if (shift == max_shift).all():
                break

            # count up dimensions
            for d in range(dims):

                if shift[d] >= max_shift[d]:
                    if d == dims - 1:
                        break
                    shift[d] = min_shift[d]
                else:
                    shift[d] += stride[d]
                    # snap to last possible shift, don't overshoot
                    if shift[d] > max_shift[d]:
                        shift[d] = max_shift[d]
                    break

        return shifts

    def __shift_request(self, request, shift):

        shifted = request.copy()
        for _, spec in shifted.items():
            spec.roi = spec.roi.shift(shift)

        return shifted

    def __worker_get_chunk(self):

        request = self.request_queue.get()
        return self.__get_chunk(request)

    def __get_chunk(self, request):

        return self.get_upstream_provider().request_batch(request)

    def __add_to_batch(self, spec, chunk):

        if self.batch.get_total_roi() is None:
            self.batch = self.__setup_batch(spec, chunk)

        for (array_key, array) in chunk.arrays.items():
            if array_key not in spec:
                continue
            self.__fill(self.batch.arrays[array_key].data, array.data,
                        spec.array_specs[array_key].roi, array.spec.roi,
                        self.spec[array_key].voxel_size)

        for (points_key, points) in chunk.points.items():
            if points_key not in spec:
                continue
            self.__fill_points(self.batch.points[points_key].data, points.data,
                               spec.points_specs[points_key].roi, points.roi)

    def __setup_batch(self, batch_spec, chunk):
        '''Allocate a batch matching the sizes of ``batch_spec``, using
        ``chunk`` as template.'''

        batch = Batch()

        for (array_key, spec) in batch_spec.array_specs.items():
            roi = spec.roi
            voxel_size = self.spec[array_key].voxel_size

            # get the 'non-spatial' shape of the chunk-batch
            # and append the shape of the request to it
            array = chunk.arrays[array_key]
            shape = array.data.shape[:-roi.dims()]
            shape += (roi.get_shape() // voxel_size)

            spec = self.spec[array_key].copy()
            spec.roi = roi
            logger.info("allocating array of shape %s for %s", shape,
                        array_key)
            batch.arrays[array_key] = Array(data=np.zeros(shape), spec=spec)

        for (points_key, spec) in batch_spec.points_specs.items():
            roi = spec.roi
            spec = self.spec[points_key].copy()
            spec.roi = roi
            batch.points[points_key] = Points(data={}, spec=spec)

        logger.debug("setup batch to fill %s", batch)

        return batch

    def __fill(self, a, b, roi_a, roi_b, voxel_size):
        logger.debug("filling " + str(roi_b) + " into " + str(roi_a))

        roi_a = roi_a // voxel_size
        roi_b = roi_b // voxel_size

        common_roi = roi_a.intersect(roi_b)
        if common_roi.empty():
            return

        common_in_a_roi = common_roi - roi_a.get_offset()
        common_in_b_roi = common_roi - roi_b.get_offset()

        slices_a = common_in_a_roi.get_bounding_box()
        slices_b = common_in_b_roi.get_bounding_box()

        if len(a.shape) > len(slices_a):
            slices_a = (slice(None), ) * (len(a.shape) -
                                          len(slices_a)) + slices_a
            slices_b = (slice(None), ) * (len(b.shape) -
                                          len(slices_b)) + slices_b

        a[slices_a] = b[slices_b]

    def __fill_points(self, a, b, roi_a, roi_b):
        logger.debug("filling points of " + str(roi_b) + " into points of" +
                     str(roi_a))

        common_roi = roi_a.intersect(roi_b)
        # intersect returns a (possibly empty) ROI, not None (cf. __fill)
        if common_roi.empty():
            return

        # find max point_id in a so far
        max_point_id = 0
        for point_id, point in a.items():
            if point_id > max_point_id:
                max_point_id = point_id

        # shift the IDs of incoming points past max_point_id to avoid
        # collisions with points already in a
        for point_id, point in b.items():
            if roi_a.contains(Coordinate(point.location)):
                a[point_id + max_point_id] = point
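
The heart of `Scan` is the tiling arithmetic: `__get_shift_roi` computes the
ROI of legal shifts (begin = spec begin - reference begin, shape = spec shape
- reference shape + 1), and `__enumerate_shifts` walks that ROI like an
odometer, snapping the last step in each dimension so the boundary is still
covered. A dimension-generic, standalone sketch of the enumeration:

import numpy as np

def enumerate_shifts(min_shift, max_shift, stride):
    '''Enumerate shifts from min_shift to max_shift (inclusive), stepping by
    stride and snapping the last step so the boundary is covered.'''
    shift = np.array(min_shift)
    shifts = []
    while True:
        shifts.append(tuple(int(s) for s in shift))
        if (shift == np.array(max_shift)).all():
            break
        for d in range(len(min_shift)):
            if shift[d] >= max_shift[d]:
                shift[d] = min_shift[d]  # reset this dimension, carry over
            else:
                # snap to the last possible shift, don't overshoot
                shift[d] = min(shift[d] + stride[d], max_shift[d])
                break
    return shifts

# a reference of width 10 scanned over [0, 25): the shift ROI is [0, 16),
# so the maximal shift is 15 and the last step snaps from 20 down to 15
print(enumerate_shifts((0,), (15,), (10,)))  # [(0,), (10,), (15,)]
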
Example #6
class Predict(BatchFilter):
    '''Augments the batch with the predicted affinities.
    '''
    def __init__(self, prototxt, weights, use_gpu=None):

        for f in [prototxt, weights]:
            if not os.path.isfile(f):
                raise RuntimeError("%s does not exist" % f)

        # start prediction as a producer pool, so that we can gracefully exit if
        # anything goes wrong
        self.worker = ProducerPool([lambda gpu=use_gpu: self.__predict(gpu)],
                                   queue_size=1)
        self.batch_in = multiprocessing.Queue(maxsize=1)

        self.prototxt = prototxt
        self.weights = weights
        self.net_initialized = False

    def setup(self):
        self.worker.start()

    def teardown(self):
        self.worker.stop()

    def prepare(self, request):

        # remove request parts that we provide
        if VolumeTypes.PRED_AFFINITIES in request.volumes:
            del request.volumes[VolumeTypes.PRED_AFFINITIES]

    def process(self, batch, request):

        self.batch_in.put(batch)

        try:
            out = self.worker.get()
        except WorkersDied:
            raise PredictProcessDied()

        affs = out.volumes[VolumeTypes.PRED_AFFINITIES]
        affs.roi = request.volumes[VolumeTypes.PRED_AFFINITIES]
        affs.resolution = batch.volumes[VolumeTypes.RAW].resolution

        batch.volumes[VolumeTypes.PRED_AFFINITIES] = affs

    def __predict(self, use_gpu):

        if not self.net_initialized:

            logger.info("Initializing solver...")

            if use_gpu is not None:

                logger.debug("Predict process: using GPU %d" % use_gpu)
                caffe.enumerate_devices(False)
                caffe.set_devices((use_gpu, ))
                caffe.set_mode_gpu()
                caffe.select_device(use_gpu, False)

            self.net = caffe.Net(self.prototxt, self.weights, caffe.TEST)
            self.net_io = NetIoWrapper(self.net)
            self.net_initialized = True

        start = time.time()

        batch = self.batch_in.get()

        fetch_time = time.time() - start

        self.net_io.set_inputs({
            'data':
            batch.volumes[VolumeTypes.RAW].data[np.newaxis, np.newaxis, :],
        })
        self.net.forward()
        output = self.net_io.get_outputs()

        predict_time = time.time() - start

        logger.info(
            "Predict process: time=%f (including %f waiting for batch)" %
            (predict_time, fetch_time))

        assert len(output['aff_pred'].shape) == 5, (
            "Got affinity prediction with unexpected number of dimensions: "
            "should be 1 (direction) + 3 (spatial) + 1 (batch, not used), "
            "but is %d" % len(output['aff_pred'].shape))
        batch.volumes[VolumeTypes.PRED_AFFINITIES] = Volume(
            output['aff_pred'][0], Roi(), (1, 1, 1))

        return batch
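
The `np.newaxis` indexing used to build the input dict simply prepends
singleton batch and channel dimensions, turning the 3D raw volume into the 5D
(batch, channel, z, y, x) layout the network expects; the concrete shape below
is only an example.

import numpy as np

raw = np.zeros((84, 268, 268))         # 3D raw data (z, y, x)
data = raw[np.newaxis, np.newaxis, :]  # -> (1, 1, 84, 268, 268)
print(data.shape)
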
Example #7
class PreCache(BatchFilter):
    '''Pre-cache repeated equal batch requests. For the first of a series of
    equal batch requests, a set of workers is spawned to pre-cache the batches
    in parallel processes. This way, subsequent requests can be served quickly.

    This node only makes sense if:

    1. Incoming batch requests are repeatedly the same.
    2. There is a source of randomness in upstream nodes.

    Args:

        cache_size (int): How many batches to hold at most in the cache.

        num_workers (int): How many processes to spawn to fill the cache.
    '''
    def __init__(self, cache_size=50, num_workers=20):

        self.current_request = None
        self.workers = None
        self.cache_size = cache_size
        self.num_workers = num_workers

    def teardown(self):

        if self.workers is not None:
            self.workers.stop()

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        if request != self.current_request:

            if self.workers is not None:
                logger.info(
                    "new request received, stopping current workers...")
                self.workers.stop()

            self.current_request = copy.deepcopy(request)

            logger.info("starting new set of workers...")
            self.workers = ProducerPool([
                lambda i=i: self.__run_worker(i)
                for i in range(self.num_workers)
            ],
                                        queue_size=self.cache_size)
            self.workers.start()

        logger.debug("getting batch from queue...")
        batch = self.workers.get()

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch

    def __run_worker(self, i):

        return self.get_upstream_provider().request_batch(self.current_request)
Example #8
class GenericPredict(BatchFilter):
    '''Generic predict node to add predictions of a trained network to each
    batch that passes through. This node alone does nothing and should be
    subclassed for concrete implementations.

    Args:

        inputs (dict): Dictionary from names of input layers in the network to
            :class:`ArrayKey` or to a batch attribute name given as a string.

        outputs (dict): Dictionary from the names of output layers in the
            network to :class:`ArrayKey`. New arrays will be generated by
            this node for each entry (if requested downstream).

        array_specs (dict, optional): An optional dictionary of
            :class:`ArrayKey` to :class:`ArraySpec` to set the specs of the
            generated arrays (``outputs``). This is useful
            to set the ``voxel_size``, for example, if it differs from the
            voxel size of the input arrays. Only fields that are not ``None``
            in the given :class:`ArraySpec` will be used.

        spawn_subprocess (bool, optional): Whether to run ``predict`` in a
            separate process. Default is false.
    '''
    def __init__(self,
                 inputs,
                 outputs,
                 array_specs=None,
                 spawn_subprocess=False):

        self.initialized = False
        self.inputs = inputs
        self.outputs = outputs
        self.array_specs = {} if array_specs is None else array_specs
        self.spawn_subprocess = spawn_subprocess
        self.timer_start = None

    def setup(self):

        # get common voxel size of inputs, or None if they differ
        common_voxel_size = None
        for key in self.inputs.values():

            if not isinstance(key, ArrayKey):
                continue

            voxel_size = self.spec[key].voxel_size

            if common_voxel_size is None:
                common_voxel_size = voxel_size
            elif common_voxel_size != voxel_size:
                common_voxel_size = None
                break

        # announce provided outputs
        for key in self.outputs.values():

            if key in self.array_specs:
                spec = self.array_specs[key].copy()
            else:
                spec = ArraySpec()

            if spec.voxel_size is None and not spec.nonspatial:

                assert common_voxel_size is not None, (
                    "There is no common voxel size of the inputs, and no "
                    "ArraySpec has been given for %s that defines "
                    "voxel_size." % key)

                spec.voxel_size = common_voxel_size

            if spec.interpolatable is None:

                # default for predictions
                spec.interpolatable = False

            self.provides(key, spec)

        if self.spawn_subprocess:
            # start prediction as a producer pool, so that we can gracefully
            # exit if anything goes wrong
            self.worker = ProducerPool([self.__produce_predict_batch],
                                       queue_size=1)
            self.batch_in = multiprocessing.Queue(maxsize=1)
            self.batch_in_lock = multiprocessing.Lock()
            self.batch_out_lock = multiprocessing.Lock()
            self.worker.start()

    def teardown(self):
        if self.spawn_subprocess:
            # signal "stop"
            self.batch_in.put((None, None))
            try:
                self.worker.get(timeout=2)
            except NoResult:
                pass
            self.worker.stop()
        else:
            self.stop()

    def prepare(self, request):

        if not self.initialized and not self.spawn_subprocess:
            self.start()
            self.initialized = True

        deps = BatchRequest()
        for key in self.inputs.values():
            deps[key] = request[key]
        return deps

    def process(self, batch, request):

        if self.spawn_subprocess:

            start = time.time()
            self.batch_in_lock.acquire()
            logger.debug("waited for batch in lock for %.3fs",
                         time.time() - start)
            start = time.time()
            self.batch_in.put((batch, request))
            logger.debug("queued batch for %.3fs", time.time() - start)

            start = time.time()
            with self.batch_out_lock:
                logger.debug("waited for batch out lock for %.3fs",
                             time.time() - start)

                start = time.time()
                self.batch_in_lock.release()
                logger.debug("released batch in lock for %.3fs",
                             time.time() - start)
                try:
                    start = time.time()
                    out = self.worker.get()
                    logger.debug("retreived batch for %.3fs",
                                 time.time() - start)
                except WorkersDied:
                    raise PredictProcessDied()

            for array_key in self.outputs.values():
                if array_key in request:
                    batch.arrays[array_key] = out.arrays[array_key]

        else:

            self.predict(batch, request)

    def start(self):
        '''To be implemented in subclasses.

        This method will be called before the first call to :func:`predict`,
        from the same process that :func:`predict` will be called from. Use
        this to initialize your model and hardware.
        '''
        pass

    def predict(self, batch, request):
        '''To be implemented in subclasses.

        In this method, an implementation should predict arrays on the given
        batch. Output arrays should be created according to the given request
        and added to ``batch``.'''
        raise NotImplementedError("Class %s does not implement 'predict'" %
                                  self.name())

    def stop(self):
        '''To be implemented in subclasses.

        This method will be called after the last call to :func:`predict`,
        from the same process that :func:`predict` will be called from. Use
        this to tear down your model and free the hardware.
        '''
        pass

    def __produce_predict_batch(self):
        '''Process one batch.'''

        if not self.initialized:

            self.start()
            self.initialized = True

        if self.timer_start is not None:
            self.time_out = time.time() - self.timer_start
            logger.info("batch in: %.3fs, predict: %.3fs, batch out: %.3fs",
                        self.time_in, self.time_predict, self.time_out)

        self.timer_start = time.time()
        batch, request = self.batch_in.get()
        self.time_in = time.time() - self.timer_start

        # stop signal
        if batch is None:
            self.stop()
            return None

        self.timer_start = time.time()
        self.predict(batch, request)
        self.time_predict = time.time() - self.timer_start
        self.timer_start = time.time()

        return batch
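
The two locks in `process` serialize the handoff to the single worker:
`batch_in_lock` ensures only one caller has a batch in flight, and
`batch_out_lock` ensures each caller picks up its own result before the next
batch is queued. A thread-based, standalone sketch of the same pattern (the
real node uses multiprocessing and a `ProducerPool`):

import queue
import threading

batch_in = queue.Queue(maxsize=1)
batch_out = queue.Queue(maxsize=1)
batch_in_lock = threading.Lock()
batch_out_lock = threading.Lock()

def worker():
    # stands in for the subprocess running predict()
    while True:
        batch = batch_in.get()
        if batch is None:  # stop signal
            return
        batch_out.put(batch * 2)  # placeholder "prediction"

def process(batch):
    batch_in_lock.acquire()
    batch_in.put(batch)
    with batch_out_lock:
        # only now may the next caller queue its batch; it will still block
        # on batch_out_lock until we have taken our own result
        batch_in_lock.release()
        return batch_out.get()

threading.Thread(target=worker, daemon=True).start()
print([process(i) for i in range(5)])  # [0, 2, 4, 6, 8]
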
Example #9
class PreCache(BatchFilter):
    '''Pre-cache repeated equal batch requests. For the first of a series of
    equal batch requests, a set of workers is spawned to pre-cache the batches
    in parallel processes. This way, subsequent requests can be served quickly.

    A note on changing the requests sent to `PreCache`:
    given requests A and B, if requests are sent in the sequence
    A, ..., A, B, A, ..., A, B, A, ...,
    `PreCache` will build a queue of batches that satisfy A and handle requests
    B on demand. This prevents `PreCache` from discarding the queue on every
    SnapshotRequest.
    However, if B replaces A as the most common request, i.e.
    A, A, A, ..., A, B, B, B, ...,
    `PreCache` will discard the A queue and build a B queue once it has seen
    more B requests than A requests among the last 5 requests.

    This node only makes sense if:

    1. Incoming batch requests are repeatedly the same.
    2. There is a source of randomness in upstream nodes.

    Args:

        cache_size (``int``):

            How many batches to hold at most in the cache.

        num_workers (``int``):

            How many processes to spawn to fill the cache.
    '''
    def __init__(self, cache_size=50, num_workers=20):

        self.current_request = None
        self.workers = None
        self.cache_size = cache_size
        self.num_workers = num_workers

        # keep track of recent requests
        self.last_5 = deque([None] * 5, maxlen=5)

    def teardown(self):

        if self.workers is not None:
            self.workers.stop()

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        # update recent requests
        self.last_5.popleft()
        self.last_5.append(request)

        if request != self.current_request:

            current_count = sum([
                recent_request == self.current_request
                for recent_request in self.last_5
            ])
            new_count = sum(
                [recent_request == request for recent_request in self.last_5])
            if new_count > current_count or self.current_request is None:

                if self.workers is not None:
                    logger.info(
                        "new request received, stopping current workers...")
                    self.workers.stop()

                self.current_request = copy.deepcopy(request)

                logger.info("starting new set of workers...")
                self.workers = ProducerPool(
                    [
                        lambda i=i: self.__run_worker(i)
                        for i in range(self.num_workers)
                    ],
                    queue_size=self.cache_size,
                )
                self.workers.start()

                logger.debug("getting batch from queue...")
                batch = self.workers.get()

                timing.stop()
                batch.profiling_stats.add(timing)

            else:
                logger.debug("Resolving new request sequentially")
                batch = self.get_upstream_provider().request_batch(request)

                timing.stop()
                batch.profiling_stats.add(timing)

        else:
            logger.debug("getting batch from queue...")
            batch = self.workers.get()

            timing.stop()
            batch.profiling_stats.add(timing)

        return batch

    def __run_worker(self, i):

        return self.get_upstream_provider().request_batch(self.current_request)
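
The switching heuristic described in the docstring can be isolated into a few
lines: a new request only evicts the cache once it outnumbers the current
request among the last five seen. A standalone sketch, with requests reduced
to plain strings:

from collections import deque

class RequestTracker:
    '''Standalone sketch of the last-5 switching heuristic above.'''

    def __init__(self):
        self.current = None
        self.last_5 = deque([None] * 5, maxlen=5)

    def observe(self, request):
        '''Return True if the cache should be rebuilt for this request.'''
        self.last_5.append(request)  # maxlen drops the oldest entry
        if request == self.current:
            return False
        current_count = sum(r == self.current for r in self.last_5)
        new_count = sum(r == request for r in self.last_5)
        if new_count > current_count or self.current is None:
            self.current = request
            return True
        return False  # serve the odd request out sequentially instead

tracker = RequestTracker()
for r in ["A", "A", "B", "A", "B", "B"]:
    print(r, tracker.observe(r))
# -> A True, A False, B False, A False, B False, B True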