import json
import logging
import math
import os
import time

import numpy as np
import tensorflow as tf

# NOTE: the gunpowder core nodes (ArrayKey, ArrayKeys, ArraySpec, BatchRequest, Coordinate,
# Normalize, Pad, RandomLocation, Reject, RandomProvider, IntensityAugment, SimpleAugment,
# ElasticAugment, DefectAugment, IntensityScaleShift, PreCache, Train, Snapshot,
# PrintProfilingStats, build, Hdf5Source, ...) and the project-specific nodes (N5Source,
# RejectEfficiently, GammaAugment, AddDistance, AddPrePostCleftDistance, BalanceByThreshold,
# ZeroOutConstSections, the `gpn` augmentation module, make_cleft_to_prepostsyn_neuron_id_dict)
# are assumed to be imported from the surrounding training framework; their exact module
# paths are not shown here.


def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    labels,
    net_name,
    min_masked_voxels=17561.0,
    mask_ds_name="volumes/masks/training_cropped",
):
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig
    # context = input_size - output_size
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations volumes/labels/ribosomes doesn't exist, so use
    # volumes/labels/all instead (only one with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }

    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_orig)),
    }
    array_specs_pred = {}

    inputs[net_io_names["raw"]] = ArrayKeys.RAW
    snapshot[ArrayKeys.RAW] = "volumes/raw"
    snapshot[ArrayKeys.GT_LABELS] = "volumes/labels/gt_labels"

    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)

    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True
        )
        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size_orig)
        request.add(label.pred_dist_key, output_size, voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size_orig)
        snapshot_request.add(label.pred_dist_key, output_size, voxel_size=voxel_size_orig)

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        if src not in ribo_sources:
            n5_source = N5Source(src.full_path, datasets=datasets_no_ribo, array_specs=array_specs)
        else:
            n5_source = N5Source(src.full_path, datasets=datasets_ribo, array_specs=array_specs)
        data_providers.append(n5_source)

    # create a tuple of data streams, one for each N5 container
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-padding of RAW and MASK would allow drawing batches close to the boundary of
        # the available data; the size is more or less irrelevant as it is followed by a
        # Reject node
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        RejectEfficiently(ArrayKeys.MASK, min_masked=keep_thr)  # reject batches with too few labelled voxels
        # Reject(ArrayKeys.MASK)
        for provider in data_providers
    )

    train_pipeline = (
        data_stream +
        RandomProvider(tuple([ds.labeled_voxels for ds in data_sources])) +
        gpn.SimpleAugment() +
        gpn.ElasticAugment(
            voxel_size_orig,
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        ) +
        # ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1)
        # ZeroOutConstSections(ArrayKeys.RAW)
    )

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=label.mask_key
            )

    train_pipeline = (
        train_pipeline +
        PreCache(cache_size=30, num_workers=30) +
        Train(
            net_name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=5,
            save_every=500,
            array_specs=array_specs_pred,
        ) +
        Snapshot(
            snapshot,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=500)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    print("Training finished")
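
# Illustrative sketch only (not part of the training script): the Reject threshold above is a
# *fraction* of the output ROI, derived from an absolute voxel count. The output_shape below is
# a made-up example value; the real shape is whatever gets passed to train_until.
def _example_keep_threshold():
    import numpy as np

    min_masked_voxels = 17561.0          # default used above
    output_shape = (56, 56, 56)          # hypothetical network output shape in voxels
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)
    # 17561 / 175616 ~= 0.1, i.e. at least ~10% of the output ROI must lie inside the mask
    return keep_thr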
def train_until(
    max_iteration,
    data_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    cremi_version,
    aligned,
):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_CLEFTS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("CLEFT_SCALE")
    ArrayKey("PRE_SCALE")
    ArrayKey("POST_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_CLEFT_DIST")
    ArrayKey("PRED_CLEFT_DIST")
    ArrayKey("GT_PRE_DIST")
    ArrayKey("PRED_PRE_DIST")
    ArrayKey("GT_POST_DIST")
    ArrayKey("PRED_POST_DIST")

    data_providers = []
    if cremi_version == "2016":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = "sample_{0:}_padded_20160501."
    elif cremi_version == "2017":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = "sample_{0:}_padded_20170424."
    if aligned:
        filename += "aligned."
    filename += "0bg.hdf"

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_CLEFTS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False),
                ArrayKeys.GT_CLEFTS: ArraySpec(interpolatable=False),
                ArrayKeys.TRAINING_MASK: ArraySpec(interpolatable=False),
            },
        )
        data_providers.append(h5_source)

    if cremi_version == "2017":
        csv_files = [
            os.path.join(cremi_dir, "cleft-partners_" + sample + "_2017.csv")
            for sample in data_sources
        ]
    elif cremi_version == "2016":
        csv_files = [
            os.path.join(
                cremi_dir,
                "cleft-partners-" + sample + "-20160501.aligned.corrected.csv",
            )
            for sample in data_sources
        ]
    cleft_to_pre, cleft_to_post = make_cleft_to_prepostsyn_neuron_id_dict(csv_files)
    print(cleft_to_pre, cleft_to_post)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_CLEFTS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.CLEFT_SCALE, output_size)
    request.add(ArrayKeys.GT_CLEFT_DIST, output_size)
    request.add(ArrayKeys.GT_PRE_DIST, output_size)
    request.add(ArrayKeys.GT_POST_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +  # invert the validation mask into a training mask
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to the boundary
        # of the available data; the size is more or less irrelevant as it is followed by a
        # Reject node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.GT_MASK) +  # reject batches which contain less than 50% ground-truth annotated data
        Reject(ArrayKeys.GT_CLEFTS, min_masked=0.0, reject_probability=0.95)  # with 95% probability, reject batches without any cleft labels
        for provider in data_providers
    )

    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_CLEFTS],
        ArrayKeys.PRED_CLEFT_DIST: request[ArrayKeys.GT_CLEFT_DIST],
        ArrayKeys.PRED_PRE_DIST: request[ArrayKeys.GT_PRE_DIST],
        ArrayKeys.PRED_POST_DIST: request[ArrayKeys.GT_POST_DIST],
    })

    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ArrayKeys.RAW: "defect_sections/raw",
                ArrayKeys.ALPHA_MASK: "defect_sections/mask",
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            },
        ) +
        RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) +
        Normalize(ArrayKeys.RAW) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
    )

    train_pipeline = (
        data_sources +
        RandomProvider() +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        gpn.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        ) +
        gpn.Misalign(
            40,
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=10,
            ignore_keys_for_slip=(
                ArrayKeys.GT_CLEFTS,
                ArrayKeys.GT_MASK,
                ArrayKeys.TRAINING_MASK,
                ArrayKeys.GT_LABELS,
            ),
        ) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(
            ArrayKeys.RAW,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ArrayKeys.RAW,
            artifacts_mask=ArrayKeys.ALPHA_MASK,
            contrast_scale=0.5,
        ) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddDistance(
            label_array_key=ArrayKeys.GT_CLEFTS,
            distance_array_key=ArrayKeys.GT_CLEFT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) +
        AddPrePostCleftDistance(
            ArrayKeys.GT_CLEFTS,
            ArrayKeys.GT_LABELS,
            ArrayKeys.GT_PRE_DIST,
            ArrayKeys.GT_POST_DIST,
            cleft_to_pre,
            cleft_to_post,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            include_cleft=False,
        ) +
        BalanceByThreshold(
            labels=ArrayKeys.GT_CLEFT_DIST,
            scales=ArrayKeys.CLEFT_SCALE,
            mask=ArrayKeys.GT_MASK,
        ) +
        BalanceByThreshold(
            labels=ArrayKeys.GT_PRE_DIST,
            scales=ArrayKeys.PRE_SCALE,
            mask=ArrayKeys.GT_MASK,
            threshold=-0.5,
        ) +
        BalanceByThreshold(
            labels=ArrayKeys.GT_POST_DIST,
            scales=ArrayKeys.POST_SCALE,
            mask=ArrayKeys.GT_MASK,
            threshold=-0.5,
        ) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_cleft_dist"]: ArrayKeys.GT_CLEFT_DIST,
                net_io_names["gt_pre_dist"]: ArrayKeys.GT_PRE_DIST,
                net_io_names["gt_post_dist"]: ArrayKeys.GT_POST_DIST,
                net_io_names["loss_weights_cleft"]: ArrayKeys.CLEFT_SCALE,
                # NOTE: the pre- and post-synaptic loss weights are also fed the cleft scale;
                # PRE_SCALE and POST_SCALE (balanced above) remain unused unless requested.
                net_io_names["loss_weights_pre"]: ArrayKeys.CLEFT_SCALE,
                net_io_names["loss_weights_post"]: ArrayKeys.CLEFT_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["cleft_dist"]: ArrayKeys.PRED_CLEFT_DIST,
                net_io_names["pre_dist"]: ArrayKeys.PRED_PRE_DIST,
                net_io_names["post_dist"]: ArrayKeys.PRED_POST_DIST,
            },
            gradients={net_io_names["cleft_dist"]: ArrayKeys.LOSS_GRADIENT},
        ) +
        Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_CLEFTS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_CLEFT_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PRED_CLEFT_DIST: "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
                ArrayKeys.PRED_PRE_DIST: "volumes/labels/pred_pre_dist",
                ArrayKeys.PRED_POST_DIST: "volumes/labels/pred_post_dist",
                ArrayKeys.GT_PRE_DIST: "volumes/labels/gt_pre_dist",
                ArrayKeys.GT_POST_DIST: "volumes/labels/gt_post_dist",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
def train_until(max_iteration, data_sources, ribo_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels, net_name,
                min_masked_voxels=17561.,
                mask_ds_name='volumes/masks/training'):
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')
    ArrayKey('RIBO_GT')

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_input = Coordinate((8, 8, 8))
    voxel_size_output = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_input
    output_size = Coordinate(output_shape) * voxel_size_output
    # context = input_size - output_size
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/ribosomes',
    }
    # for datasets without ribosome annotations volumes/labels/ribosomes doesn't exist, so use
    # volumes/labels/all instead (only one with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/all',
    }

    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_input)),
    }
    array_specs_pred = {}

    inputs[net_io_names['raw']] = ArrayKeys.RAW
    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'

    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_output)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_input)

    for label in labels:
        datasets_no_ribo[label.mask_key] = 'volumes/masks/' + label.labelname
        datasets_ribo[label.mask_key] = 'volumes/masks/' + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_output, interpolatable=True)
        inputs[net_io_names['mask_' + label.labelname]] = label.mask_key
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname
        snapshot[label.pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size_output)
        request.add(label.pred_dist_key, output_size, voxel_size=voxel_size_output)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_output)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size_output)
        snapshot_request.add(label.pred_dist_key, output_size, voxel_size=voxel_size_output)

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for src in data_sources:
        for subsample_variant in range(8):
            dnr = datasets_no_ribo.copy()
            dr = datasets_ribo.copy()
            dnr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(subsample_variant)
            dr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(subsample_variant)
            if src not in ribo_sources:
                n5_source = N5Source(src.full_path, datasets=dnr, array_specs=array_specs)
            else:
                n5_source = N5Source(src.full_path, datasets=dr, array_specs=array_specs)
            data_providers.append(n5_source)

    # create a tuple of data streams, one per provider (8 subsampling variants per N5 container)
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-padding of RAW and MASK would allow drawing batches close to the boundary of
        # the available data; the size is more or less irrelevant as it is followed by a
        # Reject node
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.MASK, min_masked=keep_thr)
        for provider in data_providers)

    train_pipeline = (
        data_stream +
        RandomProvider(
            tuple(np.repeat([ds.labeled_voxels for ds in data_sources], 8))) +
        gpn.SimpleAugment() +
        gpn.ElasticAugment(voxel_size_output, (100, 100, 100), (10., 10., 10.),
                           (0, math.pi / 2.0),
                           spatial_dims=3,
                           subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))

    for label in labels:
        if label.labelname != 'ribosomes':
            train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh',
                                          normalize_args=dt_scaling_factor,
                                          label_id=label.labelid,
                                          factor=2)
        else:
            train_pipeline += AddDistance(label_array_key=ArrayKeys.RIBO_GT,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh+',
                                          normalize_args=(dt_scaling_factor, 8),
                                          label_id=label.labelid,
                                          factor=2)

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=label.mask_key)

    train_pipeline = (train_pipeline +
                      PreCache(cache_size=10, num_workers=20) +
                      Train(net_name,
                            optimizer=net_io_names['optimizer'],
                            loss=net_io_names[loss_name],
                            inputs=inputs,
                            summary=net_io_names['summary'],
                            log_dir='log',
                            outputs=outputs,
                            gradients={},
                            log_every=5,
                            save_every=500,
                            array_specs=array_specs_pred) +
                      Snapshot(snapshot,
                               every=500,
                               output_filename='batch_{iteration}.hdf',
                               output_dir='snapshots/',
                               additional_request=snapshot_request) +
                      PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info('it {0:}: {1:}'.format(i + 1, time_it))
    print("Training finished")
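
# Illustrative sketch only (not part of the training script): each data source is opened 8
# times (once per subsampling variant of the raw data), so its labeled-voxel count is repeated
# 8 times for RandomProvider; the per-source sampling probability is therefore the same as with
# a single provider per source. The counts below are made-up example values.
def _example_subsample_weights():
    import numpy as np

    labeled_voxels = [120000, 80000]          # hypothetical counts, one per data source
    weights = np.repeat(labeled_voxels, 8)    # one weight per provider (8 per source)
    probabilities = weights / weights.sum()
    return probabilities                      # sums to 1: 0.075 for the first 8, 0.05 for the rest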
def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    dt_scaling_factor,
    loss_name,
    labels,
    scnet,
    raw_name="raw",
    min_masked_voxels=17561.0,
):
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    datasets_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations volumes/labels/ribosomes doesn't exist, so use
    # volumes/labels/all instead (only one with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }

    array_specs = {ArrayKeys.MASK: ArraySpec(interpolatable=False)}
    array_specs_pred = {}

    # individual mask per label
    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)

    # inputs = {net_io_names['mask']: ArrayKeys.MASK}
    snapshot = {ArrayKeys.GT_LABELS: "volumes/labels/all"}
    request = BatchRequest()
    snapshot_request = BatchRequest()
    raw_array_keys = []
    contexts = []

    # input and output sizes in world coordinates
    input_sizes_wc = [
        Coordinate(inp_sh) * Coordinate(vs)
        for inp_sh, vs in zip(scnet.input_shapes, scnet.voxel_sizes)
    ]
    output_size_wc = Coordinate(scnet.output_shapes[0]) * Coordinate(scnet.voxel_sizes[0])
    keep_thr = float(min_masked_voxels) / np.prod(scnet.output_shapes[0])

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    # make sure that scnet has the same base voxel size
    assert voxel_size_orig == Coordinate(scnet.voxel_sizes[0])

    inputs = {}
    # add multiscale raw data as inputs
    for k, (inp_sh_wc, vs) in enumerate(zip(input_sizes_wc, scnet.voxel_sizes)):
        ak = ArrayKey("RAW_S{0:}".format(k))
        raw_array_keys.append(ak)
        datasets_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        datasets_no_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        inputs[net_io_names["raw_{0:}".format(vs[0])]] = ak
        snapshot[ak] = "volumes/raw_s{0:}".format(k)
        array_specs[ak] = ArraySpec(voxel_size=Coordinate(vs))
        request.add(ak, inp_sh_wc, voxel_size=Coordinate(vs))
        contexts.append(inp_sh_wc - output_size_wc)

    outputs = dict()
    for label in labels:
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key
        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True
        )

    request.add(ArrayKeys.GT_LABELS, output_size_wc, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size_wc, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size_wc, voxel_size=voxel_size_up)
    for label in labels:
        request.add(label.gt_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        snapshot_request.add(label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        request.add(label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size_wc, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key, output_size_wc, voxel_size=voxel_size_orig)

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        if src not in ribo_sources:
            n5_source = N5Source(
                src.full_path, datasets=datasets_no_ribo, array_specs=array_specs
            )
        else:
            n5_source = N5Source(
                src.full_path, datasets=datasets_ribo, array_specs=array_specs
            )
        data_providers.append(n5_source)

    data_stream = []
    for provider in data_providers:
        data_stream.append(provider)
        for ak, context in zip(raw_array_keys, contexts):
            data_stream[-1] += Normalize(ak)
            # data_stream[-1] += Pad(ak, context)
            # this shouldn't be necessary as I cropped the input data to have sufficient padding
        data_stream[-1] += RandomLocation()
        data_stream[-1] += Reject(ArrayKeys.MASK, min_masked=keep_thr)
    data_stream = tuple(data_stream)

    train_pipeline = (
        data_stream +
        RandomProvider(tuple([ds.labeled_voxels for ds in data_sources])) +
        gpn.SimpleAugment() +
        gpn.ElasticAugment(
            tuple(scnet.voxel_sizes[0]),
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        ) +
        gpn.IntensityAugment(raw_array_keys, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(raw_array_keys, 0.5, 2.0)
    )
    for ak in raw_array_keys:
        train_pipeline += IntensityScaleShift(ak, 2, -1)
        # train_pipeline += ZeroOutConstSections(ak)

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=label.mask_key
            )

    train_pipeline = (
        train_pipeline +
        PreCache(cache_size=10, num_workers=40) +
        Train(
            scnet.name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=5,
            save_every=500,
            array_specs=array_specs_pred,
        ) +
        Snapshot(
            snapshot,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) +
        PrintProfilingStats(every=10)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    print("Training finished")
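
# Illustrative sketch only (not part of the training script): the world-coordinate bookkeeping
# for a multiscale network. Each raw scale level has its own input shape and voxel size; the
# per-scale context is the difference between that input ROI and the single output ROI, both
# measured in world units. The shapes below are made-up example values for a two-scale scnet.
def _example_multiscale_contexts():
    import numpy as np

    voxel_sizes = [(4, 4, 4), (8, 8, 8)]               # hypothetical scnet.voxel_sizes
    input_shapes = [(196, 196, 196), (110, 110, 110)]  # hypothetical scnet.input_shapes
    output_shape = (92, 92, 92)                        # hypothetical scnet.output_shapes[0]

    input_sizes_wc = [np.array(s) * np.array(v) for s, v in zip(input_shapes, voxel_sizes)]
    output_size_wc = np.array(output_shape) * np.array(voxel_sizes[0])
    contexts = [inp - output_size_wc for inp in input_sizes_wc]
    # input_sizes_wc -> [784 784 784] and [880 880 880]; output_size_wc -> [368 368 368]
    # contexts       -> [416 416 416] and [512 512 512]
    return contexts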