def test_merge_basics(self):
    voxel_size = (1, 1, 1)
    GraphKey("PRESYN")
    ArrayKey("GT_LABELS")
    graphsource = GraphTestSource(voxel_size)
    arraysource = ArrayTestSoure(voxel_size)
    pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation()
    window_request = Coordinate((50, 50, 50))

    with build(pipeline):
        # Check basic merging.
        request = BatchRequest()
        request.add(GraphKeys.PRESYN, window_request)
        request.add(ArrayKeys.GT_LABELS, window_request)
        batch_res = pipeline.request_batch(request)
        self.assertTrue(ArrayKeys.GT_LABELS in batch_res.arrays)
        self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

        # Check that a request for only one source also works.
        request = BatchRequest()
        request.add(GraphKeys.PRESYN, window_request)
        batch_res = pipeline.request_batch(request)
        self.assertFalse(ArrayKeys.GT_LABELS in batch_res.arrays)
        self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

    # Check that building fails when two sources provide the same key.
    arraysource2 = ArrayTestSoure(voxel_size)
    pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation()
    with self.assertRaises(PipelineSetupError):
        with build(pipeline_fail):
            pass
def test_ensure_center_non_zero(self):
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points().to_nx_graph())

    # read arrays
    swc = PointsKey("SWC")
    img = ArrayKey("IMG")
    pipeline = (
        SwcFileSource(path, [swc], [PointsSpec(roi=Roi((0, 0, 0), (11, 11, 11)))])
        + RandomLocation(ensure_nonempty=swc, ensure_centered=True)
        + RasterizeSkeleton(
            points=swc,
            array=img,
            array_spec=ArraySpec(
                interpolatable=False,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        )
    )

    request = BatchRequest()
    request.add(img, Coordinate((5, 5, 5)))
    request.add(swc, Coordinate((5, 5, 5)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

        data = batch[img].data
        g = batch[swc]
        assert g.num_vertices() > 0

        self.assertNotEqual(data[tuple(np.array(data.shape) // 2)], 0)
def test_placeholder(self):
    test_labels = ArrayKey("TEST_LABELS")
    test_points = GraphKey("TEST_POINTS")

    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
    )

    with build(pipeline):
        for i in range(2):
            request_size = Coordinate((40, 40, 40))

            request_a = BatchRequest(random_seed=i)
            request_a.add(test_points, request_size)
            request_a.add(test_labels, request_size, placeholder=True)

            request_b = BatchRequest(random_seed=i)
            request_b.add(test_points, request_size)
            request_b.add(test_labels, request_size)

            batch_a = pipeline.request_batch(request_a)
            batch_b = pipeline.request_batch(request_b)

            points_a = batch_a[test_points].nodes
            points_b = batch_b[test_points].nodes

            for a, b in zip(points_a, points_b):
                assert all(np.isclose(a.location, b.location))
def test_without_placeholder(self):
    test_labels = ArrayKey("TEST_LABELS")
    test_points = GraphKey("TEST_POINTS")

    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
    )

    with build(pipeline):
        for i in range(2):
            request_size = Coordinate((40, 40, 40))

            request_a = BatchRequest(random_seed=i)
            request_a.add(test_points, request_size)

            request_b = BatchRequest(random_seed=i)
            request_b.add(test_points, request_size)
            request_b.add(test_labels, request_size)

            # No array to provide a voxel size to ElasticAugment
            with pytest.raises(PipelineRequestError):
                pipeline.request_batch(request_a)
            batch_b = pipeline.request_batch(request_b)

            self.assertIn(test_labels, batch_b)
def test_pipeline3(self):
    array_key = ArrayKey("TEST_ARRAY")
    points_key = GraphKey("TEST_POINTS")

    voxel_size = Coordinate((1, 1))
    spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)

    hdf5_source = Hdf5Source(
        self.fake_data_file, {array_key: "testdata"}, array_specs={array_key: spec}
    )
    csv_source = CsvPointsSource(
        self.fake_points_file,
        points_key,
        GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
    )

    request = BatchRequest()
    shape = Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
    pipeline = (
        (hdf5_source, csv_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=points_key)
        + shift_node
    )
    with build(pipeline) as b:
        request = b.request_batch(request)
        # print(request[points_key])

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = request[array_key].data
        result_points = list(request[points_key].nodes)
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg="result value {} at points {} not in target values {} at points {}".format(
                    result_val,
                    list(result_points),
                    target_vals,
                    self.fake_points,
                ),
            )
def test_output(self):
    """
    Fails because probabilities are calculated in advance rather than after
    creating each ROI. The new approach does not account for all possible
    ROIs containing each point, some of which may not contain its nearest
    neighbors.
    """

    GraphKey('TEST_POINTS')

    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=GraphKeys.TEST_POINTS, point_balance_radius=100
    )

    # count the number of times we get each point
    histogram = {}

    with build(pipeline):
        for i in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_POINTS: GraphSpec(
                        roi=Roi((0, 0, 0), (100, 100, 100))
                    )
                })
            )

            points = {
                node.id: node for node in batch[GraphKeys.TEST_POINTS].nodes
            }

            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)

            for node in batch[GraphKeys.TEST_POINTS].nodes:
                if node.id not in histogram:
                    histogram[node.id] = 1
                else:
                    histogram[node.id] += 1

    total = sum(histogram.values())
    for k, v in histogram.items():
        histogram[k] = float(v) / total

    # we should get roughly the same count for each point
    for i in histogram.keys():
        for j in histogram.keys():
            self.assertAlmostEqual(histogram[i], histogram[j], 1)
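# The docstring of test_output above refers to RandomLocation's
# point_balance_radius: the intent is that points in dense clusters should not
# be over-sampled relative to isolated points. The following is a minimal,
# hypothetical sketch of that kind of density balancing (not gunpowder's
# actual implementation), assuming scipy is available:

import numpy as np
from scipy.spatial import cKDTree


def balanced_point_weights(locations, balance_radius):
    """Return sampling weights inversely proportional to the number of
    points within balance_radius of each point (each point counts itself)."""
    tree = cKDTree(locations)
    neighbor_counts = np.array(
        [len(tree.query_ball_point(loc, balance_radius)) for loc in locations]
    )
    weights = 1.0 / neighbor_counts
    return weights / weights.sum()


# Two clustered points and one isolated point: the isolated point receives
# roughly half of the probability mass, the clustered points a quarter each.
locations = np.array([[0, 0, 0], [1, 1, 1], [100, 100, 100]], dtype=float)
print(balanced_point_weights(locations, balance_radius=10))  # ~[0.25, 0.25, 0.5]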
def test_req_full_roi(self):
    GraphKey("TEST_GRAPH")

    possible_roi = Roi((0, 0, 0), (1000, 1000, 1000))

    pipeline = (
        SourceGraphLocation()
        + BatchTester(possible_roi, exact=False)
        + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH)
    )
    with build(pipeline):
        batch = pipeline.request_batch(
            BatchRequest({
                GraphKeys.TEST_GRAPH: GraphSpec(
                    roi=Roi((0, 0, 0), (1000, 1000, 1000))
                )
            })
        )

        assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
def test_roi_one_point(self):
    GraphKey("TEST_GRAPH")

    upstream_roi = Roi((500, 500, 500), (1, 1, 1))

    pipeline = (
        SourceGraphLocation()
        + BatchTester(upstream_roi, exact=True)
        + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH)
    )

    with build(pipeline):
        for i in range(500):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_GRAPH: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))
                })
            )

            assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
def test_dim_size_1(self):
    GraphKey("TEST_GRAPH")

    upstream_roi = Roi((500, 401, 401), (1, 200, 200))

    pipeline = (
        SourceGraphLocation()
        + BatchTester(upstream_roi, exact=False)
        + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH)
    )

    # count the number of times we get each node
    with build(pipeline):
        for i in range(500):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_GRAPH: GraphSpec(
                        roi=Roi((0, 0, 0), (1, 100, 100))
                    )
                })
            )

            assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
def test_output(self):
    GraphKey("TEST_GRAPH")

    pipeline = TestSourceRandomLocation() + RandomLocation(
        ensure_nonempty=GraphKeys.TEST_GRAPH
    )

    # count the number of times we get each node
    histogram = {}

    with build(pipeline):
        for i in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_GRAPH: GraphSpec(
                        roi=Roi((0, 0, 0), (100, 100, 100))
                    )
                })
            )

            nodes = list(batch[GraphKeys.TEST_GRAPH].nodes)
            node_ids = [v.id for v in nodes]

            self.assertTrue(len(nodes) > 0)
            self.assertTrue(
                (1 in node_ids) != (2 in node_ids or 3 in node_ids),
                node_ids,
            )

            for node in batch[GraphKeys.TEST_GRAPH].nodes:
                if node.id not in histogram:
                    histogram[node.id] = 1
                else:
                    histogram[node.id] += 1

    total = sum(histogram.values())
    for k, v in histogram.items():
        histogram[k] = float(v) / total

    # we should get roughly the same count for each point
    for i in histogram.keys():
        for j in histogram.keys():
            self.assertAlmostEqual(histogram[i], histogram[j], 1)
def test_equal_probability(self):
    GraphKey('TEST_POINTS')

    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=GraphKeys.TEST_POINTS
    )

    # count the number of times we get each point
    histogram = {}

    with build(pipeline):
        for i in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_POINTS: GraphSpec(
                        roi=Roi((0, 0, 0), (10, 10, 10))
                    )
                })
            )

            points = {
                node.id: node for node in batch[GraphKeys.TEST_POINTS].nodes
            }

            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)

            for point in batch[GraphKeys.TEST_POINTS].nodes:
                if point.id not in histogram:
                    histogram[point.id] = 1
                else:
                    histogram[point.id] += 1

    total = sum(histogram.values())
    for k, v in histogram.items():
        histogram[k] = float(v) / total

    # we should get roughly the same count for each point
    for i in histogram.keys():
        for j in histogram.keys():
            self.assertAlmostEqual(histogram[i], histogram[j], 1)
def train():
    gunpowder.set_verbose(False)

    affinity_neighborhood = malis.mknhood3d()

    solver_parameters = gunpowder.caffe.SolverParameters()
    solver_parameters.train_net = 'net.prototxt'
    solver_parameters.base_lr = 1e-4
    solver_parameters.momentum = 0.95
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 10000
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage('euclid')

    request = BatchRequest()
    request.add_volume_request(VolumeTypes.RAW, constants.input_shape)
    request.add_volume_request(VolumeTypes.GT_LABELS, constants.output_shape)
    request.add_volume_request(VolumeTypes.GT_MASK, constants.output_shape)
    request.add_volume_request(VolumeTypes.GT_AFFINITIES, constants.output_shape)
    request.add_volume_request(VolumeTypes.LOSS_SCALE, constants.output_shape)

    data_providers = list()
    fibsem_dir = "/groups/turaga/turagalab/data/FlyEM/fibsem_medulla_7col"
    for volume_name in ("tstvol-520-1-h5",):
        h5_filepath = "./{}.h5".format(volume_name)
        path_to_labels = os.path.join(fibsem_dir, volume_name, "groundtruth_seg.h5")
        with h5py.File(path_to_labels, "r") as f_labels:
            mask_shape = f_labels["main"].shape
        with h5py.File(h5_filepath, "w") as h5:
            h5['volumes/raw'] = h5py.ExternalLink(
                os.path.join(fibsem_dir, volume_name, "im_uint8.h5"), "main")
            h5['volumes/labels/neuron_ids'] = h5py.ExternalLink(path_to_labels, "main")
            h5.create_dataset(
                name="volumes/labels/mask",
                dtype="uint8",
                shape=mask_shape,
                fillvalue=1,
            )
        data_providers.append(
            gunpowder.Hdf5Source(
                h5_filepath,
                datasets={
                    VolumeTypes.RAW: 'volumes/raw',
                    VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                    VolumeTypes.GT_MASK: 'volumes/labels/mask',
                },
                resolution=(8, 8, 8),
            )
        )

    dvid_source = DvidSource(
        hostname='slowpoke3',
        port=32788,
        uuid='341',
        raw_array_name='grayscale',
        gt_array_name='groundtruth',
        gt_mask_roi_name="seven_column_eroded7_z_lt_5024",
        resolution=(8, 8, 8),
    )
    data_providers.extend([dvid_source])

    data_providers = tuple(
        provider + RandomLocation() + Reject(min_masked=0.5) + Normalize()
        for provider in data_providers
    )

    # create a batch provider by concatenation of filters
    batch_provider = (
        data_providers +
        RandomProvider() +
        ElasticAugment([20, 20, 20], [0, 0, 0], [0, math.pi / 2.0]) +
        SimpleAugment(transpose_only_xy=False) +
        GrowBoundary(steps=2, only_xy=False) +
        AddGtAffinities(affinity_neighborhood) +
        BalanceAffinityLabels() +
        SplitAndRenumberSegmentationLabels() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        PreCache(request, cache_size=11, num_workers=10) +
        Train(solver_parameters, use_gpu=0) +
        Typecast(
            volume_dtypes={
                VolumeTypes.GT_LABELS: np.dtype("uint32"),
                VolumeTypes.GT_MASK: np.dtype("uint8"),
                VolumeTypes.LOSS_SCALE: np.dtype("float32"),
            },
            safe=True) +
        Snapshot(every=50, output_filename='batch_{id}.hdf') +
        PrintProfilingStats(every=50)
    )

    n = 500000
    print("Training for", n, "iterations")
    with gunpowder.build(batch_provider) as pipeline:
        for i in range(n):
            pipeline.request_batch(request)
    print("Finished")
def visualize_augmentations_paintera(args=None):

    data_providers = []
    data_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data'
    # data_dir = os.path.expanduser('~/Dropbox/cremi-upsampled/')
    file_pattern = 'sample_A_padded_20160501-2-additional-sections-fixed-offset.h5'
    file_pattern = 'sample_B_padded_20160501-2-additional-sections-fixed-offset.h5'
    file_pattern = 'sample_C_padded_20160501-2-additional-sections-fixed-offset.h5'
    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'

    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW: 'defect_sections/raw',
                ALPHA_MASK: 'defect_sections/mask',
            },
            array_specs={
                RAW: ArraySpec(voxel_size=tuple(d * 9 for d in (40, 4, 4))),
                ALPHA_MASK: ArraySpec(voxel_size=tuple(d * 9 for d in (40, 4, 4))),
            }) +
        RandomLocation(min_masked=0.05, mask=ALPHA_MASK) +
        Normalize(RAW) +
        IntensityAugment(RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, np.pi / 2.0),
            subsample=8) +
        SimpleAugment(transpose_only=[1, 2])
    )

    for data in glob.glob(os.path.join(data_dir, file_pattern)):
        h5_source = Hdf5Source(
            data,
            datasets={
                RAW: 'volumes/raw',
                GT_LABELS: 'volumes/labels/neuron_ids-downsampled',
            })
        data_providers.append(h5_source)

    input_resolution = (360.0, 36.0, 36.0)
    output_resolution = Coordinate((120.0, 108.0, 108.0))
    offset = (13640, 10932, 10932)
    output_shape = Coordinate((60.0, 100.0, 100.0)) * output_resolution
    output_offset = (13320 + 3600, 32796 + 36 + 10800, 32796 + 36 + 10800)
    overhang = Coordinate((360.0, 108.0, 108.0)) * 16
    input_shape = output_shape + overhang * 2
    input_offset = Coordinate(output_offset) - overhang

    output_roi = Roi(offset=output_offset, shape=output_shape)
    input_roi = Roi(offset=input_offset, shape=input_shape)

    augmentations = (
        Snapshot(
            dataset_names={
                GT_LABELS: 'volumes/gt',
                RAW: 'volumes/raw'
            },
            output_dir='.',
            output_filename='snapshot-before.h5',
            attributes_callback=Snapshot.default_attributes_callback()),
        ElasticAugment(
            voxel_size=(360.0, 36.0, 36.0),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 5 * 2 * 36, 5 * 2 * 36),
            rotation_interval=(0 * np.pi / 8, 0 * 2 * np.pi),
            subsample=8,
            augmentation_probability=1.0,
            seed=None),
        Misalign(
            z_resolution=360,
            prob_slip=0.2,
            prob_shift=0.5,
            max_misalign=(3600, 0),
            seed=100,
            ignore_keys_for_slip=(GT_LABELS,)),
        Snapshot(
            dataset_names={
                GT_LABELS: 'volumes/gt',
                RAW: 'volumes/raw'
            },
            output_dir='.',
            output_filename='snapshot-after.h5',
            attributes_callback=Snapshot.default_attributes_callback())
    )

    keys = (RAW, GT_LABELS)[:]

    batch, snapshot = run_augmentations(
        data_providers=data_providers,
        roi=lambda key: output_roi.copy() if key == GT_LABELS else input_roi.copy(),
        augmentations=augmentations,
        keys=keys,
        voxel_size=lambda key: input_resolution if key == RAW else (
            output_resolution if key == GT_LABELS else None))

    args = get_parser().parse_args() if args is None else args
    jnius_config.add_options('-Xmx{}'.format(args.max_heap_size))

    import payntera.jfx
    from jnius import autoclass
    payntera.jfx.init_platform()

    PainteraBaseView = autoclass('org.janelia.saalfeldlab.paintera.PainteraBaseView')
    viewer = PainteraBaseView.defaultView()
    pbv = viewer.baseView
    scene, stage = payntera.jfx.start_stage(viewer.paneWithStatus.getPane())

    screen_scale_setter = lambda: pbv.orthogonalViews().setScreenScales([0.3, 0.1, 0.03])
    payntera.jfx.invoke_on_jfx_application_thread(screen_scale_setter)

    snapshot_states = add_to_viewer(
        snapshot, keys=keys, name=lambda key: '%s-snapshot' % key.identifier)
    states = add_to_viewer(batch, keys=keys)

    viewer.keyTracker.installInto(scene)
    scene.addEventFilter(
        autoclass('javafx.scene.input.MouseEvent').ANY, viewer.mouseTracker)

    while stage.isShowing():
        time.sleep(0.1)
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_mask,
        tensor_glia,
        tensor_glia_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries,
        mask_out_labels,
        snapshot_dir):

    ignore_keys_for_slip = (
        LABELS_KEY, GT_MASK_KEY, GT_GLIA_KEY, GLIA_MASK_KEY, UNLABELED_KEY
    ) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # TODO Use glia scale somehow, probably not possible with tensorflow 1.3 because it does not know uint64...

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY, input_size, voxel_size=input_voxel_size)
    request.add(LABELS_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY, output_size, voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GLIA_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GLIA_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_GLIA_KEY, output_size, voxel_size=output_voxel_size)
    request.add(UNLABELED_KEY, output_size, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    # always balance glia labels!
    request.add(GLIA_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY
    network_inputs[tensor_glia_mask] = GLIA_SCALE_KEY  # GLIA_SCALE_KEY if balance_labels else GLIA_MASK_KEY

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data;
        # size more or less irrelevant as followed by Reject node
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        Pad(GLIA_MASK_KEY, None) +
        Pad(LABELS_KEY, size=NETWORK_OUTPUT_SHAPE / 2, value=np.uint64(-3)) +
        Pad(GT_GLIA_KEY, size=NETWORK_OUTPUT_SHAPE / 2) +
        # Pad(LABELS_KEY, None) +
        # Pad(GT_GLIA_KEY, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(mask=GT_MASK_KEY, min_masked=0.5) +
        Reject(mask=GLIA_MASK_KEY, min_masked=0.5) +
        MapNumpyArray(lambda array: np.require(array, dtype=np.int64), GT_GLIA_KEY)
        # this is necessary because gunpowder 1.3 only understands int64, not uint64
        for provider in data_providers)

    # TODO figure out what this is for
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY: request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY: request[GT_AFFINITIES_KEY],
    })

    # no need to do anything here: random sections will be replaced with
    # sections from this source (only raw)
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY: 'defect_sections/raw',
                DEFECT_MASK_KEY: 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY: ArraySpec(voxel_size=input_voxel_size),
                DEFECT_MASK_KEY: ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=DEFECT_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1, 2])
    )

    train_pipeline = data_sources
    train_pipeline += RandomProvider()
    train_pipeline += ElasticAugment(
        voxel_size=(360, 36, 36),
        control_point_spacing=(4, 40, 40),
        control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
        rotation_interval=(0, math.pi / 2.0),
        augmentation_probability=0.5,
        subsample=8
    )
    # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: before misalign: ' % GT_MASK_KEY)
    train_pipeline += Misalign(
        z_resolution=360,
        prob_slip=0.05,
        prob_shift=0.05,
        max_misalign=(360,) * 2,
        ignore_keys_for_slip=ignore_keys_for_slip)
    # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: after misalign: ' % GT_MASK_KEY)
    train_pipeline += SimpleAugment(transpose_only=[1, 2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(
        RAW_KEY,
        prob_missing=0.03,
        prob_low_contrast=0.01,
        prob_artifact=0.03,
        artifact_source=artifact_source,
        artifacts=RAW_KEY,
        artifacts_mask=DEFECT_MASK_KEY,
        contrast_scale=0.5)
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)

    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    _logger.info("Renumbering connected components? %s", renumber_connected_components)
    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=LABELS_KEY)

    train_pipeline += NewKeyFromNumpyArray(lambda array: 1 - array, GT_GLIA_KEY, UNLABELED_KEY)

    if len(mask_out_labels) > 0:
        train_pipeline += MaskOutLabels(label_key=LABELS_KEY, mask_key=GT_MASK_KEY, ids_to_be_masked=mask_out_labels)

    # labels_mask: anything that connects into labels_mask will be zeroed out
    # unlabelled:  anything that points into unlabeled will have zero affinity;
    #              affinities within unlabelled will be masked out
    train_pipeline += AddAffinities(
        affinity_neighborhood=affinity_neighborhood,
        labels=LABELS_KEY,
        labels_mask=GT_MASK_KEY,
        affinities=GT_AFFINITIES_KEY,
        affinities_mask=AFFINITIES_MASK_KEY,
        unlabelled=UNLABELED_KEY
    )

    snapshot_datasets = {
        RAW_KEY: 'volumes/raw',
        LABELS_KEY: 'volumes/labels/neuron_ids',
        GT_AFFINITIES_KEY: 'volumes/affinities/gt',
        GT_GLIA_KEY: 'volumes/labels/glia_gt',
        UNLABELED_KEY: 'volumes/labels/unlabeled',
        AFFINITIES_KEY: 'volumes/affinities/prediction',
        LOSS_GRADIENT_KEY: 'volumes/loss_gradient',
        AFFINITIES_MASK_KEY: 'masks/affinities',
        GLIA_KEY: 'volumes/labels/glia_pred',
        GT_MASK_KEY: 'masks/gt',
        GLIA_MASK_KEY: 'masks/glia'}

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)
        snapshot_datasets[AFFINITIES_SCALE_KEY] = 'masks/affinity-scale'
    train_pipeline += BalanceLabels(labels=GT_GLIA_KEY, scales=GLIA_SCALE_KEY, mask=GLIA_MASK_KEY)
    snapshot_datasets[GLIA_SCALE_KEY] = 'masks/glia-scale'

    if pre_cache_size > 0 and pre_cache_num_workers > 0:
        train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)

    train_pipeline += Train(
        summary=summary,
        graph=meta_graph_filename,
        save_every=save_checkpoint_every,
        optimizer=optimizer,
        loss=loss,
        inputs=network_inputs,
        log_dir='log',
        outputs={tensor_affinities: AFFINITIES_KEY, tensor_glia: GLIA_KEY},
        gradients={tensor_affinities: LOSS_GRADIENT_KEY},
        array_specs={
            AFFINITIES_KEY: ArraySpec(voxel_size=output_voxel_size),
            LOSS_GRADIENT_KEY: ArraySpec(voxel_size=output_voxel_size),
            AFFINITIES_MASK_KEY: ArraySpec(voxel_size=output_voxel_size),
            GT_MASK_KEY: ArraySpec(voxel_size=output_voxel_size),
            AFFINITIES_SCALE_KEY: ArraySpec(voxel_size=output_voxel_size),
            GLIA_MASK_KEY: ArraySpec(voxel_size=output_voxel_size),
            GLIA_SCALE_KEY: ArraySpec(voxel_size=output_voxel_size),
            GLIA_KEY: ArraySpec(voxel_size=output_voxel_size)
        }
    )
    train_pipeline += Snapshot(
        snapshot_datasets,
        every=snapshot_every,
        output_filename='batch_{iteration}.hdf',
        output_dir=snapshot_dir,
        additional_request=snapshot_request,
        attributes_callback=Snapshot.default_attributes_callback())
    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(trained_until, stop):
            b.request_batch(request)
    print("Training finished")
def test_get_neuron_pair(self):
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points().to_nx_graph())

    # read arrays
    swc_source = PointsKey("SWC_SOURCE")
    ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    data_shape = 5
    output_shape = Coordinate((data_shape, data_shape, data_shape))

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path,
        [swc_source, ensure_nonempty],
        [
            PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
            PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
        ],
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=ensure_nonempty, ensure_centered=True)
    )

    pipeline = skeleton + GetNeuronPair(
        point_source=swc_source,
        nonempty_placeholder=ensure_nonempty,
        array_source=imgs,
        label_source=labels,
        points=(points_a, points_b),
        arrays=(img_a, img_b),
        labels=(labels_a, labels_b),
        seperate_by=(1, 3),
        shift_attempts=100,
        request_attempts=10,
        output_shape=output_shape,
    )

    request = BatchRequest()
    request.add(points_a, output_shape)
    request.add(points_b, output_shape)
    request.add(img_a, output_shape)
    request.add(img_b, output_shape)
    request.add(labels_a, output_shape)
    request.add(labels_b, output_shape)

    with build(pipeline):
        for i in range(10):
            batch = pipeline.request_batch(request)
            assert all(
                [
                    x in batch
                    for x in [points_a, points_b, img_a, img_b, labels_a, labels_b]
                ]
            )

            min_dist = 5
            for a, b in itertools.product(
                batch[points_a].nodes,
                batch[points_b].nodes,
            ):
                min_dist = min(
                    min_dist,
                    np.linalg.norm(a.location - b.location),
                )

            self.assertLessEqual(min_dist, 3)
            self.assertGreaterEqual(min_dist, 1)
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_nn,
        tensor_affinities_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries):

    ignore_keys_for_slip = (GT_LABELS_KEY, GT_MASK_KEY) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size
    output_size_nn = Coordinate(s - 2 for s in output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY, input_size, voxel_size=input_voxel_size)
    request.add(GT_LABELS_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY, output_size, voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(AFFINITIES_NN_KEY, output_size_nn, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data;
        # size more or less irrelevant as followed by Reject node
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(GT_MASK_KEY) +  # reject batches which contain less than 50% labelled data
        Reject(GT_LABELS_KEY, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    # TODO figure out what this is for
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY: request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY: request[GT_AFFINITIES_KEY],
        AFFINITIES_NN_KEY: request[AFFINITIES_NN_KEY]
    })

    # no need to do anything here: random sections will be replaced with
    # sections from this source (only raw)
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY: 'defect_sections/raw',
                ALPHA_MASK_KEY: 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY: ArraySpec(voxel_size=input_voxel_size),
                ALPHA_MASK_KEY: ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=ALPHA_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1, 2])
    )

    train_pipeline = data_sources
    train_pipeline += RandomProvider()
    train_pipeline += ElasticAugment(
        voxel_size=(360, 36, 36),
        control_point_spacing=(4, 40, 40),
        control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
        rotation_interval=(0, math.pi / 2.0),
        augmentation_probability=0.5,
        subsample=8
    )
    train_pipeline += Misalign(
        z_resolution=360,
        prob_slip=0.05,
        prob_shift=0.05,
        max_misalign=(360,) * 2,
        ignore_keys_for_slip=ignore_keys_for_slip)
    train_pipeline += SimpleAugment(transpose_only=[1, 2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(
        RAW_KEY,
        prob_missing=0.03,
        prob_low_contrast=0.01,
        prob_artifact=0.03,
        artifact_source=artifact_source,
        artifacts=RAW_KEY,
        artifacts_mask=ALPHA_MASK_KEY,
        contrast_scale=0.5)
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)

    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(GT_LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=GT_LABELS_KEY)

    train_pipeline += AddAffinities(
        affinity_neighborhood=affinity_neighborhood,
        labels=GT_LABELS_KEY,
        labels_mask=GT_MASK_KEY,
        affinities=GT_AFFINITIES_KEY,
        affinities_mask=AFFINITIES_MASK_KEY
    )

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)

    train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
        summary=summary,
        graph=meta_graph_filename,
        save_every=save_checkpoint_every,
        optimizer=optimizer,
        loss=loss,
        inputs=network_inputs,
        log_dir='log',
        outputs={tensor_affinities: AFFINITIES_KEY, tensor_affinities_nn: AFFINITIES_NN_KEY},
        gradients={tensor_affinities: LOSS_GRADIENT_KEY},
        array_specs={
            AFFINITIES_KEY: ArraySpec(voxel_size=output_voxel_size),
            LOSS_GRADIENT_KEY: ArraySpec(voxel_size=output_voxel_size),
            AFFINITIES_MASK_KEY: ArraySpec(voxel_size=output_voxel_size),
            GT_MASK_KEY: ArraySpec(voxel_size=output_voxel_size),
            AFFINITIES_SCALE_KEY: ArraySpec(voxel_size=output_voxel_size),
            AFFINITIES_NN_KEY: ArraySpec(voxel_size=output_voxel_size)
        }
    )
    train_pipeline += Snapshot(
        dataset_names={
            RAW_KEY: 'volumes/raw',
            GT_LABELS_KEY: 'volumes/labels/neuron_ids',
            GT_AFFINITIES_KEY: 'volumes/affinities/gt',
            AFFINITIES_KEY: 'volumes/affinities/prediction',
            LOSS_GRADIENT_KEY: 'volumes/loss_gradient',
            AFFINITIES_MASK_KEY: 'masks/affinities',
            AFFINITIES_NN_KEY: 'volumes/affinities/prediction-nn'
        },
        every=snapshot_every,
        output_filename='batch_{iteration}.hdf',
        output_dir='snapshots/',
        additional_request=snapshot_request,
        attributes_callback=Snapshot.default_attributes_callback())
    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(trained_until, stop):
            b.request_batch(request)
    print("Training finished")
def test_recenter(self):
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points())

    # read arrays
    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path, [swc_source], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    request = BatchRequest()
    data_shape = 9
    request.add(points_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(points_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_b, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)
        data_a = batch[img_a].data
        assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
        data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))
        data_b = batch[img_b].data
        assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
        data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

        data_c = data_a + data_b
        data = np.array((data_a, data_b, data_c))

        for _, point in batch[points_a].data.items():
            assert (
                data[(0,) + tuple(int(x) + 1 for x in point.location)] == 1
            ), "data at {} is not 1, it is {}".format(
                point.location,
                data[(0,) + tuple(int(x) + 1 for x in point.location)],
            )

        for _, point in batch[points_b].data.items():
            assert (
                data[(1,) + tuple(int(x) + 1 for x in point.location)] == 1
            ), "data at {} is not 1".format(point.location)
def test_ensure_centered(self):
    """
    Expected failure due to emergent behavior of two desired rules:
    1) Points on the upper bound of a ROI are not considered contained.
    2) When considering a point as the center of a random location, scale
       by the number of points within some delta distance.

    If two points are equally likely to be chosen, and centering the ROI on
    either of them places the other on the ROI's bounding box, then centering
    on the first means the ROI contains only that one, while centering on the
    second means both are considered contained, breaking the equal likelihood
    of picking each point.
    """

    GraphKey("TEST_POINTS")

    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=GraphKeys.TEST_POINTS, ensure_centered=True
    )

    # count the number of times we get each point
    histogram = {}

    with build(pipeline):
        for i in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_POINTS: GraphSpec(
                        roi=Roi((0, 0, 0), (100, 100, 100))
                    )
                })
            )

            points = batch[GraphKeys.TEST_POINTS].data
            roi = batch[GraphKeys.TEST_POINTS].spec.roi

            locations = tuple(
                [Coordinate(point.location) for point in points.values()]
            )
            self.assertTrue(
                Coordinate([50, 50, 50]) in locations,
                f"locations: {tuple([point.location for point in points.values()])}",
            )

            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)

            for point_id in batch[GraphKeys.TEST_POINTS].data.keys():
                if point_id not in histogram:
                    histogram[point_id] = 1
                else:
                    histogram[point_id] += 1

    total = sum(histogram.values())
    for k, v in histogram.items():
        histogram[k] = float(v) / total

    # we should get roughly the same count for each point
    for i in histogram.keys():
        for j in histogram.keys():
            self.assertAlmostEqual(histogram[i], histogram[j], 1, histogram)
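# Rule (1) in the docstring of test_ensure_centered is the half-open
# containment convention: a point that lies exactly on a ROI's upper bound is
# not contained. A minimal illustration of that convention, independent of
# gunpowder's Roi class (the helper below is hypothetical):

import numpy as np


def roi_contains(offset, shape, location):
    """Half-open containment test: offset <= location < offset + shape."""
    offset, shape, location = map(np.asarray, (offset, shape, location))
    return bool(np.all(location >= offset) and np.all(location < offset + shape))


# A (100, 100, 100) ROI at the origin contains its lower corner but not its
# upper corner, so centering on one of two points that are exactly one
# ROI-width apart changes which of them counts as contained.
assert roi_contains((0, 0, 0), (100, 100, 100), (0, 0, 0))
assert not roi_contains((0, 0, 0), (100, 100, 100), (100, 100, 100))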