def get_snapshot_source(setup_config: Dict[str, Any], source_samples: List[str]):
    snapshot = setup_config.get("SNAPSHOT_SOURCE", "snapshots/snapshot_1.hdf")

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    data_sources = SnapshotSource(
        snapshot=snapshot,
        outputs={
            "volumes/raw": raw,
            "points/consensus": consensus,
            "points/skeletonization": skeletonization,
            "points/matched": matched,
            # The original listed "points/matched" twice; a dict literal
            # cannot hold duplicate keys, so the second entry silently
            # replaced the first and `matched` was never provided. The
            # placeholder gets its own entry here (dataset name assumed).
            "points/nonempty": nonempty_placeholder,
            "points/labels": labels,
        },
        voxel_size=voxel_size,
    )

    return (
        data_sources,
        raw,
        labels,
        consensus,
        nonempty_placeholder,
        skeletonization,
        matched,
    )
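# A short usage sketch (values and sizes assumed, not from the original): the
# returned keys are meant to be requested with the network input size.
example_config = {"VOXEL_SIZE": (10, 3, 3)}
source, raw, labels, consensus, nonempty, skeletonization, matched = get_snapshot_source(
    example_config, source_samples=[]
)
input_size = gp.Coordinate((20, 64, 64)) * gp.Coordinate(example_config["VOXEL_SIZE"])
request = gp.BatchRequest()
request.add(raw, input_size)
request.add(matched, input_size)
with gp.build(source):
    batch = source.request_batch(request)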
def test_csv_header(self):
    points = gp.PointsKey("POINTS")
    tswh = TracksSource(TEST_FILE_WITH_HEADER, points)

    request = gp.BatchRequest()
    request.add(points, gp.Coordinate((5, 5, 5, 5)))

    tswh.setup()
    batch = tswh.provide(request)

    # do not shadow the points key with the returned data
    points_data = batch[points].data
    self.assertListEqual([0.0, 0.0, 0.0, 0.0], list(points_data[1].location))
    self.assertListEqual([1.0, 0.0, 0.0, 0.0], list(points_data[2].location))
    self.assertListEqual([1.0, 1.0, 2.0, 3.0], list(points_data[3].location))
    self.assertListEqual([2.0, 2.0, 2.0, 2.0], list(points_data[4].location))
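# The assertions above imply four points with ids 1-4 and 4D (t, z, y, x)
# locations. Assuming a header line followed by "t, z, y, x, cell_id,
# parent_id, track_id" columns (an assumption; the real TEST_FILE_WITH_HEADER
# is defined elsewhere in the test module), a matching file could look like:
#
#   t, z, y, x, cell_id, parent_id, track_id
#   0, 0, 0, 0, 1, -1, 0
#   1, 0, 0, 0, 2,  1, 0
#   1, 1, 2, 3, 3,  1, 0
#   2, 2, 2, 2, 4,  2, 0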
def test_delete_points_in_context(self):
    points = gp.PointsKey("POINTS")
    pv_array = gp.ArrayKey("PARENT_VECTORS")
    mask = gp.ArrayKey("MASK")
    radius = [0.1, 0.1, 0.1, 0.1]

    ts = TracksSource(TEST_FILE, points)
    apv = AddParentVectors(points, pv_array, mask, radius)

    request = gp.BatchRequest()
    request.add(points, gp.Coordinate((1, 4, 4, 4)))
    request.add(pv_array, gp.Coordinate((1, 4, 4, 4)))
    request.add(mask, gp.Coordinate((1, 4, 4, 4)))

    pipeline = ts + gp.Pad(points, None) + apv

    with gp.build(pipeline):
        pipeline.request_batch(request)
def test_pipeline3(self):
    array_key = gp.ArrayKey("TEST_ARRAY")
    points_key = gp.PointsKey("TEST_POINTS")
    voxel_size = gp.Coordinate((1, 1))
    spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

    hdf5_source = gp.Hdf5Source(
        self.fake_data_file,
        {array_key: 'testdata'},
        array_specs={array_key: spec},
    )
    csv_source = gp.CsvPointsSource(
        self.fake_points_file,
        points_key,
        gp.PointsSpec(
            roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))),
    )

    request = gp.BatchRequest()
    shape = gp.Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = gp.ShiftAugment(
        prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
    pipeline = (
        (hdf5_source, csv_source)
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=points_key)
        + shift_node)

    with gp.build(pipeline) as b:
        # request_batch returns a batch; do not overwrite the request
        batch = b.request_batch(request)

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = batch[points_key].data
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points.values()
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg="result value {} at points {} not in target values {} "
                    "at points {}".format(
                        result_val, list(result_points.values()),
                        target_vals, self.fake_points))
def test_add_parent_vectors(self):
    points = gp.PointsKey("POINTS")
    pv_array = gp.ArrayKey("PARENT_VECTORS")
    mask = gp.ArrayKey("MASK")
    radius = [0.1, 0.1, 0.1, 0.1]

    ts = TracksSource(TEST_FILE, points)
    apv = AddParentVectors(points, pv_array, mask, radius)

    request = gp.BatchRequest()
    request.add(points, gp.Coordinate((3, 4, 4, 4)))
    request.add(pv_array, gp.Coordinate((1, 4, 4, 4)))
    request.add(mask, gp.Coordinate((1, 4, 4, 4)))

    pipeline = ts + gp.Pad(points, None) + apv

    with gp.build(pipeline):
        batch = pipeline.request_batch(request)

    expected_mask = np.zeros(shape=(1, 4, 4, 4))
    expected_mask[0, 0, 0, 0] = 1
    expected_mask[0, 1, 2, 3] = 1

    # parent vectors point from each cell to its parent:
    # parent (0, 0, 0) - child (1, 2, 3) = (-1, -2, -3) in (z, y, x)
    expected_parent_vectors_z = np.zeros(shape=(1, 4, 4, 4))
    expected_parent_vectors_z[0, 1, 2, 3] = -1.0
    expected_parent_vectors_y = np.zeros(shape=(1, 4, 4, 4))
    expected_parent_vectors_y[0, 1, 2, 3] = -2.0
    expected_parent_vectors_x = np.zeros(shape=(1, 4, 4, 4))
    expected_parent_vectors_x[0, 1, 2, 3] = -3.0

    self.assertListEqual(expected_mask.tolist(), batch[mask].data.tolist())

    parent_vectors = batch[pv_array].data
    self.assertListEqual(expected_parent_vectors_z.tolist(),
                         parent_vectors[0].tolist())
    self.assertListEqual(expected_parent_vectors_y.tolist(),
                         parent_vectors[1].tolist())
    self.assertListEqual(expected_parent_vectors_x.tolist(),
                         parent_vectors[2].tolist())
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_cpv = gp.ArrayKey('GT_CPV')
    gt_points = gp.PointsKey('GT_CPV_POINTS')

    loss_weights_affs = gp.ArrayKey('LOSS_WEIGHTS_AFFS')
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')
    # loss_weights_cpv = gp.ArrayKey('LOSS_WEIGHTS_CPV')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')
    pred_cpv = gp.ArrayKey('PRED_CPV')

    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')
    pred_cpv_gradients = gp.ArrayKey('PRED_CPV_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_cpv, output_shape_world)
    request.add(gt_affs, output_shape_world)
    request.add(loss_weights_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)
    # snapshot_request.add(pred_fgbg_gradients, output_shape_world)
    snapshot_request.add(pred_cpv, output_shape_world)
    # snapshot_request.add(pred_cpv_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    # [0:4] only printed four files
    print("first 5 files: ", fls[0:5])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            # read batches from the HDF5 file
            (
                sourceNode(
                    fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        gt_labels: 'volumes/gt_labels',
                        gt_fgbg: 'volumes/gt_fgbg',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        gt_labels: gp.ArraySpec(interpolatable=False),
                        gt_fgbg: gp.ArraySpec(interpolatable=False),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }
                ),
                gp.CsvIDPointsSource(
                    fls[t] + ".csv",
                    gt_points,
                    points_spec=gp.PointsSpec(
                        roi=gp.Roi(gp.Coordinate((0, 0, 0)),
                                   gp.Coordinate(shapes[t])))
                )
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_points, None)
            + gp.Pad(gt_labels, None)
            + gp.Pad(gt_fgbg, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()
            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        gp.AddCPV(
            gt_points,
            gt_labels,
            gt_cpv) +

        # create a weight array that balances positive and negative samples
        # in the affinity array
        gp.BalanceLabels(
            gt_affs,
            loss_weights_affs) +
        gp.BalanceLabels(
            gt_fgbg,
            loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_cpv']: gt_cpv,
                net_names['gt_labels']: gt_labels,
                net_names['loss_weights_affs']: loss_weights_affs,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['pred_cpv']: pred_cpv,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
                net_names['pred_cpv']: pred_cpv_gradients
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                gt_cpv: '/volumes/gt_cpv',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
                pred_cpv: '/volumes/pred_cpv',
                pred_cpv_gradients: '/volumes/pred_cpv_gradients'
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start
            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
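# An illustrative invocation (not from the original), showing only the kwargs
# that train_until actually reads; file names and values are placeholders.
train_until(
    output_folder="output/setup01",
    name="train_net",
    max_iteration=100000,
    voxel_size=(1, 1, 1),
    input_format="hdf",          # "hdf" or "zarr"
    data_files=["sample_A.hdf"],
    augmentation={
        "elastic": {"control_point_spacing": [10, 10, 10],
                    "jitter_sigma": [1, 1, 1],
                    "rotation_min": -45, "rotation_max": 45},
        "simple": {},            # default mirror/transpose
        "intensity": {"scale": [0.9, 1.1], "shift": [-0.1, 0.1]},
    },
    cache_size=40,
    num_workers=10,
    checkpoints=10000,
    snapshots=1000,
    profiling=100,
)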
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    points = gp.PointsKey('POINTS')

    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted center points and gradients of the loss wrt them
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        # the original passed the argument as a second parameter to
        # NotImplementedError, which never formatted the message
        raise NotImplementedError(
            "train node for {} not implemented yet".format(
                kwargs['input_format']))

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    # [0:4] only printed four files
    print("first 5 files: ", fls[0:5])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    sources = tuple(
        (sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets={
                raw: 'volumes/raw',
                anchor: 'volumes/gt_fgbg',
            },
            array_specs={
                raw: gp.ArraySpec(interpolatable=True),
                anchor: gp.ArraySpec(interpolatable=False)
            }),
         gp.CsvIDPointsSource(
             fls[t] + ".csv",
             points,
             points_spec=gp.PointsSpec(
                 roi=gp.Roi(gp.Coordinate((0, 0, 0)),
                            gp.Coordinate(shapes[t])))))
        + gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None)
        + gp.Pad(points, None)
        # choose a random location for each requested batch
        + gp.RandomLocation()
        for t in range(ln))

    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
         if augmentation.get('intensity') is not None and
            augmentation.get('intensity') != {} else NoOp()) +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
                # pred_cp_gradients: '/volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start
            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
def build_pipeline(parameter, augment=True):
    voxel_size = gp.Coordinate(parameter['voxel_size'])

    # Array specifications.
    raw = gp.ArrayKey('RAW')
    gt_neurons = gp.ArrayKey('GT_NEURONS')
    gt_postpre_vectors = gp.ArrayKey('GT_POSTPRE_VECTORS')
    gt_post_indicator = gp.ArrayKey('GT_POST_INDICATOR')
    post_loss_weight = gp.ArrayKey('POST_LOSS_WEIGHT')
    vectors_mask = gp.ArrayKey('VECTORS_MASK')

    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    grad_syn_indicator = gp.ArrayKey('GRAD_SYN_INDICATOR')
    grad_partner_vectors = gp.ArrayKey('GRAD_PARTNER_VECTORS')

    # Points specifications.
    dummypostsyn = gp.PointsKey('DUMMYPOSTSYN')
    postsyn = gp.PointsKey('POSTSYN')
    presyn = gp.PointsKey('PRESYN')

    trg_context = 140  # AddPartnerVectorMap context in nm - pre-post distance

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_neurons, output_size)
    request.add(gt_postpre_vectors, output_size)
    request.add(gt_post_indicator, output_size)
    request.add(post_loss_weight, output_size)
    request.add(vectors_mask, output_size)
    request.add(dummypostsyn, output_size)

    snapshot_request = gp.BatchRequest({
        pred_post_indicator: request[gt_postpre_vectors],
        pred_postpre_vectors: request[gt_postpre_vectors],
        grad_syn_indicator: request[gt_postpre_vectors],
        grad_partner_vectors: request[gt_postpre_vectors],
        vectors_mask: request[gt_postpre_vectors]
    })

    postsyn_rastersetting = gp.RasterizationSettings(
        radius=parameter['blob_radius'],
        mask=gt_neurons,
        mode=parameter['blob_mode'])

    pipeline = tuple([
        create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                      gt_neurons)
        for sample in samples
    ])
    pipeline += gp.RandomProvider()
    if augment:
        pipeline += gp.ElasticAugment([4, 40, 40], [0, 2, 2],
                                      [0, math.pi / 2.0],
                                      prob_slip=0.05,
                                      prob_shift=0.05,
                                      max_misalign=10,
                                      subsample=8)
        pipeline += gp.SimpleAugment(transpose_only=[1, 2],
                                     mirror_only=[1, 2])

    pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                    z_section_wise=True)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)
    pipeline += gp.RasterizePoints(
        postsyn,
        gt_post_indicator,
        gp.ArraySpec(voxel_size=voxel_size, dtype=np.int32),
        postsyn_rastersetting)
    spec = gp.ArraySpec(voxel_size=voxel_size)
    pipeline += AddPartnerVectorMap(
        src_points=postsyn,
        trg_points=presyn,
        array=gt_postpre_vectors,
        radius=parameter['d_blob_radius'],
        trg_context=trg_context,  # enlarge context to catch partners
        array_spec=spec,
        mask=gt_neurons,
        pointmask=vectors_mask)
    pipeline += gp.BalanceLabels(labels=gt_post_indicator,
                                 scales=post_loss_weight,
                                 slab=(-1, -1, -1),
                                 clipmin=parameter['cliprange'][0],
                                 clipmax=parameter['cliprange'][1])
    if parameter['d_scale'] != 1:
        pipeline += gp.IntensityScaleShift(gt_postpre_vectors,
                                           scale=parameter['d_scale'],
                                           shift=0)
    pipeline += gp.PreCache(cache_size=40, num_workers=10)
    pipeline += gp.tensorflow.Train(
        './train_net',
        optimizer=net_config['optimizer'],
        loss=net_config['loss'],
        summary=net_config['summary'],
        log_dir='./tensorboard/',
        save_every=30000,
        log_every=100,
        inputs={
            net_config['raw']: raw,
            net_config['gt_partner_vectors']: gt_postpre_vectors,
            net_config['gt_syn_indicator']: gt_post_indicator,
            net_config['vectors_mask']: vectors_mask,  # loss weights --> mask
            net_config['indicator_weight']: post_loss_weight,  # loss weights
        },
        outputs={
            net_config['pred_partner_vectors']: pred_postpre_vectors,
            net_config['pred_syn_indicator']: pred_post_indicator,
        },
        gradients={
            net_config['pred_partner_vectors']: grad_partner_vectors,
            net_config['pred_syn_indicator']: grad_syn_indicator,
        },
    )

    # Visualize.
    pipeline += gp.IntensityScaleShift(raw, 0.5, 0.5)
    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
            gt_post_indicator: 'volumes/gt_post_indicator',
            gt_postpre_vectors: 'volumes/gt_postpre_vectors',
            pred_postpre_vectors: 'volumes/pred_postpre_vectors',
            pred_post_indicator: 'volumes/pred_post_indicator',
            post_loss_weight: 'volumes/post_loss_weight',
            grad_syn_indicator: 'volumes/post_indicator_gradients',
            grad_partner_vectors: 'volumes/partner_vectors_gradients',
            vectors_mask: 'volumes/vectors_mask'
        },
        every=1000,
        output_filename='batch_{iteration}.hdf',
        compression_type='gzip',
        additional_request=snapshot_request)
    pipeline += gp.PrintProfilingStats(every=100)

    print("Starting training...")
    max_iteration = parameter['max_iteration']
    with gp.build(pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
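# An illustrative `parameter` dict for build_pipeline (keys taken from the
# function body above; the values are placeholders, not the published setup).
parameter = {
    "voxel_size": (40, 4, 4),    # z, y, x in nm
    "blob_radius": 100,          # rasterized postsynaptic blob, world units
    "blob_mode": "ball",
    "d_blob_radius": 120,        # radius for the partner-vector map
    "cliprange": (0.01, 0.99),   # clipmin/clipmax for BalanceLabels
    "d_scale": 1,                # optional scaling of gt_postpre_vectors
    "max_iteration": 700000,
}
build_pipeline(parameter, augment=True)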
"score_thr": 0, "score_type": "sum", "nms_radius": None } parameters = synful.detection.SynapseExtractionParameters( extract_type=parameter_dic['extract_type'], cc_threshold=parameter_dic['cc_threshold'], loc_type=parameter_dic['loc_type'], score_thr=parameter_dic['score_thr'], score_type=parameter_dic['score_type'], nms_radius=parameter_dic['nms_radius']) gp.ArrayKey('M_PRED') gp.ArrayKey('D_PRED') gp.PointsKey('PRESYN') gp.PointsKey('POSTSYN') class TestSource(gp.BatchProvider): def __init__(self, m_pred, d_pred, voxel_size): self.voxel_size = voxel_size self.m_pred = m_pred self.d_pred = d_pred def setup(self): self.provides( gp.ArrayKeys.M_PRED, gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)), voxel_size=self.voxel_size, interpolatable=False))
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_fg = gp.ArrayKey('LABELS_FG')

    # array keys for base volume
    raw_base = gp.ArrayKey('RAW_BASE')
    labels_base = gp.ArrayKey('LABELS_BASE')
    swc_base = gp.PointsKey('SWC_BASE')
    swc_center_base = gp.PointsKey('SWC_CENTER_BASE')

    # array keys for add volume
    raw_add = gp.ArrayKey('RAW_ADD')
    labels_add = gp.ArrayKey('LABELS_ADD')
    swc_add = gp.PointsKey('SWC_ADD')
    swc_center_add = gp.PointsKey('SWC_CENTER_ADD')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((3, 3, 3))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)
    request.add(swc_center_base, output_size)
    request.add(swc_base, input_size)
    request.add(swc_center_add, output_size)
    request.add(swc_add, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple(
        (gp.Hdf5Source(
            file,
            datasets={
                raw_base: '/volume',
            },
            array_specs={
                raw_base: gp.ArraySpec(interpolatable=True,
                                       voxel_size=voxel_size,
                                       dtype=np.uint16),
            },
            channels_first=False),
         SwcSource(
             filename=file,
             dataset='/reconstruction',
             points=(swc_center_base, swc_base),
             scale=voxel_size))
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swc_center_base)
        + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(interpolatable=False,
                                    voxel_size=voxel_size,
                                    dtype=np.uint32),
            iteration=10)
        for file in files)
    data_sources_base += gp.RandomProvider()

    # data source for "add" volume
    data_sources_add = tuple(
        (gp.Hdf5Source(
            file,
            datasets={
                raw_add: '/volume',
            },
            array_specs={
                raw_add: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size,
                                      dtype=np.uint16),
            },
            channels_first=False),
         SwcSource(
             filename=file,
             dataset='/reconstruction',
             points=(swc_center_add, swc_add),
             scale=voxel_size))
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swc_center_add)
        + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(interpolatable=False,
                                    voxel_size=voxel_size,
                                    dtype=np.uint32),
            iteration=1)
        for file in files)
    data_sources_add += gp.RandomProvider()

    data_sources = tuple([data_sources_base,
                          data_sources_add]) + gp.MergeProvider()

    pipeline = (
        data_sources +
        FusionAugment(
            raw_base, raw_add, labels_base, labels_add, raw, labels,
            blend_mode='labels_mask',
            blend_smoothness=10,
            num_blended_objects=0) +

        # augment
        gp.ElasticAugment([10, 10, 10], [1, 1, 1], [0, math.pi / 2.0],
                          subsample=8) +
        gp.SimpleAugment(mirror_only=[2], transpose_only=[]) +
        gp.Normalize(raw) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +

        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +

        # train
        gp.PreCache(cache_size=40, num_workers=10) +
        gp.tensorflow.Train(
            './train_net',
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['labels_fg']: labels_fg,
                net_names['loss_weights']: loss_weights,
            },
            outputs={
                net_names['fg']: fg,
            },
            gradients={
                net_names['fg']: gradient_fg,
            },
            save_every=100) +

        # visualize
        gp.Snapshot(
            output_filename='snapshot_{iteration}.hdf',
            dataset_names={
                raw: 'volumes/raw',
                raw_base: 'volumes/raw_base',
                raw_add: 'volumes/raw_add',
                labels: 'volumes/labels',
                labels_base: 'volumes/labels_base',
                labels_add: 'volumes/labels_add',
                fg: 'volumes/fg',
                labels_fg: 'volumes/labels_fg',
                gradient_fg: 'volumes/gradient_fg',
            },
            additional_request=snapshot_request,
            every=10) +
        gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")
    labels = gp.ArrayKey("LABELS")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")

    # array keys for fused volume
    raw_fused = gp.ArrayKey("RAW_FUSED")
    labels_fused = gp.ArrayKey("LABELS_FUSED")
    swc_fused = gp.PointsKey("SWC_FUSED")

    # output data
    fg = gp.ArrayKey("FG")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw_fused, input_size)
    request.add(labels_fused, input_size)
    request.add(swc_fused, input_size)
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    # request.add(fg, output_size)
    # request.add(labels_fg, output_size)
    request.add(gradient_fg, output_size)
    request.add(raw_base, input_size)
    request.add(raw_add, input_size)
    request.add(labels_base, input_size)
    request.add(labels_add, input_size)
    request.add(swc_base, input_size)
    request.add(swc_add, input_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
                    ).absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            MouselightSwcFileSource(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs"
                    ).absolute()
                ),
                points=(swcs,),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str((filename / "transform.txt").absolute()),
                ignore_human_nodes=True,
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swcs, ensure_centered=True)
        + RasterizeSkeleton(
            points=swcs,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radius=10)
        # augment
        + gp.ElasticAugment(
            [40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0], subsample=4
        )
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for filename in Path(sample_dir).iterdir()
        if "2018-08-01" in filename.name  # or "2018-07-02" in filename.name
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + GetNeuronPair(
            swcs,
            raw,
            labels,
            (swc_base, swc_add),
            (raw_base, raw_add),
            (labels_base, labels_add),
            seperate_by=150,
            shift_attempts=50,
            request_attempts=10,
        )
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            swc_base,
            swc_add,
            raw_fused,
            labels_fused,
            swc_fused,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        )
        + Crop(labels_fused, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        # train
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw_fused,
                net_names["labels_fg"]: labels_fg_bin,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        )
        + gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw_fused: "volumes/raw_fused",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels_fused: "volumes/labels_fused",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                labels_fg_bin: "volumes/labels_fg_bin",
                fg: "volumes/pred_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            every=100,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):
        logging.info("Starting training...")
        for i in range(max_iteration - trained_until):
            logging.info("requesting batch {}".format(i))
            batch = pipeline.request_batch(request)
def get_neuron_pair(
    pipeline,
    setup_config,
    raw: ArrayKey,
    labels: ArrayKey,
    matched: PointsKey,
    nonempty_placeholder: PointsKey,
):
    # Data Properties
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    # Somewhat arbitrary hyperparameters
    blend_mode = setup_config["BLEND_MODE"]
    shift_attempts = setup_config["SHIFT_ATTEMPTS"]
    request_attempts = setup_config["REQUEST_ATTEMPTS"]
    blend_smoothness = setup_config["BLEND_SMOOTHNESS"]
    seperate_by = setup_config["SEPERATE_BY"]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for fused volume
    raw_fused = ArrayKey("RAW_FUSED")
    labels_fused = ArrayKey("LABELS_FUSED")
    matched_fused = gp.PointsKey("MATCHED_FUSED")

    # array keys for base volume
    raw_base = ArrayKey("RAW_BASE")
    labels_base = ArrayKey("LABELS_BASE")
    matched_base = gp.PointsKey("MATCHED_BASE")

    # array keys for add volume
    raw_add = ArrayKey("RAW_ADD")
    labels_add = ArrayKey("LABELS_ADD")
    matched_add = gp.PointsKey("MATCHED_ADD")

    # debug array keys
    soft_mask = gp.ArrayKey("SOFT_MASK")
    masked_base = gp.ArrayKey("MASKED_BASE")
    masked_add = gp.ArrayKey("MASKED_ADD")
    mask_maxed = gp.ArrayKey("MASK_MAXED")

    pipeline = pipeline + GetNeuronPair(
        matched,
        raw,
        labels,
        (matched_base, matched_add),
        (raw_base, raw_add),
        (labels_base, labels_add),
        output_shape=output_size,
        seperate_by=seperate_distance,
        shift_attempts=shift_attempts,
        request_attempts=request_attempts,
        nonempty_placeholder=nonempty_placeholder,
    )
    if blend_mode == "add":
        if setup_config["PRE_CLAHE"]:
            pipeline = pipeline + scipyCLAHE(
                [raw_add, raw_base],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=float(setup_config["CLIP_LIMIT"]),
                normalize=setup_config["CLAHE_NORMALIZE"],
            )
    pipeline = pipeline + FusionAugment(
        raw_base,
        raw_add,
        labels_base,
        labels_add,
        matched_base,
        matched_add,
        raw_fused,
        labels_fused,
        matched_fused,
        soft_mask=soft_mask,
        masked_base=masked_base,
        masked_add=masked_add,
        mask_maxed=mask_maxed,
        blend_mode=blend_mode,
        blend_smoothness=blend_smoothness * micron_scale,
        gaussian_smooth_mode="mirror",  # TODO: Config this
        num_blended_objects=0,  # TODO: Config this
    )
    if blend_mode == "add":
        if setup_config["POST_CLAHE"]:
            pipeline = pipeline + scipyCLAHE(
                [raw_add, raw_base],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=float(setup_config["CLIP_LIMIT"]),
                normalize=setup_config["CLAHE_NORMALIZE"],
            )

    return (
        pipeline,
        raw_fused,
        labels_fused,
        matched_fused,
        raw_base,
        labels_base,
        matched_base,
        raw_add,
        labels_add,
        matched_add,
        soft_mask,
        masked_base,
        masked_add,
        mask_maxed,
    )
def get_mouselight_data_sources(
    setup_config: Dict[str, Any], source_samples: List[str], locations=False
):
    # Source Paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])
    # specified_locations = setup_config.get("SPECIFIED_LOCATIONS")

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    matching_failures_dir = (
        matching_failures_dir
        if matching_failures_dir is None
        else Path(matching_failures_dir)
    )

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]
    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = nonempty_placeholder

    node_offset = {
        sample.name: (
            daisy.persistence.MongoDbGraphProvider(
                mongo_db_template.format(
                    sample=sample.name, source="skeletonization"
                ),
                mongo_url,
            ).num_nodes(None)
            + 1
        )
        for sample in samples_path.iterdir()
        if sample.name in source_samples
    }

    # if specified_locations is not None:
    #     centers = pickle.load(open(specified_locations, "rb"))
    #     random = gp.SpecifiedLocation
    #     kwargs = {"locations": centers, "choose_randomly": True}
    #     logger.info(f"Using specified locations from {specified_locations}")
    # elif locations:
    #     random = RandomLocations
    #     kwargs = {
    #         "ensure_nonempty": ensure_nonempty,
    #         "ensure_centered": True,
    #         "point_balance_radius": point_balance_radius * micron_scale,
    #         "loc": gp.ArrayKey("RANDOM_LOCATION"),
    #     }
    # else:
    random = RandomLocation
    kwargs = {
        "ensure_nonempty": ensure_nonempty,
        "ensure_centered": True,
        "point_balance_radius": point_balance_radius * micron_scale,
    }

    data_sources = (
        tuple(
            (
                gp.ZarrSource(
                    filename=str((sample / raw_n5).absolute()),
                    datasets={raw: "volume-rechunked"},
                    array_specs={
                        raw: gp.ArraySpec(
                            interpolatable=True,
                            voxel_size=voxel_size,
                            dtype=np.uint16,
                        )
                    },
                ),
                DaisyGraphProvider(
                    mongo_db_template.format(
                        sample=sample.name, source=matched_source
                    ),
                    mongo_url,
                    points=[matched],
                    directed=True,
                    node_attrs=[],
                    edge_attrs=[],
                ),
                FilteredDaisyGraphProvider(
                    mongo_db_template.format(
                        sample=sample.name, source=matched_source
                    ),
                    mongo_url,
                    points=[nonempty_placeholder],
                    directed=True,
                    node_attrs=["distance_to_fallback"],
                    edge_attrs=[],
                    num_nodes=max_nonempty_points,
                    dist_attribute=distance_attr,
                    min_dist=target_distance,
                ),
            )
            + gp.MergeProvider()
            + random(**kwargs)
            + gp.Normalize(raw)
            + FilterComponents(
                matched, node_offset[sample.name], centroid_size=output_size
            )
            + RasterizeSkeleton(
                points=matched,
                array=labels,
                array_spec=gp.ArraySpec(
                    interpolatable=False, voxel_size=voxel_size, dtype=np.int64
                ),
            )
            for sample in samples_path.iterdir()
            if sample.name in source_samples
        )
        + gp.RandomProvider()
    )

    return (data_sources, raw, labels, nonempty_placeholder, matched)
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # output data
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)
    request.add(swc_center_base, output_size)
    request.add(swc_center_add, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple(
        (
            gp.Hdf5Source(
                filename,
                datasets={raw_base: "/volume"},
                array_specs={
                    raw_base: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
                channels_first=False,
            ),
            SwcSource(
                filename=filename,
                dataset="/reconstruction",
                points=(swc_center_base, swc_base),
                scale=voxel_size,
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swc_center_base)
        + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
            radius=5.0,
        )
        for filename in files)

    # data source for "add" volume
    data_sources_add = tuple(
        (
            gp.Hdf5Source(
                file,
                datasets={raw_add: "/volume"},
                array_specs={
                    raw_add: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
                channels_first=False,
            ),
            SwcSource(
                filename=file,
                dataset="/reconstruction",
                points=(swc_center_add, swc_add),
                scale=voxel_size,
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swc_center_add)
        + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
            radius=5.0,
        )
        for file in files)

    data_sources = (
        (data_sources_base + gp.RandomProvider()),
        (data_sources_add + gp.RandomProvider()),
    ) + gp.MergeProvider()

    pipeline = (
        data_sources +
        FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        ) +

        # augment
        gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                          subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.Normalize(raw) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +

        # train
        gp.PreCache(cache_size=40, num_workers=10) +
        gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        ) +

        # visualize
        gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels: "volumes/labels",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                fg: "volumes/fg",
                labels_fg: "volumes/labels_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            additional_request=snapshot_request,
            every=100,
        ) +
        gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    # Network hyperparams
    INPUT_SHAPE = setup_config["INPUT_SHAPE"]
    OUTPUT_SHAPE = setup_config["OUTPUT_SHAPE"]

    # Skeleton generation hyperparams
    SKEL_GEN_RADIUS = setup_config["SKEL_GEN_RADIUS"]
    THETAS = np.array(setup_config["THETAS"]) * math.pi
    SPLIT_PS = setup_config["SPLIT_PS"]
    NOISE_VAR = setup_config["NOISE_VAR"]
    N_OBJS = setup_config["N_OBJS"]

    # Skeleton variation hyperparams
    LABEL_RADII = setup_config["LABEL_RADII"]
    RAW_RADII = setup_config["RAW_RADII"]
    RAW_INTENSITIES = setup_config["RAW_INTENSITIES"]

    # Training hyperparams
    CACHE_SIZE = setup_config["CACHE_SIZE"]
    NUM_WORKERS = setup_config["NUM_WORKERS"]
    SNAPSHOT_EVERY = setup_config["SNAPSHOT_EVERY"]
    CHECKPOINT_EVERY = setup_config["CHECKPOINT_EVERY"]

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    gt_fg = gp.ArrayKey("GT_FG")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # tensorflow tensors
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    snapshot_request.add(embedding, OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(maxima, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gradient_embedding, OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gradient_fg, OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request[emst] = gp.ArraySpec()
    snapshot_request[edges_u] = gp.ArraySpec()
    snapshot_request[edges_v] = gp.ArraySpec()
    snapshot_request[ratio_pos] = gp.ArraySpec()
    snapshot_request[ratio_neg] = gp.ArraySpec()
    snapshot_request[dist] = gp.ArraySpec()
    snapshot_request[num_pos_pairs] = gp.ArraySpec()
    snapshot_request[num_neg_pairs] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            raw,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        )
        + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
            use_component=True,
            n_objs=int(setup_config["HIDE_SIGNAL"]),
        )
        + nl.GrowLabels(labels, radii=LABEL_RADII)
        + nl.GrowLabels(raw, radii=RAW_RADII)
        + LabelToFloat32(raw, intensities=RAW_INTENSITIES)
        + gp.NoiseAugment(raw, var=NOISE_VAR)
        + gp.PreCache(cache_size=CACHE_SIZE, num_workers=NUM_WORKERS)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                "strided_slice_1:0": maxima,
                "gt_fg:0": gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=CHECKPOINT_EVERY,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        )
        + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                labels: "volumes/labels",
                point_trees: "point_trees",
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
            },
            dataset_dtypes={maxima: np.float32, gt_fg: np.float32},
            every=SNAPSHOT_EVERY,
            additional_request=snapshot_request,
        )
        # + gp.PrintProfilingStats(every=100)
    )

    with gp.build(pipeline):
        for i in range(n_iterations + 1):
            pipeline.request_batch(request)
            request._update_random_seed()
def rasterize_graph(
        graph,
        position_attribute,
        radius_pos,
        radius_tolerance,
        roi,
        voxel_size):
    '''Rasterizes a geometric graph into a numpy array.

    For that, the nodes in the graph are assumed to have a position in 3D
    (see parameter ``position_attribute``). The created array will have edges
    painted with 1, background with 0, and (optionally) a tolerance region
    around each edge with -1.

    Args:

        graph (networkx graph): The graph to rasterize. Nodes need to have a
            position attribute.

        position_attribute (string): The name of the position attribute of
            the nodes. The attribute should contain tuples of the form
            ``(z, y, x)`` in world units.

        radius_pos (float): The radius of the lines to draw for each edge in
            the graph (in world units).

        radius_tolerance (float): The radius of a region around each edge
            line that will be labelled with ``np.uint64(-1)``. Should be
            larger than ``radius_pos``. If set to ``None``, no such label
            will be produced.

        roi (gp.Roi): The ROI of the area to rasterize.

        voxel_size (tuple of int): The size of a voxel in the array to
            create, in world units.
    '''

    graph_key = gp.PointsKey('GRAPH')
    array = gp.ArrayKey('ARRAY')
    # key and spec for the optional tolerance region; the original left
    # `tolerance` and `tolerance_spec` undefined
    tolerance = gp.ArrayKey('TOLERANCE')
    array_spec = gp.ArraySpec(voxel_size=voxel_size, dtype=np.uint64)
    tolerance_spec = gp.ArraySpec(voxel_size=voxel_size, dtype=np.uint64)

    pipeline_pos = (
        NetworkxSource(graph, graph_key) +
        RasterizeSkeleton(graph_key, array, array_spec, radius_pos))
    if radius_tolerance is not None:
        # label an np.uint64(-1) tolerance region around each edge, as
        # documented above
        pipeline_pos += GrowLabels(array, tolerance, tolerance_spec,
                                   radius_tolerance)

    request = gp.BatchRequest()
    request[array] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline_pos):
        batch = pipeline_pos.request_batch(request)

    return batch[array].data
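# A minimal usage sketch (assumed values, nothing file-based): rasterize a
# two-edge graph into a 100^3 region at unit resolution. Node positions use
# the (z, y, x) convention required by `position_attribute`.
import networkx as nx

g = nx.Graph()
g.add_node(1, position=(10, 10, 10))
g.add_node(2, position=(10, 50, 50))
g.add_node(3, position=(80, 50, 50))
g.add_edges_from([(1, 2), (2, 3)])

data = rasterize_graph(
    g,
    'position',
    radius_pos=3.0,
    radius_tolerance=6.0,
    roi=gp.Roi((0, 0, 0), (100, 100, 100)),
    voxel_size=(1, 1, 1))
print(data.shape, np.unique(data))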
def train_until(max_iteration, return_intermediates=False):

    # get the latest checkpoint
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # input data
    ch1 = gp.ArrayKey('CH1')
    ch2 = gp.ArrayKey('CH2')
    swc = gp.PointsKey('SWC')
    swc_env = gp.PointsKey('SWC_ENV')
    swc_center = gp.PointsKey('SWC_CENTER')
    gt = gp.ArrayKey('GT')
    gt_fg = gp.ArrayKey('GT_FG')

    # show fusion augment batches
    if return_intermediates:
        a_ch1 = gp.ArrayKey('A_CH1')
        a_ch2 = gp.ArrayKey('A_CH2')
        b_ch1 = gp.ArrayKey('B_CH1')
        b_ch2 = gp.ArrayKey('B_CH2')
        soft_mask = gp.ArrayKey('SOFT_MASK')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(ch1, input_size)
    request.add(ch2, input_size)
    request.add(swc, input_size)
    request.add(swc_center, output_size)
    request.add(gt, output_size)
    request.add(gt_fg, output_size)
    # request.add(loss_weights, output_size)

    if return_intermediates:
        request.add(a_ch1, input_size)
        request.add(a_ch2, input_size)
        request.add(b_ch1, input_size)
        request.add(b_ch2, input_size)
        request.add(soft_mask, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    # snapshot_request[fg] = request[gt]
    # snapshot_request[gt_fg] = request[gt]
    # snapshot_request[gradient_fg] = request[gt]

    data_sources = tuple(
        (Hdf5ChannelSource(
            file,
            datasets={
                ch1: '/volume',
                ch2: '/volume',
            },
            channel_ids={
                ch1: 0,
                ch2: 1,
            },
            data_format='channels_last',
            array_specs={
                ch1: gp.ArraySpec(interpolatable=True,
                                  voxel_size=voxel_size,
                                  dtype=np.uint16),
                ch2: gp.ArraySpec(interpolatable=True,
                                  voxel_size=voxel_size,
                                  dtype=np.uint16),
            }),
         SwcSource(
             filename=file,
             dataset='/reconstruction',
             points=(swc_center, swc),
             return_env=True,
             scale=voxel_size))
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swc_center)
        + RasterizeSkeleton(
            points=swc,
            array=gt,
            array_spec=gp.ArraySpec(interpolatable=False,
                                    voxel_size=voxel_size,
                                    dtype=np.uint32),
            points_env=swc_env,
            iteration=10)
        for file in files)

    if return_intermediates:
        snapshot_datasets = {
            ch1: 'volumes/ch1',
            ch2: 'volumes/ch2',
            a_ch1: 'volumes/a_ch1',
            a_ch2: 'volumes/a_ch2',
            b_ch1: 'volumes/b_ch1',
            b_ch2: 'volumes/b_ch2',
            soft_mask: 'volumes/soft_mask',
            gt: 'volumes/gt',
            fg: 'volumes/fg',
            gt_fg: 'volumes/gt_fg',
            gradient_fg: 'volumes/gradient_fg',
        }
    else:
        snapshot_datasets = {
            ch1: 'volumes/ch1',
            ch2: 'volumes/ch2',
            gt: 'volumes/gt',
            fg: 'volumes/fg',
            gt_fg: 'volumes/gt_fg',
            gradient_fg: 'volumes/gradient_fg',
        }

    pipeline = (
        data_sources +
        # gp.RandomProvider() +
        FusionAugment(
            ch1, ch2, gt,
            smoothness=1,
            return_intermediate=return_intermediates) +

        # augment
        # gp.ElasticAugment(...) +
        # gp.SimpleAugment() +
        gp.Normalize(ch1) +
        gp.Normalize(ch2))

    # the intermediate keys only exist when FusionAugment returns them;
    # normalizing them unconditionally raises a NameError for the default
    # return_intermediates=False
    if return_intermediates:
        pipeline = (
            pipeline +
            gp.Normalize(a_ch1) +
            gp.Normalize(a_ch2) +
            gp.Normalize(b_ch1) +
            gp.Normalize(b_ch2))

    pipeline = (
        pipeline +
        gp.IntensityAugment(ch1, 0.9, 1.1, -0.001, 0.001) +
        gp.IntensityAugment(ch2, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(gt, gt_fg) +

        # visualize
        gp.Snapshot(
            output_filename='snapshot_{iteration}.hdf',
            dataset_names=snapshot_datasets,
            additional_request=snapshot_request,
            every=20) +
        gp.PrintProfilingStats(every=1000))

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train(n_iterations):
    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    # gt_fg = gp.ArrayKey("GT_FG")
    # embedding = gp.ArrayKey("EMBEDDING")
    # fg = gp.ArrayKey("FG")
    # maxima = gp.ArrayKey("MAXIMA")
    # gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    # gradient_fg = gp.ArrayKey("GRADIENT_FG")
    # emst = gp.ArrayKey("EMST")
    # edges_u = gp.ArrayKey("EDGES_U")
    # edges_v = gp.ArrayKey("EDGES_V")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    # snapshot_request.add(embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(maxima, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(
    #     gradient_embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1))
    # )
    # snapshot_request.add(gradient_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        )
        + gp.Copy(labels, raw)
        + nl.GrowLabels(labels, radii=LABEL_RADII)
        + nl.GrowLabels(raw, radii=RAW_RADII)
        + LabelToFloat32(raw, intensities=RAW_INTENSITIES)
        + gp.NoiseAugment(raw, var=NOISE_VAR)
        # + gp.PreCache(cache_size=40, num_workers=10)
        # + gp.tensorflow.Train(
        #     "train_net",
        #     optimizer=add_loss,
        #     loss=None,
        #     inputs={tensor_names["raw"]: raw, tensor_names["gt_labels"]: labels},
        #     outputs={
        #         tensor_names["embedding"]: embedding,
        #         tensor_names["fg"]: fg,
        #         "maxima:0": maxima,
        #         "gt_fg:0": gt_fg,
        #         emst_name: emst,
        #         edges_u_name: edges_u,
        #         edges_v_name: edges_v,
        #     },
        #     gradients={
        #         tensor_names["embedding"]: gradient_embedding,
        #         tensor_names["fg"]: gradient_fg,
        #     },
        # )
        + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                labels: "volumes/labels",
                point_trees: "point_trees",
                # embedding: "volumes/embedding",
                # fg: "volumes/fg",
                # maxima: "volumes/maxima",
                # gt_fg: "volumes/gt_fg",
                # gradient_embedding: "volumes/gradient_embedding",
                # gradient_fg: "volumes/gradient_fg",
                # emst: "emst",
                # edges_u: "edges_u",
                # edges_v: "edges_v",
            },
            # dataset_dtypes={maxima: np.float32, gt_fg: np.float32},
            every=100,
            additional_request=snapshot_request,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):
        for i in range(n_iterations):
            pipeline.request_batch(request)
        ))]

        output_spec = copy.deepcopy(input_spec)
        output_spec.roi = output_roi

        output_array = gp.Array(output_data, output_spec)

        batch[self.output_array] = output_array


input_size = Coordinate([74, 260, 260])
output_size = Coordinate([42, 168, 168])
path_to_data = Path("/nrs/funke/mouselight-v2")

# array keys for data sources
raw = gp.ArrayKey("RAW")
swcs = gp.PointsKey("SWCS")
labels = gp.ArrayKey("LABELS")

# array keys for base volume
raw_base = gp.ArrayKey("RAW_BASE")
labels_base = gp.ArrayKey("LABELS_BASE")
swc_base = gp.PointsKey("SWC_BASE")

# array keys for add volume
raw_add = gp.ArrayKey("RAW_ADD")
labels_add = gp.ArrayKey("LABELS_ADD")
swc_add = gp.PointsKey("SWC_ADD")

# array keys for fused volume
raw_fused = gp.ArrayKey("RAW_FUSED")
labels_fused = gp.ArrayKey("LABELS_FUSED")
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names,
                            loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]
    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # add request
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (sample / "fluorescence-near-consensus.n5").absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched, center_size=output_size)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + gp.contrib.nodes.add_distance.AddDistance(
            labels, dist, dist_mask, max_distance=max_label_dist * micron_scale
        )
        + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1
        )
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01")
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
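
# For reference, a minimal ``setup_config`` covering the keys read by
# ``train_distance_pipeline`` (and by ``train_simple_pipeline`` below, which
# reads NEURON_RADIUS instead of MAX_LABEL_DIST). Every value here is an
# illustrative placeholder, not the original experiment's setting.
setup_config = {
    "INPUT_SHAPE": (74, 260, 260),
    "OUTPUT_SHAPE": (42, 168, 168),
    "VOXEL_SIZE": (10, 3, 3),
    "NUM_ITERATIONS": 100_000,
    "CACHE_SIZE": 40,
    "NUM_WORKERS": 10,
    "SNAPSHOT_EVERY": 1_000,
    "CHECKPOINT_EVERY": 10_000,
    "PROFILE_EVERY": 100,
    "SEPERATE_BY": [0],  # spelling matches the config key used in the code
    "GAP_CROSSING_DIST": 10,
    "MATCH_DISTANCE_THRESHOLD": 50,
    "POINT_BALANCE_RADIUS": 100,
    "MAX_LABEL_DIST": 30,
    "NEURON_RADIUS": 10,
    "SAMPLES_PATH": "/nrs/funke/mouselight-v2",
    "MONGO_URL": "mongodb://localhost:27017",
}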
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]
    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)
    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (sample / "fluorescence-near-consensus.n5").absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01")
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(labels, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
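
# ``Crop``, ``BinarizeGt``, ``GrowLabels``, ``TopologicalMatcher``, etc. are
# custom nodes from the surrounding project. As an illustration of the
# pattern, a minimal BatchFilter producing the kind of 0/1 mask that
# ``BinarizeGt`` feeds into ``gp.BalanceLabels`` might look like this; it is
# a sketch, not the project's actual implementation.
import numpy as np
import gunpowder as gp


class BinarizeGtSketch(gp.BatchFilter):
    """Derive a binary foreground mask (labels > 0) from a labels array."""

    def __init__(self, gt, gt_binary):
        self.gt = gt
        self.gt_binary = gt_binary

    def setup(self):
        # provide the mask with the same ROI/voxel size as the labels
        spec = self.spec[self.gt].copy()
        spec.dtype = np.uint8
        self.provides(self.gt_binary, spec)

    def prepare(self, request):
        # ask upstream for labels wherever the mask was requested
        request[self.gt] = request[self.gt_binary].copy()

    def process(self, batch, request):
        spec = batch[self.gt].spec.copy()
        spec.dtype = np.uint8
        binarized = (batch[self.gt].data > 0).astype(np.uint8)
        batch[self.gt_binary] = gp.Array(binarized, spec)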
def train_until(max_iteration):
    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
                    ).absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            MouselightSwcFileSource(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
                    ).absolute()
                ),
                points=(swcs,),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str((filename / "transform.txt").absolute()),
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swcs, ensure_centered=True)
        for filename in Path(sample_dir).iterdir()
        if "2018-08-01" in filename.name
    )

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(
                batch[raw].data,
                points_to_graph(batch[swcs].data),
                np.array(voxel_size),
            )
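
# Hypothetical entry point: assumes ``net_config``, ``sample_dir``,
# ``vis_points_with_array``, and ``points_to_graph`` are defined at module
# scope as in the original script; the iteration count is a placeholder.
if __name__ == "__main__":
    train_until(100)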