def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name): ArrayKey("RAW") ArrayKey("ALPHA_MASK") ArrayKey("GT_LABELS") ArrayKey("GT_MASK") ArrayKey("TRAINING_MASK") ArrayKey("GT_SCALE") ArrayKey("LOSS_GRADIENT") ArrayKey("GT_DIST") ArrayKey("PREDICTED_DIST_LABELS") data_providers = [] if cremi_version == "2016": cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/" filename = "sample_{0:}_padded_20160501." elif cremi_version == "2017": cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/" filename = "sample_{0:}_padded_20170424." if aligned: filename += "aligned." filename += "0bg.hdf" if tf.train.latest_checkpoint("."): trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1]) print("Resuming training from", trained_until) else: trained_until = 0 print("Starting fresh training") for sample in data_sources: print(sample) h5_source = Hdf5Source( os.path.join(cremi_dir, filename.format(sample)), datasets={ ArrayKeys.RAW: "volumes/raw", ArrayKeys.GT_LABELS: "volumes/labels/clefts", ArrayKeys.GT_MASK: "volumes/masks/groundtruth", ArrayKeys.TRAINING_MASK: "volumes/masks/validation", }, array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)}, ) data_providers.append(h5_source) with open("net_io_names.json", "r") as f: net_io_names = json.load(f) voxel_size = Coordinate((40, 4, 4)) input_size = Coordinate(input_shape) * voxel_size output_size = Coordinate(output_shape) * voxel_size context = input_size - output_size # specifiy which Arrays should be requested for each batch request = BatchRequest() request.add(ArrayKeys.RAW, input_size) request.add(ArrayKeys.GT_LABELS, output_size) request.add(ArrayKeys.GT_MASK, output_size) request.add(ArrayKeys.TRAINING_MASK, output_size) request.add(ArrayKeys.GT_SCALE, output_size) request.add(ArrayKeys.GT_DIST, output_size) # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(ArrayKeys.RAW) + IntensityScaleShift( # ensures RAW is in float in [0, 1] ArrayKeys.TRAINING_MASK, -1, 1) + # zero-pad provided RAW and GT_MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) + Pad(ArrayKeys.TRAINING_MASK, context) + RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) + Reject(ArrayKeys.GT_MASK) + Reject( # reject batches wich do contain less than 50% labelled data ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95) for provider in data_providers) snapshot_request = BatchRequest({ ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_LABELS], ArrayKeys.PREDICTED_DIST_LABELS: request[ArrayKeys.GT_LABELS], ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST], }) train_pipeline = ( data_sources + RandomProvider() + ElasticAugment( (4, 40, 40), (0.0, 0.0, 0.0), (0, math.pi / 2.0), prob_slip=0.0, prob_shift=0.0, max_misalign=0, subsample=8, ) + SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) + IntensityAugment( ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) + IntensityScaleShift(ArrayKeys.RAW, 2, -1) + ZeroOutConstSections(ArrayKeys.RAW) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST, normalize="tanh", normalize_args=dt_scaling_factor, ) + BalanceByThreshold( ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE, mask=ArrayKeys.GT_MASK) + PreCache(cache_size=40, num_workers=10) + Train( "unet", optimizer=net_io_names["optimizer"], 
loss=net_io_names[loss_name], inputs={ net_io_names["raw"]: ArrayKeys.RAW, net_io_names["gt_dist"]: ArrayKeys.GT_DIST, net_io_names["loss_weights"]: ArrayKeys.GT_SCALE, net_io_names["mask"]: ArrayKeys.GT_MASK, }, summary=net_io_names["summary"], log_dir="log", outputs={net_io_names["dist"]: ArrayKeys.PREDICTED_DIST_LABELS}, gradients={net_io_names["dist"]: ArrayKeys.LOSS_GRADIENT}, ) + Snapshot( { ArrayKeys.RAW: "volumes/raw", ArrayKeys.GT_LABELS: "volumes/labels/gt_clefts", ArrayKeys.GT_DIST: "volumes/labels/gt_clefts_dist", ArrayKeys.PREDICTED_DIST_LABELS: "volumes/labels/pred_clefts_dist", ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient", }, every=500, output_filename="batch_{iteration}.hdf", output_dir="snapshots/", additional_request=snapshot_request, ) + PrintProfilingStats(every=50)) print("Starting training...") with build(train_pipeline) as b: for i in range(max_iteration): b.request_batch(request) print("Training finished")
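# A minimal, hypothetical driver for the CREMI cleft training function above
# (a sketch, not part of the original scripts): the sample letters, shapes,
# scaling factor and loss name below are illustrative values only.  Note that
# the function also reads the module-level globals `cremi_version` and
# `aligned`, which are assumed to be set elsewhere in this script.
def _example_cremi_cleft_run():
    train_until(
        max_iteration=400000,
        data_sources=["A", "B", "C"],   # CREMI sample letters
        input_shape=(43, 430, 430),     # in voxels at (40, 4, 4) nm
        output_shape=(23, 218, 218),    # in voxels at (40, 4, 4) nm
        dt_scaling_factor=50,
        loss_name="loss_total",
    )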
def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name): ArrayKey('RAW') ArrayKey('ALPHA_MASK') ArrayKey('GT_LABELS') ArrayKey('GT_MASK') ArrayKey('TRAINING_MASK') ArrayKey('GT_SCALE') ArrayKey('LOSS_GRADIENT') ArrayKey('GT_DIST') ArrayKey('PREDICTED_DIST_LABELS') data_providers = [] cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/" if tf.train.latest_checkpoint('.'): trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1]) print('Resuming training from', trained_until) else: trained_until = 0 print('Starting fresh training') for sample in data_sources: print(sample) h5_source = Hdf5Source( os.path.join(cremi_dir, 'sample_'+sample+'_cleftsorig.hdf'), datasets={ ArrayKeys.RAW: 'volumes/raw', ArrayKeys.GT_LABELS: 'volumes/labels/clefts', ArrayKeys.GT_MASK: 'volumes/masks/groundtruth', ArrayKeys.TRAINING_MASK: 'volumes/masks/training' }, array_specs={ ArrayKeys.GT_MASK: ArraySpec(interpolatable=False) } ) data_providers.append(h5_source) #todo: dvid source with open('net_io_names.json', 'r') as f: net_io_names = json.load(f) voxel_size = Coordinate((40, 4, 4)) input_size = Coordinate(input_shape) * voxel_size output_size = Coordinate(output_shape) * voxel_size # input_size = Coordinate((132,)*3) * voxel_size # output_size = Coordinate((44,)*3) * voxel_size # specifiy which Arrays should be requested for each batch request = BatchRequest() request.add(ArrayKeys.RAW, input_size) request.add(ArrayKeys.GT_LABELS, output_size) request.add(ArrayKeys.GT_MASK, output_size) request.add(ArrayKeys.TRAINING_MASK, output_size) request.add(ArrayKeys.GT_SCALE, output_size) request.add(ArrayKeys.GT_DIST, output_size) # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(ArrayKeys.RAW) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and GT_MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad( { ArrayKeys.RAW: Coordinate((8, 8, 8)) * voxel_size, ArrayKeys.GT_MASK: Coordinate((8, 8, 8)) * voxel_size, ArrayKeys.TRAINING_MASK: Coordinate((8, 8, 8)) * voxel_size #ArrayKeys.GT_LABELS: Coordinate((100, 100, 100)) * voxel_size # added later } ) + RandomLocation() + # chose a random location inside the provided arrays Reject(ArrayKeys.GT_MASK) + # reject batches wich do contain less than 50% labelled data Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) + Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95) for provider in data_providers) snapshot_request = BatchRequest({ ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_LABELS], ArrayKeys.PREDICTED_DIST_LABELS: request[ArrayKeys.GT_LABELS], ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST], }) artifact_source = ( Hdf5Source( os.path.join(cremi_dir, 'sample_ABC_padded_20160501.defects.hdf'), datasets={ ArrayKeys.RAW: 'defect_sections/raw', ArrayKeys.ALPHA_MASK: 'defect_sections/mask', }, array_specs={ ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)), ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)), } ) + RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) + Normalize(ArrayKeys.RAW) + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) + ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi/2.0), subsample=8) + SimpleAugment(transpose_only_xy=True) ) train_pipeline = ( data_sources + RandomProvider() + ElasticAugment((4, 40, 40), (0., 2., 2.), (0, math.pi/2.0), prob_slip=0.05, prob_shift=0.05, 
max_misalign=10, subsample=8) + SimpleAugment(transpose_only_xy=True) + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) + DefectAugment(ArrayKeys.RAW, prob_missing=0.03, prob_low_contrast=0.01, prob_artifact=0.03, artifact_source=artifact_source, artifacts=ArrayKeys.RAW, artifacts_mask=ArrayKeys.ALPHA_MASK, contrast_scale=0.5) + IntensityScaleShift(ArrayKeys.RAW, 2, -1) + ZeroOutConstSections(ArrayKeys.RAW) + #GrowBoundary(steps=1) + #SplitAndRenumberSegmentationLabels() + #AddGtAffinities(malis.mknhood3d()) + AddBoundaryDistance(label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST, normalize='tanh', normalize_args=dt_scaling_factor ) + BalanceLabels(ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE, ArrayKeys.GT_MASK) + #BalanceByThreshold( # labels=ArrayKeys.GT_DIST, # scales= ArrayKeys.GT_SCALE) + #{ # ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_SCALE # }, # { # ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_MASK # }) + PreCache( cache_size=40, num_workers=10)+ Train( 'unet', optimizer=net_io_names['optimizer'], loss=net_io_names[loss_name], inputs={ net_io_names['raw']: ArrayKeys.RAW, net_io_names['gt_dist']: ArrayKeys.GT_DIST, net_io_names['loss_weights']: ArrayKeys.GT_SCALE }, summary=net_io_names['summary'], log_dir='log', outputs={ net_io_names['dist']: ArrayKeys.PREDICTED_DIST_LABELS }, gradients={ net_io_names['dist']: ArrayKeys.LOSS_GRADIENT }) + Snapshot({ ArrayKeys.RAW: 'volumes/raw', ArrayKeys.GT_LABELS: 'volumes/labels/gt_clefts', ArrayKeys.GT_DIST: 'volumes/labels/gt_clefts_dist', ArrayKeys.PREDICTED_DIST_LABELS: 'volumes/labels/pred_clefts_dist', ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient', }, every=500, output_filename='batch_{iteration}.hdf', output_dir='snapshots/', additional_request=snapshot_request) + PrintProfilingStats(every=50)) print("Starting training...") with build(train_pipeline) as b: for i in range(max_iteration): b.request_batch(request) print("Training finished")
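# The AddBoundaryDistance node above is configured with normalize='tanh' and
# normalize_args=dt_scaling_factor.  The sketch below illustrates that
# normalization with plain numpy/scipy; it is not the repository's
# implementation and ignores anisotropic voxel sizes, label ids and masking,
# which the real node handles.
import numpy as np
from scipy.ndimage import distance_transform_edt


def _tanh_distance_sketch(binary_labels, dt_scaling_factor=50):
    """Signed Euclidean distance to the label boundary, squashed with tanh."""
    inside = distance_transform_edt(binary_labels)                   # distance inside the labels
    outside = distance_transform_edt(np.logical_not(binary_labels))  # distance inside the background
    signed_distance = inside - outside
    return np.tanh(signed_distance / dt_scaling_factor)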
def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name, labels): ArrayKey('RAW') ArrayKey('ALPHA_MASK') ArrayKey('GT_LABELS') ArrayKey('MASK') data_providers = [] data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5" voxel_size = Coordinate((2, 2, 2)) input_size = Coordinate(input_shape) * voxel_size output_size = Coordinate(output_shape) * voxel_size if tf.train.latest_checkpoint('.'): trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1]) print('Resuming training from', trained_until) else: trained_until = 0 print('Starting fresh training') for src in data_sources: n5_source = N5Source( os.path.join(data_dir.format(src)), datasets={ ArrayKeys.RAW: 'volumes/raw', ArrayKeys.GT_LABELS: 'volumes/labels/all', ArrayKeys.MASK: 'volumes/mask' }, array_specs={ArrayKeys.MASK: ArraySpec(interpolatable=False)}) data_providers.append(n5_source) with open('net_io_names.json', 'r') as f: net_io_names = json.load(f) inputs = dict() inputs[net_io_names['raw']] = ArrayKeys.RAW outputs = dict() snapshot = dict() snapshot[ArrayKeys.RAW] = 'volumes/raw' snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels' for label in labels: inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key if label.scale_loss or label.scale_key is not None: inputs[net_io_names['w_' + label.labelname]] = label.scale_key outputs[net_io_names[label.labelname]] = label.pred_dist_key snapshot[ label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname snapshot[label. pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname # specifiy which Arrays should be requested for each batch request = BatchRequest() snapshot_request = BatchRequest() request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size) request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size) request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size) for label in labels: request.add(label.gt_dist_key, output_size, voxel_size=voxel_size) snapshot_request.add(label.pred_dist_key, output_size, voxel_size=voxel_size) if label.scale_loss: request.add(label.scale_key, output_size, voxel_size=voxel_size) # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(ArrayKeys.RAW) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(ArrayKeys.RAW, None) + RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK ) # chose a random location inside the provided arrays #Reject(ArrayKeys.MASK) # reject batches wich do contain less than 50% labelled data for provider in data_providers) train_pipeline = ( data_sources + RandomProvider() + ElasticAugment( (100, 100, 100), (10., 10., 10.), (0, math.pi / 2.0), prob_slip=0, prob_shift=0, max_misalign=0, subsample=8) + SimpleAugment() + #ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) + IntensityAugment(ArrayKeys.RAW, 0.95, 1.05, -0.05, 0.05) + IntensityScaleShift(ArrayKeys.RAW, 2, -1) + ZeroOutConstSections(ArrayKeys.RAW)) for label in labels: train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS, distance_array_key=label.gt_dist_key, normalize='tanh', normalize_args=dt_scaling_factor, label_id=label.labelid) train_pipeline = (train_pipeline) for label in labels: if label.scale_loss: train_pipeline += BalanceByThreshold(label.gt_dist_key, label.scale_key, mask=ArrayKeys.MASK) train_pipeline = (train_pipeline + 
PreCache(cache_size=40, num_workers=10) + Train('build', optimizer=net_io_names['optimizer'], loss=net_io_names[loss_name], inputs=inputs, summary=net_io_names['summary'], log_dir='log', outputs=outputs, gradients={}) + Snapshot(snapshot, every=500, output_filename='batch_{iteration}.hdf', output_dir='snapshots/', additional_request=snapshot_request) + PrintProfilingStats(every=50)) print("Starting training...") with build(train_pipeline) as b: for i in range(max_iteration): b.request_batch(request) print("Training finished")
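# The train_until above only relies on a handful of attributes of the Label
# objects it is given (labelname, labelid, gt_dist_key, pred_dist_key,
# scale_loss, scale_key).  The class below is a minimal stand-in sketch; the
# repository's actual Label class lives elsewhere and its constructor may
# look different.
class _LabelSketch(object):
    def __init__(self, labelname, labelid, scale_loss=True):
        self.labelname = labelname
        self.labelid = labelid
        self.gt_dist_key = ArrayKey("GT_DIST_" + labelname.upper())
        self.pred_dist_key = ArrayKey("PRED_DIST_" + labelname.upper())
        self.scale_loss = scale_loss
        self.scale_key = ArrayKey("SCALE_" + labelname.upper()) if scale_loss else None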
def train_until( max_iteration, cremi_dir, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name, cache_size=10, num_workers=10, ): ArrayKey("RAW") ArrayKey("ALPHA_MASK") ArrayKey("GT_SYN_LABELS") ArrayKey("GT_LABELS") ArrayKey("GT_MASK") ArrayKey("TRAINING_MASK") ArrayKey("GT_SYN_SCALE") ArrayKey("LOSS_GRADIENT") ArrayKey("GT_SYN_DIST") ArrayKey("PREDICTED_SYN_DIST") ArrayKey("GT_BDY_DIST") ArrayKey("PREDICTED_BDY_DIST") data_providers = [] if tf.train.latest_checkpoint("."): trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1]) print("Resuming training from", trained_until) else: trained_until = 0 print("Starting fresh training") for sample in data_sources: print(sample) h5_source = Hdf5Source( os.path.join(cremi_dir, "sample_" + sample + "_cleftsorig.hdf"), datasets={ ArrayKeys.RAW: "volumes/raw", ArrayKeys.GT_SYN_LABELS: "volumes/labels/clefts", ArrayKeys.GT_MASK: "volumes/masks/groundtruth", ArrayKeys.TRAINING_MASK: "volumes/masks/validation", ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids", }, array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)}, ) data_providers.append(h5_source) # todo: dvid source with open("net_io_names.json", "r") as f: net_io_names = json.load(f) voxel_size = Coordinate((40, 4, 4)) input_size = Coordinate(input_shape) * voxel_size output_size = Coordinate(output_shape) * voxel_size # input_size = Coordinate((132,)*3) * voxel_size # output_size = Coordinate((44,)*3) * voxel_size # specifiy which volumes should be requested for each batch request = BatchRequest() request.add(ArrayKeys.RAW, input_size) request.add(ArrayKeys.GT_SYN_LABELS, output_size) request.add(ArrayKeys.GT_LABELS, output_size) request.add(ArrayKeys.GT_BDY_DIST, output_size) request.add(ArrayKeys.GT_MASK, output_size) request.add(ArrayKeys.TRAINING_MASK, output_size) request.add(ArrayKeys.GT_SYN_SCALE, output_size) request.add(ArrayKeys.GT_SYN_DIST, output_size) # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(ArrayKeys.RAW) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and GT_MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) + Pad(ArrayKeys.TRAINING_MASK, None) + RandomLocation() + Reject( # chose a random location inside the provided arrays ArrayKeys.GT_MASK ) + Reject( # reject batches wich do contain less than 50% labelled data ArrayKeys.TRAINING_MASK, min_masked=0.99 ) + Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95) for provider in data_providers ) snapshot_request = BatchRequest( { ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_SYN_LABELS], ArrayKeys.PREDICTED_SYN_DIST: request[ArrayKeys.GT_SYN_LABELS], ArrayKeys.PREDICTED_BDY_DIST: request[ArrayKeys.GT_LABELS], ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_SYN_DIST], } ) artifact_source = ( Hdf5Source( os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"), datasets={ ArrayKeys.RAW: "defect_sections/raw", ArrayKeys.ALPHA_MASK: "defect_sections/mask", }, array_specs={ ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)), ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)), }, ) + RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) + Normalize(ArrayKeys.RAW) + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) + ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8) + SimpleAugment(transpose_only=[1, 2]) ) 
train_pipeline = ( data_sources + RandomProvider() + ElasticAugment( (4, 40, 40), (0.0, 2.0, 2.0), (0, math.pi / 2.0), prob_slip=0.05, prob_shift=0.05, max_misalign=10, subsample=8, ) + SimpleAugment(transpose_only=[1, 2]) + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) + DefectAugment( ArrayKeys.RAW, prob_missing=0.03, prob_low_contrast=0.01, prob_artifact=0.03, artifact_source=artifact_source, artifacts=ArrayKeys.RAW, artifacts_mask=ArrayKeys.ALPHA_MASK, contrast_scale=0.5, ) + IntensityScaleShift(ArrayKeys.RAW, 2, -1) + ZeroOutConstSections(ArrayKeys.RAW) + GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1, only_xy=True) + AddBoundaryDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_BDY_DIST, normalize="tanh", normalize_args=100, ) + AddDistance( label_array_key=ArrayKeys.GT_SYN_LABELS, distance_array_key=ArrayKeys.GT_SYN_DIST, normalize="tanh", normalize_args=dt_scaling_factor, ) + BalanceLabels( ArrayKeys.GT_SYN_LABELS, ArrayKeys.GT_SYN_SCALE, ArrayKeys.GT_MASK ) + PreCache(cache_size=cache_size, num_workers=num_workers) + Train( "unet", optimizer=net_io_names["optimizer"], loss=net_io_names[loss_name], inputs={ net_io_names["raw"]: ArrayKeys.RAW, net_io_names["gt_syn_dist"]: ArrayKeys.GT_SYN_DIST, net_io_names["gt_bdy_dist"]: ArrayKeys.GT_BDY_DIST, net_io_names["loss_weights"]: ArrayKeys.GT_SYN_SCALE, net_io_names["mask"]: ArrayKeys.GT_MASK, }, summary=net_io_names["summary"], log_dir="log", outputs={ net_io_names["syn_dist"]: ArrayKeys.PREDICTED_SYN_DIST, net_io_names["bdy_dist"]: ArrayKeys.PREDICTED_BDY_DIST, }, gradients={net_io_names["syn_dist"]: ArrayKeys.LOSS_GRADIENT}, ) + Snapshot( { ArrayKeys.RAW: "volumes/raw", ArrayKeys.GT_SYN_LABELS: "volumes/labels/gt_clefts", ArrayKeys.GT_SYN_DIST: "volumes/labels/gt_clefts_dist", ArrayKeys.PREDICTED_SYN_DIST: "volumes/labels/pred_clefts_dist", ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient", ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids", ArrayKeys.PREDICTED_BDY_DIST: "volumes/labels/pred_bdy_dist", ArrayKeys.GT_BDY_DIST: "volumes/labels/gt_bdy_dist", }, every=500, output_filename="batch_{iteration}.hdf", output_dir="snapshots/", additional_request=snapshot_request, ) + PrintProfilingStats(every=50) ) print("Starting training...") with build(train_pipeline) as b: for i in range(max_iteration): b.request_batch(request) print("Training finished")
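# Every script in this file repeats the same checkpoint-probing snippet to
# decide whether to resume training.  The helper below is a sketch of that
# snippet factored out (the scripts themselves keep it inline); tensorflow is
# already imported at module level and is repeated here only so the sketch is
# self-contained.
import tensorflow as tf


def _trained_until_sketch(checkpoint_dir="."):
    """Return the iteration of the latest checkpoint in checkpoint_dir, or 0."""
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint:
        return int(checkpoint.split("_")[-1])
    return 0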
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("GT_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_DIST")
    ArrayKey("PREDICTED_DIST")

    data_providers = []
    fib25_dir = "/groups/saalfeld/saalfeldlab/larissa/data/gunpowder/fib25/"
    if "fib25h5" in data_sources:
        for volume_name in ("tstvol-520-1", "tstvol-520-2", "trvol-250-1", "trvol-250-2"):
            h5_source = Hdf5Source(
                os.path.join(fib25_dir, volume_name + ".hdf"),
                datasets={
                    ArrayKeys.RAW: "volumes/raw",
                    ArrayKeys.GT_LABELS: "volumes/labels/clefts",
                    ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                },
                array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
            )
            data_providers.append(h5_source)

    fib19_dir = "/groups/saalfeld/saalfeldlab/larissa/fib19"
    # if 'fib19h5' in data_sources:
    #     for volume_name in ("trvol-250", "trvol-600"):
    #         h5_source = prepare_h5source(fib19_dir, volume_name)
    #         data_providers.append(h5_source)

    # todo: dvid source

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate((196,) * 3) * voxel_size
    output_size = Coordinate((92,) * 3) * voxel_size
    # input_size = Coordinate((132,)*3) * voxel_size
    # output_size = Coordinate((44,)*3) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    # request.add(VolumeTypes.GT_SCALE, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is in float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        RandomLocation() +  # choose a random location inside the provided volumes
        Reject(ArrayKeys.GT_MASK)  # reject batches which contain less than 50% labelled data
        for provider in data_providers
    )

    snapshot_request = BatchRequest(
        {
            ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
            ArrayKeys.PREDICTED_DIST: request[ArrayKeys.GT_LABELS],
        }
    )

    # artifact_source = (
    #     Hdf5Source(
    #         os.path.join(data_dir, 'sample_ABC_padded_20160501.defects.hdf'),
    #         datasets = {
    #             VolumeTypes.RAW: 'defect_sections/raw',
    #             VolumeTypes.ALPHA_MASK: 'defect_sections/mask',
    #         },
    #         volume_specs = {
    #             VolumeTypes.RAW: VolumeSpec(voxel_size=(40, 4, 4)),
    #             VolumeTypes.ALPHA_MASK: VolumeSpec(voxel_size=(40, 4, 4)),
    #         }
    #     ) +
    #     RandomLocation(min_masked=0.05, mask_volume_type=VolumeTypes.ALPHA_MASK) +
    #     Normalize() +
    #     IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
    #     ElasticAugment([4, 40, 40], [0, 2, 2], [0, math.pi/2.0], subsample=8) +
    #     SimpleAugment(transpose_only_xy=True)
    # )

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            [40, 40, 40],
            [2, 2, 2],
            [0, math.pi / 2.0],
            prob_slip=0.01,
            prob_shift=0.05,
            max_misalign=1,
            subsample=8,
        ) +
        SimpleAugment() +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1) +
        # SplitAndRenumberSegmentationLabels() +
        # AddGtAffinities(malis.mknhood3d()) +
        AddBoundaryDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) +
        BalanceLabels(ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE, ArrayKeys.GT_MASK) +
        # BalanceByThreshold(
        #     labels=VolumeTypes.GT_DIST,
        #     scales=VolumeTypes.GT_SCALE) +
        #     {
        #         VolumeTypes.GT_AFFINITIES: VolumeTypes.GT_SCALE
        #     },
        #     {
        #         VolumeTypes.GT_AFFINITIES: VolumeTypes.GT_MASK
        #     }) +
        PreCache(cache_size=40, num_workers=10) +
        # DefectAugment(
        #     prob_missing=0.03,
        #     prob_low_contrast=0.01,
        #     prob_artifact=0.03,
        #     artifact_source=artifact_source,
        #     contrast_scale=0.5) +
        Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_dist"]: ArrayKeys.GT_DIST,
                # net_io_names['loss_weights']: VolumeTypes.GT_SCALE
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: ArrayKeys.PREDICTED_DIST},
            gradients={net_io_names["dist"]: ArrayKeys.LOSS_GRADIENT},
        ) +
        Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                ArrayKeys.GT_DIST: "volumes/labels/distances",
                ArrayKeys.PREDICTED_DIST: "volumes/labels/pred_distances",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
            },
            every=1000,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=10)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
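# The FIB-25 setup above hard-codes the network shapes instead of using the
# input_shape/output_shape arguments.  The sketch below spells out the size
# arithmetic that the Coordinate objects perform: shapes are in voxels, sizes
# in nm, and the context consumed by the network is the input/output difference.
def _fib25_size_sketch():
    voxel_size = (8, 8, 8)
    input_shape = (196,) * 3
    output_shape = (92,) * 3
    input_size = tuple(s * v for s, v in zip(input_shape, voxel_size))    # (1568, 1568, 1568) nm
    output_size = tuple(s * v for s, v in zip(output_shape, voxel_size))  # (736, 736, 736) nm
    context = tuple(i - o for i, o in zip(input_size, output_size))       # (832, 832, 832) nm
    return input_size, output_size, context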
def train_until( max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name ): raw = ArrayKey("RAW") # ArrayKey('ALPHA_MASK') clefts = ArrayKey("GT_LABELS") mask = ArrayKey("GT_MASK") scale = ArrayKey("GT_SCALE") # grad = ArrayKey('LOSS_GRADIENT') gt_dist = ArrayKey("GT_DIST") pred_dist = ArrayKey("PREDICTED_DIST") data_providers = [] if tf.train.latest_checkpoint("."): trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1]) print("Resuming training from", trained_until) else: trained_until = 0 print("Starting fresh training") if trained_until >= max_iteration: return data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/fib19/mine/" for sample in data_sources: print(sample) h5_source = Hdf5Source( os.path.join(data_dir, "cube{0:}.hdf".format(sample)), datasets={ raw: "volumes/raw", clefts: "volumes/labels/clefts", mask: "/volumes/masks/groundtruth", }, array_specs={mask: ArraySpec(interpolatable=False)}, ) data_providers.append(h5_source) with open("net_io_names.json", "r") as f: net_io_names = json.load(f) voxel_size = Coordinate((8, 8, 8)) input_size = Coordinate(input_shape) * voxel_size output_size = Coordinate(output_shape) * voxel_size # input_size = Coordinate((132,)*3) * voxel_size # output_size = Coordinate((44,)*3) * voxel_size # specifiy which Arrays should be requested for each batch request = BatchRequest() request.add(raw, input_size) request.add(clefts, output_size) request.add(mask, output_size) # request.add(ArrayKeys.TRAINING_MASK, output_size) request.add(scale, output_size) request.add(gt_dist, output_size) # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(ArrayKeys.RAW) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and GT_MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(raw, None) + RandomLocation() + # chose a random location inside the provided arrays # Reject(ArrayKeys.GT_MASK) + # reject batches wich do contain less than 50% labelled data # Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) + Reject(mask=mask) + Reject(clefts, min_masked=0.0, reject_probability=0.95) for provider in data_providers ) snapshot_request = BatchRequest({pred_dist: request[clefts]}) train_pipeline = ( data_sources + RandomProvider() + ElasticAugment( (40, 40, 40), (2.0, 2.0, 2.0), (0, math.pi / 2.0), prob_slip=0.01, prob_shift=0.01, max_misalign=1, subsample=8, ) + SimpleAugment() + IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) + IntensityScaleShift(raw, 2, -1) + ZeroOutConstSections(raw) + # GrowBoundary(steps=1) + # SplitAndRenumberSegmentationLabels() + # AddGtAffinities(malis.mknhood3d()) + AddDistance( label_array_key=clefts, distance_array_key=gt_dist, normalize="tanh", normalize_args=dt_scaling_factor, ) + # BalanceLabels(clefts, scale, mask) + BalanceByThreshold(labels=ArrayKeys.GT_DIST, scales=ArrayKeys.GT_SCALE) + # { # ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_SCALE # }, # { # ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_MASK # }) + PreCache(cache_size=40, num_workers=10) + Train( "unet", optimizer=net_io_names["optimizer"], loss=net_io_names[loss_name], inputs={ net_io_names["raw"]: raw, net_io_names["gt_dist"]: gt_dist, net_io_names["loss_weights"]: scale, }, summary=net_io_names["summary"], log_dir="log", outputs={net_io_names["dist"]: pred_dist}, gradients={}, ) + Snapshot( { raw: "volumes/raw", clefts: "volumes/labels/gt_clefts", gt_dist: "volumes/labels/gt_clefts_dist", pred_dist: 
"volumes/labels/pred_clefts_dist", }, dataset_dtypes={clefts: np.uint64}, every=500, output_filename="batch_{iteration}.hdf", output_dir="snapshots/", additional_request=snapshot_request, ) + PrintProfilingStats(every=50) ) print("Starting training...") with build(train_pipeline) as b: for i in range(max_iteration - trained_until): b.request_batch(request) print("Training finished")
def train_until( data_providers, affinity_neighborhood, meta_graph_filename, stop, input_shape, output_shape, loss, optimizer, tensor_affinities, tensor_affinities_mask, tensor_glia, tensor_glia_mask, summary, save_checkpoint_every, pre_cache_size, pre_cache_num_workers, snapshot_every, balance_labels, renumber_connected_components, network_inputs, ignore_labels_for_slip, grow_boundaries, mask_out_labels, snapshot_dir): ignore_keys_for_slip = (LABELS_KEY, GT_MASK_KEY, GT_GLIA_KEY, GLIA_MASK_KEY, UNLABELED_KEY) if ignore_labels_for_slip else () defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects' if tf.train.latest_checkpoint('.'): trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1]) print('Resuming training from', trained_until) else: trained_until = 0 print('Starting fresh training') input_voxel_size = Coordinate((120, 12, 12)) * 3 output_voxel_size = Coordinate((40, 36, 36)) * 3 input_size = Coordinate(input_shape) * input_voxel_size output_size = Coordinate(output_shape) * output_voxel_size num_affinities = sum(len(nh) for nh in affinity_neighborhood) gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size)) print("gt affinities size", gt_affinities_size) # TODO why is GT_AFFINITIES three-dimensional? compare to # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35 # TODO Use glia scale somehow, probably not possible with tensorflow 1.3 because it does not know uint64... # specifiy which Arrays should be requested for each batch request = BatchRequest() request.add(RAW_KEY, input_size, voxel_size=input_voxel_size) request.add(LABELS_KEY, output_size, voxel_size=output_voxel_size) request.add(GT_AFFINITIES_KEY, output_size, voxel_size=output_voxel_size) request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size) request.add(GT_MASK_KEY, output_size, voxel_size=output_voxel_size) request.add(GLIA_MASK_KEY, output_size, voxel_size=output_voxel_size) request.add(GLIA_KEY, output_size, voxel_size=output_voxel_size) request.add(GT_GLIA_KEY, output_size, voxel_size=output_voxel_size) request.add(UNLABELED_KEY, output_size, voxel_size=output_voxel_size) if balance_labels: request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size) # always balance glia labels! 
request.add(GLIA_SCALE_KEY, output_size, voxel_size=output_voxel_size) network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY network_inputs[tensor_glia_mask] = GLIA_SCALE_KEY#GLIA_SCALE_KEY if balance_labels else GLIA_MASK_KEY # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(RAW_KEY) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and GT_MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(RAW_KEY, None) + Pad(GT_MASK_KEY, None) + Pad(GLIA_MASK_KEY, None) + Pad(LABELS_KEY, size=NETWORK_OUTPUT_SHAPE / 2, value=np.uint64(-3)) + Pad(GT_GLIA_KEY, size=NETWORK_OUTPUT_SHAPE / 2) + # Pad(LABELS_KEY, None) + # Pad(GT_GLIA_KEY, None) + RandomLocation() + # chose a random location inside the provided arrays Reject(mask=GT_MASK_KEY, min_masked=0.5) + Reject(mask=GLIA_MASK_KEY, min_masked=0.5) + MapNumpyArray(lambda array: np.require(array, dtype=np.int64), GT_GLIA_KEY) # this is necessary because gunpowder 1.3 only understands int64, not uint64 for provider in data_providers) # TODO figure out what this is for snapshot_request = BatchRequest({ LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY], AFFINITIES_KEY : request[GT_AFFINITIES_KEY], }) # no need to do anything here. random sections will be replaced with sections from this source (only raw) artifact_source = ( Hdf5Source( os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'), datasets={ RAW_KEY : 'defect_sections/raw', DEFECT_MASK_KEY : 'defect_sections/mask', }, array_specs={ RAW_KEY : ArraySpec(voxel_size=input_voxel_size), DEFECT_MASK_KEY : ArraySpec(voxel_size=input_voxel_size), } ) + RandomLocation(min_masked=0.05, mask=DEFECT_MASK_KEY) + Normalize(RAW_KEY) + IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) + ElasticAugment( voxel_size=(360, 36, 36), control_point_spacing=(4, 40, 40), control_point_displacement_sigma=(0, 2 * 36, 2 * 36), rotation_interval=(0, math.pi / 2.0), subsample=8 ) + SimpleAugment(transpose_only=[1,2]) ) train_pipeline = data_sources train_pipeline += RandomProvider() train_pipeline += ElasticAugment( voxel_size=(360, 36, 36), control_point_spacing=(4, 40, 40), control_point_displacement_sigma=(0, 2 * 36, 2 * 36), rotation_interval=(0, math.pi / 2.0), augmentation_probability=0.5, subsample=8 ) # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: before misalign: ' % GT_MASK_KEY) train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip) # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: after misalign: ' % GT_MASK_KEY) train_pipeline += SimpleAugment(transpose_only=[1,2]) train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) train_pipeline += DefectAugment(RAW_KEY, prob_missing=0.03, prob_low_contrast=0.01, prob_artifact=0.03, artifact_source=artifact_source, artifacts=RAW_KEY, artifacts_mask=DEFECT_MASK_KEY, contrast_scale=0.5) train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1) train_pipeline += ZeroOutConstSections(RAW_KEY) if grow_boundaries > 0: train_pipeline += GrowBoundary(LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True) _logger.info("Renumbering connected components? 
%s", renumber_connected_components) if renumber_connected_components: train_pipeline += RenumberConnectedComponents(labels=LABELS_KEY) train_pipeline += NewKeyFromNumpyArray(lambda array: 1 - array, GT_GLIA_KEY, UNLABELED_KEY) if len(mask_out_labels) > 0: train_pipeline += MaskOutLabels(label_key=LABELS_KEY, mask_key=GT_MASK_KEY, ids_to_be_masked=mask_out_labels) # labels_mask: anything that connects into labels_mask will be zeroed out # unlabelled: anyhing that points into unlabeled will have zero affinity; # affinities within unlabelled will be masked out train_pipeline += AddAffinities( affinity_neighborhood=affinity_neighborhood, labels=LABELS_KEY, labels_mask=GT_MASK_KEY, affinities=GT_AFFINITIES_KEY, affinities_mask=AFFINITIES_MASK_KEY, unlabelled=UNLABELED_KEY ) snapshot_datasets = { RAW_KEY: 'volumes/raw', LABELS_KEY: 'volumes/labels/neuron_ids', GT_AFFINITIES_KEY: 'volumes/affinities/gt', GT_GLIA_KEY: 'volumes/labels/glia_gt', UNLABELED_KEY: 'volumes/labels/unlabeled', AFFINITIES_KEY: 'volumes/affinities/prediction', LOSS_GRADIENT_KEY: 'volumes/loss_gradient', AFFINITIES_MASK_KEY: 'masks/affinities', GLIA_KEY: 'volumes/labels/glia_pred', GT_MASK_KEY: 'masks/gt', GLIA_MASK_KEY: 'masks/glia'} if balance_labels: train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY) snapshot_datasets[AFFINITIES_SCALE_KEY] = 'masks/affinity-scale' train_pipeline += BalanceLabels(labels=GT_GLIA_KEY, scales=GLIA_SCALE_KEY, mask=GLIA_MASK_KEY) snapshot_datasets[GLIA_SCALE_KEY] = 'masks/glia-scale' if (pre_cache_size > 0 and pre_cache_num_workers > 0): train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers) train_pipeline += Train( summary=summary, graph=meta_graph_filename, save_every=save_checkpoint_every, optimizer=optimizer, loss=loss, inputs=network_inputs, log_dir='log', outputs={tensor_affinities: AFFINITIES_KEY, tensor_glia: GLIA_KEY}, gradients={tensor_affinities: LOSS_GRADIENT_KEY}, array_specs={ AFFINITIES_KEY : ArraySpec(voxel_size=output_voxel_size), LOSS_GRADIENT_KEY : ArraySpec(voxel_size=output_voxel_size), AFFINITIES_MASK_KEY : ArraySpec(voxel_size=output_voxel_size), GT_MASK_KEY : ArraySpec(voxel_size=output_voxel_size), AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size), GLIA_MASK_KEY : ArraySpec(voxel_size=output_voxel_size), GLIA_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size), GLIA_KEY : ArraySpec(voxel_size=output_voxel_size) } ) train_pipeline += Snapshot( snapshot_datasets, every=snapshot_every, output_filename='batch_{iteration}.hdf', output_dir=snapshot_dir, additional_request=snapshot_request, attributes_callback=Snapshot.default_attributes_callback()) train_pipeline += PrintProfilingStats(every=50) print("Starting training...") with build(train_pipeline) as b: for i in range(trained_until, stop): b.request_batch(request) print("Training finished")
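# The quasi-isotropic setup above derives its voxel sizes by elementwise
# scaling: Coordinate((120, 12, 12)) * 3 for the input and
# Coordinate((40, 36, 36)) * 3 for the output.  The sketch below makes those
# numbers explicit with plain tuples.
def _quasi_isotropic_voxel_size_sketch():
    input_voxel_size = tuple(3 * v for v in (120, 12, 12))   # (360, 36, 36) nm
    output_voxel_size = tuple(3 * v for v in (40, 36, 36))   # (120, 108, 108) nm
    return input_voxel_size, output_voxel_size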
def train_until( data_providers, affinity_neighborhood, meta_graph_filename, stop, input_shape, output_shape, loss, optimizer, tensor_affinities, tensor_affinities_nn, tensor_affinities_mask, summary, save_checkpoint_every, pre_cache_size, pre_cache_num_workers, snapshot_every, balance_labels, renumber_connected_components, network_inputs, ignore_labels_for_slip, grow_boundaries): ignore_keys_for_slip = (GT_LABELS_KEY, GT_MASK_KEY) if ignore_labels_for_slip else () defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects' if tf.train.latest_checkpoint('.'): trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1]) print('Resuming training from', trained_until) else: trained_until = 0 print('Starting fresh training') input_voxel_size = Coordinate((120, 12, 12)) * 3 output_voxel_size = Coordinate((40, 36, 36)) * 3 input_size = Coordinate(input_shape) * input_voxel_size output_size = Coordinate(output_shape) * output_voxel_size output_size_nn = Coordinate(s - 2 for s in output_shape) * output_voxel_size num_affinities = sum(len(nh) for nh in affinity_neighborhood) gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size)) print("gt affinities size", gt_affinities_size) # TODO why is GT_AFFINITIES three-dimensional? compare to # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35 # specifiy which Arrays should be requested for each batch request = BatchRequest() request.add(RAW_KEY, input_size, voxel_size=input_voxel_size) request.add(GT_LABELS_KEY, output_size, voxel_size=output_voxel_size) request.add(GT_AFFINITIES_KEY, output_size, voxel_size=output_voxel_size) request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size) request.add(GT_MASK_KEY, output_size, voxel_size=output_voxel_size) request.add(AFFINITIES_NN_KEY, output_size_nn, voxel_size=output_voxel_size) if balance_labels: request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size) network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(RAW_KEY) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and GT_MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(RAW_KEY, None) + Pad(GT_MASK_KEY, None) + RandomLocation() + # chose a random location inside the provided arrays Reject(GT_MASK_KEY) + # reject batches wich do contain less than 50% labelled data Reject(GT_LABELS_KEY, min_masked=0.0, reject_probability=0.95) for provider in data_providers) # TODO figure out what this is for snapshot_request = BatchRequest({ LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY], AFFINITIES_KEY : request[GT_AFFINITIES_KEY], AFFINITIES_NN_KEY : request[AFFINITIES_NN_KEY] }) # no need to do anything here. 
random sections will be replaced with sections from this source (only raw) artifact_source = ( Hdf5Source( os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'), datasets={ RAW_KEY : 'defect_sections/raw', ALPHA_MASK_KEY : 'defect_sections/mask', }, array_specs={ RAW_KEY : ArraySpec(voxel_size=input_voxel_size), ALPHA_MASK_KEY : ArraySpec(voxel_size=input_voxel_size), } ) + RandomLocation(min_masked=0.05, mask=ALPHA_MASK_KEY) + Normalize(RAW_KEY) + IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) + ElasticAugment( voxel_size=(360, 36, 36), control_point_spacing=(4, 40, 40), control_point_displacement_sigma=(0, 2 * 36, 2 * 36), rotation_interval=(0, math.pi / 2.0), subsample=8 ) + SimpleAugment(transpose_only=[1,2]) ) train_pipeline = data_sources train_pipeline += RandomProvider() train_pipeline += ElasticAugment( voxel_size=(360, 36, 36), control_point_spacing=(4, 40, 40), control_point_displacement_sigma=(0, 2 * 36, 2 * 36), rotation_interval=(0, math.pi / 2.0), augmentation_probability=0.5, subsample=8 ) train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip) train_pipeline += SimpleAugment(transpose_only=[1,2]) train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) train_pipeline += DefectAugment(RAW_KEY, prob_missing=0.03, prob_low_contrast=0.01, prob_artifact=0.03, artifact_source=artifact_source, artifacts=RAW_KEY, artifacts_mask=ALPHA_MASK_KEY, contrast_scale=0.5) train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1) train_pipeline += ZeroOutConstSections(RAW_KEY) if grow_boundaries > 0: train_pipeline += GrowBoundary(GT_LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True) if renumber_connected_components: train_pipeline += RenumberConnectedComponents(labels=GT_LABELS_KEY) train_pipeline += AddAffinities( affinity_neighborhood=affinity_neighborhood, labels=GT_LABELS_KEY, labels_mask=GT_MASK_KEY, affinities=GT_AFFINITIES_KEY, affinities_mask=AFFINITIES_MASK_KEY ) if balance_labels: train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY) train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers) train_pipeline += Train( summary=summary, graph=meta_graph_filename, save_every=save_checkpoint_every, optimizer=optimizer, loss=loss, inputs=network_inputs, log_dir='log', outputs={tensor_affinities: AFFINITIES_KEY, tensor_affinities_nn: AFFINITIES_NN_KEY}, gradients={tensor_affinities: LOSS_GRADIENT_KEY}, array_specs={ AFFINITIES_KEY : ArraySpec(voxel_size=output_voxel_size), LOSS_GRADIENT_KEY : ArraySpec(voxel_size=output_voxel_size), AFFINITIES_MASK_KEY : ArraySpec(voxel_size=output_voxel_size), GT_MASK_KEY : ArraySpec(voxel_size=output_voxel_size), AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size), AFFINITIES_NN_KEY : ArraySpec(voxel_size=output_voxel_size) } ) train_pipeline += Snapshot( dataset_names={ RAW_KEY : 'volumes/raw', GT_LABELS_KEY : 'volumes/labels/neuron_ids', GT_AFFINITIES_KEY : 'volumes/affinities/gt', AFFINITIES_KEY : 'volumes/affinities/prediction', LOSS_GRADIENT_KEY : 'volumes/loss_gradient', AFFINITIES_MASK_KEY : 'masks/affinities', AFFINITIES_NN_KEY : 'volumes/affinities/prediction-nn' }, every=snapshot_every, output_filename='batch_{iteration}.hdf', output_dir='snapshots/', additional_request=snapshot_request, attributes_callback=Snapshot.default_attributes_callback()) train_pipeline += 
PrintProfilingStats(every=50) print("Starting training...") with build(train_pipeline) as b: for i in range(trained_until, stop): b.request_batch(request) print("Training finished")
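# The train_until above expects ready-made data_providers rather than sample
# names.  The sketch below shows one hypothetical way to build them with
# gunpowder's Hdf5Source; the file paths and dataset names are placeholders,
# and RAW_KEY/GT_LABELS_KEY/GT_MASK_KEY are assumed to be the module-level
# keys used above.
def _example_affinity_providers(sample_files):
    return tuple(
        Hdf5Source(
            sample_file,
            datasets={
                RAW_KEY: "volumes/raw",
                GT_LABELS_KEY: "volumes/labels/neuron_ids",
                GT_MASK_KEY: "volumes/masks/groundtruth",
            },
            array_specs={GT_MASK_KEY: ArraySpec(interpolatable=False)},
        )
        for sample_file in sample_files
    )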
def train_until(max_iteration, data_sources, input_shape, output_shape):
    ArrayKey('RAW')
    ArrayKey('PRED_RAW')

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/superresolution/{0:}.n5"
    voxel_size = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size)
    snapshot_request = BatchRequest()
    snapshot_request.add(ArrayKeys.PRED_RAW, output_size, voxel_size=voxel_size)

    # load latest ckpt for weights if available
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    # construct DAG
    for src in data_sources:
        n5_source = N5Source(
            data_dir.format(src),
            datasets={ArrayKeys.RAW: 'volumes/raw'}
        )
        data_providers.append(n5_source)

    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +
        Pad(ArrayKeys.RAW, Coordinate((400, 400, 400))) +
        RandomLocation()
        for provider in data_providers
    )

    train_pipeline = (
        data_sources +
        ElasticAugment((100, 100, 100), (10., 10., 10.), (0, math.pi / 2.0),
                       prob_slip=0, prob_shift=0, max_misalign=0, subsample=8) +
        SimpleAugment() +
        ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        PreCache(cache_size=40, num_workers=10) +
        Train('unet',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names['loss'],
              inputs={net_io_names['raw']: ArrayKeys.RAW},
              summary=net_io_names['summary'],
              log_dir='log',
              outputs={net_io_names['pred_raw']: ArrayKeys.PRED_RAW},
              gradients={}) +
        Snapshot({ArrayKeys.RAW: 'volumes/raw',
                  ArrayKeys.PRED_RAW: 'volumes/pred_raw'},
                 every=500,
                 output_filename='batch_{iteration}.hdf',
                 output_dir='snapshots/',
                 additional_request=snapshot_request) +
        PrintProfilingStats(every=50)
    )  # no intensity augment cause currently can't apply the same to both in and out

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
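# The superresolution script above looks up five entries in net_io_names.json
# ('raw', 'pred_raw', 'loss', 'optimizer', 'summary').  The sketch below writes
# a file with that layout; the tensor names on the right-hand side are
# placeholders and depend on how the network graph was exported.
import json


def _write_example_net_io_names(path="net_io_names.json"):
    net_io_names = {
        "raw": "raw:0",
        "pred_raw": "pred_raw:0",
        "loss": "loss:0",
        "optimizer": "optimizer",
        "summary": "summary:0",
    }
    with open(path, "w") as f:
        json.dump(net_io_names, f)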
def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name): ArrayKey("RAW") ArrayKey("RAW_UP") ArrayKey("ALPHA_MASK") ArrayKey("GT_LABELS") ArrayKey("MASK") ArrayKey("MASK_UP") ArrayKey("GT_DIST_CENTROSOME") ArrayKey("GT_DIST_GOLGI") ArrayKey("GT_DIST_GOLGI_MEM") ArrayKey("GT_DIST_ER") ArrayKey("GT_DIST_ER_MEM") ArrayKey("GT_DIST_MVB") ArrayKey("GT_DIST_MVB_MEM") ArrayKey("GT_DIST_MITO") ArrayKey("GT_DIST_MITO_MEM") ArrayKey("GT_DIST_LYSOSOME") ArrayKey("GT_DIST_LYSOSOME_MEM") ArrayKey("PRED_DIST_CENTROSOME") ArrayKey("PRED_DIST_GOLGI") ArrayKey("PRED_DIST_GOLGI_MEM") ArrayKey("PRED_DIST_ER") ArrayKey("PRED_DIST_ER_MEM") ArrayKey("PRED_DIST_MVB") ArrayKey("PRED_DIST_MVB_MEM") ArrayKey("PRED_DIST_MITO") ArrayKey("PRED_DIST_MITO_MEM") ArrayKey("PRED_DIST_LYSOSOME") ArrayKey("PRED_DIST_LYSOSOME_MEM") ArrayKey("SCALE_CENTROSOME") ArrayKey("SCALE_GOLGI") ArrayKey("SCALE_GOLGI_MEM") ArrayKey("SCALE_ER") ArrayKey("SCALE_ER_MEM") ArrayKey("SCALE_MVB") ArrayKey("SCALE_MVB_MEM") ArrayKey("SCALE_MITO") ArrayKey("SCALE_MITO_MEM") ArrayKey("SCALE_LYSOSOME") ArrayKey("SCALE_LYSOSOME_MEM") data_providers = [] data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5" voxel_size_up = Coordinate((4, 4, 4)) voxel_size_orig = Coordinate((8, 8, 8)) input_size = Coordinate(input_shape) * voxel_size_orig output_size = Coordinate(output_shape) * voxel_size_orig if tf.train.latest_checkpoint("."): trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1]) print("Resuming training from", trained_until) else: trained_until = 0 print("Starting fresh training") for src in data_sources: n5_source = N5Source( os.path.join(data_dir.format(src)), datasets={ ArrayKeys.RAW_UP: "volumes/raw", ArrayKeys.GT_LABELS: "volumes/labels/all", ArrayKeys.MASK_UP: "volumes/mask", }, array_specs={ArrayKeys.MASK_UP: ArraySpec(interpolatable=False)}, ) data_providers.append(n5_source) with open("net_io_names.json", "r") as f: net_io_names = json.load(f) # specifiy which Arrays should be requested for each batch request = BatchRequest() request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.RAW_UP, input_size, voxel_size=voxel_size_up) request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up) request.add(ArrayKeys.MASK_UP, output_size, voxel_size=voxel_size_up) request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_CENTROSOME, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_GOLGI, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_GOLGI_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_ER, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_ER_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_MVB, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_MVB_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_MITO, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_MITO_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_LYSOSOME, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.GT_DIST_LYSOSOME_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.SCALE_CENTROSOME, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.SCALE_GOLGI, output_size, voxel_size=voxel_size_orig) # request.add(ArrayKeys.SCALE_GOLGI_MEM, output_size, voxel_size=voxel_size_orig) 
request.add(ArrayKeys.SCALE_ER, output_size, voxel_size=voxel_size_orig) # request.add(ArrayKeys.SCALE_ER_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.SCALE_MVB, output_size, voxel_size=voxel_size_orig) # request.add(ArrayKeys.SCALE_MVB_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.SCALE_MITO, output_size, voxel_size=voxel_size_orig) # request.add(ArrayKeys.SCALE_MITO_MEM, output_size, voxel_size=voxel_size_orig) request.add(ArrayKeys.SCALE_LYSOSOME, output_size, voxel_size=voxel_size_orig) # request.add(ArrayKeys.SCALE_LYSOSOME_MEM, output_size, voxel_size=voxel_size_orig) # create a tuple of data sources, one for each HDF file data_sources = tuple( provider + Normalize(ArrayKeys.RAW_UP) + # ensures RAW is in float in [0, 1] # zero-pad provided RAW and MASK to be able to draw batches close to # the boundary of the available data # size more or less irrelevant as followed by Reject Node Pad(ArrayKeys.RAW_UP, None) + RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK_UP ) # chose a random location inside the provided arrays # Reject(ArrayKeys.MASK) # reject batches wich do contain less than 50% labelled data for provider in data_providers) snapshot_request = BatchRequest() snapshot_request.add(ArrayKeys.PRED_DIST_CENTROSOME, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_GOLGI, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_GOLGI_MEM, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_ER, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_ER_MEM, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_MVB, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_MVB_MEM, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_MITO, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_MITO_MEM, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_LYSOSOME, output_size) snapshot_request.add(ArrayKeys.PRED_DIST_LYSOSOME_MEM, output_size) train_pipeline = ( data_sources + RandomProvider() + ElasticAugment( (100, 100, 100), (10.0, 10.0, 10.0), (0, math.pi / 2.0), prob_slip=0, prob_shift=0, max_misalign=0, subsample=8, ) + SimpleAugment() + ElasticAugment( (40, 1000, 1000), (10.0, 0.0, 0.0), (0, 0), subsample=8) + IntensityAugment(ArrayKeys.RAW_UP, 0.9, 1.1, -0.1, 0.1) + IntensityScaleShift(ArrayKeys.RAW_UP, 2, -1) + ZeroOutConstSections(ArrayKeys.RAW_UP) + # GrowBoundary(steps=1) + # SplitAndRenumberSegmentationLabels() + # AddGtAffinities(malis.mknhood3d()) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST_CENTROSOME, normalize="tanh", normalize_args=dt_scaling_factor, label_id=1, factor=2, ) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST_GOLGI, normalize="tanh", normalize_args=dt_scaling_factor, label_id=(2, 11), factor=2, ) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST_GOLGI_MEM, normalize="tanh", normalize_args=dt_scaling_factor, label_id=11, factor=2, ) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST_ER, normalize="tanh", normalize_args=dt_scaling_factor, label_id=(3, 10), factor=2, ) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST_ER_MEM, normalize="tanh", normalize_args=dt_scaling_factor, label_id=10, factor=2, ) + AddDistance( label_array_key=ArrayKeys.GT_LABELS, distance_array_key=ArrayKeys.GT_DIST_MVB, normalize="tanh", normalize_args=dt_scaling_factor, label_id=(4, 9), factor=2, ) + AddDistance( 
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MVB_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=9,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MITO,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(5, 8),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MITO_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=8,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_LYSOSOME,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(6, 7),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_LYSOSOME_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=7,
            factor=2,
        ) +
        DownSample(ArrayKeys.MASK_UP, 2, ArrayKeys.MASK) +
        BalanceByThreshold(
            ArrayKeys.GT_DIST_CENTROSOME,
            ArrayKeys.SCALE_CENTROSOME,
            mask=ArrayKeys.MASK,
        ) +
        BalanceByThreshold(ArrayKeys.GT_DIST_GOLGI, ArrayKeys.SCALE_GOLGI, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_GOLGI_MEM, ArrayKeys.SCALE_GOLGI_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_ER, ArrayKeys.SCALE_ER, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_ER_MEM, ArrayKeys.SCALE_ER_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_MVB, ArrayKeys.SCALE_MVB, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_MVB_MEM, ArrayKeys.SCALE_MVB_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_MITO, ArrayKeys.SCALE_MITO, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_MITO_MEM, ArrayKeys.SCALE_MITO_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_LYSOSOME, ArrayKeys.SCALE_LYSOSOME, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_LYSOSOME_MEM, ArrayKeys.SCALE_LYSOSOME_MEM, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(
        #     labels=ArrayKeys.GT_DIST,
        #     scales=ArrayKeys.GT_SCALE) +
        #     {
        #         ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_SCALE
        #     },
        #     {
        #         ArrayKeys.GT_AFFINITIES: ArrayKeys.MASK
        #     }) +
        DownSample(ArrayKeys.RAW_UP, 2, ArrayKeys.RAW) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            "build",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_centrosome"]: ArrayKeys.GT_DIST_CENTROSOME,
                net_io_names["gt_golgi"]: ArrayKeys.GT_DIST_GOLGI,
                net_io_names["gt_golgi_mem"]: ArrayKeys.GT_DIST_GOLGI_MEM,
                net_io_names["gt_er"]: ArrayKeys.GT_DIST_ER,
                net_io_names["gt_er_mem"]: ArrayKeys.GT_DIST_ER_MEM,
                net_io_names["gt_mvb"]: ArrayKeys.GT_DIST_MVB,
                net_io_names["gt_mvb_mem"]: ArrayKeys.GT_DIST_MVB_MEM,
                net_io_names["gt_mito"]: ArrayKeys.GT_DIST_MITO,
                net_io_names["gt_mito_mem"]: ArrayKeys.GT_DIST_MITO_MEM,
                net_io_names["gt_lysosome"]: ArrayKeys.GT_DIST_LYSOSOME,
                net_io_names["gt_lysosome_mem"]: ArrayKeys.GT_DIST_LYSOSOME_MEM,
                net_io_names["w_centrosome"]: ArrayKeys.SCALE_CENTROSOME,
                net_io_names["w_golgi"]: ArrayKeys.SCALE_GOLGI,
                net_io_names["w_golgi_mem"]: ArrayKeys.SCALE_GOLGI,
                net_io_names["w_er"]: ArrayKeys.SCALE_ER,
                net_io_names["w_er_mem"]: ArrayKeys.SCALE_ER,
                net_io_names["w_mvb"]: ArrayKeys.SCALE_MVB,
                net_io_names["w_mvb_mem"]: ArrayKeys.SCALE_MVB,
                net_io_names["w_mito"]: ArrayKeys.SCALE_MITO,
                net_io_names["w_mito_mem"]: ArrayKeys.SCALE_MITO,
                net_io_names["w_lysosome"]: ArrayKeys.SCALE_LYSOSOME,
                net_io_names["w_lysosome_mem"]: ArrayKeys.SCALE_LYSOSOME,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["centrosome"]: ArrayKeys.PRED_DIST_CENTROSOME,
                net_io_names["golgi"]: ArrayKeys.PRED_DIST_GOLGI,
                net_io_names["golgi_mem"]: ArrayKeys.PRED_DIST_GOLGI_MEM,
                net_io_names["er"]: ArrayKeys.PRED_DIST_ER,
                net_io_names["er_mem"]: ArrayKeys.PRED_DIST_ER_MEM,
                net_io_names["mvb"]: ArrayKeys.PRED_DIST_MVB,
                net_io_names["mvb_mem"]: ArrayKeys.PRED_DIST_MVB_MEM,
                net_io_names["mito"]: ArrayKeys.PRED_DIST_MITO,
                net_io_names["mito_mem"]: ArrayKeys.PRED_DIST_MITO_MEM,
                net_io_names["lysosome"]: ArrayKeys.PRED_DIST_LYSOSOME,
                net_io_names["lysosome_mem"]: ArrayKeys.PRED_DIST_LYSOSOME_MEM,
            },
            gradients={},
        ) +
        Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/gt_labels",
                ArrayKeys.GT_DIST_CENTROSOME: "volumes/labels/gt_dist_centrosome",
                ArrayKeys.PRED_DIST_CENTROSOME: "volumes/labels/pred_dist_centrosome",
                ArrayKeys.GT_DIST_GOLGI: "volumes/labels/gt_dist_golgi",
                ArrayKeys.PRED_DIST_GOLGI: "volumes/labels/pred_dist_golgi",
                ArrayKeys.GT_DIST_GOLGI_MEM: "volumes/labels/gt_dist_golgi_mem",
                ArrayKeys.PRED_DIST_GOLGI_MEM: "volumes/labels/pred_dist_golgi_mem",
                ArrayKeys.GT_DIST_ER: "volumes/labels/gt_dist_er",
                ArrayKeys.PRED_DIST_ER: "volumes/labels/pred_dist_er",
                ArrayKeys.GT_DIST_ER_MEM: "volumes/labels/gt_dist_er_mem",
                ArrayKeys.PRED_DIST_ER_MEM: "volumes/labels/pred_dist_er_mem",
                ArrayKeys.GT_DIST_MVB: "volumes/labels/gt_dist_mvb",
                ArrayKeys.PRED_DIST_MVB: "volumes/labels/pred_dist_mvb",
                ArrayKeys.GT_DIST_MVB_MEM: "volumes/labels/gt_dist_mvb_mem",
                ArrayKeys.PRED_DIST_MVB_MEM: "volumes/labels/pred_dist_mvb_mem",
                ArrayKeys.GT_DIST_MITO: "volumes/labels/gt_dist_mito",
                ArrayKeys.PRED_DIST_MITO: "volumes/labels/pred_dist_mito",
                ArrayKeys.GT_DIST_MITO_MEM: "volumes/labels/gt_dist_mito_mem",
                ArrayKeys.PRED_DIST_MITO_MEM: "volumes/labels/pred_dist_mito_mem",
                ArrayKeys.GT_DIST_LYSOSOME: "volumes/labels/gt_dist_lysosome",
                ArrayKeys.PRED_DIST_LYSOSOME: "volumes/labels/pred_dist_lysosome",
                ArrayKeys.GT_DIST_LYSOSOME_MEM: "volumes/labels/gt_dist_lysosome_mem",
                ArrayKeys.PRED_DIST_LYSOSOME_MEM: "volumes/labels/pred_dist_lysosome_mem",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
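
# Illustrative sketch, not part of the original pipelines: the GT_DIST_* arrays regressed above are
# tanh-normalized signed distance transforms (normalize="tanh" with normalize_args=dt_scaling_factor,
# i.e. tanh(distance / dt_scaling_factor)). A minimal NumPy/SciPy version of such a target, assuming a
# plain Euclidean distance transform with positive sign inside the object, could look like the helper
# below; the actual AddDistance node may handle masks, label ids and boundary effects differently.
import numpy as np
from scipy.ndimage import distance_transform_edt


def tanh_signed_distance_target(binary_labels, dt_scaling_factor=50, voxel_size=(40, 4, 4)):
    # distance of foreground voxels to the nearest background voxel, in world units
    inside = distance_transform_edt(binary_labels, sampling=voxel_size)
    # distance of background voxels to the nearest foreground voxel
    outside = distance_transform_edt(np.logical_not(binary_labels), sampling=voxel_size)
    signed_dist = inside - outside
    # squash to (-1, 1); dt_scaling_factor controls how quickly the target saturates
    return np.tanh(signed_dist / dt_scaling_factor)
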

def train_until(
    max_iteration,
    cremi_dir,
    samples,
    n5_filename_format,
    csv_filename_format,
    filter_comments_pre,
    filter_comments_post,
    labels,
    net_name,
    input_shape,
    output_shape,
    loss_name,
    aug_mode,
    include_cleft=False,
    dt_scaling_factor=50,
    cache_size=5,
    num_workers=10,
    min_masked_voxels=17561.0,
    voxel_size=Coordinate((40, 4, 4)),
):
    '''
    Trains a network to predict signed distance boundaries of synapses.

    Args:
        max_iteration(int): The number of iterations to train the network.
        cremi_dir(str): The path to the directory containing n5 files for training.
        samples(:obj:`list` of :obj:`str`): The names of samples to train on. This is used as input to
            format the `n5_filename_format` and `csv_filename_format`.
        n5_filename_format(str): The format string for n5 files.
        csv_filename_format(str): The format string for csv files.
        filter_comments_pre(:obj:`list` of :obj:`str`): A list of pre- or postsynaptic comments that
            should be excluded from the mapping of cleft ids to presynaptic neuron ids.
        filter_comments_post(:obj:`list` of :obj:`str`): A list of pre- or postsynaptic comments that
            should be excluded from the mapping of cleft ids to postsynaptic neuron ids.
        labels(:obj:`list` of :class:`Label`): The list of labels to be trained for.
        net_name(str): The name of the network, referring to the .meta file.
        input_shape(:obj:`tuple` of int): The shape of input arrays of the network.
        output_shape(:obj:`tuple` of int): The shape of output arrays of the network.
        loss_name(str): The name of the loss function as saved in the net_io_names.
        aug_mode(str): The augmentation mode ("deluxe", "classic" or "lite").
        include_cleft(boolean, optional): Whether to include the whole cleft as part of the label when
            calculating the masked distance transform for pre- and postsynaptic sites, default: False.
        dt_scaling_factor(int, optional): The factor for scaling the signed distance transform before
            applying tanh, using the formula tanh(distance_transform/dt_scaling_factor), default: 50.
        cache_size(int, optional): The size of the cache for pulling batches, default: 5.
        num_workers(int, optional): The number of workers for pulling batches, default: 10.
        min_masked_voxels(Union(int, float), optional): The number of voxels that need to be contained
            in the groundtruth mask for a batch to be viable, default: 17561.
        voxel_size(:class:`Coordinate`): The voxel size of the input and output of the network.

    Returns:
        None.
    '''

    def label_filter(cond_f):
        return [ll for ll in labels if cond_f(ll)]

    def get_label(name):
        filter = label_filter(lambda l: l.labelname == name)
        if len(filter) > 0:
            return filter[0]
        else:
            return None

    def network_setup():
        # load net_io_names.json
        with open("net_io_names.json", "r") as f:
            net_io_names = json.load(f)

        # find checkpoint from previous training, start a new one if not found
        if tf.train.latest_checkpoint("."):
            start_iteration = int(tf.train.latest_checkpoint(".").split("_")[-1])
            if start_iteration >= max_iteration:
                logging.info("Network has already been trained for {0:} iterations".format(start_iteration))
            else:
                logging.info("Resuming training from {0:}".format(start_iteration))
        else:
            start_iteration = 0
            logging.info("Starting fresh training")

        # define network inputs
        inputs = dict()
        inputs[net_io_names["raw"]] = ak_raw
        inputs[net_io_names["mask"]] = ak_training
        for label in labels:
            inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
            inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
            if label.scale_loss or label.scale_key is not None:
                inputs[net_io_names["w_" + label.labelname]] = label.scale_key

        # define network outputs
        outputs = dict()
        for label in labels:
            outputs[net_io_names[label.labelname]] = label.pred_dist_key
        return net_io_names, start_iteration, inputs, outputs

    keep_thr = float(min_masked_voxels) / np.prod(output_shape)
    max_distance = 2.76 * dt_scaling_factor

    ak_raw = ArrayKey("RAW")
    ak_alpha = ArrayKey("ALPHA_MASK")
    ak_neurons = ArrayKey("GT_NEURONS")
    ak_training = ArrayKey("TRAINING_MASK")
    ak_integral = ArrayKey("INTEGRAL_MASK")
    ak_clefts = ArrayKey("GT_CLEFTS")

    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    pad_width = input_size - output_size + voxel_size * Coordinate((20, 20, 20))
    crop_width = Coordinate((max_distance,) * len(voxel_size))
    crop_width = crop_width // voxel_size
    if crop_width == 0:
        crop_width *= voxel_size
    else:
        crop_width = (crop_width + (1,) * len(crop_width)) * voxel_size

    net_io_names, start_iteration, inputs, outputs = network_setup()

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ak_raw, input_size, voxel_size=voxel_size)
    request.add(ak_neurons, output_size, voxel_size=voxel_size)
    request.add(ak_clefts, output_size, voxel_size=voxel_size)
    request.add(ak_training, output_size, voxel_size=voxel_size)
    request.add(ak_integral, output_size, voxel_size=voxel_size)
    for l in labels:
        request.add(l.mask_key, output_size, voxel_size=voxel_size)
        request.add(l.scale_key, output_size, voxel_size=voxel_size)
        request.add(l.gt_dist_key, output_size, voxel_size=voxel_size)

    arrays_that_need_to_be_cropped = []
    arrays_that_need_to_be_cropped.append(ak_neurons)
    arrays_that_need_to_be_cropped.append(ak_clefts)
    for l in labels:
        arrays_that_need_to_be_cropped.append(l.mask_key)
        arrays_that_need_to_be_cropped.append(l.gt_dist_key)

    # specify specs for output
    array_specs_pred = dict()
    for l in labels:
        array_specs_pred[l.pred_dist_key] = ArraySpec(voxel_size=voxel_size, interpolatable=True)

    snapshot_data = {
        ak_raw: "volumes/raw",
        ak_training: "volumes/masks/training",
        ak_clefts: "volumes/labels/gt_clefts",
        ak_neurons: "volumes/labels/gt_neurons",
        ak_integral: "volumes/masks/gt_integral",
    }

    # specify snapshot data layout
    for l in labels:
        snapshot_data[l.mask_key] = "volumes/masks/" + l.labelname
        snapshot_data[l.pred_dist_key] = "volumes/labels/pred_dist_" + l.labelname
        snapshot_data[l.gt_dist_key] = "volumes/labels/gt_dist_" + l.labelname

    # specify snapshot request
    snapshot_request_dict = {}
    for l in labels:
        snapshot_request_dict[l.pred_dist_key] = request[l.gt_dist_key]
    snapshot_request = BatchRequest(snapshot_request_dict)

    csv_files = [
        os.path.join(cremi_dir, csv_filename_format.format(sample)) for sample in samples
    ]
    cleft_to_pre, cleft_to_post, cleft_to_pre_filtered, cleft_to_post_filtered = \
        make_cleft_to_prepostsyn_neuron_id_dict(csv_files, filter_comments_pre, filter_comments_post)

    data_providers = []
    for sample in samples:
        logging.info("Adding sample {0:}".format(sample))
        datasets = {
            ak_raw: "volumes/raw",
            ak_training: "volumes/masks/validation",
            ak_integral: "volumes/masks/groundtruth_integral",
            ak_clefts: "volumes/labels/clefts",
            ak_neurons: "volumes/labels/neuron_ids",
        }
        specs = {
            ak_clefts: ArraySpec(interpolatable=False),
            ak_training: ArraySpec(interpolatable=False),
            ak_integral: ArraySpec(interpolatable=False),
        }
        for l in labels:
            datasets[l.mask_key] = "volumes/masks/groundtruth"
            specs[l.mask_key] = ArraySpec(interpolatable=False)
        n5_source = ZarrSource(
            os.path.join(cremi_dir, n5_filename_format.format(sample)),
            datasets=datasets,
            array_specs=specs,
        )
        data_providers.append(n5_source)

    data_sources = []
    for provider in data_providers:
        provider += Normalize(ak_raw)
        provider += Pad(ak_training, pad_width)
        provider += Pad(ak_neurons, pad_width)
        for l in labels:
            provider += Pad(l.mask_key, pad_width)
        provider += IntensityScaleShift(ak_training, -1, 1)
        provider += RandomLocationWithIntegralMask(integral_mask=ak_integral, min_masked=keep_thr)
        provider += Reject(ak_training, min_masked=0.999)
        provider += Reject(ak_clefts, min_masked=0.0, reject_probability=0.95)
        data_sources.append(provider)

    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ArrayKeys.RAW: "defect_sections/raw",
                ArrayKeys.ALPHA_MASK: "defect_sections/mask",
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            },
        ) +
        RandomLocation(min_masked=0.05, mask=ak_alpha) +
        Normalize(ak_raw) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]))

    train_pipeline = tuple(data_sources) + RandomProvider()

    if aug_mode == "deluxe":
        slip_ignore = [ak_clefts, ak_training, ak_neurons, ak_integral]
        for l in labels:
            slip_ignore.append(l.mask_key)
        train_pipeline += fuse.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        train_pipeline += SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
        train_pipeline += fuse.Misalign(
            40,
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=(10, 10),
            ignore_keys_for_slip=tuple(slip_ignore),
        )
        train_pipeline += IntensityAugment(ak_raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        train_pipeline += DefectAugment(
            ak_raw,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ak_raw,
            artifacts_mask=ak_alpha,
            contrast_scale=0.5,
        )
    elif aug_mode == "classic":
        train_pipeline += fuse.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        train_pipeline += fuse.SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
        train_pipeline += IntensityAugment(ak_raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        train_pipeline += DefectAugment(
            ak_raw,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ak_raw,
            artifacts_mask=ak_alpha,
            contrast_scale=0.5,
        )
    elif aug_mode == "lite":
        train_pipeline += fuse.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        train_pipeline += fuse.SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
        train_pipeline += IntensityAugment(ak_raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=False)
    else:
        pass

    train_pipeline += IntensityScaleShift(ak_raw, 2, -1)
    train_pipeline += ZeroOutConstSections(ak_raw)

    clefts = get_label("clefts")
    pre = get_label("pre")
    post = get_label("post")

    if clefts is not None or pre is not None or post is not None:
        train_pipeline += AddPrePostCleftDistance(
            ak_clefts,
            ak_neurons,
            clefts.gt_dist_key if clefts is not None else None,
            pre.gt_dist_key if pre is not None else None,
            post.gt_dist_key if post is not None else None,
            clefts.mask_key if clefts is not None else None,
            pre.mask_key if pre is not None else None,
            post.mask_key if post is not None else None,
            cleft_to_pre,
            cleft_to_post,
            cleft_to_presyn_neuron_id_filtered=cleft_to_pre_filtered,
            cleft_to_postsyn_neuron_id_filtered=cleft_to_post_filtered,
            bg_value=(0, 18446744073709551613),
            include_cleft=include_cleft,
            max_distance=2.76 * dt_scaling_factor,
        )

    for ak in arrays_that_need_to_be_cropped:
        train_pipeline += CropArray(ak, crop_width, crop_width)
    for l in labels:
        train_pipeline += TanhSaturate(l.gt_dist_key, dt_scaling_factor)
    for l in labels:
        train_pipeline += BalanceByThreshold(
            labels=l.gt_dist_key,
            scales=l.scale_key,
            mask=(l.mask_key, ak_training),
            threshold=l.thr,
        )
    train_pipeline += PreCache(cache_size=cache_size, num_workers=num_workers)
    train_pipeline += Train(
        net_name,
        optimizer=net_io_names["optimizer"],
        loss=net_io_names[loss_name],
        inputs=inputs,
        summary=net_io_names["summary"],
        log_dir="log",
        save_every=500,
        log_every=5,
        outputs=outputs,
        gradients={},
        array_specs=array_specs_pred,
    )
    train_pipeline += Snapshot(
        snapshot_data,
        every=500,
        output_filename="batch_{iteration}.hdf",
        output_dir="snapshots/",
        additional_request=snapshot_request,
    )
    train_pipeline += PrintProfilingStats(every=50)

    logging.info("Starting training...")
    with build(train_pipeline) as pp:
        for i in range(start_iteration, max_iteration + 1):
            start_it = time.time()
            pp.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it{0:}: {1:}".format(i + 1, time_it))
    logging.info("Training finished")
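
# Illustrative sketch (not from the original code base): the train_until above only touches a handful
# of attributes on each element of `labels` (labelname, mask_key, gt_dist_key, pred_dist_key,
# scale_key, scale_loss and thr). A minimal stand-in with exactly those attributes, useful for reading
# the function, might look like this; the real Label class used with these scripts may carry more.
class MinimalLabel:
    def __init__(self, labelname, scale_loss=True, thr=0.0):
        self.labelname = labelname
        self.scale_loss = scale_loss
        self.thr = thr  # threshold handed to BalanceByThreshold
        # one gunpowder ArrayKey per role, named after the label
        self.mask_key = ArrayKey("MASK_" + labelname.upper())
        self.gt_dist_key = ArrayKey("GT_DIST_" + labelname.upper())
        self.pred_dist_key = ArrayKey("PRED_DIST_" + labelname.upper())
        self.scale_key = ArrayKey("SCALE_" + labelname.upper())


# Hypothetical call, with made-up paths, formats and shapes, only to show the expected argument types:
# train_until(
#     max_iteration=500000,
#     cremi_dir="/path/to/cremi",
#     samples=["A", "B", "C"],
#     n5_filename_format="sample_{0:}.n5",
#     csv_filename_format="cleft-partners_{0:}.csv",
#     filter_comments_pre=[],
#     filter_comments_post=[],
#     labels=[MinimalLabel("clefts"), MinimalLabel("pre"), MinimalLabel("post")],
#     net_name="unet",
#     input_shape=(43, 430, 430),
#     output_shape=(23, 218, 218),
#     loss_name="loss_total",
#     aug_mode="deluxe",
# )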

def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name,
                cremi_version, aligned):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_CLEFTS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('CLEFT_SCALE')
    ArrayKey('PRE_SCALE')
    ArrayKey('POST_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_CLEFT_DIST')
    ArrayKey('PRED_CLEFT_DIST')
    ArrayKey('GT_PRE_DIST')
    ArrayKey('PRED_PRE_DIST')
    ArrayKey('GT_POST_DIST')
    ArrayKey('PRED_POST_DIST')

    data_providers = []
    if cremi_version == '2016':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = 'sample_{0:}_padded_20160501.'
    elif cremi_version == '2017':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = 'sample_{0:}_padded_20170424.'
    if aligned:
        filename += 'aligned.'
    filename += '0bg.hdf'
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/validation',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids'
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False),
                ArrayKeys.GT_CLEFTS: ArraySpec(interpolatable=False),
                ArrayKeys.TRAINING_MASK: ArraySpec(interpolatable=False)
            })
        data_providers.append(h5_source)

    if cremi_version == '2017':
        csv_files = [
            os.path.join(cremi_dir, 'cleft-partners_' + sample + '_2017.csv')
            for sample in data_sources
        ]
    elif cremi_version == '2016':
        csv_files = [
            os.path.join(
                cremi_dir,
                'cleft-partners-' + sample + '-20160501.aligned.corrected.csv')
            for sample in data_sources
        ]
    cleft_to_pre, cleft_to_post = make_cleft_to_prepostsyn_neuron_id_dict(csv_files)
    print(cleft_to_pre, cleft_to_post)
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_CLEFTS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.CLEFT_SCALE, output_size)
    request.add(ArrayKeys.GT_CLEFT_DIST, output_size)
    request.add(ArrayKeys.GT_PRE_DIST, output_size)
    request.add(ArrayKeys.GT_POST_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is in float in [0, 1]
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        # choose a random location inside the provided arrays
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        Reject(ArrayKeys.GT_MASK) +
        # reject batches which contain less than 50% labelled data
        Reject(ArrayKeys.GT_CLEFTS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_CLEFTS],
        ArrayKeys.PRED_CLEFT_DIST: request[ArrayKeys.GT_CLEFT_DIST],
        ArrayKeys.PRED_PRE_DIST: request[ArrayKeys.GT_PRE_DIST],
        ArrayKeys.PRED_POST_DIST: request[ArrayKeys.GT_POST_DIST],
    })

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment((4, 40, 40), (0., 0., 0.), (0, math.pi / 2.0),
                       prob_slip=0.0,
                       prob_shift=0.0,
                       max_misalign=0,
                       subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddDistance(label_array_key=ArrayKeys.GT_CLEFTS,
                    distance_array_key=ArrayKeys.GT_CLEFT_DIST,
                    normalize='tanh',
                    normalize_args=dt_scaling_factor) +
        AddPrePostCleftDistance(ArrayKeys.GT_CLEFTS,
                                ArrayKeys.GT_LABELS,
                                ArrayKeys.GT_PRE_DIST,
                                ArrayKeys.GT_POST_DIST,
                                cleft_to_pre,
                                cleft_to_post,
                                normalize='tanh',
                                normalize_args=dt_scaling_factor,
                                include_cleft=False) +
        BalanceByThreshold(labels=ArrayKeys.GT_CLEFT_DIST,
                           scales=ArrayKeys.CLEFT_SCALE,
                           mask=ArrayKeys.GT_MASK) +
        BalanceByThreshold(labels=ArrayKeys.GT_PRE_DIST,
                           scales=ArrayKeys.PRE_SCALE,
                           mask=ArrayKeys.GT_MASK,
                           threshold=-0.5) +
        BalanceByThreshold(labels=ArrayKeys.GT_POST_DIST,
                           scales=ArrayKeys.POST_SCALE,
                           mask=ArrayKeys.GT_MASK,
                           threshold=-0.5) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']: ArrayKeys.RAW,
                net_io_names['gt_cleft_dist']: ArrayKeys.GT_CLEFT_DIST,
                net_io_names['gt_pre_dist']: ArrayKeys.GT_PRE_DIST,
                net_io_names['gt_post_dist']: ArrayKeys.GT_POST_DIST,
                net_io_names['loss_weights_cleft']: ArrayKeys.CLEFT_SCALE,
                # pre- and postsynaptic losses reuse the cleft balancing weights here
                net_io_names['loss_weights_pre']: ArrayKeys.CLEFT_SCALE,
                net_io_names['loss_weights_post']: ArrayKeys.CLEFT_SCALE,
                net_io_names['mask']: ArrayKeys.GT_MASK
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['cleft_dist']: ArrayKeys.PRED_CLEFT_DIST,
                net_io_names['pre_dist']: ArrayKeys.PRED_PRE_DIST,
                net_io_names['post_dist']: ArrayKeys.PRED_POST_DIST
            },
            gradients={net_io_names['cleft_dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot(
            {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/gt_clefts',
                ArrayKeys.GT_CLEFT_DIST: 'volumes/labels/gt_clefts_dist',
                ArrayKeys.PRED_CLEFT_DIST: 'volumes/labels/pred_clefts_dist',
                ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient',
                ArrayKeys.PRED_PRE_DIST: 'volumes/labels/pred_pre_dist',
                ArrayKeys.PRED_POST_DIST: 'volumes/labels/pred_post_dist',
                ArrayKeys.GT_PRE_DIST: 'volumes/labels/gt_pre_dist',
                ArrayKeys.GT_POST_DIST: 'volumes/labels/gt_post_dist'
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
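
# Worked example (hypothetical shapes, for illustration only): how the Coordinate arithmetic above
# turns voxel shapes into world units with the CREMI voxel size of (40, 4, 4) nm, and where the
# TRAINING_MASK padding by `context` comes from.
def _context_in_nm(input_shape, output_shape, voxel_size=(40, 4, 4)):
    input_size = Coordinate(input_shape) * Coordinate(voxel_size)
    output_size = Coordinate(output_shape) * Coordinate(voxel_size)
    return input_size - output_size


# e.g. _context_in_nm((43, 430, 430), (23, 218, 218)):
#   input_size  = (1720, 1720, 1720) nm
#   output_size = ( 920,  872,  872) nm
#   context     = ( 800,  848,  848) nm
# i.e. roughly (400, 424, 424) nm of raw context on each side of the predicted region.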