示例#1
0
    def post(self):
        '''
        Reset seg model.

        Accepts either ``application/json`` or ``multipart/form-data``. For
        multipart requests the JSON parameters are read from the ``data``
        form field and the model weights from the uploaded ``file`` (a single
        checkpoint or a list of them), and ``url`` is forced to None.
        The reset is refused unless the process state is IDLE. While
        ``init_seg_models`` runs the state is RUN; it is set back to IDLE
        afterwards whether or not the reset succeeded.

        Returns:
            ``{"completed": true}`` on success; 400 for an unsupported
            Content-Type; 500 when a process is running or the reset fails.
        '''
        # .get avoids an exception when the header is missing; the empty
        # string then fails the check below and yields the JSON 400.
        content_type = request.headers.get('Content-Type', '')
        if all(ctype not in content_type
               for ctype in ('multipart/form-data', 'application/json')):
            msg = ('Content-Type should be multipart/form-data or '
                   'application/json')
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        if get_state() != TrainState.IDLE.value:
            msg = 'Process is running. Model cannot be reset.'
            logger().error(msg)
            return make_response(jsonify(error=msg), 500)
        try:
            device = get_device()
            if 'multipart/form-data' in content_type:
                req_json = json.loads(request.form.get('data'))
                file = request.files['file']
                # NOTE(security): torch.load unpickles the upload, which can
                # execute arbitrary code if the file comes from an untrusted
                # client. Consider ``weights_only=True`` once all checkpoints
                # are known to be plain state dicts.
                checkpoint = torch.load(file.stream, map_location=device)
                # init_seg_models takes a list of state dicts.
                state_dicts = (checkpoint
                               if isinstance(checkpoint, list) else
                               [checkpoint])
                # The uploaded checkpoint supplies the weights, so ``url`` is
                # cleared.
                req_json['url'] = None
            else:
                req_json = request.get_json()
                state_dicts = None
            req_json['device'] = device
            config = ResetConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            init_seg_models(config.model_path,
                            config.keep_axials,
                            config.device,
                            config.is_3d,
                            config.n_models,
                            config.n_crops,
                            config.zpath_input,
                            config.crop_size,
                            config.scales,
                            url=config.url,
                            state_dicts=state_dicts,
                            is_cpu=config.is_cpu(),
                            cache_maxbytes=config.cache_maxbytes)
        except RuntimeError as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in init_seg_models')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # Release collectable objects and cached GPU memory regardless of
            # the outcome, then mark the process idle again.
            gc.collect()
            torch.cuda.empty_cache()
            redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
示例#2
0
    def post(self):
        '''
        Update flow label.

        Waits while another request holds the WAIT state. Then it either
        clears the flow-label zarr (when ``reset`` is truthy) or passes the
        posted spots, grouped and sorted by timepoint, to
        ``_update_flow_labels``. ``REDIS_KEY_UPDATE_ONGOING_FLOW`` is set for
        the duration of the update so that concurrent updates are rejected.
        The state is WAIT while running. On exit the flag is cleared, and the
        previous state is restored unless it was switched to IDLE.

        Returns:
            The response from ``_update_flow_labels``,
            ``{"completed": true}`` after a reset, ``{"completed": false}``
            if the state becomes IDLE while waiting, 400 for invalid
            requests, 500 on failures.
        '''
        if request.headers.get('Content-Type') != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while state == TrainState.WAIT.value:
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if state == TrainState.IDLE.value:
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_FLOW):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            # Claim the ongoing flag so the check above rejects concurrent
            # updates; it is cleared in ``finally``.
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_FLOW, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = FlowTrainConfig(req_json)
            logger().info(config)
            if req_json.get('reset'):
                # Recreate the flow-label array with the same metadata,
                # discarding its contents.
                zarr.open_like(zarr.open(config.zpath_flow_label, mode='r'),
                               config.zpath_flow_label,
                               mode='w')
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            # A missing/null ``spots`` is treated as empty -> 400 below.
            for spot in req_json.get('spots') or ():
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            response = _update_flow_labels(spots_dict, config.scales,
                                           config.zpath_flow_label,
                                           config.flow_norm_factor)
        except RuntimeError as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in update_flow_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # An IDLE state means the run was aborted; do not override it.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_FLOW)
        return response
示例#3
0
    def post(self):
        '''
        Download outputs in CTC format.

        Waits while another request holds the WAIT state. Then it groups the
        posted spots by timepoint and calls ``export_ctc_labels``. The state
        is WAIT during the export. The previous state is restored afterwards
        unless it was switched to IDLE.

        Returns:
            The exported file when ``export_ctc_labels`` returns a path (str).
            204 when it returns a falsy value, or when the state becomes IDLE
            while waiting. 200 with an empty body otherwise. 400 for a wrong
            Content-Type and 500 on failures.
        '''
        if request.headers.get('Content-Type') != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while state == TrainState.WAIT.value:
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if state == TrainState.IDLE.value:
                return make_response('', 204)
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = ExportConfig(req_json)
            za_input = zarr.open(config.zpath_input, mode='r')
            # Drop the leading axis of the input (presumably time).
            config.shape = za_input.shape[1:]
            logger().info(config)
            spots_dict = collections.defaultdict(list)
            # A missing/null ``spots`` is treated as an empty list.
            for spot in req_json.get('spots') or ():
                spots_dict[spot['t']].append(spot)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            result = export_ctc_labels(config, spots_dict)
            if isinstance(result, str):
                # TODO: the file at ``result`` is not removed after sending.
                resp = send_file(result)
            elif not result:
                resp = make_response('', 204)
            else:
                resp = make_response('', 200)
        except RuntimeError as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Runtime Error: {e}'), 500)
        except Exception as e:
            logger().exception('Failed in export_ctc_labels')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # An IDLE state means the run was aborted; do not override it.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return resp
示例#4
0
    def post(self):
        '''
        Update process state.

        Expects a JSON object whose ``REDIS_KEY_STATE`` entry is the value of
        a ``TrainState`` member. It stores that value in redis and echoes
        back the state read from redis.

        Returns:
            ``{"success": true, "state": <state>}`` on success; 400 for a
            wrong Content-Type or an invalid state.
        '''
        if request.headers.get('Content-Type') != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        req_json = request.get_json()
        # A non-object JSON body (e.g. a list) has no .get; treat it as a
        # missing state so it is reported as invalid instead of crashing.
        state = (req_json.get(REDIS_KEY_STATE)
                 if isinstance(req_json, dict) else None)
        valid_states = {member.value for member in TrainState}
        # bool is a subclass of int, so ``true``/``false`` would otherwise
        # pass the isinstance check and could match a state valued 1 or 0.
        if (isinstance(state, bool) or not isinstance(state, int)
                or state not in valid_states):
            msg = f'Invalid state: {state}'
            logger().error(msg)
            # ``error`` matches the other endpoints; ``res`` is kept for
            # clients that read the original key.
            return make_response(jsonify(error=msg, res=msg), 400)
        redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify(success=True, state=get_state()))
示例#5
0
 def post(self):
     '''
     Update training parameters.

     Expects a JSON object with ``lr`` (a non-negative finite number) and
     ``n_crops`` (a non-negative integer). It stores both in redis and
     echoes back the values read from redis.

     Returns:
         ``{"success": true, "lr": ..., "n_crops": ...}`` on success; 400
         for a wrong Content-Type or invalid values.
     '''
     import math

     if request.headers.get('Content-Type') != 'application/json':
         msg = 'Content-Type should be application/json'
         logger().error(msg)
         return make_response(jsonify(error=msg), 400)
     req_json = request.get_json()
     if not isinstance(req_json, dict):
         # A non-object body carries no parameters; this makes it fail the
         # learning-rate validation below instead of raising.
         req_json = {}
     lr = req_json.get('lr')
     # Accept JSON integers (``0``, ``1``) as learning rates. Reject bools
     # (bool subclasses int) and NaN/inf: Python's json parser accepts those
     # values, and ``nan < 0`` is False, so NaN would otherwise pass.
     if (isinstance(lr, bool) or not isinstance(lr, (int, float)) or lr < 0
             or (isinstance(lr, float) and not math.isfinite(lr))):
         msg = f'Invalid learning rate: {lr}'
         logger().error(msg)
         return make_response(jsonify(error=msg), 400)
     n_crops = req_json.get('n_crops')
     # Reject bools: ``True`` would be stored as "True" and break the int()
     # read-back below.
     if (isinstance(n_crops, bool) or not isinstance(n_crops, int)
             or n_crops < 0):
         msg = f'Invalid number of crops: {n_crops}'
         logger().error(msg)
         return make_response(jsonify(error=msg), 400)
     redis_client.set(REDIS_KEY_LR, str(lr))
     redis_client.set(REDIS_KEY_NCROPS, str(n_crops))
     logger().info(f'[params updated] lr: {lr}, n_crops: {n_crops}')
     return make_response(
         jsonify(success=True,
                 lr=float(redis_client.get(REDIS_KEY_LR)),
                 n_crops=int(redis_client.get(REDIS_KEY_NCROPS))))
示例#6
0
    def post(self):
        '''
        Train seg model.

        Validates the request and refuses to start unless the state is IDLE.
        It then marks the state RUN and updates the segmentation labels from
        the posted spots (skipped in live mode, where REDIS_KEY_TIMEPOINT is
        cleared instead). Next it derives a step offset from the last record
        of any event files in ``log_dir`` and launches ``train_seg_task``.
        The request blocks until the task finishes or the state is switched
        to IDLE.

        Returns:
            ``{"completed": true}`` when the task finishes successfully.
            ``{"completed": false}`` when aborted. 400 for invalid input.
            500 when another process is running or training fails.
        '''
        import time

        if request.headers.get('Content-Type') != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        # True once this request has switched the state to RUN; only then is
        # the state reset to IDLE on exit. Resetting unconditionally would
        # abort a process started by another request whenever this request
        # is rejected.
        owns_state = False
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            if config.n_crops < 1:
                msg = 'n_crops should be a positive number'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)

            spots_dict = collections.defaultdict(list)
            # A missing/null ``spots`` is treated as an empty list.
            for spot in req_json.get('spots') or ():
                spots_dict[spot['t']].append(spot)
            if not (spots_dict or config.is_livemode):
                msg = 'nothing to train'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            if get_state() != TrainState.IDLE.value:
                msg = 'Process is running'
                logger().error(msg)
                return make_response(jsonify(error=msg), 500)
            redis_client.set(REDIS_KEY_STATE, TrainState.RUN.value)
            owns_state = True
            if config.is_livemode:
                redis_client.delete(REDIS_KEY_TIMEPOINT)
            else:
                try:
                    _update_seg_labels(spots_dict,
                                       config.scales,
                                       config.zpath_input,
                                       config.zpath_seg_label,
                                       config.zpath_seg_label_vis,
                                       config.auto_bg_thresh,
                                       config.c_ratio,
                                       memmap_dir=config.memmap_dir)
                except KeyboardInterrupt:
                    # Raised by _update_seg_labels when the state turns IDLE.
                    return make_response(jsonify({'completed': False}))
            # Continue the step count after the last record of any existing
            # event file so new logs do not overlap old ones.
            step_offset = 0
            for path in sorted(Path(config.log_dir).glob('event*')):
                try:
                    *_, last_record = TFRecordDataset(str(path))
                    last = event_pb2.Event.FromString(last_record.numpy()).step
                    step_offset = max(step_offset, last + 1)
                except Exception:
                    # Best effort: an empty or unreadable event file must not
                    # block training; it just does not contribute an offset.
                    logger().warning(f'failed to read last step from {path}')
            epoch_start = 0
            async_result = train_seg_task.delay(
                list(spots_dict.keys()),
                config.batch_size,
                config.crop_size,
                config.class_weights,
                config.false_weight,
                config.model_path,
                config.n_epochs,
                config.keep_axials,
                config.scales,
                config.lr,
                config.n_crops,
                config.is_3d,
                config.is_livemode,
                config.scale_factor_base,
                config.rotation_angle,
                config.contrast,
                config.zpath_input,
                config.zpath_seg_label,
                config.log_interval,
                config.log_dir,
                step_offset,
                epoch_start,
                config.is_cpu(),
                config.is_mixed_precision,
                config.cache_maxbytes,
                config.memmap_dir,
                config.input_size,
            )
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('training aborted')
                    return make_response(jsonify({'completed': False}))
                # Poll instead of spinning; without a pause this loop queries
                # redis continuously for the whole training run.
                time.sleep(0.1)
            if async_result.failed():
                # Surface task failures as errors instead of "completed".
                raise async_result.result
        except Exception as e:
            logger().exception('Failed in train_seg')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            torch.cuda.empty_cache()
            if owns_state:
                redis_client.set(REDIS_KEY_STATE, TrainState.IDLE.value)
        return make_response(jsonify({'completed': True}))
示例#7
0
    def post(self):
        '''
        Update seg label.

        Waits while another request holds the WAIT state. Then it either
        clears the segmentation label zarrs (when ``reset`` is truthy) or
        rasterizes the posted spots into them via ``_update_seg_labels``.
        While running, ``REDIS_KEY_UPDATE_ONGOING_SEG`` is set so concurrent
        updates are rejected, and the state is WAIT. On exit the flag is
        cleared, and the previous state is restored unless it was switched
        to IDLE.

        Returns:
            The response from ``_update_seg_labels``.
            ``{"completed": true}`` after a reset.
            ``{"completed": false}`` when aborted.
            400 for invalid requests and 500 on failures.
        '''
        if request.headers.get('Content-Type') != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while state == TrainState.WAIT.value:
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if state == TrainState.IDLE.value:
                return make_response(jsonify({'completed': False}))
        if redis_client.get(REDIS_KEY_UPDATE_ONGOING_SEG):
            msg = 'Last update is ongoing'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        try:
            redis_client.set(REDIS_KEY_UPDATE_ONGOING_SEG, 1)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationTrainConfig(req_json)
            logger().info(config)
            # In live mode a published timepoint means a previous update is
            # still being consumed.
            if config.is_livemode and redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if req_json.get('reset'):
                try:
                    # Recreate both label arrays with the same metadata,
                    # discarding their contents.
                    zarr.open_like(zarr.open(config.zpath_seg_label, mode='r'),
                                   config.zpath_seg_label,
                                   mode='w')
                    zarr.open_like(zarr.open(config.zpath_seg_label_vis,
                                             mode='r'),
                                   config.zpath_seg_label_vis,
                                   mode='w')
                except RuntimeError as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Runtime Error: {e}'),
                                         500)
                except Exception as e:
                    logger().exception('Failed in opening zarr')
                    return make_response(jsonify(error=f'Exception: {e}'), 500)
                return make_response(jsonify({'completed': True}))

            spots_dict = collections.defaultdict(list)
            # A missing/null ``spots`` is treated as empty -> 400 below.
            for spot in req_json.get('spots') or ():
                spots_dict[spot['t']].append(spot)
            if not spots_dict:
                msg = 'nothing to update'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            if config.is_livemode and len(spots_dict) != 1:
                msg = 'Livemode should update only a single timepoint'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            spots_dict = collections.OrderedDict(sorted(spots_dict.items()))

            try:
                response = _update_seg_labels(spots_dict,
                                              config.scales,
                                              config.zpath_input,
                                              config.zpath_seg_label,
                                              config.zpath_seg_label_vis,
                                              config.auto_bg_thresh,
                                              config.c_ratio,
                                              config.is_livemode,
                                              memmap_dir=config.memmap_dir)
            except KeyboardInterrupt:
                # Raised by _update_seg_labels when the state turns IDLE.
                return make_response(jsonify({'completed': False}))
            except Exception as e:
                logger().exception('Failed in _update_seg_labels')
                return make_response(jsonify(error=f'Exception: {e}'), 500)
        except Exception as e:
            # Failures outside the label update (bad body, invalid config,
            # malformed spots) are reported as a JSON 500, like the other
            # endpoints, rather than propagating out of the handler.
            logger().exception('Failed in update_seg_label')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            # An IDLE state means the run was aborted; do not override it.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
            redis_client.delete(REDIS_KEY_UPDATE_ONGOING_SEG)
        return response
示例#8
0
    def post(self):
        '''
        Predict seg.

        Waits while another request holds the WAIT state, then runs
        ``detect_spots_task`` with the evaluation config. The state is WAIT
        while it runs. The request blocks until the task finishes or the
        state is switched to IDLE, and the previous state is restored
        afterwards unless it is IDLE.

        Returns:
            ``{"spots": [...], "completed": true}`` on success.
            ``{"spots": [], "completed": false}`` when aborted or when the
            task returns None. 400 for a wrong Content-Type and 500 on
            failures.
        '''
        if request.headers.get('Content-Type') != 'application/json':
            msg = 'Content-Type should be application/json'
            logger().error(msg)
            return make_response(jsonify(error=msg), 400)
        state = get_state()
        while state == TrainState.WAIT.value:
            logger().info(f'waiting @{request.path}')
            time.sleep(1)
            state = get_state()
            if state == TrainState.IDLE.value:
                return make_response(jsonify({'completed': False}))
        try:
            req_json = request.get_json()
            req_json['device'] = get_device()
            config = SegmentationEvalConfig(req_json)
            logger().info(config)
            redis_client.set(REDIS_KEY_STATE, TrainState.WAIT.value)
            async_result = detect_spots_task.delay(
                str(config.device),
                config.model_path,
                config.keep_axials,
                config.is_pad,
                config.is_3d,
                config.crop_size,
                config.scales,
                config.cache_maxbytes,
                config.use_2d,
                config.use_median,
                config.patch_size,
                config.crop_box,
                config.c_ratio,
                config.p_thresh,
                config.r_min,
                config.r_max,
                config.output_prediction,
                config.zpath_input,
                config.zpath_seg_output,
                config.timepoint,
                None,
                config.memmap_dir,
                config.batch_size,
                config.input_size,
            )
            while not async_result.ready():
                if (redis_client is not None
                        and get_state() == TrainState.IDLE.value):
                    logger().info('prediction aborted')
                    return make_response(
                        jsonify({
                            'spots': [],
                            'completed': False
                        }))
                # Poll instead of spinning; without a pause this loop queries
                # redis continuously while the task runs.
                time.sleep(0.1)
            if async_result.failed():
                raise async_result.result
            spots = async_result.result
            if spots is None:
                logger().info('prediction aborted')
                return make_response(jsonify({
                    'spots': [],
                    'completed': False
                }))
            publish_mq('prediction', 'Prediction updated')
        except Exception as e:
            logger().exception('Failed in detect_spots')
            return make_response(jsonify(error=f'Exception: {e}'), 500)
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            # An IDLE state means the run was aborted; do not override it.
            if get_state() != TrainState.IDLE.value:
                redis_client.set(REDIS_KEY_STATE, state)
        return make_response(jsonify({'spots': spots, 'completed': True}))
示例#9
0
def _update_seg_labels(spots_dict,
                       scales,
                       zpath_input,
                       zpath_seg_label,
                       zpath_seg_label_vis,
                       auto_bg_thresh=0,
                       c_ratio=0.5,
                       is_livemode=False,
                       memmap_dir=None):
    '''
    Rasterize annotated spots into the segmentation label zarr arrays.

    For every timepoint ``t`` in ``spots_dict``, a label volume (uint8) and
    an RGB visualization (uint8, trailing axis of size 3) are built from
    zeros and written to ``za_label[t]`` / ``za_label_vis[t]``. Existing
    chunk files of that timepoint are deleted first.

    Label values per spot tag (``offset`` is 0 for 'tp', 'tb', 'tn' and 3
    otherwise):

    * 'tn' / 'fp': ellipse/ellipsoid -> ``1 + offset``
    * 'tb' / 'fb': ellipse/ellipsoid -> ``2 + offset``
    * 'tp' / 'fn': ellipse/ellipsoid -> ``2 + offset``; its core (radii
      scaled by ``c_ratio``) -> ``3 + offset``

    Args:
        spots_dict (OrderedDict): timepoint -> list of spot dicts. Each spot
            uses the keys 'tag', 'pos' and 'covariance' (9 values, reshaped
            to 3x3). 'pos' and 'covariance' are reversed before use, and the
            trailing ``n_dims`` entries are kept.
        scales (array-like or None): per-axis voxel size; None means all 1.
        zpath_input (str): path to the input zarr; its axis 0 is indexed by
            timepoint.
        zpath_seg_label (str): path to the label zarr to update.
        zpath_seg_label_vis (str): path to the visualization zarr to update.
        auto_bg_thresh (float): if > 0, input voxels below this value are
            labeled 1 before the spots are drawn.
        c_ratio (float): radius ratio of the core region for 'tp'/'fn'.
        is_livemode (bool): if True, exactly one timepoint is allowed and it
            is published to ``REDIS_KEY_TIMEPOINT``.
        memmap_dir (str or None): directory whose cached seglabel memmap
            files for updated timepoints are deleted.

    Returns:
        A Flask response: ``{"completed": true}``, or a 400 error in live
        mode when ``REDIS_KEY_TIMEPOINT`` is already set. In that case the
        labels have already been written before the check.

    Raises:
        ValueError: if ``is_livemode`` and ``spots_dict`` does not contain
            exactly one timepoint.
        KeyboardInterrupt: if the process state becomes IDLE mid-update.
    '''
    if is_livemode and len(spots_dict) != 1:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        raise ValueError('Livemode should update only a single timepoint')
    za_input = zarr.open(zpath_input, mode='r')
    za_label = zarr.open(zpath_seg_label, mode='a')
    za_label_vis = zarr.open(zpath_seg_label_vis, mode='a')
    keyorder = ['tp', 'fp', 'tn', 'fn', 'tb', 'fb']
    MIN_AREA_ELLIPSOID = 9
    img_shape = za_input.shape[1:]
    n_dims = len(img_shape)
    keybase = Path(za_label.store.path).parent.name
    if scales is None:
        scales = (1., ) * n_dims
    scales = np.array(scales)
    # The drawing/dilation helpers depend only on dimensionality, so pick
    # them once instead of per spot.
    if n_dims == 3:
        draw_func = ellipsoid
        dilate_func = _dilate_3d_indices
    else:
        draw_func = ellipse
        dilate_func = _dilate_2d_indices
    for t, spots in spots_dict.items():
        # Coordinates touched at this timepoint. Set to None once it exceeds
        # INDICES_THRESHOLD, in which case the whole frame is written instead
        # of only the touched voxels.
        label_indices = set()
        centroids = []
        label = np.zeros(img_shape, dtype='uint8')
        label_vis = np.zeros(img_shape + (3, ), dtype='uint8')
        if 0 < auto_bg_thresh:
            indices_bg = np.nonzero(
                get_input_at(za_input, t, memmap_dir=memmap_dir) <
                auto_bg_thresh)
            label[indices_bg] = 1
            label_vis[indices_bg] = (255, 0, 0)
            if INDICES_THRESHOLD < len(indices_bg[0]):
                label_indices = None
            else:
                label_indices.update(tuple(map(tuple, np.array(indices_bg).T)))
        cnt = collections.Counter({x: 0 for x in keyorder})
        for spot in spots:
            if get_state() == TrainState.IDLE.value:
                logger().info('update aborted')
                raise KeyboardInterrupt
            cnt[spot['tag']] += 1
            centroid = np.array(spot['pos'][::-1])
            centroid = centroid[-n_dims:]
            centroids.append((centroid / scales).astype(int).tolist())
            covariance = np.array(spot['covariance'][::-1]).reshape(3, 3)
            covariance = covariance[-n_dims:, -n_dims:]
            # Radii are the square roots of the covariance eigenvalues; the
            # eigenvectors give the orientation.
            radii, rotation = np.linalg.eigh(covariance)
            radii = np.sqrt(radii)
            indices_outer = draw_func(centroid, radii, rotation, scales,
                                      img_shape, MIN_AREA_ELLIPSOID)
            if label_indices is not None:
                if INDICES_THRESHOLD < len(indices_outer[0]):
                    label_indices = None
                else:
                    label_indices.update(
                        tuple(map(tuple,
                                  np.array(indices_outer).T)))
                    if INDICES_THRESHOLD < len(label_indices):
                        label_indices = None
            if spot['tag'] in ['tp', 'tb', 'tn']:
                label_offset = 0
                label_vis_value = 255
            else:
                label_offset = 3
                label_vis_value = 127
            # ``label`` is uint8, so ``0 - 1`` wraps to 255 and
            # ``fmod(255, 3) == 0``. ``<= 1`` therefore matches every value
            # except 3 and 6, so existing core voxels are never overwritten
            # by an outer region.
            cond_outer_1 = np.fmod(label[indices_outer] - 1, 3) <= 1
            if spot['tag'] in ('tp', 'fn'):
                indices_inner = draw_func(centroid, radii * c_ratio, rotation,
                                          scales, img_shape,
                                          MIN_AREA_ELLIPSOID)
                indices_inner_p = dilate_func(*indices_inner, img_shape)
                label[indices_outer] = np.where(cond_outer_1, 2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_1[..., None],
                                                    (0, label_vis_value, 0),
                                                    label_vis[indices_outer])
                label[indices_inner_p] = 2 + label_offset
                label_vis[indices_inner_p] = (0, label_vis_value, 0)
                # fmod of a uint8 value by 3 is always <= 2, so the core is
                # set unconditionally.
                cond_inner = np.fmod(label[indices_inner] - 1, 3) <= 2
                label[indices_inner] = np.where(cond_inner, 3 + label_offset,
                                                label[indices_inner])
                label_vis[indices_inner] = np.where(cond_inner[..., None],
                                                    (0, 0, label_vis_value),
                                                    label_vis[indices_inner])
            elif spot['tag'] in ('tb', 'fb'):
                label[indices_outer] = np.where(cond_outer_1, 2 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_1[..., None],
                                                    (0, label_vis_value, 0),
                                                    label_vis[indices_outer])
            elif spot['tag'] in ('tn', 'fp'):
                # Only unlabeled voxels or values 1/4 are overwritten.
                cond_outer_0 = np.fmod(label[indices_outer] - 1, 3) <= 0
                label[indices_outer] = np.where(cond_outer_0, 1 + label_offset,
                                                label[indices_outer])
                label_vis[indices_outer] = np.where(cond_outer_0[..., None],
                                                    (label_vis_value, 0, 0),
                                                    label_vis[indices_outer])
        logger().info('frame:{}, {}'.format(
            t, sorted(cnt.items(), key=lambda i: keyorder.index(i[0]))))
        # None (too many indices) or an empty set: write the whole frame.
        # An empty set cannot go through the sparse path, which indexes
        # ``target[0]``.
        if not label_indices:
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[t] = label
            za_label_vis[t] = label_vis
        else:
            target = tuple(np.array(list(label_indices)).T)
            target_t = to_fancy_index(t, *target)
            # Same coordinates repeated for each of the 3 color channels.
            target_vis = (
                tuple(np.tile(target[i], 3) for i in range(n_dims)) +
                (np.array(
                    sum(tuple([(c, ) * len(target[0]) for c in range(3)]),
                        ())), ))
            target_vis_t = to_fancy_index(t, *target_vis)
            for chunk in Path(za_label.store.path).glob(f'{t}.*'):
                chunk.unlink()
            for chunk in Path(za_label_vis.store.path).glob(f'{t}.*'):
                chunk.unlink()
            za_label[target_t] = label[target]
            za_label_vis[target_vis_t] = label_vis[target_vis]
        za_label.attrs[f'label.indices.{t}'] = centroids
        za_label.attrs['updated'] = True
        if memmap_dir:
            # Drop stale cached label memmaps for this timepoint.
            for fpath in Path(memmap_dir).glob(
                    f'{keybase}-t{t}*-seglabel.dat'):
                lock = FileLock(str(fpath) + '.lock')
                with lock:
                    if fpath.exists():
                        logger().info(f'remove {fpath}')
                        fpath.unlink()
        if is_livemode:
            if redis_client.get(REDIS_KEY_TIMEPOINT):
                msg = 'Last update/training is ongoing'
                logger().error(msg)
                return make_response(jsonify(error=msg), 400)
            redis_client.set(REDIS_KEY_TIMEPOINT, t)
    return make_response(jsonify({'completed': True}))
示例#10
0
    def __init__(self,
                 zpath_input,
                 zpath_seg_label,
                 indices,
                 img_size,
                 crop_size,
                 n_crops,
                 keep_axials=(True, ) * 4,
                 scales=None,
                 is_livemode=False,
                 scale_factor_base=0.2,
                 is_ae=False,
                 rotation_angle=None,
                 contrast=0.5,
                 is_eval=False,
                 length=None,
                 adaptive_length=False,
                 cache_maxbytes=None,
                 memmap_dir=None):
        """Generate dataset for segmentation.

        Args:
            zpath_input(String): a path to .zarr file for input data.
            zpath_seg_label(String): a path to .zarr file for label data.
                Not opened when is_ae is True.
            indices(list): a list of timepoints to be used.
            img_size(array-like of length ndim): input image size.
            crop_size(array-like of length ndim): crop size to generate
                dataset. Each entry is clipped to the matching img_size entry.
            n_crops(int): number of crops per timepoint.
            keep_axials(array-like of length 4): this value is used to calculate
                how many times down/up sampling are performed in z direction.
                Ignored for 2D data.
            scales(array-like of length ndim): a list of pixel/voxel size in
                physical unit (e.g. 0.5 μm/px). This is used to calculate scale
                factors for augmentation.
            is_livemode(boolean): True if training is performed in live mode.
                Only honored when RUN_ON_FLASK is true and is_ae is False.
                In live mode n_crops is stored in redis, and the dataset
                length is read back from REDIS_KEY_NCROPS.
            scale_factor_base(float): a base scale factor for augmentation.
                Must satisfy 0 <= scale_factor_base < 1.
            is_ae(boolean): True if the dataset is used for autoencoder
                training. Labels are not loaded, and is_eval and is_livemode
                are forced to False.
            rotation_angle(float): rotation angle for augmentation in degree.
            contrast(float): contrast factor for augmentation.
            is_eval(boolean): True if the dataset is for evaluation, where no
                augmentation is performed.
            length(int): the length of this dataset. If None, it is
                automatically calculated by the length of indices and n_crops.
            adaptive_length(boolean): True if the length of the dataset is
                adjusted adaptively. For example, given that the length is 10
                and the len(indices) * self.n_crops is 6, the length becomes 6
                if adaptive_length is true, while it remains 10 if false.
            cache_maxbytes (int): size of the memory capacity for cache in byte.
                Split evenly between the input cache and the label cache.
            memmap_dir (str): path to a directory for storing memmap files.

        Raises:
            ValueError: if img_size and crop_size differ in length, or if
                scale_factor_base is outside [0, 1).
            RuntimeError: if live mode is active but redis_client is None.
        """
        if len(img_size) != len(crop_size):
            raise ValueError(
                'img_size: {} and crop_size: {} should have the same length'.
                format(img_size, crop_size))
        if scale_factor_base < 0 or 1 <= scale_factor_base:
            raise ValueError(
                'scale_factor_base should be 0 <= scale_factor_base < 1')
        self.n_dims = len(crop_size)
        if scales is None:
            scales = (1., ) * self.n_dims
        # Axes with larger voxels get proportionally smaller scale factors.
        scale_factors = tuple(scale_factor_base * min(scales) / scales[i]
                              for i in range(self.n_dims))
        self.za_input = zarr.open(zpath_input, mode='r')
        crop_size = tuple(
            min(crop_size[i], img_size[i]) for i in range(self.n_dims))
        self.img_size = tuple(img_size)
        self.crop_size = crop_size
        self.scale_factors = scale_factors
        # Per-axis (low, high) bounds for randomly scaled crop sizes, capped
        # by the image size.
        self.rand_crop_ranges = [
            (min(img_size[i], round(crop_size[i] * (1. - scale_factors[i]))),
             min(img_size[i] + 1,
                 int(crop_size[i] * (1. + scale_factors[i])) + 1))
            for i in range(self.n_dims)
        ]
        self.n_crops = n_crops
        self.is_ae = is_ae

        if is_ae:
            self.is_eval = False
            self.is_livemode = False
        else:
            # Label is not used for autoencoder
            self.za_label = zarr.open(zpath_seg_label, mode='r')
            self.zpath_seg_label = zpath_seg_label
            self.indices = indices
            self.is_eval = is_eval
            self.is_livemode = is_livemode and RUN_ON_FLASK
            if self.is_livemode:
                # Explicit check rather than ``assert`` so it survives -O.
                if redis_client is None:
                    raise RuntimeError('redis_client is required in live mode')
                redis_client.set(REDIS_KEY_NCROPS, str(n_crops))
            if (adaptive_length and (length is not None)
                    and (len(indices) * self.n_crops <= length)):
                length = None
        self.rotation_angle = rotation_angle
        self.contrast = contrast
        if self.is_livemode:
            self.length = int(redis_client.get(REDIS_KEY_NCROPS))
        else:
            self.length = length
        self.keep_axials = torch.tensor(keep_axials)
        if cache_maxbytes:
            self.use_cache = True
            self.cache_dict_input = LRUCacheDict(cache_maxbytes // 2)
            self.cache_dict_label = LRUCacheDict(cache_maxbytes // 2)
        else:
            self.use_cache = False
            self.cache_dict_input = None
            self.cache_dict_label = None
        self.memmap_dir = memmap_dir