def add_snapshot(
    pipeline,
    setup_config,
    datasets: List[Tuple[Union[ArrayKey, PointsKey], gp.Coordinate, str]],
):
    """Return ``pipeline`` extended with a periodic ``gp.Snapshot`` node.

    Args:
        pipeline: The gunpowder pipeline to extend.
        setup_config: Mapping read for ``"SNAPSHOT_EVERY"``,
            ``"SNAPSHOT_FILE_NAME"`` and ``"SNAPSHOT_DIRECTORY"``.
        datasets: Entries of the form ``(key, size, location, ...)``. Each
            ``key`` is added to the snapshot's additional request with
            ``size`` and is stored under the dataset name ``location``.
            Elements after the third are ignored.

    Returns:
        ``pipeline + gp.Snapshot(...)``.
    """
    every = setup_config["SNAPSHOT_EVERY"]
    file_name = setup_config["SNAPSHOT_FILE_NAME"]
    out_dir = setup_config["SNAPSHOT_DIRECTORY"]

    # Keys that are only requested on iterations where a snapshot is written.
    extra_request = gp.BatchRequest()
    for data_key, data_size, *_unused in datasets:
        extra_request.add(data_key, data_size)

    # Dataset path each key is written to inside the snapshot file.
    names = {}
    for data_key, _size, location, *_unused in datasets:
        names[data_key] = location

    snapshot = gp.Snapshot(
        additional_request=extra_request,
        output_dir=out_dir,
        output_filename=file_name,
        dataset_names=names,
        every=every,
    )
    return pipeline + snapshot
def pre_computed_fg_validation_pipeline(config, snapshot_file, raw_path, gt_path, fg_path):
    """Build a pipeline that scores MST reconstructions from a precomputed foreground.

    For every block in ``config["BLOCKS"]`` the raw volume, foreground and
    ground-truth graph are read from ``snapshot_file`` with ``SnapshotSource``.
    Candidates are extracted with ``Skeletonize``, connected with ``MiniMax``
    and compared to the ground truth with ``Evaluate``. The per-block
    pipelines are merged and their scores combined by ``MergeScores``.

    Args:
        config: Mapping providing ``BLOCKS``, ``BENCHMARK_DATA_PATH``,
            ``VALIDATION_SAMPLES``, ``VOXEL_SIZE``, ``INPUT_SHAPE``,
            ``OUTPUT_SHAPE``, ``CANDIDATE_SPACING``, ``CANDIDATE_THRESHOLD``,
            ``DISTANCE_ATTR``, ``NUM_EVAL_THRESHOLDS``,
            ``EVAL_THRESHOLD_RANGE``, ``COMPONENT_THRESHOLD_1``,
            ``EVAL_SNAPSHOT`` and, when snapshots are enabled,
            ``EVAL_SNAPSHOT_NAME`` (formatted with ``block=``).
        snapshot_file: File read by every ``SnapshotSource``.
        raw_path, gt_path, fg_path: Dataset path templates, formatted with
            ``block=`` for each block.

    Returns:
        ``(validation_pipeline, final_score)``, where ``final_score`` is the
        ``gp.ArrayKey("SCORE")`` produced by ``MergeScores``.

    Raises:
        ValueError: if ``config["BLOCKS"]`` is empty.
        FileNotFoundError: if a block's validation directory contains no
            ``cube*.swc`` file.
    """
    blocks = config["BLOCKS"]
    if not blocks:
        raise ValueError("config['BLOCKS'] must contain at least one block")
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    # Context added on each side of the output ROI to obtain the raw input ROI.
    context = (input_size - output_size) // 2

    candidate_spacing = config["CANDIDATE_SPACING"]
    candidate_threshold = config["CANDIDATE_THRESHOLD"]
    distance_attr = config["DISTANCE_ATTR"]

    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]
    component_threshold = config["COMPONENT_THRESHOLD_1"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)

        # The cube*.swc file delimits this block's evaluation region. If there
        # are several, the last one returned by iterdir() is used.
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(".swc"):
                cube = gt_file
        if cube is None:
            raise FileNotFoundError(
                f"No cube*.swc file found in validation directory {validation_dir}"
            )

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst = gp.GraphKey(f"MST_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        fg = gp.ArrayKey(f"FG_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")

        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                fg: fg_path.format(block=block),
            },
        )
        gt_source = SnapshotSource(
            snapshot_file,
            datasets={gt: gt_path.format(block=block)},
            directed={gt: False},
        )

        # Evaluate on the cube ROI moved to the origin (presumably the snapshot
        # datasets are stored origin-aligned -- TODO confirm); the raw input
        # additionally needs the network context around it.
        cube_roi_shifted = gp.Roi((0,) * len(cube_roi.get_shape()), cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(context, context)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates] = gp.ArraySpec(cube_roi_shifted)
        block_spec[fg] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)

        # Extra keys requested on iterations where an evaluation snapshot is written.
        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates] = gp.ArraySpec(cube_roi_shifted)
        additional_request[fg] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted, directed=False)

        pipeline = (
            (raw_source, gt_source)
            + gp.MergeProvider()
            + Skeletonize(fg, candidates, candidate_spacing, candidate_threshold)
            + MiniMax(fg, candidates, mst, distance_attr=distance_attr)
        )

        pipeline += Evaluate(
            gt,
            mst,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold,
        )

        if config["EVAL_SNAPSHOT"]:
            pipeline += gp.Snapshot(
                {
                    raw: "volumes/raw",
                    fg: "volumes/foreground",
                    candidates: "volumes/candidates",
                    mst: "points/mst",
                    gt: "points/gt",
                    details: "points/details",
                },
                output_dir="eval_results",
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(block=block),
                edge_attrs={
                    mst: [distance_attr],
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(validation_pipelines)
        + gp.MergeProvider()
        + MergeScores(final_score, specs)
        + gp.PrintProfilingStats()
    )
    return validation_pipeline, final_score
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    """Train the foreground/embedding network on MouseLight samples.

    Raw fluorescence is read from ``<sample>/fluorescence-near-consensus.n5``
    and the consensus and skeletonization graphs from MongoDB via
    ``gp.DaisyGraphProvider``. The skeletonization is matched to the
    consensus, rasterized into labels, augmented, and fed to
    ``gp.tensorflow.Train``. A snapshot is written every
    ``SNAPSHOT_EVERY`` iterations.

    Args:
        n_iterations: Currently unused; the number of training iterations
            is taken from ``setup_config["NUM_ITERATIONS"]``.
        setup_config: Mapping with shapes, voxel size, caching, snapshot,
            checkpoint, matching and augmentation settings, plus
            ``SAMPLES_PATH`` and ``MONGO_URL``.
        mknet_tensor_names: Mapping from ``"raw"``, ``"gt_labels"``,
            ``"loss_weights"``, ``"embedding"`` and ``"fg"`` to tensor names.
        loss_tensor_names: Mapping from the loss-side debug output names
            (``"fg_pred"``, ``"maxima"``, ``"gt_fg"``, ``"emst"``, ...) to
            tensor names.

    Raises:
        FileNotFoundError: if none of the training samples is found under
            ``SAMPLES_PATH``.
    """
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])

    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request (keys only needed on snapshot iterations)
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)
    # tensorflow requests (uncomment to include them in snapshots)
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    train_samples = ("2018-07-02", "2018-08-01")
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in train_samples
    )
    if not data_sources:
        # RandomProvider needs at least one upstream source.
        raise FileNotFoundError(
            f"None of the training samples {train_samples} was found in {samples_path}"
        )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(labels, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"
            ),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
def build_pipeline(parameter, augment=True):
    """Build the synaptic-partner training pipeline and run training.

    Despite its name, this function also trains. It builds the pipeline and
    then requests ``parameter['max_iteration']`` batches from it.

    Network tensor names and shapes are read from ``train_net_config.json``
    in the current working directory. One source is created with
    ``create_source`` for every entry of the module-level ``samples``.

    Args:
        parameter: Mapping with ``'voxel_size'``, ``'blob_radius'``,
            ``'blob_mode'``, ``'d_blob_radius'``, ``'cliprange'``,
            ``'d_scale'`` and ``'max_iteration'``; also passed on to
            ``create_source``.
        augment: If true, insert elastic, mirror/transpose and intensity
            augmentations after the random sample selection.
    """
    voxel_size = gp.Coordinate(parameter['voxel_size'])

    # Array Specifications.
    raw = gp.ArrayKey('RAW')
    gt_neurons = gp.ArrayKey('GT_NEURONS')
    gt_postpre_vectors = gp.ArrayKey('GT_POSTPRE_VECTORS')
    gt_post_indicator = gp.ArrayKey('GT_POST_INDICATOR')
    post_loss_weight = gp.ArrayKey('POST_LOSS_WEIGHT')
    vectors_mask = gp.ArrayKey('VECTORS_MASK')

    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    grad_syn_indicator = gp.ArrayKey('GRAD_SYN_INDICATOR')
    grad_partner_vectors = gp.ArrayKey('GRAD_PARTNER_VECTORS')

    # Points specifications
    dummypostsyn = gp.PointsKey('DUMMYPOSTSYN')
    postsyn = gp.PointsKey('POSTSYN')
    presyn = gp.PointsKey('PRESYN')
    trg_context = 140  # AddPartnerVectorMap context in nm - pre-post distance

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_neurons, output_size)
    request.add(gt_postpre_vectors, output_size)
    request.add(gt_post_indicator, output_size)
    request.add(post_loss_weight, output_size)
    request.add(vectors_mask, output_size)
    request.add(dummypostsyn, output_size)

    # Network outputs and gradients, requested only on snapshot iterations;
    # they reuse the ROI of the ground-truth partner vectors.
    snapshot_request = gp.BatchRequest({
        pred_post_indicator: request[gt_postpre_vectors],
        pred_postpre_vectors: request[gt_postpre_vectors],
        grad_syn_indicator: request[gt_postpre_vectors],
        grad_partner_vectors: request[gt_postpre_vectors],
        vectors_mask: request[gt_postpre_vectors],
    })

    postsyn_rastersetting = gp.RasterizationSettings(
        radius=parameter['blob_radius'],
        mask=gt_neurons,
        mode=parameter['blob_mode'])

    pipeline = tuple(
        create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter, gt_neurons)
        for sample in samples
    )

    pipeline += gp.RandomProvider()
    if augment:
        pipeline += gp.ElasticAugment([4, 40, 40],
                                      [0, 2, 2],
                                      [0, math.pi / 2.0],
                                      prob_slip=0.05,
                                      prob_shift=0.05,
                                      max_misalign=10,
                                      subsample=8)
        pipeline += gp.SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
        pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    # Map raw intensities x -> 2x - 1 for training; undone before the snapshot.
    pipeline += gp.IntensityScaleShift(raw, 2, -1)
    pipeline += gp.RasterizePoints(
        postsyn, gt_post_indicator,
        gp.ArraySpec(voxel_size=voxel_size, dtype=np.int32),
        postsyn_rastersetting)
    spec = gp.ArraySpec(voxel_size=voxel_size)
    pipeline += AddPartnerVectorMap(
        src_points=postsyn,
        trg_points=presyn,
        array=gt_postpre_vectors,
        radius=parameter['d_blob_radius'],
        trg_context=trg_context,  # enlarge
        array_spec=spec,
        mask=gt_neurons,
        pointmask=vectors_mask)
    pipeline += gp.BalanceLabels(labels=gt_post_indicator,
                                 scales=post_loss_weight,
                                 slab=(-1, -1, -1),
                                 clipmin=parameter['cliprange'][0],
                                 clipmax=parameter['cliprange'][1])
    if parameter['d_scale'] != 1:
        pipeline += gp.IntensityScaleShift(gt_postpre_vectors,
                                           scale=parameter['d_scale'],
                                           shift=0)
    pipeline += gp.PreCache(cache_size=40, num_workers=10)
    pipeline += gp.tensorflow.Train(
        './train_net',
        optimizer=net_config['optimizer'],
        loss=net_config['loss'],
        summary=net_config['summary'],
        log_dir='./tensorboard/',
        save_every=30000,
        log_every=100,
        inputs={
            net_config['raw']: raw,
            net_config['gt_partner_vectors']: gt_postpre_vectors,
            net_config['gt_syn_indicator']: gt_post_indicator,
            net_config['vectors_mask']: vectors_mask,  # Loss weights --> mask
            net_config['indicator_weight']: post_loss_weight,  # Loss weights
        },
        outputs={
            net_config['pred_partner_vectors']: pred_postpre_vectors,
            net_config['pred_syn_indicator']: pred_post_indicator,
        },
        gradients={
            net_config['pred_partner_vectors']: grad_partner_vectors,
            net_config['pred_syn_indicator']: grad_syn_indicator,
        },
    )
    # Visualize: undo the 2x - 1 mapping of raw for the snapshot.
    pipeline += gp.IntensityScaleShift(raw, 0.5, 0.5)
    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
            gt_post_indicator: 'volumes/gt_post_indicator',
            gt_postpre_vectors: 'volumes/gt_postpre_vectors',
            pred_postpre_vectors: 'volumes/pred_postpre_vectors',
            pred_post_indicator: 'volumes/pred_post_indicator',
            post_loss_weight: 'volumes/post_loss_weight',
            grad_syn_indicator: 'volumes/post_indicator_gradients',
            grad_partner_vectors: 'volumes/partner_vectors_gradients',
            vectors_mask: 'volumes/vectors_mask',
        },
        every=1000,
        output_filename='batch_{iteration}.hdf',
        compression_type='gzip',
        additional_request=snapshot_request)
    pipeline += gp.PrintProfilingStats(every=100)

    print("Starting training...")
    max_iteration = parameter['max_iteration']
    with gp.build(pipeline) as b:
        for _ in range(max_iteration):
            b.request_batch(request)
masked_base=masked_base_b, masked_add=masked_add_b, soft_mask=softmask_b, mask_maxed=mask_maxed_b, ) + gp.PrintProfilingStats(every=1) + gp.Snapshot( output_filename="snapshot_debug_scale_comparison_{}_{}.hdf".format( int(np.mean(SEPERATE_DISTANCE)), "{iteration}"), dataset_names={ raw_fused: "volumes/raw_fused", raw_base: "volumes/raw_base", raw_add: "volumes/raw_add", labels_fused: "volumes/labels_fused", labels_base: "volumes/labels_base", labels_add: "volumes/labels_add", raw_fused_b: "volumes/raw_fused_b", labels_fused_b: "volumes/labels_fused_b", masked_base: "volumes/masked_base", masked_base_b: "volumes/masked_base_b", masked_add: "volumes/masked_add", masked_add_b: "volumes/masked_add_b", softmask: "volumes/softmask", softmask_b: "volumes/softmask_b", mask_maxed: "volumes/mask_maxed", mask_maxed_b: "volumes/mask_maxed_b", }, every=1, )) with build(pipeline): for i in range(1): request = BatchRequest(random_seed=i)
def train_until(**kwargs):
    """Train the three-class network until ``kwargs['max_iteration']``.

    Training resumes from the latest TensorFlow checkpoint in
    ``kwargs['output_folder']``; the iteration is parsed from the text after
    the last ``_`` in the checkpoint path. The function returns immediately
    if that iteration already reached ``max_iteration``.

    Keyword Args:
        output_folder, name: Locate ``<name>_config.json`` and
            ``<name>_names.json``, checkpoints, logs and snapshots.
        voxel_size: Voxel size in world units.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_files: Training volumes, each read for ``volumes/raw`` and
            ``volumes/gt_threeclass``.
        augmentation: Dict with ``simple`` and ``intensity`` entries and an
            optional ``elastic`` entry.
        cache_size, num_workers: ``gp.PreCache`` settings.
        checkpoints, snapshots, profiling: Checkpoint, snapshot and
            profiling-stats intervals (in iterations).
        max_iteration: Training stops before this iteration.

    Raises:
        NotImplementedError: If ``input_format`` is not ``"hdf"`` or ``"zarr"``.
    """
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_threeclass = gp.ArrayKey('GT_THREECLASS')

    loss_weights_threeclass = gp.ArrayKey('LOSS_WEIGHTS_THREECLASS')

    pred_threeclass = gp.ArrayKey('PRED_THREECLASS')

    pred_threeclass_gradients = gp.ArrayKey('PRED_THREECLASS_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_threeclass, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(loss_weights_threeclass, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted three-class output (gradients are left out)
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_threeclass, output_shape_world)
    snapshot_request.add(pred_threeclass, output_shape_world)
    # snapshot_request.add(pred_threeclass_gradients, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        raise NotImplementedError("train node for {} not implemented".format(
            input_format))

    # Materialize once: the list is iterated for the sanity check and again
    # when building the sources.
    data_files = list(kwargs['data_files'])
    for data_file in data_files:
        if input_format == "hdf":
            # Close the file again; the source node below opens it itself.
            with h5py.File(data_file, 'r') as h5_file:
                vol = h5_file['volumes/raw']
                shape, dtype = vol.shape, vol.dtype
        else:
            vol = zarr.open(data_file, 'r')['volumes/raw']
            shape, dtype = vol.shape, vol.dtype
        print(data_file, shape, dtype)
        if dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", data_files[:5])

    source_node = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource

    augmentation = kwargs['augmentation']
    elastic = augmentation.get('elastic')
    pipeline = (
        tuple(
            # read batches from the HDF5/zarr file (the same path that was
            # sanity-checked above)
            source_node(
                os.fspath(data_file),
                datasets={
                    raw: 'volumes/raw',
                    gt_threeclass: 'volumes/gt_threeclass',
                    anchor: 'volumes/gt_threeclass',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_threeclass: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False),
                },
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_threeclass, None)
            + gp.Pad(anchor, gp.Coordinate((2, 2, 2)))
            # chose a random location for each requested batch
            + gp.RandomLocation()
            for data_file in data_files
        )
        # chose a random source (i.e., sample) from the above
        + gp.RandomProvider()
        # elastically deform the batch (only if configured)
        + (gp.ElasticAugment(
               elastic['control_point_spacing'],
               elastic['jitter_sigma'],
               [elastic['rotation_min'] * np.pi / 180.0,
                elastic['rotation_max'] * np.pi / 180.0],
               subsample=elastic.get('subsample', 1))
           if elastic is not None else NoOp())
        # apply transpose and mirror augmentations
        + gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                           transpose_only=augmentation['simple'].get("transpose"))
        # scale and shift the intensity of the raw array
        + gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
        # grow a boundary between labels
        # TODO: check
        # gp.GrowBoundary(
        #     gt_threeclass,
        #     steps=1,
        #     only_xy=False) +
        + gp.BalanceLabels(
            gt_threeclass,
            loss_weights_threeclass,
            num_classes=3)
        # pre-cache batches from the point upstream
        + gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers'])
        # perform one training iteration for each passing batch (tensor
        # names come from <name>_names.json)
        + gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['anchor']: anchor,
                net_names['gt_threeclass']: gt_threeclass,
                net_names['loss_weights_threeclass']: loss_weights_threeclass,
            },
            outputs={
                net_names['pred_threeclass']: pred_threeclass,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_threeclass']: pred_threeclass_gradients,
            },
            save_every=kwargs['checkpoints'])
        # save the passing batch as an HDF5 file for inspection
        + gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_threeclass: '/volumes/gt_threeclass',
                pred_threeclass: '/volumes/pred_threeclass',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip')
        # show a summary of time spent in each node every
        # kwargs['profiling'] iterations
        + gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
def create_pipeline_2d(task, predictor, optimizer, batch_size, outdir, snapshot_every):
    """Build a 2D training pipeline and the matching batch request.

    The training array must have layout ``([c,] n, h, w)``, where the channel
    axis is present only when ``task.data.raw.num_channels != 1``. Its spec
    is overwritten so that the sample axis ``n`` acts as a z dimension with
    voxel size 1. Random single samples are drawn, augmented, squashed to 2D
    and stacked into batches of ``batch_size`` for ``gp_torch.Train``.

    Args:
        task: Provides the data sources, augmentations and loss.
        predictor: The model; supplies input/output shapes, ``add_target``
            and the target/prediction channel counts.
        optimizer: Torch optimizer passed to ``gp_torch.Train``.
        batch_size: Number of samples stacked per batch.
        outdir: Snapshots go to ``<outdir>/snapshots``.
        snapshot_every: Snapshot interval; ``<= 0`` disables snapshots.

    Returns:
        ``(pipeline, request)``.

    Raises:
        RuntimeError: If the data (without channel axis) is not 3D.
    """
    raw_channels = task.data.raw.num_channels
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = task.data.raw.train.shape
    dataset_roi = task.data.raw.train.roi
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    data_dims = len(dataset_shape) - channel_dims
    if data_dims != 3:
        raise RuntimeError("For 2D training, please provide a 3D array where "
                           "the first dimension indexes the samples.")

    # Layout is ([c,] n, h, w): the sample axis follows the optional channel
    # axis, so it must be indexed past the channel dims (not at position 0,
    # which is the channel count for multi-channel data).
    num_samples = dataset_shape[channel_dims]
    sample_shape = gp.Coordinate(dataset_shape[channel_dims + 1:])
    sample_size = sample_shape * voxel_size

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    sources = (task.data.raw.train.get_source(raw, overwrite_spec=spec),
               task.data.gt.train.get_source(gt, overwrite_spec=spec))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    # NOTE(review): eval() executes arbitrary code taken from the task
    # definition; only use task configurations from trusted sources.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    # [weights: ([c,] d=1, h, w)]
    # get rid of z dim:
    pipeline += Squash(dim=-3)
    # raw: ([c,] h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    # prediction: (b, [c,] h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3))
        # raw: (c, b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
def train(n_iterations):
    """Train the embedding/foreground network on synthetic 2D data.

    Runs ``n_iterations`` batches from ``Synthetic2DSource`` through
    ``gp.tensorflow.Train`` (graph ``"train_net"``, loss added by
    ``add_loss``). Every 100 iterations it writes a snapshot containing the
    network outputs and gradients.

    Relies on the module-level names ``tensor_names``, ``add_loss``,
    ``emst_name``, ``edges_u_name`` and ``edges_v_name``.
    """
    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    # Named to match the variable and the "volumes/gt_fg" snapshot dataset.
    gt_fg = gp.ArrayKey("GT_FG")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")

    request = gp.BatchRequest()
    request.add(raw, (200, 200))
    request.add(gt, (160, 160))

    snapshot_request = gp.BatchRequest()
    # Spatial network outputs/gradients share the ground-truth ROI.
    for key in (embedding, fg, gt_fg, maxima, gradient_embedding, gradient_fg):
        snapshot_request[key] = request[gt]
    # EMST outputs have no spatial extent, so they get an empty spec.
    for key in (emst, edges_u, edges_v):
        snapshot_request[key] = gp.ArraySpec()

    pipeline = (
        Synthetic2DSource(raw, gt)
        + gp.Normalize(raw)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=add_loss,
            loss=None,
            inputs={tensor_names["raw"]: raw, tensor_names["gt_labels"]: gt},
            outputs={
                tensor_names["embedding"]: embedding,
                tensor_names["fg"]: fg,
                "maxima:0": maxima,
                "gt_fg:0": gt_fg,
                emst_name: emst,
                edges_u_name: edges_u,
                edges_v_name: edges_v,
            },
            gradients={
                tensor_names["embedding"]: gradient_embedding,
                tensor_names["fg"]: gradient_fg,
            },
        )
        + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                gt: "volumes/gt",
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
            },
            dataset_dtypes={maxima: np.float32, gt_fg: np.float32},
            every=100,
            additional_request=snapshot_request,
        )
    )

    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)
def train_until(max_iteration):
    """Train a UNet to regress a peak map around annotated center points.

    Positive volumes (module-level ``pos_samples``) get a center point via
    ``AddCenterPoint``; negative volumes (``neg_samples``) get none via
    ``AddNoPoint``. Positives are drawn with probability 0.9. Points are
    rasterized with ``mode='peak'`` and the model is trained against that
    raster with ``torch.nn.MSELoss``. Runs ``max_iteration`` iterations,
    with ``save_every=1000`` for the trainer and a snapshot every 500
    iterations.
    """
    # --- model ------------------------------------------------------------
    unet = UNet(
        1,  # in_channels
        12,  # num_fmaps
        6,  # fmap_inc_factors
        [(1, 3, 3), (1, 3, 3), (3, 3, 3)],  # downsample_factors
        constant_upsample=True,
    )
    model = Convolve(unet, 12, 1)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # --- gunpowder keys and geometry --------------------------------------
    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))
    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)
    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    for output_key in (points, groundtruth, prediction, grad):
        request.add(output_key, output_size)

    # --- sources ----------------------------------------------------------
    def zarr_raw(path):
        return gp.ZarrSource(
            path,
            {raw: 'volumes/raw'},
            {raw: gp.ArraySpec(interpolatable=True)},
        )

    positives = tuple(
        zarr_raw(path)
        + AddCenterPoint(points, raw)
        + gp.Pad(raw, None)
        + gp.RandomLocation(ensure_nonempty=points)
        for path in pos_samples
    ) + gp.RandomProvider()

    negatives = tuple(
        zarr_raw(path) + AddNoPoint(points, raw) + gp.RandomLocation()
        for path in neg_samples
    ) + gp.RandomProvider()

    # --- augmentation, training and output --------------------------------
    pipeline = (
        (positives, negatives)
        + gp.RandomProvider(probabilities=[0.9, 0.1])
        + gp.Normalize(raw)
        + gp.ElasticAugment(
            control_point_spacing=[4, 40, 40],
            jitter_sigma=[0, 2, 2],
            rotation_interval=[0, math.pi / 2.0],
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=10,
            subsample=8,
        )
        + gp.SimpleAugment(transpose_only=[1, 2])
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        + gp.RasterizePoints(
            points,
            groundtruth,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'),
        )
        + gp.PreCache(cache_size=40, num_workers=10)
        # prepend two singleton axes (presumably batch and channel for the
        # torch model) before training ...
        + Reshape(raw, (1, 1) + input_shape)
        + Reshape(groundtruth, (1, 1) + output_shape)
        + gp_torch.Train(
            model=model,
            loss=loss,
            optimizer=optimizer,
            inputs={'x': raw},
            outputs={0: prediction},
            loss_inputs={0: prediction, 1: groundtruth},
            gradients={0: grad},
            save_every=1000,
            log_dir='log',
        )
        # ... and strip them again for the snapshot
        + Reshape(raw, input_shape)
        + Reshape(groundtruth, output_shape)
        + Reshape(prediction, output_shape)
        + Reshape(grad, output_shape)
        + gp.Snapshot(
            {
                raw: 'volumes/raw',
                groundtruth: 'volumes/groundtruth',
                prediction: 'volumes/prediction',
                grad: 'volumes/gradient',
            },
            every=500,
            output_filename='test_{iteration}.hdf',
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):
        for _ in range(max_iteration):
            pipeline.request_batch(request)
def train_until(max_iteration): # get the latest checkpoint if tf.train.latest_checkpoint("."): trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1]) else: trained_until = 0 if trained_until >= max_iteration: return # array keys for data sources raw = gp.ArrayKey("RAW") swcs = gp.PointsKey("SWCS") labels = gp.ArrayKey("LABELS") # array keys for base volume raw_base = gp.ArrayKey("RAW_BASE") labels_base = gp.ArrayKey("LABELS_BASE") swc_base = gp.PointsKey("SWC_BASE") # array keys for add volume raw_add = gp.ArrayKey("RAW_ADD") labels_add = gp.ArrayKey("LABELS_ADD") swc_add = gp.PointsKey("SWC_ADD") # array keys for fused volume raw_fused = gp.ArrayKey("RAW_FUSED") labels_fused = gp.ArrayKey("LABELS_FUSED") swc_fused = gp.PointsKey("SWC_FUSED") # output data fg = gp.ArrayKey("FG") labels_fg = gp.ArrayKey("LABELS_FG") labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN") gradient_fg = gp.ArrayKey("GRADIENT_FG") loss_weights = gp.ArrayKey("LOSS_WEIGHTS") voxel_size = gp.Coordinate((10, 3, 3)) input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size # add request request = gp.BatchRequest() request.add(raw_fused, input_size) request.add(labels_fused, input_size) request.add(swc_fused, input_size) request.add(labels_fg, output_size) request.add(labels_fg_bin, output_size) request.add(loss_weights, output_size) # add snapshot request # request.add(fg, output_size) # request.add(labels_fg, output_size) request.add(gradient_fg, output_size) request.add(raw_base, input_size) request.add(raw_add, input_size) request.add(labels_base, input_size) request.add(labels_add, input_size) request.add(swc_base, input_size) request.add(swc_add, input_size) data_sources = tuple( ( gp.N5Source( filename=str( ( filename / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5" ).absolute() ), datasets={raw: "volume"}, array_specs={ raw: gp.ArraySpec( interpolatable=True, 
voxel_size=voxel_size, dtype=np.uint16 ) }, ), MouselightSwcFileSource( filename=str( ( filename / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs" ).absolute() ), points=(swcs,), scale=voxel_size, transpose=(2, 1, 0), transform_file=str((filename / "transform.txt").absolute()), ignore_human_nodes=True ), ) + gp.MergeProvider() + gp.RandomLocation( ensure_nonempty=swcs, ensure_centered=True ) + RasterizeSkeleton( points=swcs, array=labels, array_spec=gp.ArraySpec( interpolatable=False, voxel_size=voxel_size, dtype=np.uint32 ), ) + GrowLabels(labels, radius=10) # augment + gp.ElasticAugment( [40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0], subsample=4, ) + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) for filename in Path(sample_dir).iterdir() if "2018-08-01" in filename.name # or "2018-07-02" in filename.name ) pipeline = ( data_sources + gp.RandomProvider() + GetNeuronPair( swcs, raw, labels, (swc_base, swc_add), (raw_base, raw_add), (labels_base, labels_add), seperate_by=150, shift_attempts=50, request_attempts=10, ) + FusionAugment( raw_base, raw_add, labels_base, labels_add, swc_base, swc_add, raw_fused, labels_fused, swc_fused, blend_mode="labels_mask", blend_smoothness=10, num_blended_objects=0, ) + Crop(labels_fused, labels_fg) + BinarizeGt(labels_fg, labels_fg_bin) + gp.BalanceLabels(labels_fg_bin, loss_weights) # train + gp.PreCache(cache_size=40, num_workers=10) + gp.tensorflow.Train( "./train_net", optimizer=net_names["optimizer"], loss=net_names["loss"], inputs={ net_names["raw"]: raw_fused, net_names["labels_fg"]: labels_fg_bin, net_names["loss_weights"]: loss_weights, }, outputs={net_names["fg"]: fg}, gradients={net_names["fg"]: gradient_fg}, save_every=100000, ) + gp.Snapshot( output_filename="snapshot_{iteration}.hdf", dataset_names={ raw_fused: "volumes/raw_fused", raw_base: "volumes/raw_base", raw_add: "volumes/raw_add", labels_fused: 
"volumes/labels_fused", labels_base: "volumes/labels_base", labels_add: "volumes/labels_add", labels_fg_bin: "volumes/labels_fg_bin", fg: "volumes/pred_fg", gradient_fg: "volumes/gradient_fg", }, every=100, ) + gp.PrintProfilingStats(every=10) ) with gp.build(pipeline): logging.info("Starting training...") for i in range(max_iteration - trained_until): logging.info("requesting batch {}".format(i)) batch = pipeline.request_batch(request) """
swc_add, raw_fused_b, labels_fused_b, swc_fused_b, blend_mode="labels_mask", blend_smoothness=10, num_blended_objects=0, scale_add_volume=False, ) + gp.Snapshot( output_filename="snapshot_scale_comparison_{}_{}.hdf".format( int(np.mean(SEPERATE_DISTANCE)), "{iteration}"), dataset_names={ raw_fused: "volumes/raw_fused", raw_base: "volumes/raw_base", raw_add: "volumes/raw_add", labels_fused: "volumes/labels_fused", labels_base: "volumes/labels_base", labels_add: "volumes/labels_add", raw_fused_b: "volumes/raw_fused_b", labels_fused_b: "volumes/labels_fused_b", }, every=1, )) with build(pipeline): for i in range(1): request = BatchRequest(random_seed=i) # add request request = gp.BatchRequest() request.add(raw_fused, input_size)
def create_pipeline_3d(
    task, data, predictor, optimizer, batch_size, outdir, snapshot_every
):
    """Build the 3D training pipeline and its matching batch request.

    Args:
        task: training task. This function reads the following attributes:
            ``task.model``, whose ``input_shape`` and ``output_shape`` are used;
            ``task.loss``, whose ``add_mask`` and ``add_weights`` are called;
            ``task.predictor``;
            ``task.aux_tasks``, iterated as ``(name, predictor, loss)``
            triples;
            ``task.augmentations``, which is passed to ``eval``.
        data: dataset wrapper. ``data.raw`` and ``data.gt`` are used, and
            their ``.train`` attributes supply ``voxel_size``,
            ``num_samples`` and ``get_source(key)``.
        predictor: main prediction head. Used for ``add_target`` and for its
            ``target_channels``/``prediction_channels`` counts.
        optimizer: optimizer handed to ``Train``.
        batch_size: number of samples stacked per batch (``gp.Stack``).
        outdir: snapshots are written to ``<outdir>/snapshots``.
        snapshot_every: snapshot interval in iterations. A value ``<= 0``
            disables snapshots, together with the axis transposes that only
            serve them.

    Returns:
        ``(pipeline, request)``.

    Side effects:
        Replaces ``task.predictor`` with ``task.predictor.to("cuda")``.
    """
    raw_channels = max(1, data.raw.num_channels)
    input_shape = gp.Coordinate(task.model.input_shape)
    output_shape = gp.Coordinate(task.model.output_shape)
    voxel_size = data.raw.train.voxel_size

    task.predictor = task.predictor.to("cuda")

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    mask = gp.ArrayKey("MASK")
    target = gp.ArrayKey("TARGET")
    weights = gp.ArrayKey("WEIGHTS")
    model_outputs = gp.ArrayKey("MODEL_OUTPUTS")
    model_output_grads = gp.ArrayKey("MODEL_OUT_GRAD")
    prediction = gp.ArrayKey("PREDICTION")
    pred_gradients = gp.ArrayKey("PRED_GRADIENTS")

    snapshot_dataset_names = {
        raw: "raw",
        model_outputs: "model_outputs",
        model_output_grads: "model_out_grad",
        target: "target",
        prediction: "prediction",
        pred_gradients: "pred_gradients",
        weights: "weights",
    }

    # Per auxiliary task: (prediction, target, weights) keys. The weights entry
    # stays None unless the aux loss supplies a weights node (see below).
    aux_keys = {}
    aux_grad_keys = {}
    for name, _, _ in task.aux_tasks:
        aux_keys[name] = (
            gp.ArrayKey(f"{name.upper()}_PREDICTION"),
            gp.ArrayKey(f"{name.upper()}_TARGET"),
            None,
        )
        aux_grad_keys[name] = gp.ArrayKey(f"{name.upper()}_PRED_GRAD")

        aux_pred, aux_target, _ = aux_keys[name]

        snapshot_dataset_names[aux_pred] = f"{name}_pred"
        snapshot_dataset_names[aux_target] = f"{name}_target"

        aux_grad = aux_grad_keys[name]
        snapshot_dataset_names[aux_grad] = f"{name}_aux_grad"

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = data.raw.train.num_samples
    assert num_samples == 0, "Multiple samples for 3D training not yet implemented"

    sources = (data.raw.train.get_source(raw), data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, input_shape / 2 * voxel_size)
    # pipeline += gp.Pad(gt, input_shape / 2 * voxel_size)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)

    mask_node = task.loss.add_mask(gt, mask)
    if mask_node is not None:
        pipeline += mask_node
        pipeline += gp.RandomLocation(min_masked=1e-6, mask=mask)
    else:
        pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)

    # NOTE(review): task.augmentations is passed to eval(), which executes
    # arbitrary code from the task configuration. Only use trusted configs.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation

    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    loss_inputs = []
    if weights_node:
        pipeline += weights_node
        loss_inputs.append({0: prediction, 1: target, 2: weights})
    else:
        loss_inputs.append({0: prediction, 1: target})

    head_outputs = []
    head_gradients = []
    for name, aux_predictor, aux_loss in task.aux_tasks:
        aux_prediction, aux_target, _ = aux_keys[name]
        pipeline += aux_predictor.add_target(gt, aux_target)
        # Create the weights key *before* asking the loss for a weights node,
        # mirroring task.loss.add_weights(target, weights) above, so that
        # add_weights never receives None as its output key.
        aux_weights = gp.ArrayKey(f"{name.upper()}_WEIGHTS")
        aux_weights_node = aux_loss.add_weights(aux_target, aux_weights)
        if aux_weights_node:
            aux_keys[name] = (aux_prediction, aux_target, aux_weights)
            pipeline += aux_weights_node
            loss_inputs.append({0: aux_prediction, 1: aux_target, 2: aux_weights})
            snapshot_dataset_names[aux_weights] = f"{name}_weights"
        else:
            loss_inputs.append({0: aux_prediction, 1: aux_target})
        head_outputs.append({0: aux_prediction})
        head_gradients.append({0: aux_grad_keys[name]})

    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += Train(
        model=task.model,
        heads=[("opt", predictor)]
        + [(name, aux_pred) for name, aux_pred, _ in task.aux_tasks],
        losses=[task.loss] + [loss for _, _, loss in task.aux_tasks],
        optimizer=optimizer,
        inputs={"x": raw},
        outputs={0: model_outputs},
        head_outputs=[{0: prediction}] + head_outputs,
        loss_inputs=loss_inputs,
        gradients=[{0: model_output_grads}, {0: pred_gradients}] + head_gradients,
        save_every=1e6,
    )
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)

    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        pipeline += gp.Snapshot(
            dataset_names=snapshot_dataset_names,
            every=snapshot_every,
            output_dir=os.path.join(outdir, "snapshots"),
            output_filename="{iteration}.hdf",
        )

    pipeline += gp.PrintProfilingStats(every=10)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    if mask_node is not None:
        request.add(mask, output_size)
    request.add(target, output_size)
    for name, _, _ in task.aux_tasks:
        aux_pred, aux_target, aux_weight = aux_keys[name]
        request.add(aux_pred, output_size)
        request.add(aux_target, output_size)
        if aux_weight is not None:
            request.add(aux_weight, output_size)
        aux_pred_grad = aux_grad_keys[name]
        request.add(aux_pred_grad, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)
    request.add(pred_gradients, output_size)

    return pipeline, request
def validation_pipeline(config):
    """Build the prediction + evaluation pipeline over all validation blocks.

    Per block::

        {
            Raw -> predict -> scan
            gt  -> rasterize
        } -> merge -> candidates -> trees

    This is followed by a global ``merge -> comatch + evaluate`` over all
    blocks.

    Args:
        config: dict-like configuration. The keys read are the ones accessed
            at the top of the body.

    Returns:
        ``(pipeline, score)``, where ``score`` is the ArrayKey passed to
        ``Evaluate`` as its output.

    Raises:
        FileNotFoundError: if a block's validation directory contains no
            ``cube*.swc`` file.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = (
        config["COORDINATE_SCALE"] * np.array(voxel_size) / micron_scale
    )

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}
    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        # If several cube*.swc files exist, the last one listed is used.
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(".swc"):
                cube = gt_file
        if cube is None or not cube.exists():
            raise FileNotFoundError(f"no cube*.swc file found in {validation_dir}")

        # Use the configured voxel size (reversed), matching the other
        # validation pipelines in this module.
        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
                },
            )
            + gp.Normalize(raw, dtype=np.float32)
            + mCLAHE([raw], [20, 64, 64])
        )
        # NOTE(review): emb_validation_pipeline unpacks three values
        # (source, emb, neighborhood) from add_emb_pred; confirm which
        # signature add_emb_pred currently has.
        emb_source, emb = add_emb_pred(config, raw_source, raw, block, emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block, fg_model)
        pred_source = add_scan(
            pred_source, {raw: input_size, emb: output_size, fg: output_size}
        )

        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    )
                )
            ],
        )

        # The raw context is larger than the cube by half the input/output
        # size difference on each side.
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        # (spec name, key, spec class, roi). Each entry is recorded both in
        # `specs` (consumed by MergeGraphs) and in the snapshot's additional
        # request, using separate spec instances as before.
        requested = [
            ("raw", raw, gp.ArraySpec, input_roi),
            ("ground_truth", ground_truth, gp.GraphSpec, cube_roi),
            ("labels", labels, gp.ArraySpec, cube_roi),
            ("fg_pred", fg, gp.ArraySpec, cube_roi),
            ("emb_pred", emb, gp.ArraySpec, cube_roi),
            ("candidates", candidates, gp.ArraySpec, cube_roi),
            ("mst_pred", mst, gp.GraphSpec, cube_roi),
        ]
        block_spec = specs.setdefault(block, {})
        additional_request = BatchRequest()
        for spec_name, key, spec_cls, roi in requested:
            block_spec[spec_name] = (key, spec_cls(roi))
            additional_request[key] = spec_cls(roi=roi)

        pipeline = (
            (swc_source, pred_source)
            + gp.nodes.MergeProvider()
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * micron_scale])
            + Skeletonize(fg, candidates, candidate_spacing, candidate_threshold)
            + EMST(
                emb,
                candidates,
                mst,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            + gp.Snapshot(
                {
                    raw: f"volumes/{raw}",
                    ground_truth: f"points/{ground_truth}",
                    labels: f"volumes/{labels}",
                    fg: f"volumes/{fg}",
                    emb: f"volumes/{emb}",
                    candidates: f"volumes/{candidates}",
                    mst: f"points/{mst}",
                },
                additional_request=additional_request,
                output_dir="snapshots",
                output_filename="{id}.hdf",
                edge_attrs={mst: [distance_attr]},
            )
        )
        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")
    combined = (
        tuple(validation_pipelines)
        + gp.MergeProvider()
        + MergeGraphs(specs, full_gt, full_mst)
        + Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr)
        + gp.PrintProfilingStats()
    )
    return combined, score
def emb_validation_pipeline(
    config,
    snapshot_file,
    candidates_path,
    raw_path,
    gt_path,
    candidates_mst_path=None,
    candidates_mst_dense_path=None,
    path_stat="max",
):
    """Evaluate embedding-based reconstructions on every validation block.

    The steps for each block are:

    1. Read raw data and candidates from ``snapshot_file``, optionally
       applying CLAHE.
    2. Predict embeddings with the configured embedding model.
    3. Build an MST over the candidates (MiniMax or Euclidean EMST). If both
       ``candidates_mst_path`` and ``candidates_mst_dense_path`` are given,
       the MSTs are read instead.
    4. Threshold the edges and reconnect the components with
       ``ComponentWiseEMST``.
    5. Score the result against the ground truth with ``Evaluate``.

    Per-block scores are then combined with ``MergeScores``.

    Args:
        config: dict-like configuration. The keys read are the ones accessed
            at the top of the body.
        snapshot_file: file read by ``SnapshotSource``.
        candidates_path, raw_path, gt_path: dataset-path templates, formatted
            with ``block=...``.
        candidates_mst_path, candidates_mst_dense_path: optional dataset-path
            templates for precomputed MSTs.
        path_stat: currently unused; it is only referenced by disabled code.

    Returns:
        ``(pipeline, final_score)``.

    Raises:
        FileNotFoundError: if a block's validation directory contains no
            ``cube*.swc`` file.
    """
    checkpoint = config["EMB_EVAL_CHECKPOINT"]
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    coordinate_scale = (
        config["COORDINATE_SCALE"] * np.array(voxel_size) / micron_scale
    )
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    edge_threshold_0 = config["EVAL_EDGE_THRESHOLD_0"]
    component_threshold_0 = config["COMPONENT_THRESHOLD_0"]
    component_threshold_1 = config["COMPONENT_THRESHOLD_1"]

    clip_limit = config["CLAHE_CLIP_LIMIT"]
    normalize = config["CLAHE_NORMALIZE"]

    validation_pipelines = []
    specs = {}

    emb_model = get_emb_model(config)
    emb_model.eval()

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        # If several cube*.swc files exist, the last one listed is used.
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(".swc"):
                cube = gt_file
        if cube is None or not cube.exists():
            raise FileNotFoundError(f"no cube*.swc file found in {validation_dir}")

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates_1 = gp.ArrayKey(f"CANDIDATES_1_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst_0 = gp.GraphKey(f"MST_0_{block}")
        mst_dense_0 = gp.GraphKey(f"MST_DENSE_0_{block}")
        mst_1 = gp.GraphKey(f"MST_1_{block}")
        mst_dense_1 = gp.GraphKey(f"MST_DENSE_1_{block}")
        mst_2 = gp.GraphKey(f"MST_2_{block}")
        mst_dense_2 = gp.GraphKey(f"MST_DENSE_2_{block}")  # only used by disabled code
        gt = gp.GraphKey(f"GT_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")
        optimal_mst = gp.GraphKey(f"OPTIMAL_MST_{block}")

        # Volume Source
        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                candidates_1: candidates_path.format(block=block),
            },
        )

        # Graph Source
        graph_datasets = {gt: gt_path.format(block=block)}
        graph_directionality = {gt: False}
        edge_attrs = {}
        if candidates_mst_path is not None:
            graph_datasets[mst_0] = candidates_mst_path.format(block=block)
            graph_directionality[mst_0] = False
            edge_attrs[mst_0] = [distance_attr]
        if candidates_mst_dense_path is not None:
            graph_datasets[mst_dense_0] = candidates_mst_dense_path.format(
                block=block
            )
            graph_directionality[mst_dense_0] = False
            edge_attrs[mst_dense_0] = [distance_attr]
        gt_source = SnapshotSource(
            snapshot_file,
            datasets=graph_datasets,
            directed=graph_directionality,
            edge_attrs=edge_attrs,
        )

        if config["EVAL_CLAHE"]:
            raw_source = raw_source + scipyCLAHE(
                [raw],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=clip_limit,
                normalize=normalize,
            )

        emb_source, emb, neighborhood = add_emb_pred(
            config, raw_source, raw, block, emb_model
        )

        reference_sizes = {
            raw: input_size,
            emb: output_size,
            candidates_1: output_size,
        }
        if neighborhood is not None:
            reference_sizes[neighborhood] = output_size

        emb_source = add_scan(emb_source, reference_sizes)

        # Everything downstream works in a cube ROI shifted to the origin. The
        # raw context is grown by half the input/output size difference.
        cube_roi_shifted = gp.Roi(
            (0,) * len(cube_roi.get_shape()), cube_roi.get_shape()
        )
        input_roi = cube_roi_shifted.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        block_spec[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            block_spec[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_0] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_0] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_1] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_1] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        # block_spec[mst_dense_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)
        block_spec[optimal_mst] = gp.GraphSpec(cube_roi_shifted, directed=False)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        additional_request[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            additional_request[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst_0] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst_dense_0] = gp.GraphSpec(
            cube_roi_shifted, directed=False
        )
        additional_request[mst_1] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst_dense_1] = gp.GraphSpec(
            cube_roi_shifted, directed=False
        )
        additional_request[mst_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        # additional_request[mst_dense_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[optimal_mst] = gp.GraphSpec(
            cube_roi_shifted, directed=False
        )

        pipeline = (emb_source, gt_source) + gp.MergeProvider()

        # NOTE(review): if exactly one of candidates_mst_path /
        # candidates_mst_dense_path is given, gt_source provides that graph
        # and the branch below also computes it. Confirm this combination is
        # never used.
        if candidates_mst_path is not None and candidates_mst_dense_path is not None:
            # mst_0 provided, just need to calculate distances.
            pass
        elif config["EVAL_MINIMAX_EMBEDDING_DIST"]:
            # No mst_0 provided, must first calculate mst_0 and dense mst_0
            pipeline += MiniMaxEmbeddings(
                emb,
                candidates_1,
                decimated=mst_0,
                dense=mst_dense_0,
                distance_attr=distance_attr,
            )
        else:
            # mst/mst_dense not provided. Simply use euclidean distance on candidates
            pipeline += EMST(
                emb,
                candidates_1,
                mst_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            pipeline += EMST(
                emb,
                candidates_1,
                mst_dense_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )

        pipeline += ThresholdEdges(
            (mst_0, mst_1),
            edge_threshold_0,
            component_threshold_0,
            msts_dense=(mst_dense_0, mst_dense_1),
            distance_attr=distance_attr,
        )

        pipeline += ComponentWiseEMST(
            emb,
            mst_1,
            mst_2,
            distance_attr=distance_attr,
            coordinate_scale=coordinate_scale,
        )

        # pipeline += ScoreEdges(
        #     mst, mst_dense, emb, distance_attr=distance_attr, path_stat=path_stat
        # )

        pipeline += Evaluate(
            gt,
            mst_2,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold_1,
            # connectivity=mst_1,
            output_graph=optimal_mst,
        )

        if config["EVAL_SNAPSHOT"]:
            snapshot_datasets = {
                raw: "volumes/raw",
                emb: "volumes/embeddings",
                candidates_1: "volumes/candidates_1",
                mst_0: "points/mst_0",
                mst_dense_0: "points/mst_dense_0",
                mst_1: "points/mst_1",
                mst_dense_1: "points/mst_dense_1",
                # mst_2: "points/mst_2",
                gt: "points/gt",
                details: "points/details",
                optimal_mst: "points/optimal_mst",
            }
            if neighborhood is not None:
                snapshot_datasets[neighborhood] = "volumes/neighborhood"
            pipeline += gp.Snapshot(
                snapshot_datasets,
                output_dir=config["EVAL_SNAPSHOT_DIR"],
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    checkpoint=checkpoint,
                    block=block,
                    coordinate_scale=",".join(str(x) for x in coordinate_scale),
                ),
                edge_attrs={
                    mst_0: [distance_attr],
                    mst_dense_0: [distance_attr],
                    mst_1: [distance_attr],
                    mst_dense_1: [distance_attr],
                    # mst_2: [distance_attr],
                    # optimal_mst: [distance_attr],  # it is unclear how to add distances if using connectivity graph
                    # mst_dense_2: [distance_attr],
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    combined = (
        tuple(validation_pipelines)
        + gp.MergeProvider()
        + MergeScores(final_score, specs)
        + gp.PrintProfilingStats()
    )

    return combined, final_score
def build_pipeline(
        data_dir,
        model,
        save_every,
        batch_size,
        input_size,
        output_size,
        raw,
        labels,
        affs,
        affs_predicted,
        lr=1e-5):
    """Build a 2D affinity-training pipeline over a zarr container.

    The container at ``data_dir`` must hold ``train/raw`` and ``train/gt``.
    The first axis of these arrays is exposed as the z axis of a unit-voxel
    ROI. The shape comments below assume requests with a depth of 1, so each
    random location corresponds to one 2D section.

    Args:
        data_dir: path to the zarr container (``str`` or ``pathlib.Path``).
        model: torch module to train.
        save_every: checkpoint interval passed to ``Train``.
        batch_size: number of sections stacked per batch (``gp.Stack``).
        input_size, output_size: not used inside this function; kept for
            interface compatibility.
        raw, labels, affs, affs_predicted: gunpowder array keys.
        lr: learning rate for the ``RAdam`` optimizer.

    Returns:
        The assembled gunpowder pipeline. The caller builds the request.
    """
    # Use the same string path for zarr and for gunpowder's ZarrSource.
    data_dir = str(data_dir)
    dataset_shape = zarr.open(data_dir)['train/raw'].shape
    num_samples = dataset_shape[0]
    sample_size = dataset_shape[1:]

    loss = torch.nn.MSELoss()
    optimizer = RAdam(model.parameters(), lr=lr)

    # The sample index is used as the z axis of a unit-voxel ROI.
    full_roi = gp.Roi((0, 0, 0), (num_samples,) + sample_size)

    pipeline = (
        gp.ZarrSource(
            data_dir,
            {
                raw: 'train/raw',
                labels: 'train/gt'
            },
            array_specs={
                raw: gp.ArraySpec(roi=full_roi, voxel_size=(1, 1, 1)),
                labels: gp.ArraySpec(roi=full_roi, voxel_size=(1, 1, 1)),
            })
        # raw:    (d=1, h, w)
        # labels: (d=1, h, w)
        + gp.RandomLocation()
        # raw:    (d=1, h, w)
        # labels: (d=1, h, w)
        + gp.AddAffinities(
            affinity_neighborhood=[(0, 1, 0), (0, 0, 1)],
            labels=labels,
            affinities=affs)
        + gp.Normalize(affs, factor=1.0)
        # raw:  (d=1, h, w)
        # affs: (c=2, d=1, h, w)
        + Squash(dim=-3)  # get rid of z dim
        # raw:  (h, w)
        # affs: (c=2, h, w)
        + AddChannelDim(raw)
        # raw:  (c=1, h, w)
        # affs: (c=2, h, w)
        + gp.PreCache()
        + gp.Stack(batch_size)
        # raw:  (b=batch_size, c=1, h, w)
        # affs: (b=batch_size, c=2, h, w)
        + Train(
            model=model,
            loss=loss,
            optimizer=optimizer,
            inputs={'x': raw},
            target=affs,
            output=affs_predicted,
            save_every=save_every,
            log_dir='log')
        # raw:            (b=batch_size, c=1, h, w)
        # affs:           (b=batch_size, c=2, h, w)
        # affs_predicted: (b=batch_size, c=2, h, w)
        + TransposeDims(raw, (1, 0, 2, 3))
        + TransposeDims(affs, (1, 0, 2, 3))
        + TransposeDims(affs_predicted, (1, 0, 2, 3))
        # raw:            (c=1, b=batch_size, h, w)
        # affs:           (c=2, b=batch_size, h, w)
        # affs_predicted: (c=2, b=batch_size, h, w)
        + RemoveChannelDim(raw)
        # raw:            (b=batch_size, h, w)
        # affs:           (c=2, b=batch_size, h, w)
        # affs_predicted: (c=2, b=batch_size, h, w)
        + gp.Snapshot(
            dataset_names={
                raw: 'raw',
                labels: 'labels',
                affs: 'affs',
                affs_predicted: 'affs_predicted'
            },
            every=100)
        + gp.PrintProfilingStats(every=100)
    )
    return pipeline
def train_until(max_iteration):
    """Train the TensorFlow foreground network on fused SWC-labelled volumes.

    Training resumes from the latest checkpoint in the working directory.

    Every file in ``files`` is read twice: once as a "base" source and once as
    an "add" source. Each source pairs an HDF5 ``/volume`` with the SWC
    ``/reconstruction`` and draws random locations that contain skeleton
    center points. ``FusionAugment`` blends one base and one add sample into
    ``raw``/``labels``, which are then augmented, binarized, and
    class-balanced for training.

    Args:
        max_iteration: total number of iterations to train up to. Nothing
            happens if the latest checkpoint is already at or beyond it.
    """
    # get the latest checkpoint
    latest = tf.train.latest_checkpoint('.')
    trained_until = int(latest.split('_')[-1]) if latest else 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_fg = gp.ArrayKey('LABELS_FG')

    # array keys for base volume
    raw_base = gp.ArrayKey('RAW_BASE')
    labels_base = gp.ArrayKey('LABELS_BASE')
    swc_base = gp.PointsKey('SWC_BASE')
    swc_center_base = gp.PointsKey('SWC_CENTER_BASE')

    # array keys for add volume
    raw_add = gp.ArrayKey('RAW_ADD')
    labels_add = gp.ArrayKey('LABELS_ADD')
    swc_add = gp.PointsKey('SWC_ADD')
    swc_center_add = gp.PointsKey('SWC_CENTER_ADD')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((3, 3, 3))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # training request
    request = gp.BatchRequest()
    for key, size in (
        (raw, input_size),
        (labels, output_size),
        (labels_fg, output_size),
        (loss_weights, output_size),
        (swc_center_base, output_size),
        (swc_base, input_size),
        (swc_center_add, output_size),
        (swc_add, input_size),
    ):
        request.add(key, size)

    # extra arrays only needed for the snapshot
    snapshot_request = gp.BatchRequest()
    for key, size in (
        (fg, output_size),
        (labels_fg, output_size),
        (gradient_fg, output_size),
        (raw_base, input_size),
        (raw_add, input_size),
        (labels_base, input_size),
        (labels_add, input_size),
    ):
        snapshot_request.add(key, size)

    def swc_volume(file, raw_key, center_key, swc_key, labels_key, iteration):
        """Return a random-location source of raw + rasterized skeleton for one file."""
        return (
            (
                gp.Hdf5Source(
                    file,
                    datasets={raw_key: '/volume'},
                    array_specs={
                        raw_key: gp.ArraySpec(
                            interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                        ),
                    },
                    channels_first=False,
                ),
                SwcSource(
                    filename=file,
                    dataset='/reconstruction',
                    points=(center_key, swc_key),
                    scale=voxel_size,
                ),
            )
            + gp.MergeProvider()
            + gp.RandomLocation(ensure_nonempty=center_key)
            + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=gp.ArraySpec(
                    interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
                ),
                iteration=iteration,
            )
        )

    # NOTE(review): base labels use iteration=10 while add labels use
    # iteration=1; this asymmetry is kept as-is. Confirm it is intended.
    base_sources = tuple(
        swc_volume(file, raw_base, swc_center_base, swc_base, labels_base, 10)
        for file in files
    ) + gp.RandomProvider()
    add_sources = tuple(
        swc_volume(file, raw_add, swc_center_add, swc_add, labels_add, 1)
        for file in files
    ) + gp.RandomProvider()

    pipeline = (
        (base_sources, add_sources)
        + gp.MergeProvider()
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode='labels_mask',
            blend_smoothness=10,
            num_blended_objects=0,
        )
        # augment
        + gp.ElasticAugment([10, 10, 10], [1, 1, 1], [0, math.pi / 2.0], subsample=8)
        + gp.SimpleAugment(mirror_only=[2], transpose_only=[])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        + BinarizeGt(labels, labels_fg)
        + gp.BalanceLabels(labels_fg, loss_weights)
        # train
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            './train_net',
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['labels_fg']: labels_fg,
                net_names['loss_weights']: loss_weights,
            },
            outputs={net_names['fg']: fg},
            gradients={net_names['fg']: gradient_fg},
            save_every=100,
        )
        # visualize
        + gp.Snapshot(
            output_filename='snapshot_{iteration}.hdf',
            dataset_names={
                raw: 'volumes/raw',
                raw_base: 'volumes/raw_base',
                raw_add: 'volumes/raw_add',
                labels: 'volumes/labels',
                labels_base: 'volumes/labels_base',
                labels_add: 'volumes/labels_add',
                fg: 'volumes/fg',
                labels_fg: 'volumes/labels_fg',
                gradient_fg: 'volumes/gradient_fg',
            },
            additional_request=snapshot_request,
            every=10,
        )
        + gp.PrintProfilingStats(every=100)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def validation_pipeline(config):
    """Build the validation pipeline over all configured benchmark blocks.

    Per block:
        { raw -> normalize -> CLAHE ; ground-truth swc } -> merge -> crop
        -> rasterize ground truth -> grow labels -> crop -> snapshot
    All per-block pipelines are then merged into one provider.

    Args:
        config: Mapping providing ``BLOCKS``, ``BENCHMARK_DATA_PATH``,
            ``VALIDATION_SAMPLES``, ``SAMPLES_PATH``, ``RAW_N5``,
            ``NEURON_RADIUS``, ``VOXEL_SIZE``, ``INPUT_SHAPE`` and
            ``OUTPUT_SHAPE``.

    Returns:
        ``(pipeline, specs)`` where ``specs[block]`` maps every key of that
        block to the spec (ROI) requested for it.

    Raises:
        FileNotFoundError: If a block's validation directory contains no
            ``cube*.swc`` file.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"
    transform_file = transform_template.format(sample=sample)

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    # context added on every side of the cube so the raw input covers
    # what the network needs to predict the whole cube
    context = (input_size - output_size) // 2

    validation_pipelines = []
    specs = {}
    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)

        # the last matching cube*.swc file defines the block's ROI
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(".swc"):
                cube = gt_file
        if cube is None or not cube.exists():
            raise FileNotFoundError(
                f"No cube*.swc file found in {validation_dir} for block {block}"
            )

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_file),
            # NOTE(review): hard-coded voxel size; pre_computed_fg_validation_pipeline
            # derives it from np.array(voxel_size[::-1]) instead -- confirm
            # both agree with config["VOXEL_SIZE"].
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={
                    raw: "volume-rechunked",
                    raw_clahed: "volume-rechunked",
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                    raw_clahed: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size
                    ),
                },
            )
            + gp.Normalize(raw, dtype=np.float32)
            + gp.Normalize(raw_clahed, dtype=np.float32)
            + scipyCLAHE([raw_clahed], [20, 64, 64])
        )
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_file,
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    )
                )
            ],
        )

        # ROIs relative to the cube's own origin; SpecifiedLocation below is
        # given the cube's center to position the request
        cube_roi_shifted = gp.Roi(
            (0,) * len(cube_roi.get_shape()), cube_roi.get_shape()
        )
        input_roi = cube_roi_shifted.grow(context, context)

        block_spec = specs.setdefault(block, {})
        additional_request = gp.BatchRequest()
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = (
            (swc_source, raw_source)
            + gp.nodes.MergeProvider()
            + gp.SpecifiedLocation(locations=[cube_roi.get_center()])
            + gp.Crop(raw, roi=input_roi)
            + gp.Crop(raw_clahed, roi=input_roi)
            + gp.Crop(ground_truth, roi=cube_roi_shifted)
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * micron_scale])
            + gp.Crop(labels, roi=cube_roi_shifted)
            + gp.Snapshot(
                {
                    raw: f"volumes/{block}/raw",
                    raw_clahed: f"volumes/{block}/raw_clahe",
                    ground_truth: f"points/{block}/ground_truth",
                    labels: f"volumes/{block}/labels",
                },
                additional_request=additional_request,
                output_dir="validations",
                output_filename="validations.hdf",
            )
        )
        validation_pipelines.append(pipeline)

    merged = (
        tuple(validation_pipelines) + gp.MergeProvider() + gp.PrintProfilingStats()
    )
    return merged, specs
def train(n_iterations):
    """Generate synthetic 2D skeleton data and write periodic snapshots.

    No network is trained here. Skeletons from ``nl.SyntheticLightLike`` are
    rasterized into ``labels`` and copied to ``raw``. Both are grown, ``raw``
    is mapped to float intensities, and noise is added.
    Shapes and generator parameters come from module-level constants
    (``INPUT_SHAPE``, ``OUTPUT_SHAPE``, ``SKEL_GEN_RADIUS``, ``N_OBJS``, ...).

    Args:
        n_iterations: Number of batches to request.
    """
    unit_voxel = gp.Coordinate((1, 1))

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=unit_voxel)
    request.add(labels, OUTPUT_SHAPE, voxel_size=unit_voxel)
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)

    # unbounded 2D label volume with unit voxels
    label_spec = gp.ArraySpec(
        roi=gp.Roi((None,) * 2, (None,) * 2),
        voxel_size=unit_voxel,
        dtype=np.uint64,
    )

    pipeline = nl.SyntheticLightLike(
        point_trees,
        dims=2,
        r=SKEL_GEN_RADIUS,
        n_obj=N_OBJS,
        thetas=THETAS,
        split_ps=SPLIT_PS,
    )
    pipeline += nl.RasterizeSkeleton(point_trees, labels, label_spec)
    # raw starts out as a copy of the rasterized labels
    pipeline += gp.Copy(labels, raw)
    pipeline += nl.GrowLabels(labels, radii=LABEL_RADII)
    pipeline += nl.GrowLabels(raw, radii=RAW_RADII)
    pipeline += LabelToFloat32(raw, intensities=RAW_INTENSITIES)
    pipeline += gp.NoiseAugment(raw, var=NOISE_VAR)
    pipeline += gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names={
            raw: "volumes/raw",
            labels: "volumes/labels",
            point_trees: "point_trees",
        },
        every=100,
        additional_request=snapshot_request,
    )
    pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)
def create_pipeline_3d(task, predictor, optimizer, batch_size, outdir, snapshot_every):
    """Build the 3D training pipeline and the matching batch request.

    Args:
        task: Provides ``data.raw``/``data.gt`` (sources, voxel size, number
            of channels and samples), ``augmentations`` (a string that is
            ``eval``-ed into an iterable of gunpowder nodes) and ``loss``
            (used for training and for ``add_weights``).
        predictor: The torch model. ``input_shape``/``output_shape`` (in
            voxels), ``add_target``, ``target_channels`` and
            ``prediction_channels`` are read.
        optimizer: Passed to ``gp_torch.Train``.
        batch_size: Number of samples combined by ``gp.Stack``.
        outdir: Snapshots are written to ``<outdir>/snapshots``.
        snapshot_every: Snapshot interval. Snapshot-related nodes are only
            added when it is > 0.

    Returns:
        Tuple ``(pipeline, request)``.
    """
    raw_channels = max(1, task.data.raw.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    # 0: raw has no channel axis yet (one is added before training)
    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = task.data.raw.train.num_samples
    # NOTE(review): presumably num_samples == 0 means "one sample without a
    # sample index" -- confirm against the dataset implementation
    assert num_samples == 0, (
        "Multiple samples for 3D training not yet implemented")

    sources = (task.data.raw.train.get_source(raw),
               task.data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # NOTE(review): eval() executes arbitrary code from task.augmentations,
    # with this function's globals and locals in scope (e.g. ``raw``). Only
    # use trusted task configurations, and do not rename locals that the
    # augmentation expressions may reference.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # the loss may or may not contribute a weighting node
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # checkpoints are saved every 1e6 iterations
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
        if weights_node:
            pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            # drop the channel axis that was added before training
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        # weights is listed even when no weights node exists; only arrays
        # present in the batch can be written
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    # gt is requested at output size; raw at the larger input size
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    """Train the foreground-distance network on topologically matched skeletons.

    Setup per training sample:

    * Raw fluorescence is read from ``fluorescence-near-consensus.n5``.
    * Consensus and skeletonization graphs are read from MongoDB.
    * The skeletonization is matched to the consensus, rasterized into
      labels and turned into a saturated distance map plus a loss-weight
      mask.
    * Both are fed to the ``train_net_foreground`` TensorFlow graph.

    Args:
        n_iterations: Number of batches to request. ``None`` falls back to
            ``setup_config["NUM_ITERATIONS"]``.
        setup_config: Mapping providing ``INPUT_SHAPE``, ``OUTPUT_SHAPE``,
            ``VOXEL_SIZE``, ``SNAPSHOT_EVERY``, ``CHECKPOINT_EVERY``,
            ``PROFILE_EVERY``, ``SEPERATE_BY``, ``GAP_CROSSING_DIST``,
            ``MATCH_DISTANCE_THRESHOLD``, ``POINT_BALANCE_RADIUS``,
            ``MAX_LABEL_DIST``, ``SAMPLES_PATH`` and ``MONGO_URL``.
        mknet_tensor_names: Maps ``"optimizer"``, ``"fg_loss"``, ``"raw"``,
            ``"gt_distances"``, ``"loss_weights"`` and ``"fg_pred"`` to
            tensor names of the network graph.
        loss_tensor_names: Not used by this pipeline. It is accepted so the
            signature matches the synthetic-data ``train`` entry point.

    Raises:
        ValueError: If none of the training sample directories exist under
            ``SAMPLES_PATH``.
    """
    # only sample directories with these names are used for training
    training_samples = ("2018-07-02", "2018-08-01")

    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    # the explicit argument takes precedence over the config value
    num_iterations = (
        n_iterations if n_iterations is not None else setup_config["NUM_ITERATIONS"]
    )
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]
    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")
    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    snapshot_request = gp.BatchRequest()
    # input_size request for positioning
    snapshot_request.add(raw, input_size)
    # tensorflow outputs
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    samples = [s for s in samples_path.iterdir() if s.name in training_samples]
    if not samples:
        raise ValueError(
            f"None of the training samples {training_samples} were found in "
            f"{samples_path}"
        )

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RejectIfEmpty(matched, center_size=output_size)
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + gp.contrib.nodes.add_distance.AddDistance(
            labels, dist, dist_mask, max_distance=max_label_dist * micron_scale
        )
        + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1
        )
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples
    )

    # "{id}" is left in the name for gp.Snapshot to fill in
    snapshot_filename = f"snapshot_{int(np.min(seperate_by))}_{{id}}.hdf"

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + Crop(dist, dist_cropped)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            log_dir="tensorflow_logs",
        )
        + gp.PrintProfilingStats(every=profile_every)
        + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename=snapshot_filename,
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        )
    )

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
def train(until):
    """Train ``SpineUNet`` to predict 3D affinities from three zarr samples.

    Args:
        until: Number of training batches to request.
    """
    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # (raw dataset, labels dataset) per training sample in the container
    sample_datasets = (
        ('train/sample1/raw', 'train/sample1/labels'),
        ('train/sample2/raw', 'train/sample2/labels'),
        ('train/sample3/raw', 'train/sample3/labels'),
    )
    sources = tuple(
        gp.ZarrSource('data/20200201.zarr', {raw: raw_ds, labels: labels_ds})
        for raw_ds, labels_ds in sample_datasets
    )

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Normalize(raw)
    pipeline += gp.RandomLocation()
    pipeline += gp.SimpleAugment(transpose_only=(1, 2))
    pipeline += gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi])
    pipeline += gp.AddAffinities(
        [(1, 0, 0), (0, 1, 0), (0, 0, 1)], labels, affs)
    pipeline += gp.Normalize(affs, factor=1.0)
    # raw: (d, h, w), affs: (3, d, h, w)
    pipeline += gp.Stack(1)
    # raw: (1, d, h, w), affs: (1, 3, d, h, w)
    pipeline += AddChannelDim(raw)
    # raw: (1, 1, d, h, w), affs: (1, 3, d, h, w)
    pipeline += gp_torch.Train(
        model,
        loss,
        optimizer,
        inputs={'x': raw},
        outputs={0: affs_predicted},
        loss_inputs={0: affs_predicted, 1: affs},
        save_every=10000)
    # raw loses both its channel and its batch axis, the others only batch
    for key in (raw, raw, affs, affs_predicted):
        pipeline += RemoveChannelDim(key)
    # raw: (d, h, w), affs / affs_predicted: (3, d, h, w)
    pipeline += gp.Snapshot(
        {
            raw: 'raw',
            labels: 'labels',
            affs: 'affs',
            affs_predicted: 'affs_predicted'
        },
        every=500,
        output_filename='iteration_{iteration}.hdf')

    request = gp.BatchRequest()
    for key in (raw, labels, affs, affs_predicted):
        request.add(key, input_size)

    with gp.build(pipeline):
        for _ in range(until):
            pipeline.request_batch(request)
def _random_point_pair_sample_source(filename, ds_key, raw_keys, point_keys,
                                     point_density, dim, padding):
    """Build the provider for one sample: two independently augmented views.

    Each view combines one copy of dataset ``ds_key`` with a random point
    source. Both point sources share one ``RandomPointGenerator`` created with
    ``repetitions=2``. The merged views are cropped to the dataset's ROI,
    padded by ``padding`` and sampled at a random location.
    """
    image_sources = tuple(
        gp.ZarrSource(
            filename,
            {raw_key: ds_key},
            {raw_key: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})
        + gp.Pad(raw_key, None)
        for raw_key in raw_keys)

    shared_generator = RandomPointGenerator(density=point_density, repetitions=2)
    point_sources = tuple(
        RandomPointSource(points_key, dim, random_point_generator=shared_generator)
        for points_key in point_keys)

    # TODO: get augmentation parameters from some config file!
    views = []
    for raw_key, image_source, point_source in zip(raw_keys, image_sources,
                                                   point_sources):
        view = (image_source, point_source) + gp.MergeProvider()
        view += gp.SimpleAugment()
        view += gp.ElasticAugment(
            spatial_dims=2,
            control_point_spacing=(10, 10),
            jitter_sigma=(0.0, 0.0),
            rotation_interval=(0, math.pi / 2))
        view += gp.IntensityAugment(raw_key,
                                    scale_min=0.8,
                                    scale_max=1.2,
                                    shift_min=-0.2,
                                    shift_max=0.2,
                                    clip=False)
        view += gp.NoiseAugment(raw_key, var=0.01, clip=False)
        views.append(view)

    sample_source = tuple(views) + gp.MergeProvider()

    sample_data = daisy.open_ds(filename, ds_key)
    sample_roi = gp.Roi(sample_data.roi.get_offset(), sample_data.roi.get_shape())
    for raw_key in raw_keys:
        sample_source += gp.Crop(raw_key, sample_roi)
    for raw_key in raw_keys:
        sample_source += gp.Pad(raw_key, padding)
    sample_source += gp.RandomLocation()
    return sample_source


def random_point_pairs_pipeline(model, loss, optimizer, dataset, augmentation_parameters,
                                point_density, out_dir, normalize_factor=None,
                                checkpoint_interval=5000, snapshot_interval=5000):
    """Build a training pipeline on pairs of augmented views with shared points.

    Args:
        model: Torch model exposing ``in_shape`` and ``out_shape``.
        loss: Loss receiving ``emb_0``, ``emb_1``, ``locations_0`` and
            ``locations_1``.
        optimizer: Torch optimizer.
        dataset: Object with a ``filename`` pointing at the zarr container.
        augmentation_parameters: Currently unused.
        point_density: Density passed to ``RandomPointGenerator``.
        out_dir: Checkpoints are written as ``<out_dir>/model*``.
        normalize_factor: Currently unused.
        checkpoint_interval: Iterations between checkpoints.
        snapshot_interval: Iterations between snapshots.

    Returns:
        ``(pipeline, request)``.
    """
    raw_0, raw_1 = gp.ArrayKey('RAW_0'), gp.ArrayKey('RAW_1')
    points_0, points_1 = gp.GraphKey('POINTS_0'), gp.GraphKey('POINTS_1')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    reference = daisy.open_ds(dataset.filename, 'train/raw/0')
    source_roi = gp.Roi(reference.roi.get_offset(), reference.roi.get_shape())
    voxel_size = gp.Coordinate(reference.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    for raw_key in (raw_0, raw_1):
        request.add(raw_key, in_shape)
    for points_key in (points_0, points_1):
        request.add(points_key, out_shape)
    for locations_key in (locations_0, locations_1):
        request[locations_key] = gp.ArraySpec(nonspatial=True)

    # embeddings are snapshotted over the same ROI as their points
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = tuple(
        _random_point_pair_sample_source(
            dataset.filename, f'train/raw/{index}', (raw_0, raw_1),
            (points_0, points_1), point_density, dim, padding)
        for index in range(n_samples))

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])
    pipeline += PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                             locations_1)
    # TODO: clarify how PrepareBatch relates to Stack
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)
    # raw_0/raw_1: (1, h, w), locations_0/locations_1: (n, 2)
    pipeline += gp.Stack(batch_size)
    # raw_0/raw_1: (b, 1, h, w), locations_0/locations_1: (b, n, 2)
    pipeline += gp.PreCache(num_workers=10)
    pipeline += gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)
    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
def create_train_pipeline(self, model):
    """Build the affinity-training pipeline and its batch request.

    Pipeline stages:

    * Raw and ground-truth label datasets are read from the zarr container
      ``self.params['data_file']``.
    * Both are padded by the network context, sampled at random locations
      and augmented via ``self._augmentation_pipeline``.
    * Affinities are computed from the labels and ``model`` is trained with
      ``gp.torch.Train``.

    2D models get a leading singleton spatial dimension on the source, which
    is removed again (``RemoveSpatialDim``) before batches are stacked.

    Args:
        model: Torch module exposing ``parameters()``, ``in_shape``,
            ``out_shape`` (its first two entries are skipped, presumably
            batch and channel -- TODO confirm) and ``base_encoder.out_shape``.

    Returns:
        ``(pipeline, request)``.
    """
    logger.info(
        f"Creating training pipeline with batch size {self.params['batch_size']}"
    )

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    gt_aff = gp.ArrayKey('AFFINITIES')
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    # switch to world units
    in_shape = in_shape * out_voxel_size
    out_shape = out_shape * out_voxel_size

    # padding needed on each side; integer division makes the rounding
    # explicit instead of relying on Coordinate's int conversion
    context = (in_shape - out_shape) // 2
    gt_labels_out_shape = out_shape
    # Add fake 3rd dim
    if is_2d:
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        context = gp.Coordinate((0, *context))
        aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
        gt_labels_out_shape = (1, *gt_labels_out_shape)
    else:
        aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(gt_aff, out_shape)
    request.add(predictions, out_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))
    snapshot_request[gt_labels] = gp.ArraySpec(
        roi=gp.Roi(context, gt_labels_out_shape))

    source = (
        gp.ZarrSource(filename, {
            raw: raw_dataset,
            gt_labels: gt_dataset
        },
                      array_specs={
                          raw:
                          gp.ArraySpec(roi=source_roi,
                                       voxel_size=source_voxel_size,
                                       interpolatable=True),
                          gt_labels:
                          gp.ArraySpec(roi=source_roi,
                                       voxel_size=source_voxel_size)
                      }) + gp.Normalize(raw, self.params['norm_factor']) +
        gp.Pad(raw, context) + gp.Pad(gt_labels, context) +
        gp.RandomLocation()
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
    )
    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source +
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
        gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) +
        SetDtype(gt_aff, np.float32) +
        # raw      : (l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)
        AddChannelDim(raw)
        # raw      : (c=1, l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)
    )

    if is_2d:
        pipeline = (
            pipeline + RemoveSpatialDim(raw) + RemoveSpatialDim(gt_aff)
            # raw      : (c=1, h, w)
            # gt_aff   : (c=2, h, w)
        )

    pipeline = (
        pipeline + gp.Stack(self.params['batch_size']) + gp.PreCache() +
        # raw      : (b, c=1, h, w)
        # gt_aff   : (b, c=2, h, w)
        # (which is what train requires)
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'raw': raw},
            loss_inputs={
                0: predictions,
                1: gt_aff
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(voxel_size=out_voxel_size),
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every) +
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, h, w)
        # gt_aff     : (b, c=2, h, w)
        # predictions: (b, c=2, h, w)

        # Crop GT to look at labels
        gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
        gp.Snapshot(output_dir=self.logdir + '/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw: 'raw',
                        gt_labels: 'gt_labels',
                        predictions: 'predictions',
                        gt_aff: 'gt_aff',
                        emb: 'emb'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        gp.PrintProfilingStats(every=500))

    return pipeline, request
def train(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    """Train the embedding/foreground network on synthetic 2D skeletons.

    Args:
        n_iterations: The training loop requests ``n_iterations + 1`` batches,
            refreshing the request's random seed after each one.
        setup_config: Mapping with network shapes, skeleton-generation and
            skeleton-variation hyper-parameters, training settings and
            ``HIDE_SIGNAL``. It is also passed to ``create_custom_loss``.
        mknet_tensor_names: Tensor names of the network (``raw``,
            ``gt_labels``, ``embedding``, ``fg``). It is also passed to
            ``create_custom_loss``.
        loss_tensor_names: Names of loss tensors that are fetched as outputs
            (``emst``, ``edges_u``, ``edges_v``, ``ratio_pos``, ``ratio_neg``,
            ``dist``, ``num_pos_pairs``, ``num_neg_pairs``).
    """
    # network hyper-parameters
    input_shape = setup_config["INPUT_SHAPE"]
    output_shape = setup_config["OUTPUT_SHAPE"]
    # skeleton generation hyper-parameters
    skel_gen_radius = setup_config["SKEL_GEN_RADIUS"]
    thetas = np.array(setup_config["THETAS"]) * math.pi
    split_ps = setup_config["SPLIT_PS"]
    noise_var = setup_config["NOISE_VAR"]
    n_objs = setup_config["N_OBJS"]
    # skeleton variation hyper-parameters
    label_radii = setup_config["LABEL_RADII"]
    raw_radii = setup_config["RAW_RADII"]
    raw_intensities = setup_config["RAW_INTENSITIES"]
    # training hyper-parameters
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]

    unit_voxel = gp.Coordinate((1, 1))

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    gt_fg = gp.ArrayKey("GT_FG")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    # tensors fetched from the custom loss, as (name, key) pairs; the name is
    # both the lookup key into loss_tensor_names and the snapshot dataset name
    loss_arrays = (
        ("emst", gp.ArrayKey("EMST")),
        ("edges_u", gp.ArrayKey("EDGES_U")),
        ("edges_v", gp.ArrayKey("EDGES_V")),
        ("ratio_pos", gp.ArrayKey("RATIO_POS")),
        ("ratio_neg", gp.ArrayKey("RATIO_NEG")),
        ("dist", gp.ArrayKey("DIST")),
        ("num_pos_pairs", gp.ArrayKey("NUM_POS")),
        ("num_neg_pairs", gp.ArrayKey("NUM_NEG")),
    )

    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=unit_voxel)
    request.add(labels, output_shape, voxel_size=unit_voxel)
    request.add(point_trees, input_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    for key in (embedding, fg, gt_fg, maxima, gradient_embedding, gradient_fg):
        snapshot_request.add(key, output_shape, voxel_size=unit_voxel)
    for _, key in loss_arrays:
        snapshot_request[key] = gp.ArraySpec()

    pipeline = nl.SyntheticLightLike(
        point_trees,
        dims=2,
        r=skel_gen_radius,
        n_obj=n_objs,
        thetas=thetas,
        split_ps=split_ps,
    )
    pipeline += nl.RasterizeSkeleton(
        point_trees,
        raw,
        gp.ArraySpec(
            roi=gp.Roi((None, ) * 2, (None, ) * 2),
            voxel_size=unit_voxel,
            dtype=np.uint64,
        ),
    )
    pipeline += nl.RasterizeSkeleton(
        point_trees,
        labels,
        gp.ArraySpec(
            roi=gp.Roi((None, ) * 2, (None, ) * 2),
            voxel_size=unit_voxel,
            dtype=np.uint64,
        ),
        use_component=True,
        n_objs=int(setup_config["HIDE_SIGNAL"]),
    )
    pipeline += nl.GrowLabels(labels, radii=label_radii)
    pipeline += nl.GrowLabels(raw, radii=raw_radii)
    pipeline += LabelToFloat32(raw, intensities=raw_intensities)
    pipeline += gp.NoiseAugment(raw, var=noise_var)
    pipeline += gp.PreCache(cache_size=cache_size, num_workers=num_workers)

    custom_loss = create_custom_loss(mknet_tensor_names, setup_config)
    outputs = {
        mknet_tensor_names["embedding"]: embedding,
        mknet_tensor_names["fg"]: fg,
        "strided_slice_1:0": maxima,
        "gt_fg:0": gt_fg,
    }
    outputs.update({loss_tensor_names[name]: key for name, key in loss_arrays})
    pipeline += gp.tensorflow.Train(
        "train_net",
        optimizer=custom_loss,
        loss=None,
        inputs={
            mknet_tensor_names["raw"]: raw,
            mknet_tensor_names["gt_labels"]: labels
        },
        outputs=outputs,
        gradients={
            mknet_tensor_names["embedding"]: gradient_embedding,
            mknet_tensor_names["fg"]: gradient_fg,
        },
        save_every=checkpoint_every,
        summary="Merge/MergeSummary:0",
        log_dir="tensorflow_logs",
    )

    dataset_names = {
        raw: "volumes/raw",
        labels: "volumes/labels",
        point_trees: "point_trees",
        embedding: "volumes/embedding",
        fg: "volumes/fg",
        maxima: "volumes/maxima",
        gt_fg: "volumes/gt_fg",
        gradient_embedding: "volumes/gradient_embedding",
        gradient_fg: "volumes/gradient_fg",
    }
    dataset_names.update({key: name for name, key in loss_arrays})
    pipeline += gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names=dataset_names,
        dataset_dtypes={
            maxima: np.float32,
            gt_fg: np.float32
        },
        every=snapshot_every,
        additional_request=snapshot_request,
    )

    with gp.build(pipeline):
        for _ in range(n_iterations + 1):
            pipeline.request_batch(request)
            request._update_random_seed()
def train(iterations):
    """Train an affinity network on ``fib.hdf`` for ``iterations`` batches.

    The pipeline reads raw intensities and neuron labels from
    ``os.path.join(data_dir, 'fib.hdf')`` (``data_dir`` is a module-level
    name), normalizes and augments them, derives affinities to the three
    direct neighbors plus loss-balancing weights from the labels, and runs one
    ``gp.tensorflow.Train`` step per requested batch. Tensor names and the
    input/output shapes come from ``train_net_config.json`` in the current
    working directory.

    Args:
        iterations: number of batches to request (one training step each).
    """
    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')
    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')
    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')
    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')
    # the gradient of the loss wrt to the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((8, 8, 8))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    # the (single) data source: read from the HDF5 file, convert raw to
    # float in [0, 1] and choose a random location for each requested batch
    source = (
        gp.Hdf5Source(
            os.path.join(data_dir, 'fib.hdf'),
            datasets={
                raw: 'volumes/raw',
                gt_labels: 'volumes/labels/neuron_ids',
            })
        + gp.Normalize(raw)
        + gp.RandomLocation()
    )

    pipeline = (
        # RandomProvider takes a tuple of upstream providers. A one-element
        # tuple is required here: calling ``tuple()`` on the pipeline itself
        # would try to iterate it.
        (source,)
        + gp.RandomProvider()
        # elastically deform the batch
        + gp.ElasticAugment(
            [8, 8, 8],
            [0, 2, 2],
            [0, math.pi / 2.0],
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=25)
        # apply transpose and mirror augmentations
        + gp.SimpleAugment(transpose_only=[1, 2])
        # scale and shift the intensity of the raw array
        + gp.IntensityAugment(
            raw,
            scale_min=0.9,
            scale_max=1.1,
            shift_min=-0.1,
            shift_max=0.1,
            z_section_wise=True)
        # grow a boundary between labels
        + gp.GrowBoundary(gt_labels, steps=3, only_xy=True)
        # convert labels into affinities between voxels
        + gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs)
        # create a weight array that balances positive and negative samples
        # in the affinity array
        + gp.BalanceLabels(gt_affs, loss_weights)
        # pre-cache batches from the point upstream
        + gp.PreCache(cache_size=10, num_workers=5)
        # perform one training iteration for each passing batch (tensor
        # names are taken from train_net_config.json)
        + gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights,
            },
            outputs={net_config['pred_affs']: pred_affs},
            gradients={net_config['pred_affs']: pred_affs_gradients},
            save_every=10000)
        # save the passing batch as an HDF5 file for inspection
        + gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=1000,
            additional_request=snapshot_request,
            compression_type='gzip')
        # show a summary of time spent in each node every 1000 iterations
        + gp.PrintProfilingStats(every=1000)
    )

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for _ in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
def train_until(**kwargs):
    """Train an affinity network (optionally also predicting instance counts).

    Training resumes from the latest TensorFlow checkpoint in
    ``kwargs['output_folder']``. The checkpoint's iteration is parsed from the
    text after its last ``'_'``. Training stops at ``kwargs['max_iteration']``.

    Keys read from ``kwargs``: ``output_folder``, ``name`` (prefix of the
    ``<name>_config.json`` / ``<name>_names.json`` files and of the
    checkpoint), ``max_iteration``, ``voxel_size``, ``overlapping_inst``,
    ``input_format`` (``"hdf"`` or ``"zarr"``), ``data_files``,
    ``patchshape``, ``patchstride``, ``augmentation``, ``sampling``,
    ``cache_size``, ``num_workers``, ``checkpoints``, ``snapshots``,
    ``profiling`` and optionally ``raw_key`` (default ``'volumes/raw'``).

    Raises:
        NotImplementedError: if ``input_format`` is neither ``"hdf"`` nor
            ``"zarr"``.
    """
    # .get() so that an unset variable is reported as None instead of
    # aborting the run with a KeyError.
    print("cuda visibile devices", os.environ.get("CUDA_VISIBLE_DEVICES"))

    output_folder = kwargs['output_folder']
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(output_folder, kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    # half of the input/output size difference, used to pad raw
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        request.add(gt_numinst, output_shape_world)

    # extra arrays requested on snapshot iterations only
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        snapshot_request.add(pred_numinst, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        # exceptions do not %-format their arguments, so format explicitly
        raise NotImplementedError(
            "train node for %s not implemented yet" % input_format)

    raw_key = kwargs.get('raw_key', 'volumes/raw')
    print('raw key: ', raw_key)

    # data file paths without extension; the extension is re-added from
    # input_format when the sources are built below
    fls = []
    for data_file in kwargs['data_files']:
        fls.append(os.path.splitext(data_file)[0])
        if input_format == "hdf":
            # only the dtype is needed; close the file handle right away
            with h5py.File(data_file, 'r') as h5_file:
                raw_dtype = h5_file[raw_key].dtype
        else:
            raw_dtype = zarr.open(data_file, 'r')[raw_key].dtype
        if raw_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[:5])

    source_node = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource
    source_paths = [fl + "." + input_format for fl in fls]

    # 2D offsets (i, j) spanning the patch with the given stride; only axes 1
    # and 2 of patchshape / patchstride are used here
    neighborhood = []
    half_patch = np.array(kwargs['patchshape']) // 2
    for i in range(-half_patch[1], half_patch[1] + 1, kwargs['patchstride'][1]):
        for j in range(-half_patch[2], half_patch[2] + 1, kwargs['patchstride'][2]):
            neighborhood.append([i, j])

    datasets = {
        raw: raw_key,
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances',
    }
    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False),
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
    }
    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw: '/volumes/raw',
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_affs: '/volumes/pred_affs',
        pred_affs_gradients: '/volumes/pred_affs_gradients',
    }
    if kwargs['overlapping_inst']:
        datasets[gt_numinst] = 'volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[gt_numinst] = '/volumes/gt_numinst'
        snapshot[pred_numinst] = '/volumes/pred_numinst'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']

    # sources sampling random locations with enough foreground
    source_fg = tuple(
        source_node(path, datasets=datasets, array_specs=array_specs)
        + gp.Pad(raw, context)
        + nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1)
        # chose a random location for each requested batch
        + gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask)
        for path in source_paths
    )
    source_fg += gp.RandomProvider()

    # sources sampling random locations close to overlapping instances
    source_overlap = tuple(
        source_node(path, datasets=datasets, array_specs=array_specs)
        + gp.Pad(raw, context)
        + nl.MaskCloseDistanceToOverlap(
            gt_labels, gt_sample_mask,
            sampling['overlap_min_dist'],
            sampling['overlap_max_dist'])
        # chose a random location for each requested batch
        + gp.RandomLocation(
            min_masked=sampling['min_masked_overlap'],
            mask=gt_sample_mask)
        for path in source_paths
    )
    source_overlap += gp.RandomProvider()

    pipeline = (
        (source_fg, source_overlap)
        # chose a random source from the above
        + gp.RandomProvider(probabilities=[
            sampling['probability_fg'],
            sampling['probability_overlap']])
        # elastically deform the batch
        + gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min'] * np.pi / 180.0,
             augmentation['elastic']['rotation_max'] * np.pi / 180.0])
        # apply transpose and mirror augmentations
        + gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose"))
        # scale and shift the intensity of the raw array
        + gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
        + gp.IntensityScaleShift(raw, 2, -1)
        # convert labels into affinities between voxels
        + nl.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs,
            multiple_labels=kwargs['overlapping_inst'])
        # pre-cache batches from the point upstream
        + gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers'])
        # perform one training iteration for each passing batch
        + gp.tensorflow.Train(
            os.path.join(output_folder, kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=output_folder,
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            save_every=kwargs['checkpoints'])
        # save the passing batch as an HDF5 file for inspection
        + gp.Snapshot(
            snapshot,
            output_dir=os.path.join(output_folder, 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip')
        # show a summary of time spent in each node
        + gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f",
                i, time_of_iteration)
    print("Training finished")
def create_train_pipeline(self, model):
    """Build the gunpowder training pipeline and its per-iteration request.

    Raw data and ground-truth labels are read from the zarr file
    ``self.params['data_file']`` and merged with point annotations from
    ``self.data`` (via ``PointsLabelsSource``). For 2D models a leading
    spatial axis is added to the source ROI / voxel size and removed again
    with ``RemoveSpatialDim`` before batches are stacked and passed to
    ``gp.torch.Train``.

    Args:
        model: the torch model. This method reads ``model.parameters()``,
            ``model.in_shape``, ``model.out_shape`` and
            ``model.base_encoder.out_shape``.

    Returns:
        ``(pipeline, request)``: the assembled pipeline and the batch request
        to issue each training iteration.
    """
    batch_size = self.params['batch_size']
    print(f"Creating training pipeline with batch size {batch_size}")

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    points = gp.GraphKey("POINTS")
    locations = gp.ArrayKey("LOCATIONS")
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape (converted to world units below)
    in_shape = gp.Coordinate(model.in_shape)
    out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_roi = out_roi * out_voxel_size
    out_shape = gp.Coordinate(
        (self.params["num_points"], *model.out_shape[2:]))

    if is_2d:
        # Add a fake leading spatial dim to the source; it is removed again
        # by RemoveSpatialDim further down the pipeline.
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        points_roi = out_voxel_size * tuple(self.params["point_roi"])
        points_pad = (0, *points_roi)
        # NOTE(review): None entries presumably mean unbounded padding along
        # that axis -- confirm against gp.Pad
        context = gp.Coordinate((0, None, None))
    else:
        points_roi = source_voxel_size * tuple(self.params["point_roi"])
        points_pad = points_roi
        context = gp.Coordinate((None, None, None))

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")
    logger.info(f"out_voxel_size: {out_voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(points, points_roi)
    request.add(gt_labels, out_roi)
    request[locations] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)

    # the embedding is only fetched on snapshot iterations
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))

    source = (
        (
            gp.ZarrSource(
                filename,
                {
                    raw: raw_dataset,
                    gt_labels: gt_dataset,
                },
                array_specs={
                    raw: gp.ArraySpec(roi=source_roi,
                                      voxel_size=source_voxel_size,
                                      interpolatable=True),
                    gt_labels: gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size),
                }),
            PointsLabelsSource(points, self.data, scale=source_voxel_size),
        )
        + gp.MergeProvider()
        + gp.Pad(raw, context)
        + gp.Pad(gt_labels, context)
        + gp.Pad(points, points_pad)
        + gp.RandomLocation(ensure_nonempty=points)
        + gp.Normalize(raw, self.params['norm_factor'])
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        # If 2d then source_roi = (1, input_shape) in order to select a RL
    )
    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source
        # Batches seem to be rejected because points are chosen near the
        # edge of the points ROI and the augmentations remove them.
        # TODO: Figure out if this is an actual issue, and if anything can
        # be done.
        + gp.Reject(ensure_nonempty=points)
        + SetDtype(gt_labels, np.int64)
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        + AddChannelDim(raw)
        + AddChannelDim(gt_labels)
        # raw      : (c=1, source_roi)
        # gt_labels: (c=1, source_roi)
        # points   : (c=1, source_locations_shape)
    )

    if is_2d:
        pipeline = (
            # Remove the extra dim the 2d roi had
            pipeline
            + RemoveSpatialDim(raw)
            + RemoveSpatialDim(gt_labels)
            + RemoveSpatialDim(points)
            # raw      : (c=1, roi)
            # gt_labels: (c=1, roi)
            # points   : (c=1, locations_shape)
        )

    pipeline = (
        pipeline
        + FillLocations(raw, points, locations, is_2d=False, max_points=1)
        + gp.Stack(batch_size)
        + gp.PreCache()
        # raw      : (b, c=1, roi)
        # gt_labels: (b, c=1, roi)
        # locations: (b, c=1, locations_shape)
        # (which is what train requires)
        + gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={
                'raw': raw,
                'points': locations,
            },
            loss_inputs={
                0: predictions,
                1: gt_labels,
                2: locations,
            },
            outputs={
                0: predictions,
                1: emb,
            },
            array_specs={
                predictions: gp.ArraySpec(nonspatial=True),
                emb: gp.ArraySpec(voxel_size=out_voxel_size),
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every)
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, roi)
        # gt_labels  : (b, c=1, roi)
        # predictions: (b, num_points)
        + gp.Snapshot(
            output_dir=self.logdir + '/snapshots',
            output_filename='it{iteration}.hdf',
            dataset_names={
                raw: 'raw',
                gt_labels: 'gt_labels',
                predictions: 'predictions',
                emb: 'emb',
            },
            additional_request=snapshot_request,
            every=self.params['save_every'])
        + InspectBatch('END')
        + gp.PrintProfilingStats(every=500)
    )

    return pipeline, request
def train_until(max_iteration, return_intermediates=False):
    """Run the FusionAugment visualization pipeline up to ``max_iteration``.

    Despite the name, this pipeline contains no training node. It reads
    two-channel volumes and SWC skeletons from every path in the
    module-level ``files`` and rasterizes the skeletons into ``GT``. It then
    applies ``FusionAugment``, normalization and intensity augmentation,
    binarizes ``GT`` into ``GT_FG``, and writes a snapshot every 20 batches.
    The number of batches requested is ``max_iteration`` minus the iteration
    of the latest checkpoint in the working directory.

    Args:
        max_iteration: iteration to run up to.
        return_intermediates: if True, also request, normalize and snapshot
            the intermediate FusionAugment arrays (``A_CH1``, ``A_CH2``,
            ``B_CH1``, ``B_CH2``, ``SOFT_MASK``).
    """
    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # input data
    ch1 = gp.ArrayKey('CH1')
    ch2 = gp.ArrayKey('CH2')
    swc = gp.PointsKey('SWC')
    swc_env = gp.PointsKey('SWC_ENV')
    swc_center = gp.PointsKey('SWC_CENTER')
    gt = gp.ArrayKey('GT')
    gt_fg = gp.ArrayKey('GT_FG')

    # show fusion augment batches; these keys only exist in this mode
    if return_intermediates:
        a_ch1 = gp.ArrayKey('A_CH1')
        a_ch2 = gp.ArrayKey('A_CH2')
        b_ch1 = gp.ArrayKey('B_CH1')
        b_ch2 = gp.ArrayKey('B_CH2')
        soft_mask = gp.ArrayKey('SOFT_MASK')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')  # currently not requested

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(ch1, input_size)
    request.add(ch2, input_size)
    request.add(swc, input_size)
    request.add(swc_center, output_size)
    request.add(gt, output_size)
    request.add(gt_fg, output_size)
    # request.add(loss_weights, output_size)
    if return_intermediates:
        for key in (a_ch1, a_ch2, b_ch1, b_ch2, soft_mask):
            request.add(key, input_size)

    # no additional arrays are requested on snapshot iterations
    snapshot_request = gp.BatchRequest()

    # one (volume, skeleton) source per file, merged and rasterized
    data_sources = tuple(
        (
            Hdf5ChannelSource(
                data_file,
                datasets={
                    ch1: '/volume',
                    ch2: '/volume',
                },
                channel_ids={
                    ch1: 0,
                    ch2: 1,
                },
                data_format='channels_last',
                array_specs={
                    ch1: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size,
                                      dtype=np.uint16),
                    ch2: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size,
                                      dtype=np.uint16),
                }),
            SwcSource(
                filename=data_file,
                dataset='/reconstruction',
                points=(swc_center, swc),
                return_env=True,
                scale=voxel_size),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(ensure_nonempty=swc_center)
        + RasterizeSkeleton(
            points=swc,
            array=gt,
            array_spec=gp.ArraySpec(interpolatable=False,
                                    voxel_size=voxel_size,
                                    dtype=np.uint32),
            points_env=swc_env,
            iteration=10)
        for data_file in files)

    # snapshot datasets, in the same order as before: channels,
    # intermediates (if any), then gt / fg arrays
    snapshot_datasets = {
        ch1: 'volumes/ch1',
        ch2: 'volumes/ch2',
    }
    if return_intermediates:
        snapshot_datasets.update({
            a_ch1: 'volumes/a_ch1',
            a_ch2: 'volumes/a_ch2',
            b_ch1: 'volumes/b_ch1',
            b_ch2: 'volumes/b_ch2',
            soft_mask: 'volumes/soft_mask',
        })
    snapshot_datasets.update({
        gt: 'volumes/gt',
        fg: 'volumes/fg',
        gt_fg: 'volumes/gt_fg',
        gradient_fg: 'volumes/gradient_fg',
    })

    # NOTE(review): gp.RandomProvider() is disabled, so every source in
    # data_sources feeds FusionAugment directly; presumably `files` holds a
    # single file -- confirm.
    pipeline = data_sources + FusionAugment(
        ch1, ch2, gt,
        smoothness=1,
        return_intermediate=return_intermediates)

    # Normalize only arrays that exist: the A_/B_ keys are defined solely
    # when return_intermediates is True (normalizing them unconditionally
    # raised UnboundLocalError for the default call).
    normalized_arrays = [ch1, ch2]
    if return_intermediates:
        normalized_arrays += [a_ch1, a_ch2, b_ch1, b_ch2]
    for array in normalized_arrays:
        pipeline = pipeline + gp.Normalize(array)

    pipeline = (
        pipeline
        + gp.IntensityAugment(ch1, 0.9, 1.1, -0.001, 0.001)
        + gp.IntensityAugment(ch2, 0.9, 1.1, -0.001, 0.001)
        + BinarizeGt(gt, gt_fg)
        # visualize
        + gp.Snapshot(
            output_filename='snapshot_{iteration}.hdf',
            dataset_names=snapshot_datasets,
            additional_request=snapshot_request,
            every=20)
        + gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Train a distance-transform regression network up to ``max_iteration``.

    Raw volumes and foreground masks are read from every top-level group
    ("sample") of each HDF5 file in the module-level ``data_files`` (relative
    to ``data_dir``). Batches with less than 0.5% foreground are rejected. The
    mask is turned into ``GT_DT`` by ``DistanceTransform``. Raw is clipped to
    ``[0, clip_max]``, scaled by ``1 / clip_max``, augmented, and mapped
    through ``2 * x - 1`` before each TensorFlow training step.

    Args:
        max_iteration: iteration to train up to; training resumes from the
            latest checkpoint in ``output_folder``.
        name: prefix of ``<name>_config.json``, ``<name>_names.json`` and of
            the network/checkpoint files in ``output_folder``.
        output_folder: folder holding the network files and checkpoints.
            Snapshots are written to ``<output_folder>/snapshots``.
        clip_max: upper clipping value for raw intensities.
    """
    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    # unit voxels: shapes in voxels equal sizes in world units
    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source: one source per sample group in each file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            # the generator iterates the open file, so it is consumed here
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True,
                                          dtype=np.uint16,
                                          voxel_size=voxel_size),
                        # np.bool was removed in NumPy 1.24; np.bool_ is the
                        # same boolean dtype and works on all versions
                        gt_mask: gp.ArraySpec(interpolatable=False,
                                              dtype=np.bool_,
                                              voxel_size=voxel_size),
                    })
                + Convert(gt_mask, np.uint8)
                + gp.Pad(raw, context)
                + gp.Pad(gt_mask, context)
                + gp.RandomLocation()
                for sample in f)

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.)
        + DistanceTransform(gt_mask, gt_dt, 3)
        + nl.Clip(raw, 0, clip_max)
        + gp.Normalize(raw, factor=1.0 / clip_max)
        + gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi / 2.0],
            subsample=4)
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
        + gp.IntensityScaleShift(raw, 2, -1)
        # train
        + gp.PreCache(
            cache_size=40,
            num_workers=5)
        + gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_dt']: gt_dt,
            },
            outputs={
                net_names['pred_dt']: pred_dt,
            },
            gradients={
                net_names['pred_dt']: loss_gradient,
            },
            save_every=5000)
        # visualize; Snapshot creates output_dir itself, so the folder is
        # passed there rather than embedded in output_filename
        + gp.Snapshot(
            {
                raw: 'volumes/raw',
                gt_mask: 'volumes/gt_mask',
                gt_dt: 'volumes/gt_dt',
                pred_dt: 'volumes/pred_dt',
                loss_gradient: 'volumes/gradient',
            },
            output_dir=os.path.join(output_folder, 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            additional_request=snapshot_request,
            every=2000)
        + gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(max_iteration):
    """Train a foreground-prediction network on fused neuron volumes.

    Two independent random crops ("base" and "add") are drawn from the
    module-level ``files`` (volume at ``/volume``, skeleton at
    ``/reconstruction``). FusionAugment blends them into one ``RAW`` /
    ``LABELS`` pair, which is augmented, binarized into ``LABELS_FG`` and
    used for TensorFlow training with balanced loss weights. Training resumes
    from the latest checkpoint in the working directory and stops at
    ``max_iteration``.
    """
    checkpoint = tf.train.latest_checkpoint(".")
    trained_until = int(checkpoint.split("_")[-1]) if checkpoint else 0
    if trained_until >= max_iteration:
        return

    # fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # "base" volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # "add" volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # network outputs / loss inputs
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    request = gp.BatchRequest()
    for key, size in (
        (raw, input_size),
        (labels, output_size),
        (labels_fg, output_size),
        (loss_weights, output_size),
        (swc_center_base, output_size),
        (swc_center_add, output_size),
    ):
        request.add(key, size)

    snapshot_request = gp.BatchRequest()
    for key, size in (
        (fg, output_size),
        (labels_fg, output_size),
        (gradient_fg, output_size),
        (raw_base, input_size),
        (raw_add, input_size),
        (labels_base, input_size),
        (labels_add, input_size),
    ):
        snapshot_request.add(key, size)

    def make_sources(raw_key, labels_key, swc_key, swc_center_key):
        """Build one merged (volume, skeleton) source per path in ``files``."""
        return tuple(
            (
                gp.Hdf5Source(
                    path,
                    datasets={raw_key: "/volume"},
                    array_specs={
                        raw_key: gp.ArraySpec(interpolatable=True,
                                              voxel_size=voxel_size,
                                              dtype=np.uint16)
                    },
                    channels_first=False,
                ),
                SwcSource(
                    filename=path,
                    dataset="/reconstruction",
                    points=(swc_center_key, swc_key),
                    scale=voxel_size,
                ),
            )
            + gp.MergeProvider()
            + gp.RandomLocation(ensure_nonempty=swc_center_key)
            + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=gp.ArraySpec(interpolatable=False,
                                        voxel_size=voxel_size,
                                        dtype=np.uint32),
                radius=5.0,
            )
            for path in files)

    base_sources = make_sources(raw_base, labels_base, swc_base, swc_center_base)
    add_sources = make_sources(raw_add, labels_add, swc_add, swc_center_add)

    # each branch picks a random file, then both crops are merged
    data_sources = (
        (base_sources + gp.RandomProvider()),
        (add_sources + gp.RandomProvider()),
    ) + gp.MergeProvider()

    snapshot_datasets = {
        raw: "volumes/raw",
        raw_base: "volumes/raw_base",
        raw_add: "volumes/raw_add",
        labels: "volumes/labels",
        labels_base: "volumes/labels_base",
        labels_add: "volumes/labels_add",
        fg: "volumes/fg",
        labels_fg: "volumes/labels_fg",
        gradient_fg: "volumes/gradient_fg",
    }

    pipeline = (
        data_sources
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        )
        # augmentation
        + gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                            subsample=4)
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        # targets and loss weighting
        + BinarizeGt(labels, labels_fg)
        + gp.BalanceLabels(labels_fg, loss_weights)
        # training
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        )
        # visualization
        + gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names=snapshot_datasets,
            additional_request=snapshot_request,
            every=100,
        )
        + gp.PrintProfilingStats(every=100)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)