import math
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

import daisy
import gunpowder as gp
from gunpowder import ArrayKey, ArraySpec, Coordinate, GraphKey

# NOTE: the remaining un-prefixed node names used below (TestImageSource, TestPointSource,
# MergeProvider, Normalize, RandomLocation, RasterizeSkeleton, TopologicalMatcher,
# RejectIfEmpty, GrowLabels, Crop, BinarizeGt, ThresholdMask, FilterComponents,
# DaisyGraphProvider, FilteredDaisyGraphProvider, create_custom_loss) are assumed to be
# importable from gunpowder or the surrounding project package; their exact module paths
# are not shown in this file.


def get_test_data_sources(setup_config):
    input_shape = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = input_shape * voxel_size
    micron_scale = voxel_size[0]

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = GraphKey("MATCHED")
    nonempty_placeholder = GraphKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = matched

    data_sources = (
        (
            TestImageSource(
                array=raw,
                array_specs={
                    raw: ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
                size=input_size * 3,
                voxel_size=voxel_size,
            ),
            TestPointSource(
                points=[matched, nonempty_placeholder],
                directed=False,
                size=input_size * 3,
                num_points=333,
            ),
        )
        + MergeProvider()
        + RandomLocation(
            ensure_nonempty=ensure_nonempty,
            ensure_centered=True,
            point_balance_radius=10 * micron_scale,
        )
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint64
            ),
        )
        + Normalize(raw)
    )

    return (
        data_sources,
        raw,
        labels,
        nonempty_placeholder,
        matched,
    )
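# The note above says these keys are intended to be requested with size input_size. The
# driver below is a hypothetical sketch (not part of the original module) showing one way
# to exercise that contract; it only assumes the standard gunpowder
# BatchRequest/build/request_batch workflow and a setup_config providing INPUT_SHAPE and
# VOXEL_SIZE.
def _example_request_test_batch(setup_config):
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * Coordinate(
        setup_config["VOXEL_SIZE"]
    )

    pipeline, raw, labels, nonempty_placeholder, matched = get_test_data_sources(
        setup_config
    )

    # request every returned key with a ROI of size input_size
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)

    with gp.build(pipeline):
        return pipeline.request_batch(request)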
def train_simple_pipeline(
    n_iterations, setup_config, mknet_tensor_names, loss_tensor_names
):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)
    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01")
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(labels, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
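# train_simple_pipeline is driven entirely by setup_config. The dictionary below is a
# hypothetical sketch that simply enumerates the keys read above; every value is a
# placeholder, not a setting taken from the original code or any real experiment.
_example_simple_setup_config = {
    "INPUT_SHAPE": (84, 268, 268),  # placeholder, in voxels
    "OUTPUT_SHAPE": (44, 228, 228),  # placeholder, in voxels
    "VOXEL_SIZE": (1000, 300, 300),  # placeholder, in world units
    "NUM_ITERATIONS": 100000,
    "CACHE_SIZE": 40,
    "NUM_WORKERS": 10,
    "SNAPSHOT_EVERY": 1000,
    "CHECKPOINT_EVERY": 10000,
    "PROFILE_EVERY": 100,
    "SEPERATE_BY": [100],
    "GAP_CROSSING_DIST": 30,
    "MATCH_DISTANCE_THRESHOLD": 50,
    "POINT_BALANCE_RADIUS": 100,
    "NEURON_RADIUS": 1,
    "SAMPLES_PATH": "/path/to/samples",  # placeholder path
    "MONGO_URL": "mongodb://localhost:27017",  # placeholder URL
}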
def get_test_data_sources(setup_config):
    input_shape = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = input_shape * voxel_size
    micron_scale = voxel_size[0]

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    consensus = GraphKey("CONSENSUS")
    skeletonization = GraphKey("SKELETONIZATION")
    matched = GraphKey("MATCHED")
    nonempty_placeholder = GraphKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    use_gurobi = setup_config["USE_GUROBI"]

    if setup_config["FUSION_PIPELINE"]:
        ensure_nonempty = nonempty_placeholder
    else:
        ensure_nonempty = consensus

    data_sources = (
        (
            TestImageSource(
                array=raw,
                array_specs={
                    raw: ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
                size=input_size * 3,
                voxel_size=voxel_size,
            ),
            TestPointSource(
                points=[consensus, nonempty_placeholder],
                directed=True,
                size=input_size * 3,
                num_points=30,
            ),
            TestPointSource(
                points=[skeletonization],
                directed=False,
                size=input_size * 3,
                num_points=333,
            ),
        )
        + MergeProvider()
        + RandomLocation(
            ensure_nonempty=ensure_nonempty,
            ensure_centered=True,
            point_balance_radius=10 * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            match_distance_threshold=50 * micron_scale,
            max_gap_crossing=30 * micron_scale,
            try_complete=False,
            use_gurobi=use_gurobi,
            expected_edge_len=1000,
        )
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint64
            ),
        )
        + Normalize(raw)
    )

    return (
        data_sources,
        raw,
        labels,
        nonempty_placeholder,
        matched,
    )
def train_distance_pipeline(
    n_iterations, setup_config, mknet_tensor_names, loss_tensor_names
):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # add request
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched, center_size=output_size)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + gp.contrib.nodes.add_distance.AddDistance(
            labels, dist, dist_mask, max_distance=max_label_dist * micron_scale
        )
        + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1
        )
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01")
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
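# Snapshots written by the gp.Snapshot node above are HDF5 files whose contents follow the
# dataset_names mapping. The helper below is an illustrative sketch for inspecting them
# offline; the "snapshots/" directory is gunpowder's default output location and the file
# name is a placeholder - substitute whatever files actually appear on disk.
def _example_inspect_distance_snapshot(path="snapshots/snapshot_100_1.hdf"):
    import h5py

    with h5py.File(path, "r") as f:
        raw = f["volumes/raw"][:]  # network input
        gt_dist = f["volumes/dist"][:]  # cropped, tanh-saturated distance target
        fg_dist = f["volumes/fg_dist"][:]  # network prediction
        loss_weights = f["volumes/loss_weights"][:]  # mask produced by ThresholdMask
    return raw, gt_dist, fg_dist, loss_weights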
def get_mouselight_data_sources(
    setup_config: Dict[str, Any], source_samples: List[str], locations=False
):
    # Source Paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])

    # specified_locations = setup_config.get("SPECIFIED_LOCATIONS")

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    matching_failures_dir = (
        matching_failures_dir
        if matching_failures_dir is None
        else Path(matching_failures_dir)
    )

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = nonempty_placeholder

    node_offset = {
        sample.name: (
            daisy.persistence.MongoDbGraphProvider(
                mongo_db_template.format(sample=sample.name, source="skeletonization"),
                mongo_url,
            ).num_nodes(None)
            + 1
        )
        for sample in samples_path.iterdir()
        if sample.name in source_samples
    }

    # if specified_locations is not None:
    #     centers = pickle.load(open(specified_locations, "rb"))
    #     random = gp.SpecifiedLocation
    #     kwargs = {"locations": centers, "choose_randomly": True}
    #     logger.info(f"Using specified locations from {specified_locations}")
    # elif locations:
    #     random = RandomLocations
    #     kwargs = {
    #         "ensure_nonempty": ensure_nonempty,
    #         "ensure_centered": True,
    #         "point_balance_radius": point_balance_radius * micron_scale,
    #         "loc": gp.ArrayKey("RANDOM_LOCATION"),
    #     }
    # else:
    random = RandomLocation
    kwargs = {
        "ensure_nonempty": ensure_nonempty,
        "ensure_centered": True,
        "point_balance_radius": point_balance_radius * micron_scale,
    }

    data_sources = (
        tuple(
            (
                gp.ZarrSource(
                    filename=str((sample / raw_n5).absolute()),
                    datasets={raw: "volume-rechunked"},
                    array_specs={
                        raw: gp.ArraySpec(
                            interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                        )
                    },
                ),
                DaisyGraphProvider(
                    mongo_db_template.format(sample=sample.name, source=matched_source),
                    mongo_url,
                    points=[matched],
                    directed=True,
                    node_attrs=[],
                    edge_attrs=[],
                ),
                FilteredDaisyGraphProvider(
                    mongo_db_template.format(sample=sample.name, source=matched_source),
                    mongo_url,
                    points=[nonempty_placeholder],
                    directed=True,
                    node_attrs=["distance_to_fallback"],
                    edge_attrs=[],
                    num_nodes=max_nonempty_points,
                    dist_attribute=distance_attr,
                    min_dist=target_distance,
                ),
            )
            + gp.MergeProvider()
            + random(**kwargs)
            + gp.Normalize(raw)
            + FilterComponents(
                matched, node_offset[sample.name], centroid_size=output_size
            )
            + RasterizeSkeleton(
                points=matched,
                array=labels,
                array_spec=gp.ArraySpec(
                    interpolatable=False, voxel_size=voxel_size, dtype=np.int64
                ),
            )
            for sample in samples_path.iterdir()
            if sample.name in source_samples
        )
        + gp.RandomProvider()
    )

    return (data_sources, raw, labels, nonempty_placeholder, matched)
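# MONGO_DB_TEMPLATE is formatted above with `sample` and `source` fields. The helper below
# is a hypothetical illustration of that naming scheme, guessed from the database names
# used elsewhere in this file ("mouselight-{sample.name}-consensus" etc.); the real
# template comes from setup_config and may differ.
def _example_mongo_db_name(sample_name: str, source: str = "matched") -> str:
    example_template = "mouselight-{sample}-{source}"  # placeholder template
    # e.g. _example_mongo_db_name("2018-07-02") == "mouselight-2018-07-02-matched"
    return example_template.format(sample=sample_name, source=source)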