def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    """Build and run the gunpowder pipeline that trains ``"train_net"`` with a custom loss.

    For every selected sample, raw fluorescence is read from
    ``<sample>/fluorescence-near-consensus.n5``. It is merged with the
    ``consensus`` and ``skeletonization`` graphs served by
    ``gp.DaisyGraphProvider`` from ``MONGO_URL``. Then:

    - the skeletonization is matched to the consensus;
    - the matched graph is rasterized into ``labels`` and grown by
      ``NEURON_RADIUS``;
    - the result is augmented and fed to the network.

    Args:
        n_iterations: Currently unused. The number of batches requested is
            taken from ``setup_config["NUM_ITERATIONS"]``.
        setup_config: Mapping of configuration values. It holds shapes, voxel
            size, caching, snapshot/checkpoint/profile intervals, matching
            parameters, ``SAMPLES_PATH`` and ``MONGO_URL``.
        mknet_tensor_names: Mapping from logical names (``"raw"``,
            ``"embedding"``, ``"fg"``, ...) to tensor names of the network.
        loss_tensor_names: Mapping from logical names (``"fg_pred"``,
            ``"emst"``, ...) to tensor names produced by the custom loss.

    Raises:
        ValueError: If none of the training samples exist under ``SAMPLES_PATH``.
    """
    # Fail fast, before any pipeline objects are created, when the data is missing.
    # Otherwise an empty tuple of sources only fails later, inside RandomProvider.
    training_sample_names = ("2018-07-02", "2018-08-01")
    samples_path = Path(setup_config["SAMPLES_PATH"])
    samples = [
        sample for sample in samples_path.iterdir() if sample.name in training_sample_names
    ]
    if not samples:
        raise ValueError(
            f"none of the training samples {training_sample_names} found in {samples_path}"
        )

    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]

    # Only used to tag the snapshot file names.
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # main request: network targets at output size, source data at input size
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # Extra keys requested only on snapshot iterations. The network outputs are
    # currently disabled; uncomment entries to store them in the snapshots.
    snapshot_request = gp.BatchRequest()
    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    # one source branch per sample; RandomProvider picks among them per batch
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(labels, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
request.add(masked_base, input_size) request.add(masked_add, input_size) request.add(softmask, input_size) request.add(mask_maxed, input_size) request.add(masked_base_b, input_size) request.add(masked_add_b, input_size) request.add(softmask_b, input_size) request.add(mask_maxed_b, input_size) data_sources = tuple(( gp.N5Source( filename=str(( filename / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5" ).absolute()), datasets={raw: "volume"}, array_specs={ raw: gp.ArraySpec( interpolatable=True, voxel_size=voxel_size, dtype=np.uint16) }, ), MouselightSwcFileSource( filename=str((filename / "high-res-swcs/G-002.swc").absolute()), points=(swcs, ), scale=voxel_size, transpose=(2, 1, 0), transform_file=str((filename / "transform.txt").absolute()), ignore_human_nodes=False, ), ) + gp.MergeProvider() + gp.RandomLocation( ensure_nonempty=swcs, ensure_centered=True) + RasterizeSkeleton(
def train_until(max_iteration): # get the latest checkpoint if tf.train.latest_checkpoint("."): trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1]) else: trained_until = 0 if trained_until >= max_iteration: return # array keys for data sources raw = gp.ArrayKey("RAW") swcs = gp.PointsKey("SWCS") labels = gp.ArrayKey("LABELS") # array keys for base volume raw_base = gp.ArrayKey("RAW_BASE") labels_base = gp.ArrayKey("LABELS_BASE") swc_base = gp.PointsKey("SWC_BASE") # array keys for add volume raw_add = gp.ArrayKey("RAW_ADD") labels_add = gp.ArrayKey("LABELS_ADD") swc_add = gp.PointsKey("SWC_ADD") # array keys for fused volume raw_fused = gp.ArrayKey("RAW_FUSED") labels_fused = gp.ArrayKey("LABELS_FUSED") swc_fused = gp.PointsKey("SWC_FUSED") # output data fg = gp.ArrayKey("FG") labels_fg = gp.ArrayKey("LABELS_FG") labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN") gradient_fg = gp.ArrayKey("GRADIENT_FG") loss_weights = gp.ArrayKey("LOSS_WEIGHTS") voxel_size = gp.Coordinate((10, 3, 3)) input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size # add request request = gp.BatchRequest() request.add(raw_fused, input_size) request.add(labels_fused, input_size) request.add(swc_fused, input_size) request.add(labels_fg, output_size) request.add(labels_fg_bin, output_size) request.add(loss_weights, output_size) # add snapshot request # request.add(fg, output_size) # request.add(labels_fg, output_size) request.add(gradient_fg, output_size) request.add(raw_base, input_size) request.add(raw_add, input_size) request.add(labels_base, input_size) request.add(labels_add, input_size) request.add(swc_base, input_size) request.add(swc_add, input_size) data_sources = tuple( ( gp.N5Source( filename=str( ( filename / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5" ).absolute() ), datasets={raw: "volume"}, array_specs={ raw: gp.ArraySpec( interpolatable=True, 
voxel_size=voxel_size, dtype=np.uint16 ) }, ), MouselightSwcFileSource( filename=str( ( filename / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs" ).absolute() ), points=(swcs,), scale=voxel_size, transpose=(2, 1, 0), transform_file=str((filename / "transform.txt").absolute()), ignore_human_nodes=True ), ) + gp.MergeProvider() + gp.RandomLocation( ensure_nonempty=swcs, ensure_centered=True ) + RasterizeSkeleton( points=swcs, array=labels, array_spec=gp.ArraySpec( interpolatable=False, voxel_size=voxel_size, dtype=np.uint32 ), ) + GrowLabels(labels, radius=10) # augment + gp.ElasticAugment( [40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0], subsample=4, ) + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) for filename in Path(sample_dir).iterdir() if "2018-08-01" in filename.name # or "2018-07-02" in filename.name ) pipeline = ( data_sources + gp.RandomProvider() + GetNeuronPair( swcs, raw, labels, (swc_base, swc_add), (raw_base, raw_add), (labels_base, labels_add), seperate_by=150, shift_attempts=50, request_attempts=10, ) + FusionAugment( raw_base, raw_add, labels_base, labels_add, swc_base, swc_add, raw_fused, labels_fused, swc_fused, blend_mode="labels_mask", blend_smoothness=10, num_blended_objects=0, ) + Crop(labels_fused, labels_fg) + BinarizeGt(labels_fg, labels_fg_bin) + gp.BalanceLabels(labels_fg_bin, loss_weights) # train + gp.PreCache(cache_size=40, num_workers=10) + gp.tensorflow.Train( "./train_net", optimizer=net_names["optimizer"], loss=net_names["loss"], inputs={ net_names["raw"]: raw_fused, net_names["labels_fg"]: labels_fg_bin, net_names["loss_weights"]: loss_weights, }, outputs={net_names["fg"]: fg}, gradients={net_names["fg"]: gradient_fg}, save_every=100000, ) + gp.Snapshot( output_filename="snapshot_{iteration}.hdf", dataset_names={ raw_fused: "volumes/raw_fused", raw_base: "volumes/raw_base", raw_add: "volumes/raw_add", labels_fused: 
"volumes/labels_fused", labels_base: "volumes/labels_base", labels_add: "volumes/labels_add", labels_fg_bin: "volumes/labels_fg_bin", fg: "volumes/pred_fg", gradient_fg: "volumes/gradient_fg", }, every=100, ) + gp.PrintProfilingStats(every=10) ) with gp.build(pipeline): logging.info("Starting training...") for i in range(max_iteration - trained_until): logging.info("requesting batch {}".format(i)) batch = pipeline.request_batch(request) """
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    """Build and run the gunpowder pipeline that trains ``"train_net_foreground"``.

    For every selected sample, raw fluorescence is read from
    ``<sample>/fluorescence-near-consensus.n5``. It is merged with the
    ``consensus`` and ``skeletonization`` graphs served by
    ``gp.DaisyGraphProvider`` from ``MONGO_URL``. Then:

    - the skeletonization is matched to the consensus and rasterized into
      ``labels``;
    - the labels are turned into a saturated distance target ``dist`` with a
      loss mask ``loss_weights``;
    - the network is trained against the cropped distances.

    Args:
        n_iterations: Currently unused. The number of batches requested is
            taken from ``setup_config["NUM_ITERATIONS"]``.
        setup_config: Mapping of configuration values. It holds shapes, voxel
            size, snapshot/checkpoint/profile intervals, matching parameters,
            ``MAX_LABEL_DIST``, ``SAMPLES_PATH`` and ``MONGO_URL``.
            ``CACHE_SIZE`` and ``NUM_WORKERS`` are read but currently unused
            because ``PreCache`` is disabled.
        mknet_tensor_names: Mapping from logical names (``"raw"``,
            ``"gt_distances"``, ``"fg_pred"``, ``"optimizer"``, ...) to tensor
            names of the network.
        loss_tensor_names: Currently unused by this pipeline.

    Raises:
        ValueError: If none of the training samples exist under ``SAMPLES_PATH``.
    """
    # Fail fast, before any pipeline objects are created, when the data is missing.
    # Otherwise an empty tuple of sources only fails later, inside RandomProvider.
    training_sample_names = ("2018-07-02", "2018-08-01")
    samples_path = Path(setup_config["SAMPLES_PATH"])
    samples = [
        sample for sample in samples_path.iterdir() if sample.name in training_sample_names
    ]
    if not samples:
        raise ValueError(
            f"none of the training samples {training_sample_names} found in {samples_path}"
        )

    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]

    # Only used to tag the snapshot file names.
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # main request: network targets at output size, source data at input size
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # extra keys requested only on snapshot iterations
    snapshot_request = gp.BatchRequest()
    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    # one source branch per sample; RandomProvider picks among them per batch
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched, center_size=output_size)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        # distance target derived from the labels, capped at MAX_LABEL_DIST microns
        + gp.contrib.nodes.add_distance.AddDistance(
            labels, dist, dist_mask, max_distance=max_label_dist * micron_scale
        )
        + gp.contrib.nodes.tanh_saturate.TanhSaturate(dist, scale=micron_scale, offset=1)
        # NOTE(review): presumably marks voxels of `dist` above 1e-4 in loss_weights — confirm
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
def train_until(max_iteration):
    """Request batches from the ``2018-08-01`` sample data and visualize each one.

    No network is trained here. Each batch of raw data and SWC points is passed
    to ``vis_points_with_array``.

    The number of batches is ``max_iteration`` minus the iteration of the
    latest TensorFlow checkpoint in the working directory. If no checkpoint
    exists, that iteration is taken as 0.

    Args:
        max_iteration: Iteration count to reach. Nothing is done if the
            latest checkpoint is already at or beyond it.

    Raises:
        ValueError: If no directory under ``sample_dir`` has ``"2018-08-01"``
            in its name.
    """
    # Query the checkpoint once so the existence test and the parse see the same file.
    latest_checkpoint = tf.train.latest_checkpoint(".")
    # assumes checkpoint names end in "_<iteration>" — the parse below relies on it
    trained_until = int(latest_checkpoint.split("_")[-1]) if latest_checkpoint else 0
    if trained_until >= max_iteration:
        return

    # Fail fast instead of building a RandomProvider with no upstream sources.
    samples = [
        filename for filename in Path(sample_dir).iterdir() if "2018-08-01" in filename.name
    ]
    if not samples:
        raise ValueError(f"no '2018-08-01' sample found in {sample_dir}")

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    # twice the network input shape, in world units
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
                    ).absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            MouselightSwcFileSource(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
                    ).absolute()
                ),
                points=(swcs,),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str((filename / "transform.txt").absolute()),
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swcs, ensure_centered=True)
        for filename in samples
    )

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(
                batch[raw].data, points_to_graph(batch[swcs].data), np.array(voxel_size)
            )