def get_source(self, array, overwrite_spec=None):
    """Return a ``gp.ZarrSource`` providing ``self.ds_name`` of ``self.filename`` as ``array``.

    Args:
        array: the ``gp.ArrayKey`` under which the dataset is provided.
        overwrite_spec: optional ``gp.ArraySpec`` for ``array``; when falsy
            no ``array_specs`` argument is passed to the source.
    """
    datasets = {array: self.ds_name}
    if not overwrite_spec:
        return gp.ZarrSource(self.filename, datasets)
    return gp.ZarrSource(self.filename, datasets, array_specs={array: overwrite_spec})
def build_source(self):
    """Build a ``gp.ZarrSource`` that provides ``self.raw_0`` and ``self.raw_1``.

    Both array keys read the same dataset and receive identical,
    interpolatable array specs. When ``self.time_window`` is ``None`` the full
    dataset ROI is used. Otherwise the ROI along axis 1 is restricted to
    ``[offset + time_window[0], offset + time_window[1])``. The window is added
    to the ROI offset, so presumably it is in the dataset's world units -- TODO
    confirm.

    Returns:
        The configured ``gp.ZarrSource``.
    """
    # NOTE(review): ``filename`` and ``key`` are not defined in this method and
    # are not read from ``self``; they must resolve from an enclosing/module
    # scope. Confirm whether ``self.filename`` / ``self.key`` were intended.
    data = daisy.open_ds(filename, key)
    voxel_size = gp.Coordinate(data.voxel_size)

    if self.time_window is None:
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    else:
        offset = list(data.roi.get_offset())
        shape = list(data.roi.get_shape())
        offset[1] += self.time_window[0]
        # The window length is the ROI *shape* along axis 1. Previously the
        # length overwrote offset[1], losing the window start and leaving the
        # shape at the full dataset extent.
        shape[1] = self.time_window[1] - self.time_window[0]
        source_roi = gp.Roi(tuple(offset), tuple(shape))

    def _raw_spec():
        # a fresh spec object for each array key
        return gp.ArraySpec(roi=source_roi, voxel_size=voxel_size, interpolatable=True)

    return gp.ZarrSource(
        filename,
        {self.raw_0: key, self.raw_1: key},
        array_specs={self.raw_0: _raw_spec(), self.raw_1: _raw_spec()})
def predict(iteration,path_to_dataGP):
    """Run the ``SpineUNet`` checkpoint of ``iteration`` over validation sample 1.

    The whole ``validate/sample1/raw`` dataset of ``path_to_dataGP`` is padded,
    normalized and predicted chunk-wise via ``gp.Scan`` (input chunk
    ``(8, 96, 96)``, output chunk ``(4, 64, 64)``).

    Returns:
        ``(raw, affs_predicted)`` numpy arrays covering the full source ROI.
    """
    chunk_in = (8, 96, 96)
    chunk_out = (4, 64, 64)
    pad_amount = gp.Coordinate((2, 16, 16))

    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # per-chunk request used by gp.Scan
    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, chunk_in)
    chunk_request.add(affs_predicted, chunk_out)

    source = gp.ZarrSource(path_to_dataGP, {raw: 'validate/sample1/raw'})

    # query the full extent of the raw data so the whole volume is requested
    with gp.build(source):
        whole_roi = source.spec[raw].roi

    full_request = gp.BatchRequest()
    full_request[raw] = gp.ArraySpec(roi=whole_roi)
    full_request[affs_predicted] = gp.ArraySpec(roi=whole_roi)

    pipeline = source
    pipeline += gp.Pad(raw, pad_amount)
    pipeline += gp.Normalize(raw)
    # raw: (d, h, w)
    pipeline += gp.Stack(1)
    # raw: (1, d, h, w)
    pipeline += AddChannelDim(raw)
    # raw: (1, 1, d, h, w)
    pipeline += gp_torch.Predict(
        model,
        inputs={'x': raw},
        outputs={0: affs_predicted},
        checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}')
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(affs_predicted)
    # raw: (d, h, w)
    # affs_predicted: (3, d, h, w)
    pipeline += gp.Scan(chunk_request)

    with gp.build(pipeline):
        batch = pipeline.request_batch(full_request)

    return batch[raw].data, batch[affs_predicted].data
def __init__(self, filename, key, density=None, channels=0, shape=(16, 256, 256), time_window=None, add_sparse_mosaic_channel=True, random_rot=False):
    """Open ``key`` of ``filename`` and set up a randomly-located, augmented pipeline.

    Args:
        filename: zarr container path.
        key: dataset name inside ``filename``.
        density, channels, shape, add_sparse_mosaic_channel, random_rot:
            stored on the instance; not used by this constructor otherwise.
        time_window: optional ``(start, stop)``. Restricts the source ROI along
            axis 1 to ``[offset + start, offset + stop)``. The window is added
            to the ROI offset, so presumably it is in world units -- TODO
            confirm.
    """
    self.filename = filename
    self.key = key
    self.shape = shape
    self.density = density
    self.time_window = time_window
    self.raw = gp.ArrayKey('RAW_0')
    self.add_sparse_mosaic_channel = add_sparse_mosaic_channel
    self.random_rot = random_rot
    self.channels = channels

    data = daisy.open_ds(filename, key)

    if time_window is None:
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    else:
        offs = list(data.roi.get_offset())
        sh = list(data.roi.get_shape())
        offs[1] += time_window[0]
        # The window length is the ROI *shape* along axis 1. Previously it
        # overwrote offs[1], discarding the window start.
        sh[1] = time_window[1] - time_window[0]
        source_roi = gp.Roi(tuple(offs), tuple(sh))

    voxel_size = gp.Coordinate(data.voxel_size)

    self.pipeline = gp.ZarrSource(
        filename,
        {self.raw: key},
        array_specs={
            self.raw: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True)
        }) + gp.RandomLocation() + IntensityDiffFilter(
            self.raw, 0, min_distance=0.1, channels=Slice(None))

    # add augmentations
    self.pipeline = self.pipeline + gp.ElasticAugment([40, 40], [2, 2], [0, math.pi / 2.0], prob_slip=-1, spatial_dims=2)

    self.pipeline.setup()
    # Reseed numpy per process. np.random.seed only accepts values in
    # [0, 2**32 - 1], so reduce the pid + time sum into that range.
    np.random.seed((os.getpid() + int(time.time())) % 2**32)
def build_pipeline(data_dir, model, checkpoint_file, input_size, output_size, raw, labels, affs_predicted, dataset_shape, num_samples, sample_size):
    """Load ``checkpoint_file`` into ``model`` and return a scanning prediction pipeline.

    The pipeline reads ``validate/raw`` and ``validate/gt`` from the zarr
    container ``data_dir``, predicts ``affs_predicted`` and scans with chunks
    of ``input_size`` (raw) / ``output_size`` (affs, labels). The pipeline is
    returned unbuilt. ``dataset_shape``, ``num_samples`` and ``sample_size``
    are accepted but not used.
    """
    weights = torch.load(checkpoint_file)['model_state_dict']
    model.load_state_dict(weights)

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(affs_predicted, output_size)
    scan_request.add(labels, output_size)

    # permutation that swaps the first two axes (samples <-> channels)
    swap_first_two = (1, 0, 2, 3)

    pipeline = gp.ZarrSource(str(data_dir), {raw: 'validate/raw', labels: 'validate/gt'})
    pipeline += gp.Pad(raw, size=None)
    pipeline += gp.Normalize(raw)
    # raw: (s, h, w), labels: (s, h, w)
    pipeline += train.AddChannelDim(raw)
    # raw: (c=1, s, h, w)
    pipeline += train.TransposeDims(raw, swap_first_two)
    # raw: (s, c=1, h, w)
    pipeline += Predict(model=model, inputs={'x': raw}, outputs={0: affs_predicted})
    # affs_predicted: (s, c=2, h, w)
    pipeline += train.TransposeDims(raw, swap_first_two)
    pipeline += train.RemoveChannelDim(raw)
    # raw: (s, h, w)
    pipeline += gp.PrintProfilingStats(every=100)
    pipeline += gp.Scan(scan_request)
    return pipeline
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   input_key='0/raw',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   input_name='raw_0',
                   checkpoint=None,
                   apply_voxel_size=True):
    """Predict ``model`` chunk-wise over a zarr dataset and write the result.

    Reads ``dataset.ds_names[0]`` from ``dataset.filename``, adds sample and
    channel dimensions as needed, optionally normalizes, pads by the model
    context, runs ``gp.torch.Predict`` under ``gp.Scan`` and writes the
    prediction to ``out_dir/out_filename[out_ds_names[0]]``.

    Args:
        input_key: not used by this function; kept for compatibility.
        normalize_factor: passed to ``gp.Normalize``; ``"skip"`` disables it.
        in_shape, out_shape: model in/out shape; default to
            ``model.in_shape`` / ``model.out_shape``.
        input_name: name of the model input receiving raw (previously an
            undefined name).
        checkpoint: checkpoint for ``gp.torch.Predict``. ``None`` uses the
            model's current weights (previously an undefined name).
        apply_voxel_size: if True, ``in_shape``/``out_shape`` are voxel counts
            scaled by the dataset voxel size (previously an undefined name).
    """
    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2
    logger.info("context %s %s %s", context, in_shape, out_shape)

    source = (gp.ZarrSource(
        dataset.filename,
        {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    # ensure raw has sample and channel dims
    num_additional_channels = (2 + spatial_dims) - data_dims
    for _ in range(num_additional_channels):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # NOTE: the sample dimension is NOT removed here before writing.

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties=None,
            **kwargs):
    """Run blockwise synapse prediction with a TensorFlow checkpoint.

    Loads ``<network_config>_net_config.json`` (and optional ``parameter.json``)
    from this file's directory. It then predicts ``pred_syn_indicator`` and
    ``pred_partner_vectors`` for the raw volume and writes both to ``out_file``
    via ``gp.DaisyRequestBlocks``.

    Args:
        iteration: checkpoint iteration (``train_net_checkpoint_<iteration>``).
        raw_file: input container. ``.hdf``/``.h5``/``.hdf5`` use
            ``gp.Hdf5Source``; ``.zarr``/``.n5`` use ``gp.ZarrSource``.
        raw_dataset: dataset name of raw inside ``raw_file``.
        out_file: output zarr filename.
        db_host, db_name: forwarded to ``block_done_callback``.
        worker_config: mapping; ``'num_cache_workers'`` is read here.
        network_config: prefix of the net config / meta graph files.
        out_properties: optional mapping with ``'pred_partner_vectors'`` and/or
            ``'pred_syn_indicator_out'`` entries (``'scale'``, ``'dtype'``).

    Raises:
        RuntimeError: if ``raw_file`` has an unsupported extension.
    """
    if out_properties is None:
        # avoid a shared mutable default argument
        out_properties = {}
    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(os.path.join(setup_dir, '{}_net_config.json'.format(network_config)), 'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # tolerate a trailing slash on directory-style containers (foo.zarr/)
    raw_ext_check = raw_file.rstrip('/')
    if raw_ext_check.endswith(('.hdf', '.h5', '.hdf5')):
        pipeline = gp.Hdf5Source(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    elif raw_ext_check.endswith(('.zarr', '.n5')):
        pipeline = gp.ZarrSource(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)
    pipeline += gp.Normalize(raw)
    # map normalized raw to [-1, 1]
    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    # undo the training-time scaling of the partner vectors
    d_scale = parameters.get('d_scale')
    if d_scale is not None and d_scale != 1:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale, 0)

    # Map back to nm world.
    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator, m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] in ('int8', 'float32'), \
            'predict not adapted to dtype {}'.format(d_property['dtype'])
        if d_property['dtype'] == 'int8':
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors, 1, 0, clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator: 'volumes/pred_syn_indicator',
        pred_postpre_vectors: 'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
def create_train_pipeline(self, model):
    """Build the gunpowder training pipeline and the per-iteration request.

    Reads raw and ground-truth labels from ``self.params['data_file']``. It
    normalizes, pads and randomly locates them, applies
    ``self._augmentation_pipeline``, computes affinities and trains ``model``
    with ``gp.torch.Train``. Snapshots (including the ``emb`` output and
    cropped labels) are written every ``self.params['save_every']`` iterations.

    When ``model.in_shape`` is 2D, a leading dimension of voxel size 1 is added
    to the source ROI/voxel size and to the context. That dimension is then
    removed again from raw and gt_aff before stacking.

    Args:
        model: torch model with ``in_shape``, ``out_shape`` (whose first two
            entries are skipped) and ``base_encoder.out_shape`` attributes.

    Returns:
        ``(pipeline, request)``: the unbuilt pipeline and the request for raw,
        gt_aff and predictions to issue per iteration.
    """
    print(
        f"Creating training pipeline with batch size {self.params['batch_size']}"
    )

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    gt_aff = gp.ArrayKey('AFFINITIES')
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape (model shapes are in voxels; out_shape skips the
    # first two entries of model.out_shape)
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    # convert to world units
    in_shape = in_shape * out_voxel_size
    out_shape = out_shape * out_voxel_size

    context = (in_shape - out_shape) / 2
    gt_labels_out_shape = out_shape
    # Add fake 3rd dim: the source gets an extra leading axis spanning
    # raw_data.shape[0] with voxel size 1, and context/label shapes are
    # extended to match.
    if is_2d:
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        context = gp.Coordinate((0, *context))
        aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
        gt_labels_out_shape = (1, *gt_labels_out_shape)
    else:
        aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")

    # NOTE(review): in the 2D case these shapes are 2D while the source ROI
    # has an extra leading axis; confirm the downstream nodes handle this.
    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(gt_aff, out_shape)
    request.add(predictions, out_shape)

    # Extra arrays only fetched for snapshots: the encoder embedding (ROI at
    # the origin, sized by model.base_encoder.out_shape) and the labels
    # cropped to the output region (gt_labels is not in `request`).
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))
    snapshot_request[gt_labels] = gp.ArraySpec(
        roi=gp.Roi(context, gt_labels_out_shape))

    source = (
        gp.ZarrSource(filename, {
            raw: raw_dataset,
            gt_labels: gt_dataset
        },
                      array_specs={
                          raw:
                          gp.ArraySpec(roi=source_roi,
                                       voxel_size=source_voxel_size,
                                       interpolatable=True),
                          gt_labels:
                          gp.ArraySpec(roi=source_roi,
                                       voxel_size=source_voxel_size)
                      }) +
        gp.Normalize(raw, self.params['norm_factor']) +
        gp.Pad(raw, context) +
        gp.Pad(gt_labels, context) +
        gp.RandomLocation()
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
    )
    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source +
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
        gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) +
        SetDtype(gt_aff, np.float32) +
        # raw      : (l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)
        AddChannelDim(raw)
        # raw      : (c=1, l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)
    )

    if is_2d:
        # drop the fake leading spatial axis again
        pipeline = (
            pipeline +
            RemoveSpatialDim(raw) +
            RemoveSpatialDim(gt_aff)
            # raw      : (c=1, h, w)
            # gt_aff   : (c=2, h, w)
        )

    pipeline = (
        pipeline +
        gp.Stack(self.params['batch_size']) +
        gp.PreCache() +
        # raw      : (b, c=1, h, w)
        # gt_aff   : (b, c=2, h, w)
        # (which is what train requires)
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'raw': raw},
            loss_inputs={
                0: predictions,
                1: gt_aff
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(voxel_size=out_voxel_size),
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every) +
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, h, w)
        # gt_aff     : (b, c=2, h, w)
        # predictions: (b, c=2, h, w)

        # Crop GT to look at labels
        gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
        gp.Snapshot(output_dir=self.logdir + '/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw: 'raw',
                        gt_labels: 'gt_labels',
                        predictions: 'predictions',
                        gt_aff: 'gt_aff',
                        emb: 'emb'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        gp.PrintProfilingStats(every=500))

    return pipeline, request
def train(until):
    """Train ``SpineUNet`` on three zarr samples for ``until`` iterations.

    Samples are drawn at random from ``train/sample{1,2,3}`` of
    ``data/20200201.zarr``, augmented, and converted to 3-neighbour
    affinities. The model is trained with ``BCELoss``/Adam. Snapshots are
    written every 500 iterations and checkpoints every 10000.
    """
    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    sources = (
        gp.ZarrSource('data/20200201.zarr', {
            raw: 'train/sample1/raw',
            labels: 'train/sample1/labels'
        }),
        gp.ZarrSource('data/20200201.zarr', {
            raw: 'train/sample2/raw',
            labels: 'train/sample2/labels'
        }),
        gp.ZarrSource('data/20200201.zarr', {
            raw: 'train/sample3/raw',
            labels: 'train/sample3/labels'
        }),
    )

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Normalize(raw)
    pipeline += gp.RandomLocation()
    pipeline += gp.SimpleAugment(transpose_only=(1, 2))
    pipeline += gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi])
    pipeline += gp.AddAffinities([(1, 0, 0), (0, 1, 0), (0, 0, 1)], labels, affs)
    pipeline += gp.Normalize(affs, factor=1.0)
    # raw: (d, h, w), affs: (3, d, h, w)
    pipeline += gp.Stack(1)
    # raw: (1, d, h, w), affs: (1, 3, d, h, w)
    pipeline += AddChannelDim(raw)
    # raw: (1, 1, d, h, w), affs: (1, 3, d, h, w)
    pipeline += gp_torch.Train(
        model,
        loss,
        optimizer,
        inputs={'x': raw},
        outputs={0: affs_predicted},
        loss_inputs={0: affs_predicted, 1: affs},
        save_every=10000)
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(affs)
    pipeline += RemoveChannelDim(affs_predicted)
    # raw: (d, h, w), affs: (3, d, h, w), affs_predicted: (3, d, h, w)
    pipeline += gp.Snapshot(
        {
            raw: 'raw',
            labels: 'labels',
            affs: 'affs',
            affs_predicted: 'affs_predicted'
        },
        every=500,
        output_filename='iteration_{iteration}.hdf')

    request = gp.BatchRequest()
    for array_key in (raw, labels, affs, affs_predicted):
        request.add(array_key, input_size)

    with gp.build(pipeline):
        for _ in range(until):
            pipeline.request_batch(request)
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):
    """Predict ``model`` chunk-wise over a whole zarr dataset and write the result.

    Reads ``dataset.ds_names[0]`` from ``dataset.filename``, gives raw sample
    and channel dimensions, optionally normalizes, pads by the model context
    and runs ``gp.torch.Predict`` under ``gp.Scan``. The prediction is written
    to ``out_dir/out_filename[out_ds_names[0]]``.

    Args:
        model: torch model; ``model.in_shape`` / ``model.out_shape`` are used
            when ``in_shape`` / ``out_shape`` are not given.
        dataset: object with ``filename`` and ``ds_names`` attributes.
        out_dir, out_filename, out_ds_names: output location; only
            ``out_ds_names[0]`` is written.
        checkpoint: checkpoint passed to ``gp.torch.Predict``.
        input_name: name of the model input that receives raw.
        normalize_factor: passed to ``gp.Normalize``; ``"skip"`` disables it.
        model_output: key of the model output to write.
        in_shape, out_shape: model input/output shape.
        spawn_subprocess: passed to ``gp.torch.Predict``.
        num_workers: passed to ``gp.Scan``.
        apply_voxel_size: if True (default), ``in_shape``/``out_shape`` are
            voxel counts and are multiplied by the dataset voxel size. If
            False they are used unchanged (world units). Previously this flag
            was accepted but ignored.
    """
    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()
    is_2d = spatial_dims == 2

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    source = (gp.ZarrSource(
        dataset.filename,
        {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    # ensure raw has sample and channel dims
    #
    # n = number of samples
    # c = number of channels

    # 2D raw is either (n, y, x) or (c, n, y, x)
    # 3D raw is either (z, y, x) or (c, z, y, x)
    for _ in range((2 + spatial_dims) - data_dims):
        source += AddChannelDim(raw)

    # 2D raw: (c, n, y, x)
    # 3D raw: (c, n=1, z, y, x)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    # 2D raw: (n, c, y, x)
    # 3D raw: (n=1, c, z, y, x)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # 2D raw       : (n, c, y, x)
    # 2D prediction: (n, c, y, x)
    # 3D raw       : (n=1, c, z, y, x)
    # 3D prediction: (n=1, c, z, y, x)

    if is_2d:

        # restore channels first for 2D data
        pipeline += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))
        pipeline += TransposeDims(prediction, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    else:

        # remove sample dimension for 3D data
        pipeline += RemoveChannelDim(raw)
        pipeline += RemoveChannelDim(prediction)

    # 2D raw       : (c, n, y, x)
    # 2D prediction: (c, n, y, x)
    # 3D raw       : (c, z, y, x)
    # 3D prediction: (c, z, y, x)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    logger.info("Writing prediction to %s/%s[%s]", out_dir, out_filename,
                out_ds_names[0])

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def train_until(max_iteration):
    """Train a point-detection UNet for ``max_iteration`` iterations.

    Positive samples (``pos_samples``) get a centre point and are located so
    the point is non-empty. Negative samples (``neg_samples``) get no point.
    The two groups are chosen with probabilities 0.9 / 0.1. Points are
    rasterized in ``'peak'`` mode and regressed with ``MSELoss``.
    """
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)
    model = Convolve(unet, 12, 1)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # start of gunpowder part:

    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))
    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)

    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    for output_key in (points, groundtruth, prediction, grad):
        request.add(output_key, output_size)

    def open_raw(filename):
        # zarr source exposing 'volumes/raw' as interpolatable raw
        return gp.ZarrSource(filename, {raw: 'volumes/raw'},
                             {raw: gp.ArraySpec(interpolatable=True)})

    pos_sources = tuple(
        open_raw(filename) + AddCenterPoint(points, raw) + gp.Pad(raw, None) +
        gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples) + gp.RandomProvider()
    neg_sources = tuple(
        open_raw(filename) + AddNoPoint(points, raw) + gp.RandomLocation()
        for filename in neg_samples) + gp.RandomProvider()

    train_pipeline = (pos_sources, neg_sources)
    train_pipeline += gp.RandomProvider(probabilities=[0.9, 0.1])
    train_pipeline += gp.Normalize(raw)

    # augmentations
    train_pipeline += gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                                        jitter_sigma=[0, 2, 2],
                                        rotation_interval=[0, math.pi / 2.0],
                                        prob_slip=0.05,
                                        prob_shift=0.05,
                                        max_misalign=10,
                                        subsample=8)
    train_pipeline += gp.SimpleAugment(transpose_only=[1, 2])
    train_pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                          z_section_wise=True)

    train_pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'))
    train_pipeline += gp.PreCache(cache_size=40, num_workers=10)

    # add (batch, channel) axes for torch, train, then strip them again
    train_pipeline += Reshape(raw, (1, 1) + input_shape)
    train_pipeline += Reshape(groundtruth, (1, 1) + output_shape)
    train_pipeline += gp_torch.Train(model=model,
                                     loss=loss,
                                     optimizer=optimizer,
                                     inputs={'x': raw},
                                     outputs={0: prediction},
                                     loss_inputs={
                                         0: prediction,
                                         1: groundtruth
                                     },
                                     gradients={0: grad},
                                     save_every=1000,
                                     log_dir='log')
    train_pipeline += Reshape(raw, input_shape)
    train_pipeline += Reshape(groundtruth, output_shape)
    train_pipeline += Reshape(prediction, output_shape)
    train_pipeline += Reshape(grad, output_shape)

    train_pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient'
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    train_pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(train_pipeline):
        for _ in range(max_iteration):
            train_pipeline.request_batch(request)
def make_pipeline(self):
    """Build a scanning pipeline that writes ``self.model`` embeddings to zarr.

    Raw is read from ``self.dataset`` in ``self.data_file``, given channel
    dims, normalized with ``self.params['norm_factor']`` and padded by the
    model context. It is then fed to ``gp.torch.Predict``. The ``embs`` output
    is written to ``<self.curr_log_dir>/<self.dataset>_embs.zarr`` under
    ``'embs'``.

    Returns:
        ``(pipeline, request, embs)``: the unbuilt pipeline, the per-chunk scan
        request and the array key of the embeddings.
    """
    raw = gp.ArrayKey('RAW')
    embs = gp.ArrayKey('EMBS')

    # The raw ROI used for Predict's output spec comes from the built source
    # below, so the dataset is opened only once here (a previous zarr.open
    # whose result was overwritten has been removed).
    data = daisy.open_ds(self.data_file, self.dataset)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape (voxels; out_shape skips the first two entries)
    in_shape = gp.Coordinate(self.model.in_shape)
    out_shape = gp.Coordinate(self.model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    # convert to world units
    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(embs, out_shape)

    context = (in_shape - out_shape) / 2

    source = gp.ZarrSource(self.data_file, {
        raw: self.dataset,
    },
                           array_specs={
                               raw: gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                           })

    if is_2d:
        source = source + AddChannelDim(raw, axis=1)
    else:
        source = source + AddChannelDim(raw, axis=0) + AddChannelDim(raw)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = (
        source +
        gp.Normalize(raw, factor=self.params['norm_factor']) +
        gp.Pad(raw, context) +
        gp.PreCache() +
        gp.torch.Predict(self.model,
                         inputs={'raw': raw},
                         outputs={0: embs},
                         array_specs={embs: gp.ArraySpec(roi=raw_roi)}))

    pipeline = (pipeline + gp.ZarrWrite({
        embs: 'embs',
    },
                                        output_dir=self.curr_log_dir,
                                        output_filename=self.dataset + '_embs.zarr',
                                        compression_type='gzip') +
                gp.Scan(request))

    return pipeline, request, embs
def predict_frame(in_shape,
                  out_shape,
                  model_output,
                  model_configfile,
                  model_checkpoint,
                  input_dataset_file,
                  inference_frame,
                  out_dir,
                  out_filename,
                  out_key_or_index=1,
                  intermediate_layer=None,
                  dataset_raw_key="train/raw",
                  dataset_prediction_key="train/prediction",
                  dataset_intermediate_key="train/prediction_interm",
                  model_input_tensor_name="patches",
                  model_architecture="PatchedResnet",
                  num_workers=5):
    """Predict one frame of a zarr dataset and write the result to zarr.

    Raw is read from ``<dataset_raw_key>/<inference_frame>`` (voxel size
    ``(1, 1)``), given two leading singleton axes and padded. It is predicted
    chunk-wise (``in_shape`` / ``out_shape``) with ``gp.Scan``. The output
    ``out_key_or_index`` -- and, if ``intermediate_layer`` is given, that
    layer too -- is written under ``<dataset_prediction_key>/<frame>`` (and
    ``<dataset_intermediate_key>/<frame>``) of ``out_dir/out_filename``.

    ``model_output`` is not used by this function.

    Raises:
        NotImplementedError: for an unknown ``model_architecture``.
    """
    # initialize model
    if model_architecture not in ("PatchedResnet", "unet"):
        raise NotImplementedError(f"{model_architecture} not implemented")
    if model_architecture == "PatchedResnet":
        model = PatchedResnet(1, 2, resnet_size=18)
    else:
        model = lisl.models.create(model_configfile)
    model.add_spatial_dim = True
    model.eval()

    # gp variables
    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    raw = gp.ArrayKey(f'RAW_{inference_frame}')
    prediction = gp.ArrayKey(f'PREDICTION_{inference_frame}')
    intermediate_prediction = gp.ArrayKey(f'ITERM_{inference_frame}')

    ds_key = f'{dataset_raw_key}/{inference_frame}'
    out_key = f'{dataset_prediction_key}/{inference_frame}'
    interm_key = f'{dataset_intermediate_key}/{inference_frame}'

    with_intermediate = intermediate_layer is not None

    # build pipeline
    zsource = gp.ZarrSource(
        input_dataset_file, {raw: ds_key},
        {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})

    with gp.build(zsource):
        raw_roi = zsource.spec[raw].roi
    logger.info(f"raw_roi: {raw_roi}")

    # prediction node outputs and their specs
    predict_outputs = {out_key_or_index: prediction}
    predict_specs = {prediction: gp.ArraySpec(roi=raw_roi)}
    if with_intermediate:
        predict_outputs[intermediate_layer] = intermediate_prediction
        predict_specs[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    pipeline = (zsource + AddChannelDim(raw) + AddChannelDim(raw) +
                gp.Pad(raw, None) +
                gp.torch.Predict(model,
                                 inputs={model_input_tensor_name: raw},
                                 outputs=predict_outputs,
                                 array_specs=predict_specs,
                                 checkpoint=model_checkpoint,
                                 spawn_subprocess=True))

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, in_shape)
    chunk_request.add(prediction, out_shape)

    write_datasets = {prediction: out_key}
    if with_intermediate:
        write_datasets[intermediate_prediction] = interm_key
        chunk_request.add(intermediate_prediction, out_shape)

    pipeline += gp.Scan(chunk_request, num_workers=num_workers)
    pipeline += gp.ZarrWrite(write_datasets,
                             output_dir=out_dir,
                             output_filename=out_filename,
                             compression_type='gzip')

    total_request = gp.BatchRequest()
    total_request[prediction] = gp.ArraySpec(roi=raw_roi)
    if with_intermediate:
        total_request[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    with gp.build(pipeline):
        pipeline.request_batch(total_request)
def get_mouselight_data_sources(setup_config: Dict[str, Any],
                                source_samples: List[str],
                                locations=False):
    """Build one randomly-located training source per selected MouseLight sample.

    For every directory under ``setup_config["SAMPLES_PATH"]`` whose name is in
    ``source_samples``, this merges three providers:

    * a raw zarr/n5 source (``volume-rechunked``, uint16),
    * a MongoDB graph provider for the ``matched`` graph, and
    * a filtered graph provider used as the non-empty placeholder.

    It then picks a random location, normalizes raw, filters the matched
    components and rasterizes them into ``labels``. The per-sample pipelines
    are combined with ``gp.RandomProvider``.

    Args:
        setup_config: configuration mapping; the keys read are visible below.
        source_samples: names of the sample directories to use.
        locations: not used by the active code (only by the commented-out
            alternative below).

    Returns:
        ``(data_sources, raw, labels, nonempty_placeholder, matched)``.
    """
    # Source Paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])
    # specified_locations = setup_config.get("SPECIFIED_LOCATIONS")

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    # NOTE(review): matching_failures_dir is not used after this point.
    matching_failures_dir = (matching_failures_dir
                             if matching_failures_dir is None else
                             Path(matching_failures_dir))

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    # voxel size along axis 0, used to scale point_balance_radius below
    # (presumably one micron in world units -- TODO confirm)
    micron_scale = voxel_size[0]

    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = nonempty_placeholder

    # Per sample: number of nodes in its "skeletonization" graph db plus one.
    # Passed to FilterComponents below (presumably as an id offset -- TODO
    # confirm).
    node_offset = {
        sample.name: (daisy.persistence.MongoDbGraphProvider(
            mongo_db_template.format(sample=sample.name,
                                     source="skeletonization"),
            mongo_url,
        ).num_nodes(None) + 1)
        for sample in samples_path.iterdir()
        if sample.name in source_samples
    }

    # if specified_locations is not None:
    #     centers = pickle.load(open(specified_locations, "rb"))
    #     random = gp.SpecifiedLocation
    #     kwargs = {"locations": centers, "choose_randomly": True}
    #     logger.info(f"Using specified locations from {specified_locations}")
    # elif locations:
    #     random = RandomLocations
    #     kwargs = {
    #         "ensure_nonempty": ensure_nonempty,
    #         "ensure_centered": True,
    #         "point_balance_radius": point_balance_radius * micron_scale,
    #         "loc": gp.ArrayKey("RANDOM_LOCATION"),
    #     }
    # else:

    # Random-location node class and its arguments (the local name `random`
    # shadows the stdlib module only inside this function).
    random = RandomLocation
    kwargs = {
        "ensure_nonempty": ensure_nonempty,
        "ensure_centered": True,
        "point_balance_radius": point_balance_radius * micron_scale,
    }

    data_sources = (tuple(
        (
            gp.ZarrSource(
                filename=str((sample / raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            DaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[matched],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            # only nodes passing the distance filter, capped at
            # max_nonempty_points; used to steer RandomLocation
            FilteredDaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[nonempty_placeholder],
                directed=True,
                node_attrs=["distance_to_fallback"],
                edge_attrs=[],
                num_nodes=max_nonempty_points,
                dist_attribute=distance_attr,
                min_dist=target_distance,
            ),
        ) + gp.MergeProvider() + random(**kwargs) + gp.Normalize(raw) +
        FilterComponents(
            matched, node_offset[sample.name], centroid_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.int64),
        ) for sample in samples_path.iterdir()
        if sample.name in source_samples) + gp.RandomProvider())

    return (data_sources, raw, labels, nonempty_placeholder, matched)
def validation_pipeline(config):
    """Build the merged validation pipeline over all configured blocks.

    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate

    For each block in ``config["BLOCKS"]`` the pipeline:

    * reads and normalizes raw (with mCLAHE),
    * adds embedding and foreground predictions and scans them,
    * merges them with the block's SWC ground truth,
    * rasterizes and grows the labels,
    * extracts candidates and builds an EMST,
    * snapshots all arrays/graphs for the block's cube ROI.

    All per-block pipelines are merged, their graphs combined by
    ``MergeGraphs`` and scored by ``Evaluate``.

    Returns:
        ``(validation_pipeline, score)``: the unbuilt pipeline and the array
        key holding the evaluation score.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    # largest voxel dimension; used as the world-unit scale for radii/spacing
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    # per-axis scale passed to EMST: COORDINATE_SCALE * voxel_size / micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    # per block: name -> (key, spec), consumed by MergeGraphs
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        # NOTE(review): `trees` is collected but never used.
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        # NOTE(review): if no cube*.swc exists, `cube` is None and this line
        # raises AttributeError rather than a clear error.
        assert cube.exists()
        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        # block-specific keys so the per-block pipelines can be merged
        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        ) + gp.Normalize(raw, dtype=np.float32) + mCLAHE([raw], [20, 64, 64]))
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block, fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        # raw needs the network context around the cube ROI
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = ((swc_source, pred_source) + gp.nodes.MergeProvider() +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    Skeletonize(fg, candidates, candidate_spacing,
                                candidate_threshold) +
                    EMST(
                        emb,
                        candidates,
                        mst,
                        distance_attr=distance_attr,
                        coordinate_scale=coordinate_scale,
                    ) + gp.Snapshot(
                        {
                            raw: f"volumes/{raw}",
                            ground_truth: f"points/{ground_truth}",
                            labels: f"volumes/{labels}",
                            fg: f"volumes/{fg}",
                            emb: f"volumes/{emb}",
                            candidates: f"volumes/{candidates}",
                            mst: f"points/{mst}",
                        },
                        additional_request=additional_request,
                        output_dir="snapshots",
                        output_filename="{id}.hdf",
                        edge_attrs={mst: [distance_attr]},
                    ))
        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines) +
        gp.MergeProvider() + MergeGraphs(specs, full_gt, full_mst) +
        Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr) +
        gp.PrintProfilingStats())
    return validation_pipeline, score
def validation_data_sources_recomputed(config, blocks):
    """Build one raw + ground-truth pipeline (with its request) per validation block.

    Validation directories are discovered under ``config["BENCHMARK_DATA_PATH"]``:
    every sub-directory whose name contains ``"validation"`` is scanned and its
    children are matched to block numbers by the integer after their last
    ``"_"``. Each per-block pipeline merges the raw N5 volume with the block's
    SWC skeletons, normalizes raw to float32, rasterizes the skeletons into
    connected-component labels and grows those labels by
    ``NEURON_RADIUS * 1000``.

    Args:
        config: configuration mapping; reads ``BENCHMARK_DATA_PATH``,
            ``VALIDATION_SAMPLES``, ``SAMPLES_PATH``, ``RAW_N5``,
            ``NEURON_RADIUS``, ``VOXEL_SIZE``, ``INPUT_SHAPE`` and
            ``OUTPUT_SHAPE``.
        blocks: collection of validation block numbers (iterated more than
            once, so pass a list/tuple/set rather than a generator).

    Returns:
        ``(validation_pipelines, (raw, labels, ground_truth))`` where
        ``validation_pipelines`` is a list of ``(pipeline, request)`` pairs in
        the order of ``blocks``.

    Raises:
        KeyError: if no validation directory was found for a requested block.
        FileNotFoundError: if a validation directory has no ``cube*.swc`` file.
    """
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    # Context (world units) added on each side of the output cube for raw.
    context = (input_size - output_size) // 2

    transform_file = transform_template.format(sample=sample)
    raw_filename = str(Path(sample_dir, sample, raw_n5).absolute())

    # Map block number -> validation directory (names end in "_<block>").
    found_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    found_dirs[validation_num] = validation_dir
    missing = [block for block in blocks if block not in found_dirs]
    if missing:
        raise KeyError(f"No validation directory found for block(s): {missing}")
    validation_dirs = [found_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        # The last matching cube*.swc file wins (same as the original scan).
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(".swc"):
                cube = gt_file
        if cube is None:
            raise FileNotFoundError(f"No cube*.swc file found in {validation_dir}")

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_file),
            np.array([300, 300, 1000]),
        )

        pipeline = (
            (
                gp.ZarrSource(
                    filename=raw_filename,
                    datasets={raw: "volume-rechunked"},
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
                    },
                ),
                nl.gunpowder.nodes.MouselightSwcFileSource(
                    validation_dir,
                    [ground_truth],
                    transform_file=transform_file,
                    ignore_human_nodes=False,
                    scale=voxel_size,
                    transpose=[2, 1, 0],
                    points_spec=[
                        gp.PointsSpec(
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            )
                        )
                    ],
                ),
            )
            + gp.nodes.MergeProvider()
            + gp.Normalize(raw, dtype=np.float32)
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000])
        )

        request = gp.BatchRequest()
        input_roi = cube_roi.grow(context, context)
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))

    return validation_pipelines, (raw, labels, ground_truth)
chunks=(50, 50, 50), overwrite=True) f["sphere"].attrs["offset"] = (0, 0, 0) f["sphere"].attrs["resolution"] = (1, 1, 1) f["sphere"][:] = sphere # declare arrays to use labels = gp.ArrayKey("LABELS") stardists = gp.ArrayKey("STARDIST") # prepare requests scan_request = gp.BatchRequest() scan_request[stardists] = gp.Roi((0, 0, 0), (50, 50, 50)) request = gp.BatchRequest() source = gp.ZarrSource(os.path.join(directory, "sphere.n5"), datasets={labels: "sphere"}) # prepare node for 3D stardist generation with a maximum distance stardist_gen = gpstardist.AddStarDist3D(labels, stardists, rays=96, anisotropy=(1, 1, 1), grid=(1, 1, 1), max_dist=max_dist, unlabeled_id=-1, invalid_value=-3) # write result to a new dataset writer = gp.ZarrWrite( output_dir=directory, output_filename="sphere.n5",
def random_point_pairs_pipeline(model, loss, optimizer, dataset,
                                augmentation_parameters, point_density,
                                out_dir, normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):
    """Build a two-view training pipeline on random point pairs.

    For each of the (hard-coded) 447 datasets ``train/raw/<i>`` in
    ``dataset.filename``, two independently augmented copies of the raw image
    (``raw_0``/``raw_1``) are produced together with random point sets that
    share one ``RandomPointGenerator``. The per-sample sources are combined
    with ``gp.RandomProvider`` and fed to ``gp.torch.Train``; embeddings are
    periodically written with ``gp.Snapshot``.

    Args:
        model: torch model; must expose ``in_shape`` and ``out_shape``.
        loss: loss module; receives ``emb_0``, ``emb_1``, ``locations_0``
            and ``locations_1``.
        optimizer: torch optimizer for ``model``.
        dataset: object with a ``filename`` attribute pointing to the zarr
            container.
        augmentation_parameters: currently unused (see TODO below).
        point_density: density passed to ``RandomPointGenerator``.
        out_dir: directory for model checkpoints (``<out_dir>/model``).
        normalize_factor: if not ``None``, raw is normalized with
            ``gp.Normalize(raw, normalize_factor)`` right after reading.
            ``None`` (default) leaves raw un-normalized.
        checkpoint_interval: iterations between checkpoints.
        snapshot_interval: iterations between snapshots.

    Returns:
        ``(pipeline, request)``.
    """
    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):
        ds_key = f'train/raw/{i}'

        # One image branch per view; both read the same dataset.
        image_sources = []
        for raw in [raw_0, raw_1]:
            image_source = gp.ZarrSource(
                dataset.filename,
                {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})
            if normalize_factor is not None:
                image_source += gp.Normalize(raw, normalize_factor)
            image_source += gp.Pad(raw, None)
            image_sources.append(image_source)

        # A shared generator with repetitions=2 so both views draw from it.
        random_point_generator = RandomPointGenerator(
            density=point_density, repetitions=2)
        point_sources = (
            RandomPointSource(points_0, dim,
                              random_point_generator=random_point_generator),
            RandomPointSource(points_1, dim,
                              random_point_generator=random_point_generator),
        )

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source)
            + gp.MergeProvider()
            + gp.SimpleAugment()
            + gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi / 2))
            + gp.IntensityAugment(r,
                                  scale_min=0.8,
                                  scale_max=1.2,
                                  shift_min=-0.2,
                                  shift_max=0.2,
                                  clip=False)
            + gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source in zip(
                [raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])

    # NOTE(review): PrepareBatch fills locations_0/1 from the point sets; its
    # relation to gp.Stack (below) should be confirmed.
    pipeline += PrepareBatch(raw_0, raw_1,
                             points_0, points_1,
                             locations_0, locations_1)

    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:          (1, h, w)
    # raw_1:          (1, h, w)
    # locations_0:    (n, 2)
    # locations_1:    (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:          (b, 1, h, w)
    # raw_1:          (b, 1, h, w)
    # locations_0:    (b, n, 2)
    # locations_1:    (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model, loss, optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0 : 'locations_0',
            # locations_1 : 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
def batch__aug_data_generator(input_path, batch_size=12, voxel_shape=(1, 1, 1),
                              input_shape=(240, 240, 4),
                              output_shape=(240, 240, 4),
                              without_background=False, mix_output=False,
                              validate=False, seq=None):
    """Endlessly yield ``(imgs, masks)`` batches sampled from zarr containers.

    Every entry of ``input_path`` is opened as a zarr container with datasets
    ``raw`` and ``ground_truth``. Each container gets its own
    ``gp.RandomLocation`` and the containers are mixed by ``gp.RandomProvider``.
    ``batch_size`` samples are requested one at a time and stacked with
    ``np.asarray``.

    Args:
        input_path: directory whose entries are zarr containers.
        batch_size: number of samples per yielded batch.
        voxel_shape: voxel size used for both arrays' specs.
        input_shape: shape requested for both raw and ground truth.
        output_shape: target mask shape; when its second entry is smaller
            than ``input_shape[1]``, masks are centre-cropped along their
            first two axes by ``(input_shape[1] - output_shape[1]) / 2``.
        without_background: keep only channels ``1:4`` of the last-sliced
            axis (``out[:, :, 1:4]``).
        mix_output: replace the mask by its ``argmax`` over axis 3 (as float).
        validate: resample until ``validate_mask`` accepts the ground truth.
        seq: optional augmentation passed to ``augmentation(imgs, masks, seq)``.

    Yields:
        ``(imgs, masks)`` numpy arrays.

    Raises:
        ValueError: (on first iteration) if ``input_path`` is empty.
    """
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')

    files = [os.path.join(input_path, f) for f in os.listdir(input_path)]
    if not files:
        raise ValueError(f"No zarr containers found in {input_path!r}")

    pipeline = (
        tuple(
            gp.ZarrSource(
                container,  # the zarr container
                {
                    raw: 'raw',
                    gt: 'ground_truth'
                },  # which dataset to associate to the array key
                {
                    raw: gp.ArraySpec(interpolatable=True,
                                      dtype=np.dtype('float32'),
                                      voxel_size=voxel_shape),
                    gt: gp.ArraySpec(interpolatable=True,
                                     dtype=np.dtype('float32'),
                                     voxel_size=voxel_shape)
                }  # meta-information
            ) + gp.RandomLocation()
            for container in files)
        + gp.RandomProvider()
    )

    input_size = gp.Coordinate(input_shape)
    request = gp.BatchRequest()
    request.add(raw, input_size)
    # Ground truth is requested at the input size on purpose; it is cropped
    # to the output size only after (optional) augmentation below.
    request.add(gt, input_size)

    diff = int((input_shape[1] - output_shape[1]) / 2)
    max_p = input_shape[1] - diff
    different_shape = diff > 0
    if different_shape:
        print('Difference padding: {}'.format(diff))

    with gp.build(pipeline):
        while True:
            imgs = []
            masks = []
            for _ in range(batch_size):
                batch = pipeline.request_batch(request)
                # Optionally keep resampling until the mask is acceptable.
                while validate and not validate_mask(batch[gt].data):
                    batch = pipeline.request_batch(request)

                im = batch[raw].data
                out = batch[gt].data
                if without_background:
                    out = out[:, :, 1:4]
                if mix_output:
                    out = out.argmax(axis=3).astype(float)
                imgs.append(im)
                masks.append(out)

            imgs = np.asarray(imgs)
            masks = np.asarray(masks)
            if seq is not None:
                imgs, masks = augmentation(imgs, masks, seq)
            if different_shape:
                masks = np.asarray(
                    [m[diff:max_p, diff:max_p, :] for m in masks])
            yield imgs, masks
def batch_data_generator(input_path, batch_size=12, voxel_shape=(1, 1, 1),
                         input_shape=(240, 240, 4), output_shape=(240, 240, 4),
                         without_background=False, mix_output=False,
                         validate=False):
    """Set up a random-location zarr pipeline and its request (incomplete).

    Every entry of ``input_path`` is opened as a zarr container with datasets
    ``raw`` and ``ground_truth``, each followed by ``gp.RandomLocation`` and
    combined with ``gp.RandomProvider``. A request for both arrays at
    ``input_shape`` is prepared.

    NOTE(review): the body stops after computing ``diff``. It never builds the
    pipeline, requests a batch or yields, so the function returns ``None``,
    and ``batch_size``, ``without_background``, ``mix_output`` and
    ``validate`` are unused. It looks like a truncated non-augmenting
    variant of ``batch__aug_data_generator``; confirm the intended behaviour.
    """
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')
    files = [os.path.join(input_path, f) for f in os.listdir(input_path)]
    pipeline = (
        tuple(
            gp.ZarrSource(
                files[t],  # the zarr container
                {raw: 'raw', gt: 'ground_truth'},  # which dataset to associate to the array key
                {raw: gp.ArraySpec(interpolatable=True, dtype=np.dtype('float32'), voxel_size=voxel_shape),
                 gt: gp.ArraySpec(interpolatable=True, dtype=np.dtype('float32'), voxel_size=voxel_shape)}  # meta-information
            )
            + gp.RandomLocation()
            for t in range(len(files))
        )
        + gp.RandomProvider()
        # +gp.Stack(batch_size)
    )
    input_size = gp.Coordinate(input_shape)
    output_size = gp.Coordinate(output_shape)
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, input_size)
    diff = input_shape[1] - output_shape[1]
def create_train_pipeline(self, model):
    """Build the supervised point-label training pipeline for ``model``.

    Raw and label volumes are read from ``self.params['data_file']`` (datasets
    ``self.params['dataset']['train']['raw'|'gt']``) and merged with a
    ``PointsLabelsSource`` over ``self.data``. For 2D models a fake leading
    spatial axis is added upstream and removed again with
    ``RemoveSpatialDim`` before training.

    Args:
        model: torch model; must expose ``in_shape``, ``out_shape`` and
            ``base_encoder.out_shape``.

    Returns:
        ``(pipeline, request)``.
    """
    print(f"Creating training pipeline with batch size {self.params['batch_size']}")

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    points = gp.GraphKey("POINTS")
    locations = gp.ArrayKey("LOCATIONS")
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_roi = out_roi * out_voxel_size
    out_shape = gp.Coordinate(
        (self.params["num_points"], *model.out_shape[2:]))

    if is_2d:
        # Add fake 3rd dim: the first axis of the 2D data is treated as a
        # (unit-spaced) spatial axis so RandomLocation can pick a sample.
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        points_roi = out_voxel_size * tuple(self.params["point_roi"])
        points_pad = (0, *points_roi)
        # No padding along the fake axis, unbounded padding in y/x.
        context = gp.Coordinate((0, None, None))
    else:
        points_roi = source_voxel_size * tuple(self.params["point_roi"])
        points_pad = points_roi
        context = gp.Coordinate((None, None, None))

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")
    logger.info(f"out_voxel_size: {out_voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(points, points_roi)
    request.add(gt_labels, out_roi)
    request[locations] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))

    source = (
        (gp.ZarrSource(filename, {
            raw: raw_dataset,
            gt_labels: gt_dataset
        },
            array_specs={
                raw: gp.ArraySpec(roi=source_roi,
                                  voxel_size=source_voxel_size,
                                  interpolatable=True),
                gt_labels: gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size)
        }),
            PointsLabelsSource(points, self.data, scale=source_voxel_size))
        + gp.MergeProvider()
        + gp.Pad(raw, context)
        + gp.Pad(gt_labels, context)
        + gp.Pad(points, points_pad)
        + gp.RandomLocation(ensure_nonempty=points)
        + gp.Normalize(raw, self.params['norm_factor'])
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        # If 2d then source_roi = (1, input_shape) in order to select a RL
    )
    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source
        # Batches seem to be rejected because points are chosen near the
        # edge of the points ROI and the augmentations remove them.
        # TODO: Figure out if this is an actual issue, and if anything can
        # be done.
        + gp.Reject(ensure_nonempty=points)
        + SetDtype(gt_labels, np.int64)
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        + AddChannelDim(raw)
        + AddChannelDim(gt_labels)
        # raw      : (c=1, source_roi)
        # gt_labels: (c=1, source_roi)
        # points   : (c=1, source_locations_shape)
    )

    if is_2d:
        pipeline = (
            # Remove extra dim the 2d roi had
            pipeline
            + RemoveSpatialDim(raw)
            + RemoveSpatialDim(gt_labels)
            + RemoveSpatialDim(points)
            # raw      : (c=1, roi)
            # gt_labels: (c=1, roi)
            # points   : (c=1, locations_shape)
        )

    pipeline = (
        pipeline
        + FillLocations(raw, points, locations, is_2d=False, max_points=1)
        + gp.Stack(self.params['batch_size'])
        + gp.PreCache()
        # raw      : (b, c=1, roi)
        # gt_labels: (b, c=1, roi)
        # locations: (b, c=1, locations_shape)
        # (which is what train requires)
        + gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={
                'raw': raw,
                'points': locations
            },
            loss_inputs={
                0: predictions,
                1: gt_labels,
                2: locations
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(nonspatial=True),
                emb: gp.ArraySpec(voxel_size=out_voxel_size)
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every)
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, roi)
        # gt_labels  : (b, c=1, roi)
        # predictions: (b, num_points)
        + gp.Snapshot(output_dir=self.logdir + '/snapshots',
                      output_filename='it{iteration}.hdf',
                      dataset_names={
                          raw: 'raw',
                          gt_labels: 'gt_labels',
                          predictions: 'predictions',
                          emb: 'emb'
                      },
                      additional_request=snapshot_request,
                      every=self.params['save_every'])
        + InspectBatch('END')
        + gp.PrintProfilingStats(every=500)
    )

    return pipeline, request
def build_pipeline(
        data_dir,
        model,
        save_every,
        batch_size,
        input_size,
        output_size,
        raw,
        labels,
        affs,
        affs_predicted,
        lr=1e-5):
    """Build a 2D affinity-training pipeline over a stack of samples.

    ``train/raw`` and ``train/gt`` in the zarr container ``data_dir`` are read
    as 3D volumes whose first axis indexes samples (unit voxel size).
    Affinities are computed from the labels for the (0, 1, 0) and (0, 0, 1)
    neighbours, the sample axis is squashed away, samples are stacked into
    batches and the model is trained with MSE loss and RAdam.

    Args:
        data_dir: path of the zarr container.
        model: torch model to train.
        save_every: checkpoint interval passed to ``Train``.
        batch_size: number of samples stacked per batch.
        input_size: currently unused by this function.
        output_size: currently unused by this function.
        raw, labels, affs, affs_predicted: array keys to use.
        lr: learning rate for RAdam.

    Returns:
        The assembled gunpowder pipeline.
    """
    # Read-only: zarr.open's default mode 'a' would silently create an
    # empty container if data_dir does not exist.
    dataset_shape = zarr.open(str(data_dir), mode='r')['train/raw'].shape
    num_samples = dataset_shape[0]
    sample_size = dataset_shape[1:]
    roi_shape = (num_samples,) + sample_size

    loss = torch.nn.MSELoss()
    optimizer = RAdam(model.parameters(), lr=lr)

    pipeline = (
        gp.ZarrSource(
            data_dir,
            {
                raw: 'train/raw',
                labels: 'train/gt'
            },
            array_specs={
                raw: gp.ArraySpec(
                    roi=gp.Roi((0, 0, 0), roi_shape),
                    voxel_size=(1, 1, 1)),
                labels: gp.ArraySpec(
                    roi=gp.Roi((0, 0, 0), roi_shape),
                    voxel_size=(1, 1, 1))
            }) +
        # raw:      (d=1, h, w)
        # labels:   (d=1, h, w)
        gp.RandomLocation() +
        # raw:      (d=1, h, w)
        # labels:   (d=1, h, w)
        gp.AddAffinities(
            affinity_neighborhood=[(0, 1, 0), (0, 0, 1)],
            labels=labels,
            affinities=affs) +
        gp.Normalize(affs, factor=1.0) +
        # raw:      (d=1, h, w)
        # affs:     (c=2, d=1, h, w)
        Squash(dim=-3) +
        # get rid of z dim
        # raw:      (h, w)
        # affs:     (c=2, h, w)
        AddChannelDim(raw) +
        # raw:      (c=1, h, w)
        # affs:     (c=2, h, w)
        gp.PreCache() +
        gp.Stack(batch_size) +
        # raw:      (b, c=1, h, w)
        # affs:     (b, c=2, h, w)
        Train(
            model=model,
            loss=loss,
            optimizer=optimizer,
            inputs={'x': raw},
            target=affs,
            output=affs_predicted,
            save_every=save_every,
            log_dir='log') +
        # raw:      (b, c=1, h, w)
        # affs:     (b, c=2, h, w)
        # affs_predicted: (b, c=2, h, w)
        TransposeDims(raw, (1, 0, 2, 3)) +
        TransposeDims(affs, (1, 0, 2, 3)) +
        TransposeDims(affs_predicted, (1, 0, 2, 3)) +
        # raw:      (c=1, b, h, w)
        # affs:     (c=2, b, h, w)
        # affs_predicted: (c=2, b, h, w)
        RemoveChannelDim(raw) +
        # raw:      (b, h, w)
        # affs:     (c=2, b, h, w)
        # affs_predicted: (c=2, b, h, w)
        gp.Snapshot(
            dataset_names={
                raw: 'raw',
                labels: 'labels',
                affs: 'affs',
                affs_predicted: 'affs_predicted'
            },
            every=100) +
        gp.PrintProfilingStats(every=100)
    )

    return pipeline
def make_pipeline(self):
    """Build a scan-based prediction pipeline that writes to ``predictions.zarr``.

    Raw is read from ``self.data_file``/``self.dataset``. Channel/batch
    axes are added as needed to match ``self.model.in_shape``. Raw is
    normalized with ``self.params['norm_factor']`` and padded by the
    network context, then predicted with ``gp.torch.Predict``. The result is
    written with ``gp.ZarrWrite`` into ``self.curr_log_dir``, and
    ``gp.Scan`` tiles the whole input.

    Returns:
        ``(pipeline, request, pred_affs)``: the pipeline, the per-tile scan
        request and the array key of the predictions.
    """
    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PREDICTIONS')

    data = daisy.open_ds(self.data_file, self.dataset)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(self.model.in_shape)
    out_shape = gp.Coordinate(self.model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(pred_affs, out_shape)

    # Network context (world units) on each side of an output tile.
    context = (in_shape - out_shape) / 2

    source = gp.ZarrSource(self.data_file,
                           {
                               raw: self.dataset,
                           },
                           array_specs={
                               raw: gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                           })

    # Number of stored axes beyond the model's spatial dims.
    extra_dims = len(data.shape) - len(self.model.in_shape)
    if is_2d:
        # 2D: [samples, y, x] or [samples, channels, y, x]
        needs_channel_fix = extra_dims == 1
        if needs_channel_fix:
            source = source + AddChannelDim(raw, axis=1)
        # raw [samples, channels, y, x]
    else:
        # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
        needs_channel_fix = extra_dims == 0
        needs_batch_fix = extra_dims <= 1

        if needs_channel_fix:
            source = source + AddChannelDim(raw, axis=0)

        # Batch fix
        if needs_batch_fix:
            source = source + AddChannelDim(raw)

        # raw: [sample, channels, z, y, x]

    # The predicted array covers the (possibly channel-augmented) raw ROI.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = (source
                + gp.Normalize(raw, factor=self.params['norm_factor'])
                + gp.Pad(raw, context)
                + gp.PreCache()
                + gp.torch.Predict(
                    self.model,
                    inputs={'raw': raw},
                    outputs={0: pred_affs},
                    array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)}))

    pipeline = (pipeline
                + gp.ZarrWrite({
                    pred_affs: 'predictions',
                },
                    output_dir=self.curr_log_dir,
                    output_filename='predictions.zarr',
                    compression_type='gzip')
                + gp.Scan(request))

    return pipeline, request, pred_affs
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate

    As implemented here, each block in ``config["BLOCKS"]`` gets a pipeline
    that reads raw twice (plain and CLAHE-filtered), rasterizes and grows the
    SWC ground truth, crops everything around the block's cube and writes a
    snapshot to ``validations/validations.hdf``. All block pipelines are
    merged with ``gp.MergeProvider``.

    Args:
        config: configuration mapping; reads ``BLOCKS``,
            ``BENCHMARK_DATA_PATH``, ``VALIDATION_SAMPLES``, ``SAMPLES_PATH``,
            ``RAW_N5``, ``NEURON_RADIUS``, ``VOXEL_SIZE``, ``INPUT_SHAPE``,
            ``OUTPUT_SHAPE`` and ``DISTANCE_ATTR``.

    Returns:
        ``(validation_pipeline, specs)`` where ``specs[block]`` maps each of
        the block's keys to the spec that should be requested for it.

    Raises:
        FileNotFoundError: if a validation directory has no ``cube*.swc`` file.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    # Context (world units) added on each side of the cube for raw.
    context = (input_size - output_size) // 2

    distance_attr = config["DISTANCE_ATTR"]

    transform_file = transform_template.format(sample=sample)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)

        # The last matching cube*.swc file wins (same as the original scan).
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(".swc"):
                cube = gt_file
        if cube is None:
            raise FileNotFoundError(f"No cube*.swc file found in {validation_dir}")

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_file),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed: gp.ArraySpec(interpolatable=True,
                                         voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32)
            + gp.Normalize(raw_clahed, dtype=np.float32)
            + scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_file,
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        # All ROIs are expressed relative to a cube shifted to the origin;
        # SpecifiedLocation below centres requests on the real cube.
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(context, context)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source)
                    + gp.nodes.MergeProvider()
                    + gp.SpecifiedLocation(locations=[cube_roi.get_center()])
                    + gp.Crop(raw, roi=input_roi)
                    + gp.Crop(raw_clahed, roi=input_roi)
                    + gp.Crop(ground_truth, roi=cube_roi_shifted)
                    + nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    )
                    + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale])
                    + gp.Crop(labels, roi=cube_roi_shifted)
                    + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    validation_pipeline = (tuple(pipeline for pipeline in validation_pipelines)
                           + gp.MergeProvider()
                           + gp.PrintProfilingStats())
    return validation_pipeline, specs
out_density_map = gp.ArrayKey('OUT_DENSITY_MAP') prediction = gp.ArrayKey('PREDICTION') class PrepareTrainingData(gp.BatchFilter): def process(self, batch, request): batch[out_cage_map].data = batch[out_cage_map].data.astype(np.float32) batch[out_cage_map].spec.dtype = np.float32 # assemble pipeline sourceA = gp.ZarrSource('../data/cropped_sample_A.zarr', { raw: 'raw', seg: 'segmentation' }, { raw: gp.ArraySpec(interpolatable=True), seg: gp.ArraySpec(interpolatable=False) }) sourceB = gp.ZarrSource('../data/cropped_sample_B.zarr', { raw: 'raw', seg: 'segmentation' }, { raw: gp.ArraySpec(interpolatable=True), seg: gp.ArraySpec(interpolatable=False) }) sourceC = gp.ZarrSource('../data/cropped_sample_C.zarr', { raw: 'raw', seg: 'segmentation' }, { raw: gp.ArraySpec(interpolatable=True),
def create_train_pipeline(self, model):
    """Build the contrastive two-view training pipeline for ``model``.

    For each view (``raw_0``/``raw_1``) one ``ZarrSource`` per dataset in
    ``self.params['dataset']`` is created. ``RandomMultiBranchSource`` then
    picks one of them with probability proportional to its voxel count. Each
    view is merged with a ``RandomPointSource`` (sharing one
    ``RandomPointGenerator``) and augmented, and the views are trained jointly
    with ``ContrastiveVolumeLoss``. For 2D models a fake leading spatial axis
    is added upstream.

    Args:
        model: torch model; must expose ``in_shape`` and ``out_shape``.

    Returns:
        ``(pipeline, request)``.
    """
    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    filename = self.params['data_file']
    datasets = self.params['dataset']

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # NOTE(review): geometry is taken from datasets[0] and reused for every
    # dataset's ArraySpec and for the Crop below; confirm all datasets share
    # the same ROI.
    data = daisy.open_ds(filename, datasets[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    emb_voxel_size = voxel_size

    cv_loss = ContrastiveVolumeLoss(self.params['temperature'],
                                    self.params['point_density'],
                                    out_shape * voxel_size)

    # Add fake 3rd dim
    if is_2d:
        in_shape = gp.Coordinate((1, *in_shape))
        out_shape = gp.Coordinate((1, *out_shape))
        voxel_size = gp.Coordinate((1, *voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (data.shape[0], *source_roi.get_shape()))

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    random_point_generator = RandomPointGenerator(
        density=self.params['point_density'], repetitions=2)

    # Weight each dataset by its volume (voxel count); RandomSourceGenerator
    # normalizes these weights to probabilities. np.prod replaces
    # np.product, which was removed in NumPy 2.0.
    probabilities = np.array([
        np.prod(daisy.open_ds(filename, dataset).shape)
        for dataset in datasets
    ])
    random_source_generator = RandomSourceGenerator(
        num_sources=len(datasets),
        probabilities=probabilities,
        repetitions=2)

    array_sources = tuple(
        tuple(
            gp.ZarrSource(
                filename,
                {raw: dataset},
                # fake 3D data
                array_specs={
                    raw: gp.ArraySpec(roi=source_roi,
                                      voxel_size=voxel_size,
                                      interpolatable=True)
                })
            for dataset in datasets)
        for raw in [raw_0, raw_1])

    # Choose a random dataset to pull from
    array_sources = tuple(
        arrays
        + RandomMultiBranchSource(random_source_generator)
        + gp.Normalize(raw, self.params['norm_factor'])
        + gp.Pad(raw, None)
        for raw, arrays in zip([raw_0, raw_1], array_sources))

    point_sources = (
        RandomPointSource(points_0,
                          random_point_generator=random_point_generator),
        RandomPointSource(points_1,
                          random_point_generator=random_point_generator),
    )

    # Merge the point and array sources together.
    # There is one array and point source per branch.
    sources = tuple(
        (array_source, point_source) + gp.MergeProvider()
        for array_source, point_source in zip(array_sources, point_sources))

    sources = tuple(
        self._make_train_augmentation_pipeline(raw, source)
        for raw, source in zip([raw_0, raw_1], sources))

    pipeline = (sources
                + gp.MergeProvider()
                + gp.Crop(raw_0, source_roi)
                + gp.Crop(raw_1, source_roi)
                + gp.RandomLocation()
                + PrepareBatch(raw_0, raw_1,
                               points_0, points_1,
                               locations_0, locations_1,
                               is_2d)
                + RejectArray(ensure_nonempty=locations_0)
                + RejectArray(ensure_nonempty=locations_1))

    if not is_2d:
        pipeline = (pipeline
                    + AddChannelDim(raw_0)
                    + AddChannelDim(raw_1))

    pipeline = (pipeline
                + gp.PreCache()
                + gp.torch.Train(
                    model, cv_loss, optimizer,
                    inputs={
                        'raw_0': raw_0,
                        'raw_1': raw_1
                    },
                    loss_inputs={
                        'emb_0': emb_0,
                        'emb_1': emb_1,
                        'locations_0': locations_0,
                        'locations_1': locations_1
                    },
                    outputs={
                        2: emb_0,
                        3: emb_1
                    },
                    array_specs={
                        emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
                        emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
                    },
                    checkpoint_basename=self.logdir + '/contrastive/checkpoints/model',
                    save_every=self.params['save_every'],
                    log_dir=self.logdir + "/contrastive",
                    log_every=self.log_every))

    if is_2d:
        pipeline = (
            pipeline
            # everything is 3D, except emb_0 and emb_1
            + AddSpatialDim(emb_0)
            + AddSpatialDim(emb_1)
        )

    pipeline = (
        pipeline
        # now everything is 3D
        + RemoveChannelDim(raw_0)
        + RemoveChannelDim(raw_1)
        + RemoveChannelDim(emb_0)
        + RemoveChannelDim(emb_1)
        + gp.Snapshot(output_dir=self.logdir + '/contrastive/snapshots',
                      output_filename='it{iteration}.hdf',
                      dataset_names={
                          raw_0: 'raw_0',
                          raw_1: 'raw_1',
                          locations_0: 'locations_0',
                          locations_1: 'locations_1',
                          emb_0: 'emb_0',
                          emb_1: 'emb_1'
                      },
                      additional_request=snapshot_request,
                      every=self.params['save_every'])
        + gp.PrintProfilingStats(every=500)
    )

    return pipeline, request