def test_shift_points5(self):
    data = {
        0: gp.Point([3, 0]),
        1: gp.Point([3, 2]),
        2: gp.Point([3, 4]),
        3: gp.Point([3, 6]),
        4: gp.Point([3, 8])
    }
    spec = gp.PointsSpec(gp.Roi(offset=(0, 0), shape=(15, 10)))
    points = gp.Points(data, spec)
    request_roi = gp.Roi(offset=(3, 0), shape=(9, 10))
    shift_array = np.array(
        [[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]], dtype=int)
    lcm_voxel_size = gp.Coordinate((3, 2))

    shifted_data = {
        0: gp.Point([6, 0]),
        2: gp.Point([3, 4]),
        4: gp.Point([6, 8])
    }
    result = gp.ShiftAugment.shift_points(
        points,
        request_roi,
        shift_array,
        shift_axis=1,
        lcm_voxel_size=lcm_voxel_size)
    # print("test 5", result.data, shifted_data)
    self.assertTrue(self.points_equal(result.data, shifted_data))
    self.assertTrue(result.spec == gp.PointsSpec(request_roi))
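# The points_equal helper used by the assertions above is not defined in this
# excerpt. A minimal sketch of what it might look like, assuming it only needs
# to compare the point ids and (integer) locations of two {id: gp.Point} dicts;
# this is an illustration, not the original implementation.
def points_equal(self, points_a, points_b):
    if set(points_a.keys()) != set(points_b.keys()):
        return False
    for point_id, point in points_a.items():
        other = points_b[point_id]
        if not np.array_equal(
                np.round(point.location), np.round(other.location)):
            return False
    return True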
def setup(self):
    self.spec_src = gp.PointsSpec()
    self.spec_trg = gp.PointsSpec()

    self.provides(self.srcpoints, self.spec_src)
    self.provides(self.trgpoints, self.spec_trg)

    self.enable_autoskip()
def test_context(self):
    d_pred = gp.ArrayKeys.D_PRED
    m_pred = gp.ArrayKeys.M_PRED
    presyn = gp.PointsKeys.PRESYN
    postsyn = gp.PointsKeys.POSTSYN
    outdir = tempfile.mkdtemp()

    voxel_size = gp.Coordinate((10, 10, 10))
    size = (200, 200, 200)

    # Check whether the score of the entire cube is measured, even though the
    # cube around the border point lies partially outside the request ROI.
    context = 40
    shape = gp.Coordinate(size) / voxel_size
    m_predar = np.zeros(shape, dtype=np.float32)
    outsidepoint = gp.Coordinate((13, 13, 13))
    borderpoint = (4, 4, 4)
    m_predar[3:5, 3:5, 3:5] = 1
    m_predar[outsidepoint] = 1
    d_predar = np.zeros((3, shape[0], shape[1], shape[2]))

    pipeline = (
        TestSource(m_predar, d_predar, voxel_size=voxel_size) +
        ExtractSynapses(m_pred, d_pred, presyn, postsyn,
                        out_dir=outdir, settings=parameters,
                        context=context) +
        gp.PrintProfilingStats())

    request = gp.BatchRequest()
    roi = gp.Roi((40, 40, 40), (80, 80, 80))
    request[presyn] = gp.PointsSpec(roi=roi)
    request[postsyn] = gp.PointsSpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(request)

    synapsefile = os.path.join(outdir, "40", "40", "40.npz")
    with np.load(synapsefile) as data:
        data = dict(data)
        self.assertTrue(len(data['ids']) == 1)
        self.assertEqual(data['scores'][0], 2.0**3)  # Size of the cube.

        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][0][ii],
                             borderpoint[ii] * voxel_size[ii])
        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][1][ii],
                             borderpoint[ii] * voxel_size[ii] + 0)

    shutil.rmtree(outdir)
def test_output_basics(self):
    d_pred = gp.ArrayKeys.D_PRED
    m_pred = gp.ArrayKeys.M_PRED
    presyn = gp.PointsKeys.PRESYN
    postsyn = gp.PointsKeys.POSTSYN

    voxel_size = gp.Coordinate((10, 10, 10))
    size = (200, 200, 200)
    context = 40
    shape = gp.Coordinate(size) / voxel_size
    m_predar = np.zeros(shape, dtype=np.float32)
    insidepoint = gp.Coordinate((10, 10, 10))
    outsidepoint = gp.Coordinate((15, 15, 15))
    m_predar[insidepoint] = 1
    m_predar[outsidepoint] = 1
    d_predar = np.ones((3, shape[0], shape[1], shape[2])) * 10
    outdir = tempfile.mkdtemp()

    pipeline = (
        TestSource(m_predar, d_predar, voxel_size=voxel_size) +
        ExtractSynapses(m_pred, d_pred, presyn, postsyn,
                        out_dir=outdir, settings=parameters,
                        context=context))

    request = gp.BatchRequest()
    roi = gp.Roi((40, 40, 40), (80, 80, 80))
    request[presyn] = gp.PointsSpec(roi=roi)
    request[postsyn] = gp.PointsSpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(request)

    synapsefile = os.path.join(outdir, "40", "40", "40.npz")
    with np.load(synapsefile) as data:
        data = dict(data)
        self.assertTrue(len(data['ids']) == 1)
        self.assertEqual(data['scores'][0], 1.0)  # Size of the cube.

        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][1][ii],
                             insidepoint[ii] * voxel_size[ii])
        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][0][ii],
                             insidepoint[ii] * voxel_size[ii] + 10)

    shutil.rmtree(outdir)
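# The TestSource node used by the two tests above is defined elsewhere in the
# test module. A rough sketch of what such a provider could look like, assuming
# the M_PRED/D_PRED keys have already been registered and that the arrays start
# at the origin; this is an illustration only, not the original class.
class TestSource(gp.BatchProvider):

    def __init__(self, m_pred, d_pred, voxel_size):
        self.m_pred = m_pred
        self.d_pred = d_pred
        self.voxel_size = voxel_size

    def setup(self):
        roi = gp.Roi(
            (0, 0, 0), self.voxel_size * gp.Coordinate(self.m_pred.shape))
        self.provides(
            gp.ArrayKeys.M_PRED,
            gp.ArraySpec(roi=roi, voxel_size=self.voxel_size,
                         interpolatable=False))
        self.provides(
            gp.ArrayKeys.D_PRED,
            gp.ArraySpec(roi=roi, voxel_size=self.voxel_size,
                         interpolatable=True))

    def provide(self, request):
        batch = gp.Batch()
        for key, data in ((gp.ArrayKeys.M_PRED, self.m_pred),
                          (gp.ArrayKeys.D_PRED, self.d_pred)):
            spec = self.spec[key].copy()
            spec.roi = request[key].roi
            # Crop the in-memory array to the requested ROI (in voxels).
            start = request[key].roi.get_begin() / self.voxel_size
            end = request[key].roi.get_end() / self.voxel_size
            slices = tuple(slice(s, e) for s, e in zip(start, end))
            if data.ndim > len(slices):
                # d_pred carries a leading channel dimension.
                slices = (slice(None),) + slices
            batch.arrays[key] = gp.Array(data[slices], spec)
        return batch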
def test_shift_points2(self):
    data = {1: gp.Point([0, 1])}
    spec = gp.PointsSpec(gp.Roi(offset=(0, 0), shape=(5, 5)))
    points = gp.Points(data, spec)
    request_roi = gp.Roi(offset=(0, 1), shape=(5, 3))
    shift_array = np.array(
        [[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int)
    lcm_voxel_size = gp.Coordinate((1, 1))

    result = gp.ShiftAugment.shift_points(
        points,
        request_roi,
        shift_array,
        shift_axis=0,
        lcm_voxel_size=lcm_voxel_size)
    # print("test 2", result.data, data)
    self.assertTrue(self.points_equal(result.data, data))
    self.assertTrue(result.spec == gp.PointsSpec(request_roi))
def test_pipeline3(self):
    array_key = gp.ArrayKey("TEST_ARRAY")
    points_key = gp.PointsKey("TEST_POINTS")
    voxel_size = gp.Coordinate((1, 1))
    spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

    hdf5_source = gp.Hdf5Source(
        self.fake_data_file,
        {array_key: 'testdata'},
        array_specs={array_key: spec})
    csv_source = gp.CsvPointsSource(
        self.fake_points_file,
        points_key,
        gp.PointsSpec(
            roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

    request = gp.BatchRequest()
    shape = gp.Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = gp.ShiftAugment(
        prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
    pipeline = (
        (hdf5_source, csv_source) +
        gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=points_key) +
        shift_node)

    with gp.build(pipeline) as b:
        request = b.request_batch(request)
        # print(request[points_key])

    target_vals = [
        self.fake_data[point[0]][point[1]] for point in self.fake_points
    ]
    result_data = request[array_key].data
    result_points = request[points_key].data
    result_vals = [
        result_data[int(point.location[0])][int(point.location[1])]
        for point in result_points.values()
    ]

    for result_val in result_vals:
        self.assertTrue(
            result_val in target_vals,
            msg="result value {} at points {} not in target values {} "
                "at points {}".format(
                    result_val, list(result_points.values()),
                    target_vals, self.fake_points))
def validation_data_sources_recomputed(config, blocks):
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    validation_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    validation_dirs[validation_num] = validation_dir

    validation_dirs = [validation_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        pipeline = (
            (
                gp.ZarrSource(
                    filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                    datasets={raw: "volume-rechunked"},
                    array_specs={
                        raw: gp.ArraySpec(
                            interpolatable=True, voxel_size=voxel_size)
                    },
                ),
                nl.gunpowder.nodes.MouselightSwcFileSource(
                    validation_dir,
                    [ground_truth],
                    transform_file=transform_template.format(sample=sample),
                    ignore_human_nodes=False,
                    scale=voxel_size,
                    transpose=[2, 1, 0],
                    points_spec=[
                        gp.PointsSpec(roi=gp.Roi(
                            gp.Coordinate([None, None, None]),
                            gp.Coordinate([None, None, None]),
                        ))
                    ],
                ),
            )
            + gp.nodes.MergeProvider()
            + gp.Normalize(raw, dtype=np.float32)
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000])
        )

        request = gp.BatchRequest()
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2)

        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")

        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))

    return validation_pipelines, (raw, labels, ground_truth)
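# Hedged usage sketch (not part of the original script): how the
# (pipeline, request) pairs returned above could be consumed. The `config`
# dict and the block numbers are placeholders and must match the data on disk.
def run_recomputed_validation(config):
    pipelines, (raw, labels, ground_truth) = \
        validation_data_sources_recomputed(config, blocks=[1, 2, 3])
    for pipeline, request in pipelines:
        # gp.build sets up and tears down every node in the pipeline.
        with gp.build(pipeline):
            batch = pipeline.request_batch(request)
        print(batch[raw].data.shape, batch[labels].data.shape)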
def validation_pipeline(config):
    """
    Per block:
        raw -> predict -> scan
        gt -> rasterize
        -> merge -> candidates -> trees
    -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={
                    raw: "volume-rechunked",
                    raw_clahed: "volume-rechunked"
                },
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size),
                    raw_clahed: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size),
                },
            )
            + gp.Normalize(raw, dtype=np.float32)
            + gp.Normalize(raw_clahed, dtype=np.float32)
            + scipyCLAHE([raw_clahed], [20, 64, 64])
        )
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        # Shift the cube ROI to the origin; all requests are made relative
        # to this shifted ROI.
        cube_roi_shifted = gp.Roi(
            (0,) * len(cube_roi.get_shape()), cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = (
            (swc_source, raw_source)
            + gp.nodes.MergeProvider()
            + gp.SpecifiedLocation(locations=[cube_roi.get_center()])
            + gp.Crop(raw, roi=input_roi)
            + gp.Crop(raw_clahed, roi=input_roi)
            + gp.Crop(ground_truth, roi=cube_roi_shifted)
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(
                labels, radii=[neuron_width * micron_scale])
            + gp.Crop(labels, roi=cube_roi_shifted)
            + gp.Snapshot(
                {
                    raw: f"volumes/{block}/raw",
                    raw_clahed: f"volumes/{block}/raw_clahe",
                    ground_truth: f"points/{block}/ground_truth",
                    labels: f"volumes/{block}/labels",
                },
                additional_request=additional_request,
                output_dir="validations",
                output_filename="validations.hdf",
            )
        )

        validation_pipelines.append(pipeline)

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines)
        + gp.MergeProvider()
        + gp.PrintProfilingStats())
    return validation_pipeline, specs
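# Hedged usage sketch (not part of the original script): build the merged
# snapshot pipeline returned above and issue one request per block, using the
# per-block specs stored in `specs`. The `config` dict is a placeholder.
def run_snapshot_validation(config):
    pipeline, specs = validation_pipeline(config)
    with gp.build(pipeline):
        for block, block_spec in specs.items():
            request = gp.BatchRequest()
            for key, spec in block_spec.items():
                request[key] = spec
            pipeline.request_batch(request)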
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(
                kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_cpv = gp.ArrayKey('GT_CPV')
    gt_points = gp.PointsKey('GT_CPV_POINTS')

    loss_weights_affs = gp.ArrayKey('LOSS_WEIGHTS_AFFS')
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')
    # loss_weights_cpv = gp.ArrayKey('LOSS_WEIGHTS_CPV')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')
    pred_cpv = gp.ArrayKey('PRED_CPV')

    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')
    pred_cpv_gradients = gp.ArrayKey('PRED_CPV_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_cpv, output_shape_world)
    request.add(gt_affs, output_shape_world)
    request.add(loss_weights_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predictions and the gradients of the loss w.r.t. the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)
    # snapshot_request.add(pred_fgbg_gradients, output_shape_world)
    snapshot_request.add(pred_cpv, output_shape_world)
    # snapshot_request.add(pred_cpv_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']

    pipeline = (
        tuple(
            # read batches from the HDF5/zarr files
            (
                sourceNode(
                    fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        gt_labels: 'volumes/gt_labels',
                        gt_fgbg: 'volumes/gt_fgbg',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        gt_labels: gp.ArraySpec(interpolatable=False),
                        gt_fgbg: gp.ArraySpec(interpolatable=False),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }
                ),
                gp.CsvIDPointsSource(
                    fls[t] + ".csv",
                    gt_points,
                    points_spec=gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate((0, 0, 0)),
                        gp.Coordinate(shapes[t])))
                )
            ) +
            gp.MergeProvider() +
            gp.Pad(raw, None) +
            gp.Pad(gt_points, None) +
            gp.Pad(gt_labels, None) +
            gp.Pad(gt_fgbg, None) +

            # choose a random location for each requested batch
            gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        gp.AddCPV(
            gt_points,
            gt_labels,
            gt_cpv) +

        # create a weight array that balances positive and negative samples
        # in the affinity array
        gp.BalanceLabels(
            gt_affs,
            loss_weights_affs) +
        gp.BalanceLabels(
            gt_fgbg,
            loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_cpv']: gt_cpv,
                net_names['gt_labels']: gt_labels,
                net_names['loss_weights_affs']: loss_weights_affs,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['pred_cpv']: pred_cpv,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
                net_names['pred_cpv']: pred_cpv_gradients
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                gt_cpv: '/volumes/gt_cpv',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
                pred_cpv: '/volumes/pred_cpv',
                pred_cpv_gradients: '/volumes/pred_cpv_gradients'
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
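# Hedged illustration (not from the original script): the keyword arguments
# train_until expects. The keys are taken from the kwargs lookups above; all
# concrete values below are made-up placeholders.
example_kwargs = {
    'output_folder': 'experiments/run_01',
    'name': 'train_net',
    'max_iteration': 100000,
    'voxel_size': (1, 1, 1),
    'input_format': 'zarr',            # or 'hdf'
    'data_files': ['sample_A.zarr'],
    'augmentation': {
        'elastic': {
            'control_point_spacing': [10, 10, 10],
            'jitter_sigma': [1, 1, 1],
            'rotation_min': -45,
            'rotation_max': 45,
        },
        'simple': {},
        'intensity': {'scale': [0.9, 1.1], 'shift': [-0.1, 0.1]},
    },
    'cache_size': 40,
    'num_workers': 10,
    'checkpoints': 5000,
    'snapshots': 1000,
    'profiling': 100,
}
# train_until(**example_kwargs)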
def validation_pipeline(config):
    """
    Per block:
        raw -> predict -> scan
        gt -> rasterize
        -> merge -> candidates -> trees
    -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size)
                },
            )
            + gp.Normalize(raw, dtype=np.float32)
            + mCLAHE([raw], [20, 64, 64])
        )
        emb_source, emb = add_emb_pred(
            config, raw_source, raw, block, emb_model)
        pred_source, fg = add_fg_pred(
            config, emb_source, raw, block, fg_model)
        pred_source = add_scan(
            pred_source,
            {raw: input_size, emb: output_size, fg: output_size})
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = (
            (swc_source, pred_source)
            + gp.nodes.MergeProvider()
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(
                labels, radii=[neuron_width * micron_scale])
            + Skeletonize(
                fg, candidates, candidate_spacing, candidate_threshold)
            + EMST(
                emb,
                candidates,
                mst,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            + gp.Snapshot(
                {
                    raw: f"volumes/{raw}",
                    ground_truth: f"points/{ground_truth}",
                    labels: f"volumes/{labels}",
                    fg: f"volumes/{fg}",
                    emb: f"volumes/{emb}",
                    candidates: f"volumes/{candidates}",
                    mst: f"points/{mst}",
                },
                additional_request=additional_request,
                output_dir="snapshots",
                output_filename="{id}.hdf",
                edge_attrs={mst: [distance_attr]},
            )
        )

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines)
        + gp.MergeProvider()
        + MergeGraphs(specs, full_gt, full_mst)
        + Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr)
        + gp.PrintProfilingStats())
    return validation_pipeline, score
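# Hedged usage sketch (not part of the original script): build the merged
# evaluation pipeline returned above and request the single, non-spatial
# score array it provides. The `config` dict is a placeholder.
def run_evaluation(config):
    pipeline, score_key = validation_pipeline(config)
    with gp.build(pipeline):
        request = gp.BatchRequest()
        request[score_key] = gp.ArraySpec(nonspatial=True)
        batch = pipeline.request_batch(request)
    print("validation score:", batch[score_key].data)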
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(
                kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    points = gp.PointsKey('POINTS')

    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the prediction and the gradients of the loss w.r.t. it
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for {} not implemented yet".format(
                kwargs['input_format']))

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']

    sources = tuple(
        (
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets={
                    raw: 'volumes/raw',
                    anchor: 'volumes/gt_fgbg',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    anchor: gp.ArraySpec(interpolatable=False)
                }),
            gp.CsvIDPointsSource(
                fls[t] + ".csv",
                points,
                points_spec=gp.PointsSpec(
                    roi=gp.Roi(gp.Coordinate((0, 0, 0)),
                               gp.Coordinate(shapes[t]))))
        ) +
        gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None)
        + gp.Pad(points, None)

        # choose a random location for each requested batch
        + gp.RandomLocation()

        for t in range(ln))

    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +
        # (gp.SimpleAugment(
        #     mirror_only=augmentation['simple'].get("mirror"),
        #     transpose_only=augmentation['simple'].get("transpose"))
        #  if augmentation.get('simple') is not None and
        #     augmentation.get('simple') != {} else NoOp()) +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
         if augmentation.get('intensity') is not None and
            augmentation.get('intensity') != {} else NoOp()) +

        # rasterize the annotated points into a ground truth array
        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
                # pred_cp_gradients: '/volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")