def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, cremi_version, aligned):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("GT_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_DIST")
    ArrayKey("PREDICTED_DIST_LABELS")

    data_providers = []
    if cremi_version == "2016":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = "sample_{0:}_padded_20160501."
    elif cremi_version == "2017":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = "sample_{0:}_padded_20170424."
    if aligned:
        filename += "aligned."
    filename += "0bg.hdf"

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
            },
            array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SCALE, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        Reject(ArrayKeys.GT_MASK) +
        # randomly reject (p=0.95) batches that contain no labelled data
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.PREDICTED_DIST_LABELS: request[ArrayKeys.GT_LABELS],
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
    })

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            prob_slip=0.0,
            prob_shift=0.0,
            max_misalign=0,
            subsample=8,
        ) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) +
        BalanceByThreshold(ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE,
                           mask=ArrayKeys.GT_MASK) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_dist"]: ArrayKeys.GT_DIST,
                net_io_names["loss_weights"]: ArrayKeys.GT_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: ArrayKeys.PREDICTED_DIST_LABELS},
            gradients={net_io_names["dist"]: ArrayKeys.LOSS_GRADIENT},
        ) +
        Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PREDICTED_DIST_LABELS: "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
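# All of these training scripts resolve TensorFlow tensor names through
# net_io_names.json, which is written out when the network graph is built.
# A minimal sketch of the structure the function above expects (the keys are
# taken from the code; the tensor names themselves are illustrative
# assumptions, not values from this repo):
#
# {
#     "raw": "Placeholder:0",
#     "gt_dist": "Placeholder_1:0",
#     "loss_weights": "Placeholder_2:0",
#     "mask": "Placeholder_3:0",
#     "dist": "unet/output:0",
#     "optimizer": "Adam",
#     "summary": "Merge/MergeSummary:0"
# }
#
# loss_name indexes the same dict, so the chosen loss tensor must be present
# under that key as well.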
def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    labels,
    net_name,
    min_masked_voxels=17561.0,
    mask_ds_name="volumes/masks/training_cropped",
):
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")
    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig
    # context = input_size - output_size
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations, volumes/labels/ribosomes doesn't
    # exist, so use volumes/labels/all instead (the only other dataset with the
    # right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }
    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_orig)),
    }
    array_specs_pred = {}

    inputs[net_io_names["raw"]] = ArrayKeys.RAW
    snapshot[ArrayKeys.RAW] = "volumes/raw"
    snapshot[ArrayKeys.GT_LABELS] = "volumes/labels/gt_labels"
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)

    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True)
        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size_orig)
        request.add(label.pred_dist_key, output_size, voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size_orig)
        snapshot_request.add(label.pred_dist_key, output_size,
                             voxel_size=voxel_size_orig)

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        if src not in ribo_sources:
            n5_source = N5Source(src.full_path, datasets=datasets_no_ribo,
                                 array_specs=array_specs)
        else:
            n5_source = N5Source(src.full_path, datasets=datasets_ribo,
                                 array_specs=array_specs)
        data_providers.append(n5_source)

    # create a tuple of data sources, one for each N5 container
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-padding (Pad) is unnecessary here, as batches are rejected near
        # the boundary anyway
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        RejectEfficiently(ArrayKeys.MASK, min_masked=keep_thr)
        # Reject(ArrayKeys.MASK)
        for provider in data_providers)

    train_pipeline = (
        data_stream +
        RandomProvider(tuple([ds.labeled_voxels for ds in data_sources])) +
        gpn.SimpleAugment() +
        gpn.ElasticAugment(
            voxel_size_orig,
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        ) +
        # ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))
        # ZeroOutConstSections(ArrayKeys.RAW))

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )
    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=label.mask_key)

    train_pipeline = (
        train_pipeline +
        PreCache(cache_size=30, num_workers=30) +
        Train(
            net_name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=5,
            save_every=500,
            array_specs=array_specs_pred,
        ) +
        Snapshot(
            snapshot,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    print("Training finished")
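# The multi-label variants in this collection iterate over `labels` and over
# data source objects without defining either here. A minimal sketch of the
# attributes they rely on (hypothetical stand-ins named *Sketch; the actual
# classes live elsewhere in the repo):

from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass
class LabelSketch:
    """What train_until reads from each entry of `labels`."""
    labelname: str                          # e.g. "mito"
    labelid: Union[int, Tuple[int, ...]]    # label id(s) in volumes/labels/all
    gt_dist_key: "ArrayKey"                 # target distance array
    pred_dist_key: "ArrayKey"               # network prediction array
    mask_key: "ArrayKey"                    # per-label training mask
    scale_loss: bool = False                # request/feed loss weights?
    scale_key: Optional["ArrayKey"] = None  # loss-weight array


@dataclass
class DataSourceSketch:
    """What train_until reads from each entry of `data_sources`."""
    full_path: str       # path to the N5 container
    labeled_voxels: int  # used to weight RandomProvider sampling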
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5"
    voxel_size = Coordinate((2, 2, 2))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for src in data_sources:
        n5_source = N5Source(
            os.path.join(data_dir.format(src)),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_LABELS: 'volumes/labels/all',
                ArrayKeys.MASK: 'volumes/mask'
            },
            array_specs={ArrayKeys.MASK: ArraySpec(interpolatable=False)})
        data_providers.append(n5_source)

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    inputs = dict()
    inputs[net_io_names['raw']] = ArrayKeys.RAW
    outputs = dict()
    snapshot = dict()
    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'
    for label in labels:
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname
        snapshot[label.pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    snapshot_request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size)
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size)
    for label in labels:
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size)
        snapshot_request.add(label.pred_dist_key, output_size, voxel_size=voxel_size)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size)

    # create a tuple of data sources, one for each N5 container
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to the
        # boundary of the available data
        Pad(ArrayKeys.RAW, None) +
        # choose a random location inside the provided arrays
        RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK)
        # Reject(ArrayKeys.MASK)
        for provider in data_providers)

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            (100, 100, 100), (10., 10., 10.), (0, math.pi / 2.0),
            prob_slip=0, prob_shift=0, max_misalign=0, subsample=8) +
        SimpleAugment() +
        # ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        IntensityAugment(ArrayKeys.RAW, 0.95, 1.05, -0.05, 0.05) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW))

    for label in labels:
        train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                      distance_array_key=label.gt_dist_key,
                                      normalize='tanh',
                                      normalize_args=dt_scaling_factor,
                                      label_id=label.labelid)
    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=ArrayKeys.MASK)

    train_pipeline = (
        train_pipeline +
        PreCache(cache_size=40, num_workers=10) +
        Train('build',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names[loss_name],
              inputs=inputs,
              summary=net_io_names['summary'],
              log_dir='log',
              outputs=outputs,
              gradients={}) +
        Snapshot(snapshot,
                 every=500,
                 output_filename='batch_{iteration}.hdf',
                 output_dir='snapshots/',
                 additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
def train_until(
    max_iteration,
    cremi_dir,
    data_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    cache_size=10,
    num_workers=10,
):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_SYN_LABELS")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("GT_SYN_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_SYN_DIST")
    ArrayKey("PREDICTED_SYN_DIST")
    ArrayKey("GT_BDY_DIST")
    ArrayKey("PREDICTED_BDY_DIST")

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, "sample_" + sample + "_cleftsorig.hdf"),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_SYN_LABELS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
            },
            array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    # todo: dvid source

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_SYN_LABELS, output_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_BDY_DIST, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SYN_SCALE, output_size)
    request.add(ArrayKeys.GT_SYN_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.GT_MASK) +  # reject batches with too little groundtruth mask
        # require batches to lie (almost) entirely within the training region
        Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) +
        # randomly reject (p=0.95) batches that contain no labelled data
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers
    )

    snapshot_request = BatchRequest(
        {
            ArrayKeys.PREDICTED_SYN_DIST: request[ArrayKeys.GT_SYN_LABELS],
            ArrayKeys.PREDICTED_BDY_DIST: request[ArrayKeys.GT_LABELS],
            ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_SYN_DIST],
        }
    )

    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ArrayKeys.RAW: "defect_sections/raw",
                ArrayKeys.ALPHA_MASK: "defect_sections/mask",
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            },
        ) +
        RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) +
        Normalize(ArrayKeys.RAW) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8) +
        SimpleAugment(transpose_only=[1, 2])
    )

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            (4, 40, 40),
            (0.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=10,
            subsample=8,
        ) +
        SimpleAugment(transpose_only=[1, 2]) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(
            ArrayKeys.RAW,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ArrayKeys.RAW,
            artifacts_mask=ArrayKeys.ALPHA_MASK,
            contrast_scale=0.5,
        ) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1, only_xy=True) +
        AddBoundaryDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_BDY_DIST,
            normalize="tanh",
            normalize_args=100,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_SYN_LABELS,
            distance_array_key=ArrayKeys.GT_SYN_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) +
        BalanceLabels(ArrayKeys.GT_SYN_LABELS, ArrayKeys.GT_SYN_SCALE,
                      ArrayKeys.GT_MASK) +
        PreCache(cache_size=cache_size, num_workers=num_workers) +
        Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_syn_dist"]: ArrayKeys.GT_SYN_DIST,
                net_io_names["gt_bdy_dist"]: ArrayKeys.GT_BDY_DIST,
                net_io_names["loss_weights"]: ArrayKeys.GT_SYN_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["syn_dist"]: ArrayKeys.PREDICTED_SYN_DIST,
                net_io_names["bdy_dist"]: ArrayKeys.PREDICTED_BDY_DIST,
            },
            gradients={net_io_names["syn_dist"]: ArrayKeys.LOSS_GRADIENT},
        ) +
        Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_SYN_LABELS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_SYN_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PREDICTED_SYN_DIST: "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                ArrayKeys.PREDICTED_BDY_DIST: "volumes/labels/pred_bdy_dist",
                ArrayKeys.GT_BDY_DIST: "volumes/labels/gt_bdy_dist",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_SYN_LABELS')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('GT_SYN_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_SYN_DIST')
    ArrayKey('PREDICTED_SYN_DIST')
    ArrayKey('GT_BDY_DIST')
    ArrayKey('PREDICTED_BDY_DIST')

    data_providers = []
    cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, 'sample_' + sample + '_cleftsorig.hdf'),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_SYN_LABELS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/validation',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids'
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)
            })
        data_providers.append(h5_source)

    # todo: dvid source

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_SYN_LABELS, output_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_BDY_DIST, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SYN_SCALE, output_size)
    request.add(ArrayKeys.GT_SYN_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.GT_MASK) +  # reject batches with too little groundtruth mask
        # require batches to lie (almost) entirely within the training region
        Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) +
        # randomly reject (p=0.95) batches that contain no labelled data
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.PREDICTED_SYN_DIST: request[ArrayKeys.GT_SYN_LABELS],
        ArrayKeys.PREDICTED_BDY_DIST: request[ArrayKeys.GT_LABELS],
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_SYN_DIST],
    })

    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                ArrayKeys.RAW: 'defect_sections/raw',
                ArrayKeys.ALPHA_MASK: 'defect_sections/mask',
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            }) +
        RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) +
        Normalize(ArrayKeys.RAW) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8) +
        SimpleAugment(transpose_only=[1, 2])
    )

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment((4, 40, 40), (0., 2., 2.), (0, math.pi / 2.0),
                       prob_slip=0.05, prob_shift=0.05, max_misalign=10,
                       subsample=8) +
        SimpleAugment(transpose_only=[1, 2]) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(ArrayKeys.RAW,
                      prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      artifacts=ArrayKeys.RAW,
                      artifacts_mask=ArrayKeys.ALPHA_MASK,
                      contrast_scale=0.5) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1, only_xy=True) +
        AddBoundaryDistance(label_array_key=ArrayKeys.GT_LABELS,
                            distance_array_key=ArrayKeys.GT_BDY_DIST,
                            normalize='tanh',
                            normalize_args=100) +
        AddDistance(label_array_key=ArrayKeys.GT_SYN_LABELS,
                    distance_array_key=ArrayKeys.GT_SYN_DIST,
                    normalize='tanh',
                    normalize_args=dt_scaling_factor) +
        BalanceLabels(ArrayKeys.GT_SYN_LABELS, ArrayKeys.GT_SYN_SCALE,
                      ArrayKeys.GT_MASK) +
        PreCache(cache_size=40, num_workers=10) +
        Train('unet',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names[loss_name],
              inputs={
                  net_io_names['raw']: ArrayKeys.RAW,
                  net_io_names['gt_syn_dist']: ArrayKeys.GT_SYN_DIST,
                  net_io_names['gt_bdy_dist']: ArrayKeys.GT_BDY_DIST,
                  net_io_names['loss_weights']: ArrayKeys.GT_SYN_SCALE,
                  net_io_names['mask']: ArrayKeys.GT_MASK,
              },
              summary=net_io_names['summary'],
              log_dir='log',
              outputs={
                  net_io_names['syn_dist']: ArrayKeys.PREDICTED_SYN_DIST,
                  net_io_names['bdy_dist']: ArrayKeys.PREDICTED_BDY_DIST
              },
              gradients={net_io_names['syn_dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot({
            ArrayKeys.RAW: 'volumes/raw',
            ArrayKeys.GT_SYN_LABELS: 'volumes/labels/gt_clefts',
            ArrayKeys.GT_SYN_DIST: 'volumes/labels/gt_clefts_dist',
            ArrayKeys.PREDICTED_SYN_DIST: 'volumes/labels/pred_clefts_dist',
            ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient',
            ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids',
            ArrayKeys.PREDICTED_BDY_DIST: 'volumes/labels/pred_bdy_dist',
            ArrayKeys.GT_BDY_DIST: 'volumes/labels/gt_bdy_dist'
        },
                 every=500,
                 output_filename='batch_{iteration}.hdf',
                 output_dir='snapshots/',
                 additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
def train_until(
    max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name
):
    raw = ArrayKey("RAW")
    # ArrayKey('ALPHA_MASK')
    clefts = ArrayKey("GT_LABELS")
    mask = ArrayKey("GT_MASK")
    scale = ArrayKey("GT_SCALE")
    # grad = ArrayKey('LOSS_GRADIENT')
    gt_dist = ArrayKey("GT_DIST")
    pred_dist = ArrayKey("PREDICTED_DIST")

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    if trained_until >= max_iteration:
        return

    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/fib19/mine/"
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(data_dir, "cube{0:}.hdf".format(sample)),
            datasets={
                raw: "volumes/raw",
                clefts: "volumes/labels/clefts",
                mask: "/volumes/masks/groundtruth",
            },
            array_specs={mask: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(clefts, output_size)
    request.add(mask, output_size)
    # request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(scale, output_size)
    request.add(gt_dist, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(raw) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to the
        # boundary of the available data
        # size more or less irrelevant as followed by Reject node
        Pad(raw, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        # Reject(ArrayKeys.GT_MASK) +
        # Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) +
        Reject(mask=mask) +
        # randomly reject (p=0.95) batches that contain no labelled data
        Reject(clefts, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers
    )

    snapshot_request = BatchRequest({pred_dist: request[clefts]})

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            (40, 40, 40),
            (2.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            prob_slip=0.01,
            prob_shift=0.01,
            max_misalign=1,
            subsample=8,
        ) +
        SimpleAugment() +
        IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(raw, 2, -1) +
        ZeroOutConstSections(raw) +
        # GrowBoundary(steps=1) +
        # SplitAndRenumberSegmentationLabels() +
        # AddGtAffinities(malis.mknhood3d()) +
        AddDistance(
            label_array_key=clefts,
            distance_array_key=gt_dist,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) +
        # BalanceLabels(clefts, scale, mask) +
        BalanceByThreshold(labels=gt_dist, scales=scale) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: raw,
                net_io_names["gt_dist"]: gt_dist,
                net_io_names["loss_weights"]: scale,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: pred_dist},
            gradients={},
        ) +
        Snapshot(
            {
                raw: "volumes/raw",
                clefts: "volumes/labels/gt_clefts",
                gt_dist: "volumes/labels/gt_clefts_dist",
                pred_dist: "volumes/labels/pred_clefts_dist",
            },
            dataset_dtypes={clefts: np.uint64},
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration - trained_until):
            b.request_batch(request)
    print("Training finished")
def train_until(max_iteration, data_sources, ribo_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels, net_name,
                min_masked_voxels=17561., mask_ds_name='volumes/masks/training'):
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')
    ArrayKey('RIBO_GT')
    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_input = Coordinate((8, 8, 8))
    voxel_size_output = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_input
    output_size = Coordinate(output_shape) * voxel_size_output
    # context = input_size - output_size
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/ribosomes',
    }
    # for datasets without ribosome annotations, volumes/labels/ribosomes doesn't
    # exist, so use volumes/labels/all instead (the only other dataset with the
    # right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/all',
    }
    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_input))
    }
    array_specs_pred = {}

    inputs[net_io_names['raw']] = ArrayKeys.RAW
    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_output)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_input)

    for label in labels:
        datasets_no_ribo[label.mask_key] = 'volumes/masks/' + label.labelname
        datasets_ribo[label.mask_key] = 'volumes/masks/' + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_output, interpolatable=True)
        inputs[net_io_names['mask_' + label.labelname]] = label.mask_key
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname
        snapshot[label.pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size_output)
        request.add(label.pred_dist_key, output_size, voxel_size=voxel_size_output)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_output)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size_output)
        snapshot_request.add(label.pred_dist_key, output_size,
                             voxel_size=voxel_size_output)

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for src in data_sources:
        for subsample_variant in range(8):
            dnr = datasets_no_ribo.copy()
            dr = datasets_ribo.copy()
            dnr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(subsample_variant)
            dr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(subsample_variant)
            if src not in ribo_sources:
                n5_source = N5Source(src.full_path, datasets=dnr,
                                     array_specs=array_specs)
            else:
                n5_source = N5Source(src.full_path, datasets=dr,
                                     array_specs=array_specs)
            data_providers.append(n5_source)

    # create a tuple of data sources, one for each N5 container
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.MASK, min_masked=keep_thr)
        for provider in data_providers)

    train_pipeline = (
        data_stream +
        RandomProvider(
            tuple(np.repeat([ds.labeled_voxels for ds in data_sources], 8))) +
        gpn.SimpleAugment() +
        gpn.ElasticAugment(voxel_size_output, (100, 100, 100), (10., 10., 10.),
                           (0, math.pi / 2.0), spatial_dims=3, subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))

    for label in labels:
        if label.labelname != 'ribosomes':
            train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh',
                                          normalize_args=dt_scaling_factor,
                                          label_id=label.labelid,
                                          factor=2)
        else:
            train_pipeline += AddDistance(label_array_key=ArrayKeys.RIBO_GT,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh+',
                                          normalize_args=(dt_scaling_factor, 8),
                                          label_id=label.labelid,
                                          factor=2)
    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=label.mask_key)

    train_pipeline = (train_pipeline +
                      PreCache(cache_size=10, num_workers=20) +
                      Train(net_name,
                            optimizer=net_io_names['optimizer'],
                            loss=net_io_names[loss_name],
                            inputs=inputs,
                            summary=net_io_names['summary'],
                            log_dir='log',
                            outputs=outputs,
                            gradients={},
                            log_every=5,
                            save_every=500,
                            array_specs=array_specs_pred) +
                      Snapshot(snapshot,
                               every=500,
                               output_filename='batch_{iteration}.hdf',
                               output_dir='snapshots/',
                               additional_request=snapshot_request) +
                      PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info('it {0:}: {1:}'.format(i + 1, time_it))
    print("Training finished")
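# Shapes are specified in voxels and converted to world coordinates by
# multiplying with a voxel size, so a single BatchRequest can mix resolutions,
# as in the function above. A worked example with the voxel sizes used there
# (input_shape itself is a hypothetical value for illustration):
#
#   voxel_size_input = Coordinate((8, 8, 8))
#   input_shape = (196, 196, 196)                            # voxels
#   input_size = Coordinate(input_shape) * voxel_size_input  # (1568, 1568, 1568) nm
#
# Requesting GT_LABELS with voxel_size_up = (2, 2, 2) over the same
# world-coordinate region then yields a 4x denser array than the 8 nm raw
# input, which is why the request carries an explicit voxel_size per key.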
def train_until(max_iteration, data_dir, data_sources, input_shape, output_shape,
                loss_name):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_DIST_SCALE')
    # ArrayKey('GT_AFF_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_DIST')
    ArrayKey('PREDICTED_DIST')
    # ArrayKey('GT_AFF')
    # ArrayKey('PREDICTED_AFF1')
    # ArrayKey('PREDICTED_AFF3')
    # ArrayKey('PREDICTED_AFF9')

    data_providers = []
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for sample in data_sources:
        h5_source = Hdf5Source(
            data_dir,
            datasets={
                ArrayKeys.RAW: sample + '/image',
                ArrayKeys.GT_LABELS: sample + '/mask',
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate((1, 1))),
                ArrayKeys.GT_LABELS: ArraySpec(voxel_size=Coordinate((1, 1)))
            })
        data_providers.append(h5_source)

    # todo: dvid source

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((1, 1))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    # request.add(ArrayKeys.GT_AFF, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)
    request.add(ArrayKeys.GT_DIST_SCALE, output_size)
    # request.add(ArrayKeys.GT_AFF_SCALE, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to the
        # boundary of the available data
        Pad(ArrayKeys.RAW, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        # randomly reject (p=0.95) batches that contain no labelled data
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.PREDICTED_DIST: request[ArrayKeys.GT_DIST],
        # ArrayKeys.PREDICTED_AFF1: request[ArrayKeys.GT_AFF],
        # ArrayKeys.PREDICTED_AFF3: request[ArrayKeys.GT_AFF],
        # ArrayKeys.PREDICTED_AFF9: request[ArrayKeys.GT_AFF],
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
    })

    train_pipeline = (
        data_sources +
        RandomProvider() +
        # ElasticAugment((40, 40), (2., 2.), (0, math.pi/2.0),
        #                subsample=4, spatial_dims=2) +
        # SimpleAugment() +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        # DefectAugment(ArrayKeys.RAW, prob_low_contrast=0.01, contrast_scale=0.5) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        # AddAffinities([[-1, 0], [0, -1],
        #                [-3, 0], [0, -3],
        #                [-9, 0], [0, -9]],
        #               ArrayKeys.GT_LABELS,
        #               ArrayKeys.GT_AFF) +
        AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                    distance_array_key=ArrayKeys.GT_DIST,
                    normalize='tanh',
                    normalize_args=150) +
        # BalanceLabels(ArrayKeys.GT_AFF, ArrayKeys.GT_AFF_SCALE) +
        BalanceByThreshold(labels=ArrayKeys.GT_LABELS,
                           scales=ArrayKeys.GT_DIST_SCALE) +
        # PreCache(cache_size=40, num_workers=10) +
        Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']: ArrayKeys.RAW,
                net_io_names['gt_dist']: ArrayKeys.GT_DIST,
                # net_io_names['gt_aff']: ArrayKeys.GT_AFF,
                net_io_names['loss_weights_dist']: ArrayKeys.GT_DIST_SCALE,
                # net_io_names['loss_weights_aff']: ArrayKeys.GT_AFF_SCALE
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['dist']: ArrayKeys.PREDICTED_DIST,
                # net_io_names['aff1']: ArrayKeys.PREDICTED_AFF1,
                # net_io_names['aff3']: ArrayKeys.PREDICTED_AFF3,
                # net_io_names['aff9']: ArrayKeys.PREDICTED_AFF9
            },
            gradients={net_io_names['dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot(
            {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_DIST: 'volumes/labels/dist',
                # ArrayKeys.GT_AFF: 'volumes/labels/aff',
                ArrayKeys.GT_LABELS: 'volumes/labels/nuclei',
                ArrayKeys.PREDICTED_DIST: 'volumes/predictions/dist',
                # ArrayKeys.PREDICTED_AFF1: 'volumes/predictions/aff1',
                # ArrayKeys.PREDICTED_AFF3: 'volumes/predictions/aff3',
                # ArrayKeys.PREDICTED_AFF9: 'volumes/predictions/aff9',
                ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient',
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
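# AddDistance(normalize='tanh', normalize_args=s) squashes the signed euclidean
# distance d to the nearest object boundary into (-1, 1). As we understand the
# node, the transform is roughly tanh(d / s) (a sketch of the idea, not the
# node's exact implementation):
#
#   import numpy as np
#
#   def tanh_normalize(dist, scaling_factor):
#       return np.tanh(dist / scaling_factor)
#
# With normalize_args=150 as above, distances beyond a few hundred pixels
# saturate near +/-1, keeping the regression targets bounded.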
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    ArrayKey("RAW")
    ArrayKey("RAW_UP")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("MASK_UP")
    ArrayKey("GT_DIST_CENTROSOME")
    ArrayKey("GT_DIST_GOLGI")
    ArrayKey("GT_DIST_GOLGI_MEM")
    ArrayKey("GT_DIST_ER")
    ArrayKey("GT_DIST_ER_MEM")
    ArrayKey("GT_DIST_MVB")
    ArrayKey("GT_DIST_MVB_MEM")
    ArrayKey("GT_DIST_MITO")
    ArrayKey("GT_DIST_MITO_MEM")
    ArrayKey("GT_DIST_LYSOSOME")
    ArrayKey("GT_DIST_LYSOSOME_MEM")
    ArrayKey("PRED_DIST_CENTROSOME")
    ArrayKey("PRED_DIST_GOLGI")
    ArrayKey("PRED_DIST_GOLGI_MEM")
    ArrayKey("PRED_DIST_ER")
    ArrayKey("PRED_DIST_ER_MEM")
    ArrayKey("PRED_DIST_MVB")
    ArrayKey("PRED_DIST_MVB_MEM")
    ArrayKey("PRED_DIST_MITO")
    ArrayKey("PRED_DIST_MITO_MEM")
    ArrayKey("PRED_DIST_LYSOSOME")
    ArrayKey("PRED_DIST_LYSOSOME_MEM")
    ArrayKey("SCALE_CENTROSOME")
    ArrayKey("SCALE_GOLGI")
    ArrayKey("SCALE_GOLGI_MEM")
    ArrayKey("SCALE_ER")
    ArrayKey("SCALE_ER_MEM")
    ArrayKey("SCALE_MVB")
    ArrayKey("SCALE_MVB_MEM")
    ArrayKey("SCALE_MITO")
    ArrayKey("SCALE_MITO_MEM")
    ArrayKey("SCALE_LYSOSOME")
    ArrayKey("SCALE_LYSOSOME_MEM")

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5"
    voxel_size_up = Coordinate((4, 4, 4))
    voxel_size_orig = Coordinate((8, 8, 8))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        n5_source = N5Source(
            os.path.join(data_dir.format(src)),
            datasets={
                ArrayKeys.RAW_UP: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/all",
                ArrayKeys.MASK_UP: "volumes/mask",
            },
            array_specs={ArrayKeys.MASK_UP: ArraySpec(interpolatable=False)},
        )
        data_providers.append(n5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RAW_UP, input_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK_UP, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_CENTROSOME, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_GOLGI, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_GOLGI_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_ER, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_ER_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MVB, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MVB_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MITO, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MITO_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_LYSOSOME, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_LYSOSOME_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_CENTROSOME, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_GOLGI, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_GOLGI_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_ER, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_ER_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_MVB, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_MVB_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_MITO, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_MITO_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_LYSOSOME, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_LYSOSOME_MEM, output_size, voxel_size=voxel_size_orig)

    # create a tuple of data sources, one for each N5 container
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW_UP) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to the
        # boundary of the available data
        Pad(ArrayKeys.RAW_UP, None) +
        # choose a random location inside the provided arrays
        RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK_UP)
        # Reject(ArrayKeys.MASK)
        for provider in data_providers)

    snapshot_request = BatchRequest()
    snapshot_request.add(ArrayKeys.PRED_DIST_CENTROSOME, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_GOLGI, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_GOLGI_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_ER, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_ER_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MVB, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MVB_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MITO, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MITO_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_LYSOSOME, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_LYSOSOME_MEM, output_size)

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            prob_slip=0,
            prob_shift=0,
            max_misalign=0,
            subsample=8,
        ) +
        SimpleAugment() +
        ElasticAugment((40, 1000, 1000), (10.0, 0.0, 0.0), (0, 0), subsample=8) +
        IntensityAugment(ArrayKeys.RAW_UP, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(ArrayKeys.RAW_UP, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW_UP) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_CENTROSOME,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=1,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_GOLGI,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(2, 11),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_GOLGI_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=11,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_ER,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(3, 10),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_ER_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=10,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MVB,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(4, 9),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MVB_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=9,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MITO,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(5, 8),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MITO_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=8,
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_LYSOSOME,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(6, 7),
            factor=2,
        ) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_LYSOSOME_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=7,
            factor=2,
        ) +
        DownSample(ArrayKeys.MASK_UP, 2, ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_CENTROSOME,
                           ArrayKeys.SCALE_CENTROSOME,
                           mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_GOLGI, ArrayKeys.SCALE_GOLGI,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_GOLGI_MEM, ArrayKeys.SCALE_GOLGI_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_ER, ArrayKeys.SCALE_ER,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_ER_MEM, ArrayKeys.SCALE_ER_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_MVB, ArrayKeys.SCALE_MVB,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_MVB_MEM, ArrayKeys.SCALE_MVB_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_MITO, ArrayKeys.SCALE_MITO,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_MITO_MEM, ArrayKeys.SCALE_MITO_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_LYSOSOME, ArrayKeys.SCALE_LYSOSOME,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_LYSOSOME_MEM, ArrayKeys.SCALE_LYSOSOME_MEM, mask=ArrayKeys.MASK) +
        DownSample(ArrayKeys.RAW_UP, 2, ArrayKeys.RAW) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            "build",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_centrosome"]: ArrayKeys.GT_DIST_CENTROSOME,
                net_io_names["gt_golgi"]: ArrayKeys.GT_DIST_GOLGI,
                net_io_names["gt_golgi_mem"]: ArrayKeys.GT_DIST_GOLGI_MEM,
                net_io_names["gt_er"]: ArrayKeys.GT_DIST_ER,
                net_io_names["gt_er_mem"]: ArrayKeys.GT_DIST_ER_MEM,
                net_io_names["gt_mvb"]: ArrayKeys.GT_DIST_MVB,
                net_io_names["gt_mvb_mem"]: ArrayKeys.GT_DIST_MVB_MEM,
                net_io_names["gt_mito"]: ArrayKeys.GT_DIST_MITO,
                net_io_names["gt_mito_mem"]: ArrayKeys.GT_DIST_MITO_MEM,
                net_io_names["gt_lysosome"]: ArrayKeys.GT_DIST_LYSOSOME,
                net_io_names["gt_lysosome_mem"]: ArrayKeys.GT_DIST_LYSOSOME_MEM,
                # membrane loss weights reuse the corresponding organelle
                # scales (the SCALE_*_MEM arrays are not requested above)
                net_io_names["w_centrosome"]: ArrayKeys.SCALE_CENTROSOME,
                net_io_names["w_golgi"]: ArrayKeys.SCALE_GOLGI,
                net_io_names["w_golgi_mem"]: ArrayKeys.SCALE_GOLGI,
                net_io_names["w_er"]: ArrayKeys.SCALE_ER,
                net_io_names["w_er_mem"]: ArrayKeys.SCALE_ER,
                net_io_names["w_mvb"]: ArrayKeys.SCALE_MVB,
                net_io_names["w_mvb_mem"]: ArrayKeys.SCALE_MVB,
                net_io_names["w_mito"]: ArrayKeys.SCALE_MITO,
                net_io_names["w_mito_mem"]: ArrayKeys.SCALE_MITO,
                net_io_names["w_lysosome"]: ArrayKeys.SCALE_LYSOSOME,
                net_io_names["w_lysosome_mem"]: ArrayKeys.SCALE_LYSOSOME,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["centrosome"]: ArrayKeys.PRED_DIST_CENTROSOME,
                net_io_names["golgi"]: ArrayKeys.PRED_DIST_GOLGI,
                net_io_names["golgi_mem"]: ArrayKeys.PRED_DIST_GOLGI_MEM,
                net_io_names["er"]: ArrayKeys.PRED_DIST_ER,
                net_io_names["er_mem"]: ArrayKeys.PRED_DIST_ER_MEM,
                net_io_names["mvb"]: ArrayKeys.PRED_DIST_MVB,
                net_io_names["mvb_mem"]: ArrayKeys.PRED_DIST_MVB_MEM,
                net_io_names["mito"]: ArrayKeys.PRED_DIST_MITO,
                net_io_names["mito_mem"]: ArrayKeys.PRED_DIST_MITO_MEM,
                net_io_names["lysosome"]: ArrayKeys.PRED_DIST_LYSOSOME,
                net_io_names["lysosome_mem"]: ArrayKeys.PRED_DIST_LYSOSOME_MEM,
            },
            gradients={},
        ) +
        Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/gt_labels",
                ArrayKeys.GT_DIST_CENTROSOME: "volumes/labels/gt_dist_centrosome",
                ArrayKeys.PRED_DIST_CENTROSOME: "volumes/labels/pred_dist_centrosome",
                ArrayKeys.GT_DIST_GOLGI: "volumes/labels/gt_dist_golgi",
                ArrayKeys.PRED_DIST_GOLGI: "volumes/labels/pred_dist_golgi",
                ArrayKeys.GT_DIST_GOLGI_MEM: "volumes/labels/gt_dist_golgi_mem",
                ArrayKeys.PRED_DIST_GOLGI_MEM: "volumes/labels/pred_dist_golgi_mem",
                ArrayKeys.GT_DIST_ER: "volumes/labels/gt_dist_er",
                ArrayKeys.PRED_DIST_ER: "volumes/labels/pred_dist_er",
                ArrayKeys.GT_DIST_ER_MEM: "volumes/labels/gt_dist_er_mem",
                ArrayKeys.PRED_DIST_ER_MEM: "volumes/labels/pred_dist_er_mem",
                ArrayKeys.GT_DIST_MVB: "volumes/labels/gt_dist_mvb",
                ArrayKeys.PRED_DIST_MVB: "volumes/labels/pred_dist_mvb",
                ArrayKeys.GT_DIST_MVB_MEM: "volumes/labels/gt_dist_mvb_mem",
                ArrayKeys.PRED_DIST_MVB_MEM: "volumes/labels/pred_dist_mvb_mem",
                ArrayKeys.GT_DIST_MITO: "volumes/labels/gt_dist_mito",
                ArrayKeys.PRED_DIST_MITO: "volumes/labels/pred_dist_mito",
                ArrayKeys.GT_DIST_MITO_MEM: "volumes/labels/gt_dist_mito_mem",
                ArrayKeys.PRED_DIST_MITO_MEM: "volumes/labels/pred_dist_mito_mem",
                ArrayKeys.GT_DIST_LYSOSOME: "volumes/labels/gt_dist_lysosome",
                ArrayKeys.PRED_DIST_LYSOSOME: "volumes/labels/pred_dist_lysosome",
                ArrayKeys.GT_DIST_LYSOSOME_MEM: "volumes/labels/gt_dist_lysosome_mem",
                ArrayKeys.PRED_DIST_LYSOSOME_MEM: "volumes/labels/pred_dist_lysosome_mem",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
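# The variant above trains at two resolutions: labels and masks are read at
# 4 nm (voxel_size_up) while the network consumes 8 nm data (voxel_size_orig).
# As we read it, AddDistance(..., factor=2) accounts for the 2x grid difference
# when producing the 8 nm distance targets, and DownSample(<key>_UP, 2, <key>)
# brings RAW_UP and MASK_UP onto the 8 nm grid before Train sees them.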
def train_until(
    max_iteration,
    data_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    cremi_version,
    aligned,
):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_CLEFTS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('CLEFT_SCALE')
    ArrayKey('PRE_SCALE')
    ArrayKey('POST_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_CLEFT_DIST')
    ArrayKey('PRED_CLEFT_DIST')
    ArrayKey('GT_PRE_DIST')
    ArrayKey('PRED_PRE_DIST')
    ArrayKey('GT_POST_DIST')
    ArrayKey('PRED_POST_DIST')

    data_providers = []
    if cremi_version == '2016':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = 'sample_{0:}_padded_20160501.'
    elif cremi_version == '2017':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = 'sample_{0:}_padded_20170424.'
    if aligned:
        filename += 'aligned.'
    filename += '0bg.hdf'

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/validation',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids',
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False),
                ArrayKeys.GT_CLEFTS: ArraySpec(interpolatable=False),
                ArrayKeys.TRAINING_MASK: ArraySpec(interpolatable=False),
            })
        data_providers.append(h5_source)

    if cremi_version == '2017':
        csv_files = [
            os.path.join(cremi_dir, 'cleft-partners_' + sample + '_2017.csv')
            for sample in data_sources
        ]
    elif cremi_version == '2016':
        csv_files = [
            os.path.join(
                cremi_dir,
                'cleft-partners-' + sample + '-20160501.aligned.corrected.csv')
            for sample in data_sources
        ]
    cleft_to_pre, cleft_to_post = make_cleft_to_prepostsyn_neuron_id_dict(
        csv_files)
    print(cleft_to_pre, cleft_to_post)

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_CLEFTS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.CLEFT_SCALE, output_size)
    request.add(ArrayKeys.PRE_SCALE, output_size)
    request.add(ArrayKeys.POST_SCALE, output_size)
    request.add(ArrayKeys.GT_CLEFT_DIST, output_size)
    request.add(ArrayKeys.GT_PRE_DIST, output_size)
    request.add(ArrayKeys.GT_POST_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data;
        # size more or less irrelevant as followed by Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        # choose a random location inside the provided arrays
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        # reject batches that contain less than 50% annotated data
        Reject(ArrayKeys.GT_MASK) +
        # with probability 0.95, reject batches that contain no cleft labels
        Reject(ArrayKeys.GT_CLEFTS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_CLEFTS],
        ArrayKeys.PRED_CLEFT_DIST: request[ArrayKeys.GT_CLEFT_DIST],
        ArrayKeys.PRED_PRE_DIST: request[ArrayKeys.GT_PRE_DIST],
        ArrayKeys.PRED_POST_DIST: request[ArrayKeys.GT_POST_DIST],
    })

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            (4, 40, 40), (0., 0., 0.), (0, math.pi / 2.0),
            prob_slip=0.0,
            prob_shift=0.0,
            max_misalign=0,
            subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(
            ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddDistance(
            label_array_key=ArrayKeys.GT_CLEFTS,
            distance_array_key=ArrayKeys.GT_CLEFT_DIST,
            normalize='tanh',
            normalize_args=dt_scaling_factor) +
        AddPrePostCleftDistance(
            ArrayKeys.GT_CLEFTS,
            ArrayKeys.GT_LABELS,
            ArrayKeys.GT_PRE_DIST,
            ArrayKeys.GT_POST_DIST,
            cleft_to_pre,
            cleft_to_post,
            normalize='tanh',
            normalize_args=dt_scaling_factor,
            include_cleft=False) +
        BalanceByThreshold(
            labels=ArrayKeys.GT_CLEFT_DIST,
            scales=ArrayKeys.CLEFT_SCALE,
            mask=ArrayKeys.GT_MASK) +
        BalanceByThreshold(
            labels=ArrayKeys.GT_PRE_DIST,
            scales=ArrayKeys.PRE_SCALE,
            mask=ArrayKeys.GT_MASK,
            threshold=-0.5) +
        BalanceByThreshold(
            labels=ArrayKeys.GT_POST_DIST,
            scales=ArrayKeys.POST_SCALE,
            mask=ArrayKeys.GT_MASK,
            threshold=-0.5) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']: ArrayKeys.RAW,
                net_io_names['gt_cleft_dist']: ArrayKeys.GT_CLEFT_DIST,
                net_io_names['gt_pre_dist']: ArrayKeys.GT_PRE_DIST,
                net_io_names['gt_post_dist']: ArrayKeys.GT_POST_DIST,
                net_io_names['loss_weights_cleft']: ArrayKeys.CLEFT_SCALE,
                # feed the pre-/post-specific balanced scales to the pre/post
                # loss weights
                net_io_names['loss_weights_pre']: ArrayKeys.PRE_SCALE,
                net_io_names['loss_weights_post']: ArrayKeys.POST_SCALE,
                net_io_names['mask']: ArrayKeys.GT_MASK,
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['cleft_dist']: ArrayKeys.PRED_CLEFT_DIST,
                net_io_names['pre_dist']: ArrayKeys.PRED_PRE_DIST,
                net_io_names['post_dist']: ArrayKeys.PRED_POST_DIST,
            },
            gradients={net_io_names['cleft_dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot(
            {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/gt_clefts',
                ArrayKeys.GT_CLEFT_DIST: 'volumes/labels/gt_clefts_dist',
                ArrayKeys.PRED_CLEFT_DIST: 'volumes/labels/pred_clefts_dist',
                ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient',
                ArrayKeys.PRED_PRE_DIST: 'volumes/labels/pred_pre_dist',
                ArrayKeys.PRED_POST_DIST: 'volumes/labels/pred_post_dist',
                ArrayKeys.GT_PRE_DIST: 'volumes/labels/gt_pre_dist',
                ArrayKeys.GT_POST_DIST: 'volumes/labels/gt_post_dist',
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
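
# A hypothetical driver for the trainer above. The CREMI samples A, B and C are
# real, but the iteration count, shapes, scaling factor and loss name are
# illustrative placeholders that must match the serialized "unet" graph:
#
#     train_until(
#         max_iteration=500000,
#         data_sources=['A', 'B', 'C'],
#         input_shape=(43, 430, 430),
#         output_shape=(23, 218, 218),
#         dt_scaling_factor=50,
#         loss_name='loss_total',
#         cremi_version='2017',
#         aligned=True,
#     )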
def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    dt_scaling_factor,
    loss_name,
    labels,
    scnet,
    raw_name="raw",
    min_masked_voxels=17561.0,
):
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    datasets_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations volumes/labels/ribosomes doesn't
    # exist, so use volumes/labels/all instead (only one with the right
    # resolution)
    datasets_no_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }
    array_specs = {ArrayKeys.MASK: ArraySpec(interpolatable=False)}
    array_specs_pred = {}

    # individual mask per label
    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)

    # inputs = {net_io_names['mask']: ArrayKeys.MASK}
    snapshot = {ArrayKeys.GT_LABELS: "volumes/labels/all"}
    request = BatchRequest()
    snapshot_request = BatchRequest()

    raw_array_keys = []
    contexts = []
    # input and output sizes in world coordinates
    input_sizes_wc = [
        Coordinate(inp_sh) * Coordinate(vs)
        for inp_sh, vs in zip(scnet.input_shapes, scnet.voxel_sizes)
    ]
    output_size_wc = Coordinate(scnet.output_shapes[0]) * Coordinate(
        scnet.voxel_sizes[0]
    )
    keep_thr = float(min_masked_voxels) / np.prod(scnet.output_shapes[0])

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    # make sure that scnet has the same base voxel size
    assert voxel_size_orig == Coordinate(scnet.voxel_sizes[0])

    inputs = {}
    # add multiscale raw data as inputs
    for k, (inp_sh_wc, vs) in enumerate(zip(input_sizes_wc, scnet.voxel_sizes)):
        ak = ArrayKey("RAW_S{0:}".format(k))
        raw_array_keys.append(ak)
        datasets_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        datasets_no_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        inputs[net_io_names["raw_{0:}".format(vs[0])]] = ak
        snapshot[ak] = "volumes/raw_s{0:}".format(k)
        array_specs[ak] = ArraySpec(voxel_size=Coordinate(vs))
        request.add(ak, inp_sh_wc, voxel_size=Coordinate(vs))
        contexts.append(inp_sh_wc - output_size_wc)

    outputs = dict()
    for label in labels:
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key
        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True
        )

    request.add(ArrayKeys.GT_LABELS, output_size_wc, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size_wc, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size_wc, voxel_size=voxel_size_up)
    for label in labels:
        request.add(label.gt_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        snapshot_request.add(
            label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig
        )
        request.add(label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size_wc, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key, output_size_wc, voxel_size=voxel_size_orig)

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        if src not in ribo_sources:
            n5_source = N5Source(
                src.full_path, datasets=datasets_no_ribo, array_specs=array_specs
            )
        else:
            n5_source = N5Source(
                src.full_path, datasets=datasets_ribo, array_specs=array_specs
            )
        data_providers.append(n5_source)

    data_stream = []
    for provider in data_providers:
        data_stream.append(provider)
        for ak, context in zip(raw_array_keys, contexts):
            data_stream[-1] += Normalize(ak)
            # data_stream[-1] += Pad(ak, context)
            # this shouldn't be necessary as I cropped the input data to have
            # sufficient padding
        data_stream[-1] += RandomLocation()
        data_stream[-1] += Reject(ArrayKeys.MASK, min_masked=keep_thr)
    data_stream = tuple(data_stream)

    train_pipeline = (
        data_stream
        + RandomProvider(tuple([ds.labeled_voxels for ds in data_sources]))
        + gpn.SimpleAugment()
        + gpn.ElasticAugment(
            tuple(scnet.voxel_sizes[0]),
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        + gpn.IntensityAugment(raw_array_keys, 0.25, 1.75, -0.5, 0.35)
        + GammaAugment(raw_array_keys, 0.5, 2.0)
    )
    for ak in raw_array_keys:
        train_pipeline += IntensityScaleShift(ak, 2, -1)
        # train_pipeline += ZeroOutConstSections(ak)

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )
    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=label.mask_key
            )

    train_pipeline = (
        train_pipeline
        + PreCache(cache_size=10, num_workers=40)
        + Train(
            scnet.name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=5,
            save_every=500,
            array_specs=array_specs_pred,
        )
        + Snapshot(
            snapshot,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=10)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    print("Training finished")
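
# Worked example for the keep_thr computation above, assuming a hypothetical
# scnet output shape of (56, 56, 56): np.prod((56, 56, 56)) == 175616, so the
# default min_masked_voxels of 17561.0 yields keep_thr of roughly 0.1. In
# other words, Reject only accepts batches whose output region is at least
# ~10% covered by the training mask:
#
#     keep_thr = 17561.0 / (56 * 56 * 56)  # ~0.09999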
def train_until(
    max_iteration: int,
    gt_version: str,
    labels: List[CNNectome.utils.label.Label],
    net_name: str,
    input_shape: Union[np.ndarray, List[int]],
    output_shape: Union[np.ndarray, List[int]],
    loss_name: str,
    balance_global: bool = False,
    data_dir: Optional[str] = None,
    prioritized_label: Optional[CNNectome.utils.label.Label] = None,
    dataset: Optional[str] = None,
    prob_prioritized: float = 0.5,
    completion_min: int = 6,
    dt_scaling_factor: int = 50,
    cache_size: int = 5,
    num_workers: int = 10,
    min_masked_voxels: Union[float, int] = 17561.,
    voxel_size_labels: Coordinate = Coordinate((2, 2, 2)),
    voxel_size: Coordinate = Coordinate((4, 4, 4)),
    voxel_size_input: Coordinate = Coordinate((4, 4, 4)),
):
    """
    Train a tensorflow network to learn signed distance transforms of specified
    labels (organelles) using gunpowder. Training data is read from crops whose
    metadata are organized in a database.

    Args:
        max_iteration: Total number of iterations that the network should be
            trained for.
        gt_version: Version of groundtruth annotations, e.g. "v0003".
        labels: List of labels that the network needs to be trained for.
        net_name: Filename of tensorflow meta graph definition.
        input_shape: Input shape of network.
        output_shape: Output shape of network.
        loss_name: Name of loss as stored in the net io names json file.
        balance_global: If True, use global balancing, i.e. weigh the loss for
            each label using its `frac_pos` and `frac_neg` attributes.
        data_dir: Path to directory where data is stored. If None, read from
            config file.
        prioritized_label: Label to use for prioritizing sampling from crops
            that contain examples of it. If None (default), sample from each
            crop equally.
        dataset: Only consider crops that come from the specified dataset. If
            None (default), use all otherwise eligible training data.
        prob_prioritized: If `prioritized_label` is not None, this is the
            probability with which to sample from the crops containing the
            label. Default is .5, which implies sampling equally from crops
            containing the label and all others.
        completion_min: Minimal completion status for a crop from the database
            to be added to the training.
        dt_scaling_factor: Scaling factor to divide distance transform by
            before applying nonlinearity tanh.
        cache_size: Cache size for queue grabbing batches.
        num_workers: Number of workers grabbing batches.
        min_masked_voxels: Minimum number of voxels in a batch that need to be
            part of the groundtruth annotation.
        voxel_size_labels: Voxel size of the annotated labels.
        voxel_size: Voxel size of the desired output.
        voxel_size_input: Voxel size of the raw input data.
    """
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)
    one_vx_thr = 1. / np.prod(output_shape)
    max_distance = 2.76 * dt_scaling_factor

    ak_raw = ArrayKey("RAW")
    ak_labels = ArrayKey("GT_LABELS")
    ak_labels_downsampled = ArrayKey("GT_LABELS_DOWNSAMPLED")
    ak_mask = ArrayKey("MASK")
    ak_labelmasks_comb = ArrayKey("LABELMASKS_COMBINED")

    input_size = Coordinate(input_shape) * voxel_size_input
    output_size = Coordinate(output_shape) * voxel_size

    crop_width = Coordinate((max_distance,) * len(voxel_size_labels))
    crop_width = crop_width // voxel_size
    if crop_width == 0:
        crop_width *= voxel_size
    else:
        crop_width = (crop_width + (1,) * len(crop_width)) * voxel_size

    db = CNNectome.utils.cosem_db.MongoCosemDB(gt_version=gt_version)
    collection = db.access("crops", db.gt_version)
    db_filter = {"completion": {"$gte": completion_min}}
    if dataset is not None:
        db_filter["dataset_id"] = dataset
    skip = {
        "_id": 0,
        "number": 1,
        "labels": 1,
        "dataset_id": 1,
        "parent": 1,
        "dimensions": 1,
    }

    net_io_names, start_iteration, inputs, outputs = _network_setup(
        max_iteration, ak_raw, ak_mask, labels
    )

    # construct batch request
    request = BatchRequest()
    request.add(ak_labels, output_size, voxel_size=voxel_size_labels)
    request.add(ak_labels_downsampled, output_size, voxel_size=voxel_size)
    request.add(ak_mask, output_size, voxel_size=voxel_size)
    request.add(ak_labelmasks_comb, output_size, voxel_size=voxel_size)
    request.add(ak_raw, input_size, voxel_size=voxel_size_input)
    for label in labels:
        if label.separate_labelset:
            request.add(label.gt_key, output_size, voxel_size=voxel_size_labels)
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size)
        request.add(label.pred_dist_key, output_size, voxel_size=voxel_size)
        request.add(label.mask_key, output_size, voxel_size=voxel_size)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size)

    # specify specs for output
    array_specs_pred = dict()
    for label in labels:
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size, interpolatable=True
        )

    # specify snapshot data layout
    snapshot_data = dict()
    snapshot_data[ak_raw] = "volumes/raw"
    snapshot_data[ak_mask] = "volumes/masks/all"
    if len(_label_filter(lambda l: not l.separate_labelset, labels)) > 0:
        snapshot_data[ak_labels] = "volumes/labels/gt_labels"
    for label in _label_filter(lambda l: l.separate_labelset, labels):
        snapshot_data[label.gt_key] = "volumes/labels/gt_" + label.labelname
    for label in labels:
        snapshot_data[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot_data[label.pred_dist_key] = (
            "volumes/labels/pred_dist_" + label.labelname
        )
        snapshot_data[label.mask_key] = "volumes/masks/" + label.labelname

    # specify snapshot request
    snapshot_request = BatchRequest()

    crop_srcs = []
    crop_sizes = []
    if prioritized_label is not None:
        crop_prioritized_label_indicator = []
    for crop in collection.find(db_filter, skip):
        if len(
            set(get_all_annotated_label_ids(crop)).intersection(
                set(get_all_labelids(labels))
            )
        ) > 0:
            logging.info("Adding crop number {0:}".format(crop["number"]))
            if voxel_size_input != voxel_size:
                for subsample_variant in range(
                    int(np.prod(voxel_size_input / voxel_size))
                ):
                    crop_srcs.append(
                        _make_crop_source(
                            crop,
                            data_dir,
                            subsample_variant,
                            gt_version,
                            labels,
                            ak_raw,
                            ak_labels,
                            ak_labels_downsampled,
                            ak_mask,
                            input_size,
                            output_size,
                            voxel_size_input,
                            voxel_size,
                            crop_width,
                            keep_thr,
                        )
                    )
                    crop_sizes.append(get_crop_size(crop))
                if prioritized_label is not None:
                    crop_prioritized = is_prioritized(crop, prioritized_label)
                    logging.info(
                        f"Crop {crop['number']} is "
                        f"{'not ' if not crop_prioritized else ''}prioritized"
                    )
                    crop_prioritized_label_indicator.extend(
                        [crop_prioritized]
                        * int(np.prod(voxel_size_input / voxel_size))
                    )
            else:
                crop_srcs.append(
                    _make_crop_source(
                        crop,
                        data_dir,
                        None,
                        gt_version,
                        labels,
                        ak_raw,
                        ak_labels,
                        ak_labels_downsampled,
                        ak_mask,
                        input_size,
                        output_size,
                        voxel_size_input,
                        voxel_size,
                        crop_width,
                        keep_thr,
                    )
                )
                crop_sizes.append(get_crop_size(crop))
                if prioritized_label is not None:
                    crop_prioritized = is_prioritized(crop, prioritized_label)
                    logging.info(
                        f"Crop {crop['number']} is "
                        f"{'not ' if not crop_prioritized else ''}prioritized"
                    )
                    crop_prioritized_label_indicator.append(crop_prioritized)

    if prioritized_label is not None:
        sampling_probs = prioritized_sampling_probabilities(
            crop_sizes, crop_prioritized_label_indicator, prob_prioritized
        )
    else:
        sampling_probs = crop_sizes
    print(sampling_probs)

    pipeline = tuple(crop_srcs) + RandomProvider(sampling_probs)
    pipeline += Normalize(ak_raw, 1.0 / 255)
    pipeline += IntensityCrop(ak_raw, 0., 1.)

    # augmentations
    pipeline = (
        pipeline
        + fuse.SimpleAugment()
        + fuse.ElasticAugment(
            voxel_size,
            (100, 100, 100),
            (10., 10., 10.),
            (0, math.pi / 2.),
            spatial_dims=3,
            subsample=8,
        )
        + fuse.IntensityAugment(ak_raw, 0.25, 1.75, -0.5, 0.35)
        + GammaAugment(ak_raw, 0.5, 2.)
    )
    pipeline += IntensityScaleShift(ak_raw, 2, -1)

    # label generation
    for label in labels:
        pipeline += AddDistance(
            label_array_key=label.gt_key,
            distance_array_key=label.gt_dist_key,
            mask_array_key=label.mask_key,
            add_constant=label.add_constant,
            label_id=label.labelid,
            factor=2,
            max_distance=max_distance,
        )

    # combine distances for centrosomes
    centrosome = _get_label("centrosome", labels)
    microtubules = _get_label("microtubules", labels)
    microtubules_out = _get_label("microtubules_out", labels)
    subdistal_app = _get_label("subdistal_app", labels)
    distal_app = _get_label("distal_app", labels)
    # add the centrosomes to the microtubules
    if microtubules_out is not None and centrosome is not None:
        pipeline += CombineDistances(
            (microtubules_out.gt_dist_key, centrosome.gt_dist_key),
            microtubules_out.gt_dist_key,
            (microtubules_out.mask_key, centrosome.mask_key),
            microtubules_out.mask_key,
        )
    if microtubules is not None and centrosome is not None:
        pipeline += CombineDistances(
            (microtubules.gt_dist_key, centrosome.gt_dist_key),
            microtubules.gt_dist_key,
            (microtubules.mask_key, centrosome.mask_key),
            microtubules.mask_key,
        )
    # add the distal_app and subdistal_app to the centrosomes
    if centrosome is not None and distal_app is not None and subdistal_app is not None:
        pipeline += CombineDistances(
            (distal_app.gt_dist_key, subdistal_app.gt_dist_key, centrosome.gt_dist_key),
            centrosome.gt_dist_key,
            (distal_app.mask_key, subdistal_app.mask_key, centrosome.mask_key),
            centrosome.mask_key,
        )

    arrays_that_need_to_be_cropped = []
    for label in labels:
        arrays_that_need_to_be_cropped.append(label.gt_key)
        arrays_that_need_to_be_cropped.append(label.gt_dist_key)
        arrays_that_need_to_be_cropped.append(label.mask_key)
    arrays_that_need_to_be_cropped.append(ak_labels)
    arrays_that_need_to_be_cropped.append(ak_labels_downsampled)
    arrays_that_need_to_be_cropped = list(set(arrays_that_need_to_be_cropped))
    for ak in arrays_that_need_to_be_cropped:
        pipeline += CropArray(ak, crop_width, crop_width)

    for label in labels:
        pipeline += TanhSaturate(label.gt_dist_key, dt_scaling_factor)
    for label in _label_filter(lambda l: l.scale_loss, labels):
        if balance_global:
            pipeline += BalanceGlobalByThreshold(
                label.gt_dist_key, label.scale_key, label.frac_pos, label.frac_neg
            )
        else:
            pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=(label.mask_key, ak_mask)
            )

    pipeline += Sum(
        [l.mask_key for l in labels],
        ak_labelmasks_comb,
        sum_array_spec=ArraySpec(dtype=np.uint8, interpolatable=False),
    )
    pipeline += Reject(ak_labelmasks_comb, min_masked=one_vx_thr)

    pipeline = (
        pipeline
        + PreCache(cache_size=cache_size, num_workers=num_workers)
        + Train(
            net_name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=10,
            save_every=500,
            array_specs=array_specs_pred,
        )
        + Snapshot(
            snapshot_data,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=50)
    )

    logging.info("Starting training...")
    with build(pipeline) as pp:
        for i in range(start_iteration, max_iteration + 1):
            start_it = time.time()
            pp.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    logging.info("Training finished")
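
# A hypothetical driver for the database-backed trainer above. The Label
# constructor arguments, network name, shapes and loss name below are
# illustrative guesses, not the actual CNNectome API or values from any
# experiment; gt_version "v0003" follows the example in the docstring:
#
#     labels = [
#         CNNectome.utils.label.Label("mito", 4),
#         CNNectome.utils.label.Label("mito_membrane", 3),
#     ]
#     train_until(
#         max_iteration=500000,
#         gt_version="v0003",
#         labels=labels,
#         net_name="unet",
#         input_shape=(196, 196, 196),
#         output_shape=(92, 92, 92),
#         loss_name="loss_total",
#     )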