def create_train_pipeline(self, model):

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    points = gp.ArrayKey('POINTS')
    predictions = gp.ArrayKey('PREDICTIONS')
    gt_labels = gp.ArrayKey('LABELS')

    request = gp.BatchRequest()
    # Because of PointsLabelsSource we can keep everything as nonspatial
    request[points] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)
    request[gt_labels] = gp.ArraySpec(nonspatial=True)

    pipeline = (
        PointsLabelsSource(points, self.data, gt_labels, self.labels, 1) +
        gp.Stack(self.params['batch_size']) +
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'points': points},
            loss_inputs={
                0: predictions,
                1: gt_labels
            },
            outputs={0: predictions},
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every))

    return pipeline, request
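# Hedged usage sketch (not part of the original code): the standard gunpowder
# pattern for consuming the (pipeline, request) pair returned above. The
# `trainer` object owning create_train_pipeline() and the iteration count are
# assumptions for illustration only.
import gunpowder as gp  # same alias the functions in this file rely on


def run_point_training(trainer, model, n_iterations):
    pipeline, request = trainer.create_train_pipeline(model)
    with gp.build(pipeline):
        for _ in range(n_iterations):
            # Each request pulls one batch through PointsLabelsSource,
            # gp.Stack and gp.torch.Train, i.e. performs one optimizer step.
            pipeline.request_batch(request)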
def predict(iteration, path_to_dataGP):

    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)

    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        })

    with gp.build(source):
        source_roi = source.spec[raw].roi

    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request))

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
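# Hedged usage sketch (not part of the original code): calling the predict()
# helper above. The checkpoint iteration and zarr path are placeholders;
# predict() builds and runs its own gunpowder pipeline internally.
if __name__ == '__main__':
    raw_data, affs_data = predict(100000, 'data/20200201.zarr')
    print('raw:', raw_data.shape, 'affs_predicted:', affs_data.shape)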
        seg: gp.ArraySpec(interpolatable=False)
    })

sourceC = gp.ZarrSource(
    '../data/cropped_sample_C.zarr',
    {
        raw: 'raw',
        seg: 'segmentation'
    },
    {
        raw: gp.ArraySpec(interpolatable=True),
        seg: gp.ArraySpec(interpolatable=False)
    })

source = (sourceA, sourceB, sourceC) + gp.MergeProvider()
print(source)

normalize = gp.Normalize(raw)
simulate_cages = SimulateCages(raw, seg, out_cage_map, out_density_map, psf,
                               (min_density, max_density), [cage1], 0.5)
add_channel_dim = gp.Stack(1)
stack = gp.Stack(5)
prepare_data = PrepareTrainingData()
train = gp.torch.Train(
    model, loss, optimizer,
    inputs={'input': raw},
    loss_inputs={
        0: prediction,
        1: out_cage_map
    },
    outputs={0: prediction})

pipeline = (source +
            normalize +
            gp.RandomLocation() +
            simulate_cages +
            add_channel_dim +
            stack +
            prepare_data +
            gp.PreCache(num_workers=40) +
            train +
            gp.PrintProfilingStats(every=1))
def create_train_pipeline(self, model):

    print(f"Creating training pipeline with batch size "
          f"{self.params['batch_size']}")

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    points = gp.GraphKey('POINTS')
    locations = gp.ArrayKey('LOCATIONS')
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_roi = out_roi * out_voxel_size
    out_shape = gp.Coordinate(
        (self.params["num_points"], *model.out_shape[2:]))

    context = (in_shape - out_roi) / 2
    gt_labels_out_shape = out_roi

    # Add fake 3rd dim
    if is_2d:
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        context = gp.Coordinate((0, *context))
        gt_labels_out_shape = (1, *gt_labels_out_shape)

        points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
        points_pad = (0, *points_roi)
        context = gp.Coordinate((0, None, None))
    else:
        points_roi = source_voxel_size * tuple(self.params["point_roi"])
        points_pad = points_roi
        context = gp.Coordinate((None, None, None))

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")
    logger.info(f"out_voxel_size: {out_voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(points, points_roi)
    request.add(gt_labels, out_roi)
    request[locations] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))

    source = (
        (gp.ZarrSource(
            filename,
            {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
            array_specs={
                raw: gp.ArraySpec(roi=source_roi,
                                  voxel_size=source_voxel_size,
                                  interpolatable=True),
                gt_labels: gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size)
            }),
         PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
        gp.MergeProvider() +
        gp.Pad(raw, context) +
        gp.Pad(gt_labels, context) +
        gp.Pad(points, points_pad) +
        gp.RandomLocation(ensure_nonempty=points) +
        gp.Normalize(raw, self.params['norm_factor'])
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        # If 2d then source_roi = (1, input_shape) in order to select a RL
    )

    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source +
        # Batches seem to be rejected because points are chosen near the
        # edge of the points ROI and the augmentations remove them.
        # TODO: Figure out if this is an actual issue, and if anything can
        # be done.
        gp.Reject(ensure_nonempty=points) +
        SetDtype(gt_labels, np.int64) +
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        AddChannelDim(raw) +
        AddChannelDim(gt_labels)
        # raw      : (c=1, source_roi)
        # gt_labels: (c=1, source_roi)
        # points   : (c=1, source_locations_shape)
    )

    if is_2d:
        pipeline = (
            # Remove extra dim the 2d roi had
            pipeline +
            RemoveSpatialDim(raw) +
            RemoveSpatialDim(gt_labels) +
            RemoveSpatialDim(points)
            # raw      : (c=1, roi)
            # gt_labels: (c=1, roi)
            # points   : (c=1, locations_shape)
        )

    pipeline = (
        pipeline +
        FillLocations(raw, points, locations, is_2d=False, max_points=1) +
        gp.Stack(self.params['batch_size']) +
        gp.PreCache() +
        # raw      : (b, c=1, roi)
        # gt_labels: (b, c=1, roi)
        # locations: (b, c=1, locations_shape)
        # (which is what train requires)
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={
                'raw': raw,
                'points': locations
            },
            loss_inputs={
                0: predictions,
                1: gt_labels,
                2: locations
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(nonspatial=True),
                emb: gp.ArraySpec(voxel_size=out_voxel_size)
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every) +
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, roi)
        # gt_labels  : (b, c=1, roi)
        # predictions: (b, num_points)
        gp.Snapshot(output_dir=self.logdir + '/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw: 'raw',
                        gt_labels: 'gt_labels',
                        predictions: 'predictions',
                        emb: 'emb'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        InspectBatch('END') +
        gp.PrintProfilingStats(every=500))

    return pipeline, request
def create_train_pipeline(self, model):

    print(
        f"Creating training pipeline with batch size {self.params['batch_size']}"
    )

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    gt_aff = gp.ArrayKey('AFFINITIES')
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_shape = out_shape * out_voxel_size
    context = (in_shape - out_shape) / 2
    gt_labels_out_shape = out_shape

    # Add fake 3rd dim
    if is_2d:
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        context = gp.Coordinate((0, *context))
        aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
        gt_labels_out_shape = (1, *gt_labels_out_shape)
    else:
        aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(gt_aff, out_shape)
    request.add(predictions, out_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))
    snapshot_request[gt_labels] = gp.ArraySpec(
        roi=gp.Roi(context, gt_labels_out_shape))

    source = (
        gp.ZarrSource(
            filename,
            {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
            array_specs={
                raw: gp.ArraySpec(roi=source_roi,
                                  voxel_size=source_voxel_size,
                                  interpolatable=True),
                gt_labels: gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size)
            }) +
        gp.Normalize(raw, self.params['norm_factor']) +
        gp.Pad(raw, context) +
        gp.Pad(gt_labels, context) +
        gp.RandomLocation()
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
    )

    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source +
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
        gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) +
        SetDtype(gt_aff, np.float32) +
        # raw   : (l=1, h, w)
        # gt_aff: (c=2, l=1, h, w)
        AddChannelDim(raw)
        # raw   : (c=1, l=1, h, w)
        # gt_aff: (c=2, l=1, h, w)
    )

    if is_2d:
        pipeline = (
            pipeline +
            RemoveSpatialDim(raw) +
            RemoveSpatialDim(gt_aff)
            # raw   : (c=1, h, w)
            # gt_aff: (c=2, h, w)
        )

    pipeline = (
        pipeline +
        gp.Stack(self.params['batch_size']) +
        gp.PreCache() +
        # raw   : (b, c=1, h, w)
        # gt_aff: (b, c=2, h, w)
        # (which is what train requires)
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'raw': raw},
            loss_inputs={
                0: predictions,
                1: gt_aff
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(voxel_size=out_voxel_size),
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every) +
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, h, w)
        # gt_aff     : (b, c=2, h, w)
        # predictions: (b, c=2, h, w)
        # Crop GT to look at labels
        gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
        gp.Snapshot(output_dir=self.logdir + '/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw: 'raw',
                        gt_labels: 'gt_labels',
                        predictions: 'predictions',
                        gt_aff: 'gt_aff',
                        emb: 'emb'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        gp.PrintProfilingStats(every=500))

    return pipeline, request
def train(until):

    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    pipeline = (
        (
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample1/raw',
                    labels: 'train/sample1/labels'
                }),
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample2/raw',
                    labels: 'train/sample2/labels'
                }),
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample3/raw',
                    labels: 'train/sample3/labels'
                })
        ) + gp.RandomProvider() +
        gp.Normalize(raw) +
        gp.RandomLocation() +
        gp.SimpleAugment(transpose_only=(1, 2)) +
        gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi]) +
        gp.AddAffinities(
            [(1, 0, 0), (0, 1, 0), (0, 0, 1)],
            labels, affs) +
        gp.Normalize(affs, factor=1.0) +
        # gp.PreCache(num_workers=1) +
        # raw: (d, h, w)
        # affs: (3, d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        # affs: (1, 3, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        # affs: (1, 3, d, h, w)
        gp_torch.Train(
            model,
            loss,
            optimizer,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            loss_inputs={0: affs_predicted, 1: affs},
            save_every=10000) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs: (3, d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Snapshot(
            {
                raw: 'raw',
                labels: 'labels',
                affs: 'affs',
                affs_predicted: 'affs_predicted'
            },
            every=500,
            output_filename='iteration_{iteration}.hdf')
    )

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(affs, input_size)
    request.add(affs_predicted, input_size)

    with gp.build(pipeline):
        for i in range(until):
            pipeline.request_batch(request)
def create_pipeline_2d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):

    raw_channels = task.data.raw.num_channels
    filename = task.data.raw.train.filename
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = task.data.raw.train.shape
    dataset_roi = task.data.raw.train.roi
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError("For 2D training, please provide a 3D array where "
                           "the first dimension indexes the samples.")

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    sources = (task.data.raw.train.get_source(raw, overwrite_spec=spec),
               task.data.gt.train.get_source(gt, overwrite_spec=spec))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    # [weights: ([c,] d=1, h, w)]
    # get rid of z dim:
    pipeline += Squash(dim=-3)
    # raw: ([c,] h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    # prediction: (b, [c,] h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3))
        if weights_node:
            pipeline += TransposeDims(weights, (1, 0, 2, 3))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3))
        # raw: (c, b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
def create_pipeline_3d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):

    raw_channels = max(1, task.data.raw.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = task.data.raw.train.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D training not yet implemented")

    sources = (task.data.raw.train.get_source(raw),
               task.data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
        if weights_node:
            pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
def build_pipeline(
        data_dir,
        model,
        save_every,
        batch_size,
        input_size,
        output_size,
        raw,
        labels,
        affs,
        affs_predicted,
        lr=1e-5):

    dataset_shape = zarr.open(str(data_dir))['train/raw'].shape
    num_samples = dataset_shape[0]
    sample_size = dataset_shape[1:]

    loss = torch.nn.MSELoss()
    optimizer = RAdam(model.parameters(), lr=lr)

    pipeline = (
        gp.ZarrSource(
            data_dir,
            {
                raw: 'train/raw',
                labels: 'train/gt'
            },
            array_specs={
                raw: gp.ArraySpec(
                    roi=gp.Roi((0, 0, 0), (num_samples,) + sample_size),
                    voxel_size=(1, 1, 1)),
                labels: gp.ArraySpec(
                    roi=gp.Roi((0, 0, 0), (num_samples,) + sample_size),
                    voxel_size=(1, 1, 1))
            }) +
        # raw: (d=1, h, w)
        # labels: (d=1, h, w)
        gp.RandomLocation() +
        # raw: (d=1, h, w)
        # labels: (d=1, h, w)
        gp.AddAffinities(
            affinity_neighborhood=[(0, 1, 0), (0, 0, 1)],
            labels=labels,
            affinities=affs) +
        gp.Normalize(affs, factor=1.0) +
        # raw: (d=1, h, w)
        # affs: (c=2, d=1, h, w)
        Squash(dim=-3) +  # get rid of z dim
        # raw: (h, w)
        # affs: (c=2, h, w)
        AddChannelDim(raw) +
        # raw: (c=1, h, w)
        # affs: (c=2, h, w)
        gp.PreCache() +
        gp.Stack(batch_size) +
        # raw: (b=10, c=1, h, w)
        # affs: (b=10, c=2, h, w)
        Train(
            model=model,
            loss=loss,
            optimizer=optimizer,
            inputs={'x': raw},
            target=affs,
            output=affs_predicted,
            save_every=save_every,
            log_dir='log') +
        # raw: (b=10, c=1, h, w)
        # affs: (b=10, c=2, h, w)
        # affs_predicted: (b=10, c=2, h, w)
        TransposeDims(raw, (1, 0, 2, 3)) +
        TransposeDims(affs, (1, 0, 2, 3)) +
        TransposeDims(affs_predicted, (1, 0, 2, 3)) +
        # raw: (c=1, b=10, h, w)
        # affs: (c=2, b=10, h, w)
        # affs_predicted: (c=2, b=10, h, w)
        RemoveChannelDim(raw) +
        # raw: (b=10, h, w)
        # affs: (c=2, b=10, h, w)
        # affs_predicted: (c=2, b=10, h, w)
        gp.Snapshot(
            dataset_names={
                raw: 'raw',
                labels: 'labels',
                affs: 'affs',
                affs_predicted: 'affs_predicted'
            },
            every=100) +
        gp.PrintProfilingStats(every=100)
    )

    return pipeline
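# Hedged usage sketch (not part of the original code): build_pipeline() above
# returns only the pipeline, so the array keys and a matching BatchRequest
# have to be assembled by the caller. The zarr path, patch sizes, and
# iteration count below are placeholders, not values from the original source.
import gunpowder as gp  # same alias the functions in this file rely on


def run_affinity_training(model, n_iterations=10000):
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    input_size = (1, 128, 128)   # placeholder (d=1, h, w) patch size
    output_size = (1, 128, 128)  # placeholder, must match the model's output

    pipeline = build_pipeline(
        'data/train.zarr',       # placeholder path
        model,
        save_every=1000,
        batch_size=10,
        input_size=input_size,
        output_size=output_size,
        raw=raw,
        labels=labels,
        affs=affs,
        affs_predicted=affs_predicted)

    # Assemble the request that build_pipeline() itself does not create.
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(affs, output_size)
    request.add(affs_predicted, output_size)

    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)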
def random_point_pairs_pipeline(
        model, loss, optimizer,
        dataset,
        augmentation_parameters,
        point_density,
        out_dir,
        normalize_factor=None,
        checkpoint_interval=5000,
        snapshot_interval=5000):

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        ds_key = f'train/raw/{i}'
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename,
                {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None)
            for raw in [raw_0, raw_1])

        random_point_generator = RandomPointGenerator(
            density=point_density, repetitions=2)

        point_sources = tuple((
            RandomPointSource(
                points_0,
                dim,
                random_point_generator=random_point_generator),
            RandomPointSource(
                points_1,
                dim,
                random_point_generator=random_point_generator)
        ))

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source) +
            gp.MergeProvider() +
            gp.SimpleAugment() +
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2)) +
            gp.IntensityAugment(
                r,
                scale_min=0.8,
                scale_max=1.2,
                shift_min=-0.2,
                shift_max=0.2,
                clip=False) +
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source in zip(
                [raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])

    pipeline += PrepareBatch(raw_0, raw_1,
                             points_0, points_1,
                             locations_0, locations_1)
    # How does prepare batch relate to Stack?????

    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:       (1, h, w)
    # raw_1:       (1, h, w)
    # locations_0: (n, 2)
    # locations_1: (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:       (b, 1, h, w)
    # raw_1:       (b, 1, h, w)
    # locations_0: (b, n, 2)
    # locations_1: (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model, loss, optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0: 'locations_0',
            # locations_1: 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
def build_batch_provider(self, datasets, model, task, snapshot_container=None):

    input_shape = Coordinate(model.input_shape)
    output_shape = Coordinate(model.output_shape)

    # get voxel sizes
    raw_voxel_size = datasets[0].raw.voxel_size
    prediction_voxel_size = model.scale(raw_voxel_size)

    # define input and output size:
    # switch to world units
    input_size = raw_voxel_size * input_shape
    output_size = prediction_voxel_size * output_shape

    # padding of groundtruth/mask
    gt_mask_padding = output_size + task.predictor.padding(prediction_voxel_size)

    # define keys:
    raw_key = gp.ArrayKey("RAW")
    gt_key = gp.ArrayKey("GT")
    mask_key = gp.ArrayKey("MASK")
    target_key = gp.ArrayKey("TARGET")
    weight_key = gp.ArrayKey("WEIGHT")

    # Get source nodes
    dataset_sources = []
    for dataset in datasets:
        raw_source = DaCapoArraySource(dataset.raw, raw_key)
        raw_source += gp.Pad(raw_key, None, 0)
        gt_source = DaCapoArraySource(dataset.gt, gt_key)
        gt_source += gp.Pad(gt_key, gt_mask_padding, 0)
        if dataset.mask is not None:
            mask_source = DaCapoArraySource(dataset.mask, mask_key)
        else:
            # Always provide a mask. By default it is simply an array
            # of ones with the same shape/roi as gt. Avoids making us
            # specially handle no mask case and allows padding of the
            # ground truth without worrying about training on incorrect
            # data.
            mask_source = DaCapoArraySource(OnesArray.like(dataset.gt), mask_key)
        mask_source += gp.Pad(mask_key, gt_mask_padding, 0)
        array_sources = [raw_source, gt_source, mask_source]

        dataset_source = (
            tuple(array_sources) + gp.MergeProvider() + gp.RandomLocation()
        )

        dataset_sources.append(dataset_source)
    pipeline = tuple(dataset_sources) + gp.RandomProvider()

    for augment in self.augments:
        pipeline += augment.node(raw_key, gt_key, mask_key)

    pipeline += gp.Reject(mask_key, min_masked=self.min_masked)

    # Add predictor nodes to pipeline
    pipeline += DaCapoTargetFilter(
        task.predictor,
        gt_key=gt_key,
        target_key=target_key,
        weights_key=weight_key,
        mask_key=mask_key,
    )

    # Trainer attributes:
    if self.num_data_fetchers > 1:
        pipeline += gp.PreCache(num_workers=self.num_data_fetchers)

    # stack to create a batch dimension
    pipeline += gp.Stack(self.batch_size)

    # print profiling stats
    pipeline += gp.PrintProfilingStats(every=self.print_profiling)

    # generate request for all necessary inputs to training
    request = gp.BatchRequest()
    request.add(raw_key, input_size)
    request.add(target_key, output_size)
    request.add(weight_key, output_size)
    # request additional keys for snapshots
    request.add(gt_key, output_size)
    request.add(mask_key, output_size)

    self._request = request
    self._pipeline = pipeline
    self._raw_key = raw_key
    self._gt_key = gt_key
    self._mask_key = mask_key
    self._weight_key = weight_key
    self._target_key = target_key
    self._loss = task.loss

    self.snapshot_container = snapshot_container
def create_pipeline_3d(
    task, data, predictor, optimizer, batch_size, outdir, snapshot_every
):
    raw_channels = max(1, data.raw.num_channels)
    input_shape = gp.Coordinate(task.model.input_shape)
    output_shape = gp.Coordinate(task.model.output_shape)
    voxel_size = data.raw.train.voxel_size

    task.predictor = task.predictor.to("cuda")

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    mask = gp.ArrayKey("MASK")
    target = gp.ArrayKey("TARGET")
    weights = gp.ArrayKey("WEIGHTS")
    model_outputs = gp.ArrayKey("MODEL_OUTPUTS")
    model_output_grads = gp.ArrayKey("MODEL_OUT_GRAD")
    prediction = gp.ArrayKey("PREDICTION")
    pred_gradients = gp.ArrayKey("PRED_GRADIENTS")

    snapshot_dataset_names = {
        raw: "raw",
        model_outputs: "model_outputs",
        model_output_grads: "model_out_grad",
        target: "target",
        prediction: "prediction",
        pred_gradients: "pred_gradients",
        weights: "weights",
    }

    aux_keys = {}
    aux_grad_keys = {}
    for name, _, _ in task.aux_tasks:
        aux_keys[name] = (
            gp.ArrayKey(f"{name.upper()}_PREDICTION"),
            gp.ArrayKey(f"{name.upper()}_TARGET"),
            None,
        )
        aux_grad_keys[name] = gp.ArrayKey(f"{name.upper()}_PRED_GRAD")

        aux_pred, aux_target, _ = aux_keys[name]

        snapshot_dataset_names[aux_pred] = f"{name}_pred"
        snapshot_dataset_names[aux_target] = f"{name}_target"

        aux_grad = aux_grad_keys[name]
        snapshot_dataset_names[aux_grad] = f"{name}_aux_grad"

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = data.raw.train.num_samples
    assert num_samples == 0, "Multiple samples for 3D training not yet implemented"

    sources = (data.raw.train.get_source(raw), data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, input_shape / 2 * voxel_size)
    # pipeline += gp.Pad(gt, input_shape / 2 * voxel_size)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    mask_node = task.loss.add_mask(gt, mask)
    if mask_node is not None:
        pipeline += mask_node
        pipeline += gp.RandomLocation(min_masked=1e-6, mask=mask)
    else:
        # raw: ([c,] d, h, w)
        # gt: ([c,] d, h, w)
        pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    loss_inputs = []
    if weights_node:
        pipeline += weights_node
        loss_inputs.append({0: prediction, 1: target, 2: weights})
    else:
        loss_inputs.append({0: prediction, 1: target})

    head_outputs = []
    head_gradients = []
    for name, aux_predictor, aux_loss in task.aux_tasks:
        aux_prediction, aux_target, aux_weights = aux_keys[name]
        pipeline += aux_predictor.add_target(gt, aux_target)
        aux_weights_node = aux_loss.add_weights(aux_target, aux_weights)
        if aux_weights_node:
            aux_weights = gp.ArrayKey(f"{name.upper()}_WEIGHTS")
            aux_keys[name] = (
                aux_prediction,
                aux_target,
                aux_weights,
            )
            pipeline += aux_weights_node
            loss_inputs.append({0: aux_prediction, 1: aux_target, 2: aux_weights})
            snapshot_dataset_names[aux_weights] = f"{name}_weights"
        else:
            loss_inputs.append({0: aux_prediction, 1: aux_target})
        head_outputs.append({0: aux_prediction})
        aux_pred_gradient = aux_grad_keys[name]
        head_gradients.append({0: aux_pred_gradient})
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += Train(
        model=task.model,
        heads=[("opt", predictor)]
        + [(name, aux_pred) for name, aux_pred, _ in task.aux_tasks],
        losses=[task.loss] + [loss for _, _, loss in task.aux_tasks],
        optimizer=optimizer,
        inputs={"x": raw},
        outputs={0: model_outputs},
        head_outputs=[{0: prediction}] + head_outputs,
        loss_inputs=loss_inputs,
        gradients=[{0: model_output_grads}, {0: pred_gradients}] + head_gradients,
        save_every=1e6,
    )
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
        if weights_node:
            pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(
            dataset_names=snapshot_dataset_names,
            every=snapshot_every,
            output_dir=os.path.join(outdir, "snapshots"),
            output_filename="{iteration}.hdf",
        )
    pipeline += gp.PrintProfilingStats(every=10)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    if mask_node is not None:
        request.add(mask, output_size)
    request.add(target, output_size)
    for name, _, _ in task.aux_tasks:
        aux_pred, aux_target, aux_weight = aux_keys[name]
        request.add(aux_pred, output_size)
        request.add(aux_target, output_size)
        if aux_weight is not None:
            request.add(aux_weight, output_size)
        aux_pred_grad = aux_grad_keys[name]
        request.add(aux_pred_grad, output_size)

    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)
    request.add(pred_gradients, output_size)

    return pipeline, request