def create_train_pipeline(self, model):
    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    points = gp.ArrayKey('POINTS')
    predictions = gp.ArrayKey('PREDICTIONS')
    gt_labels = gp.ArrayKey('LABELS')

    request = gp.BatchRequest()
    # Because of PointsLabelsSource we can keep everything as nonspatial
    request[points] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)
    request[gt_labels] = gp.ArraySpec(nonspatial=True)

    pipeline = (
        PointsLabelsSource(points, self.data, gt_labels, self.labels, 1) +
        gp.Stack(self.params['batch_size']) +
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'points': points},
            loss_inputs={
                0: predictions,
                1: gt_labels
            },
            outputs={0: predictions},
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every))

    return pipeline, request
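# Minimal usage sketch for create_train_pipeline (hedged: `trainer` and
# `model` are hypothetical stand-ins for an instance of the class above and
# its torch model, and 'num_iterations' is an assumed entry in its params):
#
#     pipeline, request = trainer.create_train_pipeline(model)
#     with gp.build(pipeline):
#         for _ in range(trainer.params['num_iterations']):
#             pipeline.request_batch(request)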
def build_source(self):
    data = daisy.open_ds(self.filename, self.key)

    if self.time_window is None:
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    else:
        offs = list(data.roi.get_offset())
        offs[1] += self.time_window[0]
        sh = list(data.roi.get_shape())
        # restrict the shape (not the offset) along axis 1 to the time window
        sh[1] = self.time_window[1] - self.time_window[0]
        source_roi = gp.Roi(tuple(offs), tuple(sh))

    voxel_size = gp.Coordinate(data.voxel_size)

    return gp.ZarrSource(
        self.filename,
        {
            self.raw_0: self.key,
            self.raw_1: self.key
        },
        array_specs={
            self.raw_0: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True),
            self.raw_1: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True)
        })
def create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                  gt_neurons):
    data_sources = tuple((
        Hdf5PointsSource(os.path.join(data_dir_syn, sample + '.hdf'),
                         datasets={
                             presyn: 'annotations',
                             postsyn: 'annotations'
                         },
                         rois={
                             presyn: cremi_roi,
                             postsyn: cremi_roi
                         }),
        Hdf5PointsSource(
            os.path.join(data_dir_syn, sample + '.hdf'),
            datasets={dummypostsyn: 'annotations'},
            rois={
                # presyn: cremi_roi,
                dummypostsyn: cremi_roi
            },
            kind='postsyn'),
        gp.Hdf5Source(os.path.join(data_dir, sample + '.hdf'),
                      datasets={
                          raw: 'volumes/raw',
                          gt_neurons: 'volumes/labels/neuron_ids',
                      },
                      array_specs={
                          raw: gp.ArraySpec(interpolatable=True),
                          gt_neurons: gp.ArraySpec(interpolatable=False),
                      })))

    source_pip = data_sources + gp.MergeProvider() + gp.Normalize(
        raw) + gp.RandomLocation(ensure_nonempty=dummypostsyn,
                                 p_nonempty=parameter['reject_probability'])
    return source_pip
def setup(self):
    self.ndims = self.data.shape[1]

    if self.points_spec is not None:
        self.provides(self.points, self.points_spec)
    elif isinstance(self.points, gp.ArrayKey):
        self.provides(self.points, gp.ArraySpec(voxel_size=((1, ))))
    elif isinstance(self.points, gp.GraphKey):
        print(self.ndims)
        min_bb = gp.Coordinate(
            np.floor(np.amin(self.data[:, :self.ndims], 0)))
        max_bb = gp.Coordinate(
            np.ceil(np.amax(self.data[:, :self.ndims], 0)) + 1)

        roi = gp.Roi(min_bb, max_bb - min_bb)
        logger.debug(f"Bounding Box: {roi}")

        self.provides(self.points, gp.GraphSpec(roi=roi))

    if self.labels is not None:
        assert isinstance(self.labels, gp.ArrayKey), \
            f"Label key must be an ArrayKey, was given {type(self.labels)}"

        if self.labels_spec is not None:
            self.provides(self.labels, self.labels_spec)
        else:
            self.provides(self.labels, gp.ArraySpec(voxel_size=((1, ))))
def predict(iteration, path_to_dataGP):

    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))

    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)

    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        }
    )

    with gp.build(source):
        source_roi = source.spec[raw].roi

    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
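# Usage sketch for predict() (hedged: the checkpoint iteration and the zarr
# path are placeholders; 'validate/sample1/raw' must exist in the container):
#
#     raw_data, affs = predict(50000, '/path/to/data.zarr')
#     print(raw_data.shape, affs.shape)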
def setup(self):
    self.provides(
        gp.ArrayKeys.M_PRED,
        gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
                     voxel_size=self.voxel_size,
                     interpolatable=False))
    self.provides(
        gp.ArrayKeys.D_PRED,
        gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
                     voxel_size=self.voxel_size,
                     interpolatable=False))
def provide(self, request):
    roi_array = request[gp.ArrayKeys.M_PRED].roi
    slices = (roi_array / self.voxel_size).to_slices()

    batch = gp.Batch()
    batch.arrays[gp.ArrayKeys.M_PRED] = gp.Array(
        self.m_pred[slices],
        spec=gp.ArraySpec(roi=roi_array, voxel_size=self.voxel_size))
    batch.arrays[gp.ArrayKeys.D_PRED] = gp.Array(
        self.d_pred[:, slices[0], slices[1], slices[2]],
        spec=gp.ArraySpec(roi=roi_array, voxel_size=self.voxel_size))

    return batch
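# Request sketch for the node above (hedged: assumes it sits in a built
# pipeline named `pipeline` and that self.voxel_size is, e.g., (1, 1, 1)):
#
#     request = gp.BatchRequest()
#     roi = gp.Roi((0, 0, 0), (20, 20, 20))
#     request[gp.ArrayKeys.M_PRED] = gp.ArraySpec(roi=roi)
#     request[gp.ArrayKeys.D_PRED] = gp.ArraySpec(roi=roi)
#     batch = pipeline.request_batch(request)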
def setup(self):
    # we provide cage maps everywhere where we have a segmentation:
    roi = self.spec[self.seg].roi.copy()
    voxel_size = self.spec[self.seg].voxel_size
    self.provides(
        self.cage_map,
        gp.ArraySpec(roi=roi, dtype=np.uint16, voxel_size=voxel_size))

    # same for the density map
    roi = self.spec[self.seg].roi.copy()
    self.provides(
        self.density_map,
        gp.ArraySpec(roi=roi, dtype=np.float32, voxel_size=voxel_size))
def setup(self):
    self.provides(
        self.raw,
        gp.ArraySpec(roi=gp.Roi((0, 0), (1000, 1000)),
                     dtype=np.uint8,
                     interpolatable=True,
                     voxel_size=(1, 1)))
    self.provides(
        self.gt,
        gp.ArraySpec(roi=gp.Roi((0, 0), (1000, 1000)),
                     dtype=np.uint64,
                     interpolatable=False,
                     voxel_size=(1, 1)))
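# Request sketch against this toy source (hedged: `source` stands for an
# instance of the class defining the setup above):
#
#     request = gp.BatchRequest()
#     request[source.raw] = gp.ArraySpec(roi=gp.Roi((0, 0), (100, 100)))
#     request[source.gt] = gp.ArraySpec(roi=gp.Roi((0, 0), (100, 100)))
#     with gp.build(source):
#         batch = source.request_batch(request)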
def setup(self):
    provided_spec = gp.ArraySpec(
        roi=self.spec[self.gt_key].roi,
        voxel_size=self.spec[self.gt_key].voxel_size,
        interpolatable=self.predictor.output_array_type.interpolatable,
    )
    self.provides(self.target_key, provided_spec)

    provided_spec = gp.ArraySpec(
        roi=self.spec[self.gt_key].roi,
        voxel_size=self.spec[self.gt_key].voxel_size,
        interpolatable=True,
    )
    self.provides(self.weights_key, provided_spec)
def __init__(self, voxel_size):
    self.voxel_size = gp.Coordinate(voxel_size)
    self.roi = gp.Roi((0, 0, 0), (10, 10, 10)) * self.voxel_size

    self.raw = gp.ArrayKey("RAW")
    self.labels = gp.ArrayKey("LABELS")

    self.array_spec_raw = gp.ArraySpec(roi=self.roi,
                                       voxel_size=self.voxel_size,
                                       dtype='uint8',
                                       interpolatable=True)
    self.array_spec_labels = gp.ArraySpec(roi=self.roi,
                                          voxel_size=self.voxel_size,
                                          dtype='uint64',
                                          interpolatable=False)
def setup(self):
    self.enable_autoskip()
    self.provides(self.output, gp.ArraySpec(nonspatial=True))
    if self.details is not None:
        self.provides(self.details, self.spec[self.mst].copy())
    if self.output_graph is not None:
        self.provides(self.output_graph, self.spec[self.mst].copy())
def setup(self):
    self.provides(
        self.array_key,
        gp.ArraySpec(roi=gp.Roi(offset=gp.Coordinate((-10000, -10000, -10000)),
                                shape=gp.Coordinate((20000, 20000, 20000))),
                     voxel_size=(1, 1, 1)))
def __init__(self,
             filename,
             key,
             density=None,
             channels=0,
             shape=(16, 256, 256),
             time_window=None,
             add_sparse_mosaic_channel=True,
             random_rot=False):

    self.filename = filename
    self.key = key
    self.shape = shape
    self.density = density
    self.raw = gp.ArrayKey('RAW_0')
    self.add_sparse_mosaic_channel = add_sparse_mosaic_channel
    self.random_rot = random_rot
    self.channels = channels

    data = daisy.open_ds(filename, key)

    if time_window is None:
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    else:
        offs = list(data.roi.get_offset())
        offs[1] += time_window[0]
        sh = list(data.roi.get_shape())
        # restrict the shape (not the offset) along axis 1 to the time window
        sh[1] = time_window[1] - time_window[0]
        source_roi = gp.Roi(tuple(offs), tuple(sh))

    voxel_size = gp.Coordinate(data.voxel_size)

    self.pipeline = gp.ZarrSource(
        filename,
        {
            self.raw: key
        },
        array_specs={
            self.raw: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True)
        }) + gp.RandomLocation() + IntensityDiffFilter(self.raw,
                                                       0,
                                                       min_distance=0.1,
                                                       channels=slice(None))

    # add augmentations
    self.pipeline = self.pipeline + gp.ElasticAugment([40, 40],
                                                      [2, 2],
                                                      [0, math.pi / 2.0],
                                                      prob_slip=-1,
                                                      spatial_dims=2)

    self.pipeline.setup()
    np.random.seed(os.getpid() + int(time.time()))
def prepare(self, request):
    context = self.context
    dims = request[self.srcpoints].roi.dims()

    assert isinstance(context, list)
    if len(context) == 1:
        context = context * dims

    # request array in a larger area to get predictions from outside
    # write roi
    m_roi = request[self.srcpoints].roi.grow(gp.Coordinate(context),
                                             gp.Coordinate(context))

    # however, restrict the request to the array actually provided
    # m_roi = m_roi.intersect(self.spec[self.m_array].roi)
    request[self.m_array] = gp.ArraySpec(roi=m_roi)

    # Do the same for the direction vector array.
    request[self.d_array] = gp.ArraySpec(roi=m_roi)
def get_requests(config, blocks, raw, emb_pred, labels, gt):
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    diff = input_size - output_size

    cube_rois = [get_cube_roi(config, block) for block in blocks]

    requests = []
    for cube_roi in cube_rois:
        context_roi = cube_roi.grow(diff // 2, diff // 2)

        request = gp.BatchRequest()
        request[raw] = gp.ArraySpec(roi=context_roi)
        request[emb_pred] = gp.ArraySpec(roi=cube_roi)
        request[labels] = gp.ArraySpec(roi=cube_roi)
        request[gt] = gp.GraphSpec(roi=cube_roi)
        requests.append(request)

    return requests
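# Usage sketch for get_requests (hedged: `pipeline`, `config`, `blocks`, and
# the keys are assumed to come from the surrounding prediction setup):
#
#     requests = get_requests(config, blocks, raw, emb_pred, labels, gt)
#     with gp.build(pipeline):
#         batches = [pipeline.request_batch(r) for r in requests]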
def process(self, batch, request):
    spec = self.spec[self.fg].copy()
    voxel_size = (1, ) + spec.voxel_size

    merged = np.stack([batch[self.fg].data, batch[self.bg].data], axis=0)

    batch[self.raw] = gp.Array(
        data=merged.astype(spec.dtype),
        spec=gp.ArraySpec(dtype=spec.dtype,
                          roi=gp.Roi((0, 0, 0, 0), merged.shape) * voxel_size,
                          interpolatable=True,
                          voxel_size=voxel_size))
def validation_data_sources_from_snapshots(config, blocks):
    validation_blocks = Path(config["VALIDATION_BLOCKS"])

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    block_pipelines = []
    for block in blocks:
        pipelines = (
            SnapshotSource(
                validation_blocks / f"block_{block}.hdf",
                {
                    labels: "volumes/labels",
                    ground_truth: "points/gt"
                },
                directed={ground_truth: True},
            ),
            SnapshotSource(validation_blocks / f"block_{block}.hdf",
                           {raw: "volumes/raw"}),
        )

        cube_roi = get_cube_roi(config, block)

        request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        block_pipelines.append((pipelines, request))
    return block_pipelines, (raw, labels, ground_truth)
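# Usage sketch for the snapshot-based validation sources (hedged: `config`
# and `blocks` come from the surrounding validation setup, and merging the
# per-block sources with MergeProvider is an assumption about downstream use):
#
#     block_pipelines, (raw, labels, gt) = \
#         validation_data_sources_from_snapshots(config, blocks)
#     for pipelines, request in block_pipelines:
#         pipeline = pipelines + gp.MergeProvider()
#         with gp.build(pipeline):
#             batch = pipeline.request_batch(request)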
def process(self, batch, request):
    final_scores = {}
    for key, array in batch.items():
        if "SCORE" in str(key):
            block = int(str(key).split("_")[1])
            final_scores[block] = array.data
    final_scores = [
        final_scores[block] for block in range(1, 26)
        if block in final_scores
    ]

    outputs = gp.Batch()
    outputs[self.output] = gp.Array(np.array(final_scores),
                                    gp.ArraySpec(nonspatial=True))

    return outputs
def evaluate_affs(pred_labels, gt_labels, return_results=False):
    results = rand_voi(gt_labels.data, pred_labels.data)
    results["voi_sum"] = results["voi_split"] + results["voi_merge"]

    scores = {"sample": results, "average": results}

    if return_results:
        results = {
            "pred_labels": gp.Array(
                pred_labels.data.astype(np.uint64),
                gp.ArraySpec(roi=pred_labels.spec.roi,
                             voxel_size=pred_labels.spec.voxel_size)),
            "gt_labels": gp.Array(
                gt_labels.data.astype(np.uint64),
                gp.ArraySpec(roi=gt_labels.spec.roi,
                             voxel_size=gt_labels.spec.voxel_size)),
        }

        return scores, results

    return scores
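# Usage sketch for evaluate_affs (hedged: `pred` and `gt` are gp.Array
# instances holding label volumes with matching specs):
#
#     scores = evaluate_affs(pred, gt)
#     print(scores["sample"]["voi_sum"])
#     scores, arrays = evaluate_affs(pred, gt, return_results=True)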
def test_gp_dacapo_array_source(array_config):
    # Create Array from config
    array = array_config.array_type(array_config)

    # Make sure the DaCapoArraySource can properly read
    # the data in `array`
    key = gp.ArrayKey("TEST")
    source_node = DaCapoArraySource(array, key)

    with gp.build(source_node):
        request = gp.BatchRequest()
        request[key] = gp.ArraySpec(roi=array.roi)
        batch = source_node.request_batch(request)
        data = batch[key].data

    assert (data - array[array.roi]).sum() == 0
def test_pipeline3(self):
    array_key = gp.ArrayKey("TEST_ARRAY")
    points_key = gp.PointsKey("TEST_POINTS")
    voxel_size = gp.Coordinate((1, 1))
    spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

    hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                {array_key: 'testdata'},
                                array_specs={array_key: spec})
    csv_source = gp.CsvPointsSource(
        self.fake_points_file, points_key,
        gp.PointsSpec(
            roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

    request = gp.BatchRequest()
    shape = gp.Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = gp.ShiftAugment(prob_slip=0.2,
                                 prob_shift=0.2,
                                 sigma=5,
                                 shift_axis=0)
    pipeline = ((hdf5_source, csv_source) + gp.MergeProvider() +
                gp.RandomLocation(ensure_nonempty=points_key) + shift_node)
    with gp.build(pipeline) as b:
        request = b.request_batch(request)
        # print(request[points_key])
        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = request[array_key].data
        result_points = request[points_key].data
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points.values()
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg="result value {} at points {} not in target values {} "
                    "at points {}".format(result_val,
                                          list(result_points.values()),
                                          target_vals, self.fake_points))
def test_prepare1(self):
    key = gp.ArrayKey("TEST_ARRAY")
    spec = gp.ArraySpec(voxel_size=gp.Coordinate((1, 1)),
                        interpolatable=True)

    hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                {key: 'testdata'},
                                array_specs={key: spec})

    request = gp.BatchRequest()
    shape = gp.Coordinate((3, 3))
    request.add(key, shape, voxel_size=gp.Coordinate((1, 1)))

    shift_node = gp.ShiftAugment(sigma=1, shift_axis=0)
    with gp.build((hdf5_source + shift_node)):
        shift_node.prepare(request)
        self.assertTrue(shift_node.ndim == 2)
        self.assertTrue(shift_node.shift_sigmas == tuple([0.0, 1.0]))
def test_pipeline2(self):
    key = gp.ArrayKey("TEST_ARRAY")
    spec = gp.ArraySpec(voxel_size=gp.Coordinate((3, 1)),
                        interpolatable=True)

    hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                {key: 'testdata'},
                                array_specs={key: spec})

    request = gp.BatchRequest()
    shape = gp.Coordinate((3, 3))
    request.add(key, shape, voxel_size=gp.Coordinate((3, 1)))

    shift_node = gp.ShiftAugment(prob_slip=0.2,
                                 prob_shift=0.2,
                                 sigma=1,
                                 shift_axis=0)
    with gp.build((hdf5_source + shift_node)) as b:
        b.request_batch(request)
def __read_spec(self, array_key):

    if array_key in self.array_specs:
        spec = self.array_specs[array_key].copy()
    else:
        spec = gp.ArraySpec()

    assert spec.voxel_size is not None, "Voxel size needs to be given"

    self.ndims = len(spec.voxel_size)

    if spec.roi is None:
        roi = gp.Roi(gp.Coordinate((0, ) * self.ndims),
                     shape=gp.Coordinate((1, ) * self.ndims))
        roi.set_shape(None)
        spec.roi = roi

    arr = self.func((2, ) * self.ndims)

    if spec.dtype is not None:
        assert spec.dtype == arr.dtype, (
            "dtype %s provided in array_specs for %s, "
            "but differs from function output %s dtype %s" %
            (self.array_specs[array_key].dtype, array_key, self.func,
             arr.dtype))
    else:
        spec.dtype = arr.dtype

    if spec.interpolatable is None:
        spec.interpolatable = spec.dtype in [
            float,
            np.float32,
            np.float64,
            np.float128,
            np.uint8  # assuming this is not used for labels
        ]
        logger.warning(
            "WARNING: You didn't set 'interpolatable' for %s "
            "(func %s). Based on the dtype %s, it has been "
            "set to %s. This might not be what you want.",
            array_key, self.func, spec.dtype, spec.interpolatable)

    return spec
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True,
                                          dtype=np.uint16,
                                          voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False,
                                              dtype=bool,
                                              voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
        data_sources +
        gp.RandomProvider() +
        gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
        DistanceTransform(gt_mask, gt_dt, 3) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0 / clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi / 2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=5) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_dt']: gt_dt,
            },
            outputs={
                net_names['pred_dt']: pred_dt,
            },
            gradients={
                net_names['pred_dt']: loss_gradient,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot({
                raw: 'volumes/raw',
                gt_mask: 'volumes/gt_mask',
                gt_dt: 'volumes/gt_dt',
                pred_dt: 'volumes/pred_dt',
                loss_gradient: 'volumes/gradient',
            },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2000) +
        gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
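# Example invocation (hedged: the iteration count and folder are placeholders
# for whatever the surrounding setup uses):
#
#     train_until(100000, name='train_net', output_folder='.', clip_max=2000)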
def predict_2d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
def predict_3d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=roi)
    total_request[prediction] = gp.ArraySpec(roi=roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=roi)
        total_request[target] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
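# Usage sketch shared by predict_2d/predict_3d (hedged: `raw_data`, `gt_data`,
# and `predictor` stand for the dataset wrappers and torch predictor these
# functions expect; pass gt_data=None to skip ground-truth outputs):
#
#     ret = predict_3d(raw_data, gt_data=None, predictor=predictor)
#     affs = ret['prediction'].data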
prediction = gp.ArrayKey('PREDICTION')


class PrepareTrainingData(gp.BatchFilter):

    def process(self, batch, request):
        batch[out_cage_map].data = batch[out_cage_map].data.astype(np.float32)
        batch[out_cage_map].spec.dtype = np.float32


# assemble pipeline
sourceA = gp.ZarrSource('../data/cropped_sample_A.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})

sourceB = gp.ZarrSource('../data/cropped_sample_B.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})

sourceC = gp.ZarrSource('../data/cropped_sample_C.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]

    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size,
                                      dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(interpolatable=False,
                                    voxel_size=voxel_size,
                                    dtype=np.uint32),
        )
        + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(labels, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))
    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)