def get_labels_snapshot_source(config, blocks):
    validation_blocks = Path(config["VALIDATION_BLOCKS"])

    labels = gp.ArrayKey("LABELS")
    gt = gp.GraphKey("GT")

    block_pipelines = []
    for block in blocks:
        pipeline = SnapshotSource(
            validation_blocks / f"block_{block}.hdf",
            {labels: "volumes/labels", gt: "points/gt"},
            directed={gt: True},
        )
        block_pipelines.append(pipeline)
    return block_pipelines, (labels, gt)
def validation_data_sources_from_snapshots(config, blocks):
    validation_blocks = Path(config["VALIDATION_BLOCKS"])

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    block_pipelines = []
    for block in blocks:
        pipelines = (
            SnapshotSource(
                validation_blocks / f"block_{block}.hdf",
                {labels: "volumes/labels", ground_truth: "points/gt"},
                directed={ground_truth: True},
            ),
            SnapshotSource(
                validation_blocks / f"block_{block}.hdf",
                {raw: "volumes/raw"},
            ),
        )

        cube_roi = get_cube_roi(config, block)

        request = gp.BatchRequest()
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        block_pipelines.append((pipelines, request))
    return block_pipelines, (raw, labels, ground_truth)
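# Illustrative sketch (not part of the original pipeline code): one way the
# per-block sources returned above could be consumed. It assumes the same
# module-level imports used throughout this file (gunpowder as gp); the
# function name `_example_run_snapshot_validation` is hypothetical.
def _example_run_snapshot_validation(config, blocks):
    block_pipelines, (raw, labels, ground_truth) = \
        validation_data_sources_from_snapshots(config, blocks)

    for (label_source, raw_source), request in block_pipelines:
        # The two SnapshotSources are returned unmerged; merge them before
        # building, then pull the single validation batch for this block.
        pipeline = (label_source, raw_source) + gp.MergeProvider()
        with gp.build(pipeline):
            batch = pipeline.request_batch(request)
        # batch[raw], batch[labels], and batch[ground_truth] now hold this
        # block's validation data.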
def create_train_pipeline(self, model):

    print(f"Creating training pipeline with batch size "
          f"{self.params['batch_size']}")

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    points = gp.GraphKey("POINTS")
    locations = gp.ArrayKey("LOCATIONS")
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_roi = out_roi * out_voxel_size
    out_shape = gp.Coordinate(
        (self.params["num_points"], *model.out_shape[2:]))

    context = (in_shape - out_roi) / 2
    gt_labels_out_shape = out_roi

    # Add fake 3rd dim
    if is_2d:
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        context = gp.Coordinate((0, *context))
        gt_labels_out_shape = (1, *gt_labels_out_shape)

        points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
        points_pad = (0, *points_roi)
        context = gp.Coordinate((0, None, None))
    else:
        points_roi = source_voxel_size * tuple(self.params["point_roi"])
        points_pad = points_roi
        context = gp.Coordinate((None, None, None))

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")
    logger.info(f"out_voxel_size: {out_voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(points, points_roi)
    request.add(gt_labels, out_roi)
    request[locations] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi(
            (0, ) * in_shape.dims(),
            gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
            out_voxel_size))

    source = (
        (gp.ZarrSource(
            filename,
            {raw: raw_dataset, gt_labels: gt_dataset},
            array_specs={
                raw: gp.ArraySpec(roi=source_roi,
                                  voxel_size=source_voxel_size,
                                  interpolatable=True),
                gt_labels: gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size)
            }),
         PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
        gp.MergeProvider() +
        gp.Pad(raw, context) +
        gp.Pad(gt_labels, context) +
        gp.Pad(points, points_pad) +
        gp.RandomLocation(ensure_nonempty=points) +
        gp.Normalize(raw, self.params['norm_factor'])
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        # If 2d then source_roi = (1, input_shape) in order to select a RL
    )

    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source +
        # Batches seem to be rejected because points are chosen near the
        # edge of the points ROI and the augmentations remove them.
        # TODO: Figure out if this is an actual issue, and if anything can
        # be done.
        gp.Reject(ensure_nonempty=points) +
        SetDtype(gt_labels, np.int64) +
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        AddChannelDim(raw) +
        AddChannelDim(gt_labels)
        # raw      : (c=1, source_roi)
        # gt_labels: (c=2, source_roi)
        # points   : (c=1, source_locations_shape)
    )

    if is_2d:
        pipeline = (
            # Remove extra dim the 2d roi had
            pipeline +
            RemoveSpatialDim(raw) +
            RemoveSpatialDim(gt_labels) +
            RemoveSpatialDim(points)
            # raw      : (c=1, roi)
            # gt_labels: (c=1, roi)
            # points   : (c=1, locations_shape)
        )

    pipeline = (
        pipeline +
        FillLocations(raw, points, locations, is_2d=False, max_points=1) +
        gp.Stack(self.params['batch_size']) +
        gp.PreCache() +
        # raw      : (b, c=1, roi)
        # gt_labels: (b, c=1, roi)
        # locations: (b, c=1, locations_shape)
        # (which is what train requires)
        gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={
                'raw': raw,
                'points': locations
            },
            loss_inputs={
                0: predictions,
                1: gt_labels,
                2: locations
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(nonspatial=True),
                emb: gp.ArraySpec(voxel_size=out_voxel_size)
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every) +
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, roi)
        # gt_labels  : (b, c=1, roi)
        # predictions: (b, num_points)
        gp.Snapshot(output_dir=self.logdir + '/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw: 'raw',
                        gt_labels: 'gt_labels',
                        predictions: 'predictions',
                        emb: 'emb'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        InspectBatch('END') +
        gp.PrintProfilingStats(every=500))

    return pipeline, request
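# Illustrative sketch (not part of the original code): driving the returned
# (pipeline, request) pair, mirroring the gp.build loop used in train_until
# below. `trainer` and `n_iterations` are hypothetical names; `trainer` stands
# for whatever object this method is bound to.
def _example_run_training(trainer, model, n_iterations):
    pipeline, request = trainer.create_train_pipeline(model)

    with gp.build(pipeline):
        for iteration in range(n_iterations):
            # Each request pulls one batch through augmentation, stacking,
            # and gp.torch.Train, performing one optimizer step.
            pipeline.request_batch(request)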
def validation_data_sources_recomputed(config, blocks):
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    validation_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    validation_dirs[validation_num] = validation_dir

    validation_dirs = [validation_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        pipeline = (
            (
                gp.ZarrSource(
                    filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                    datasets={raw: "volume-rechunked"},
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True,
                                          voxel_size=voxel_size)
                    },
                ),
                nl.gunpowder.nodes.MouselightSwcFileSource(
                    validation_dir,
                    [ground_truth],
                    transform_file=transform_template.format(sample=sample),
                    ignore_human_nodes=False,
                    scale=voxel_size,
                    transpose=[2, 1, 0],
                    points_spec=[
                        gp.PointsSpec(roi=gp.Roi(
                            gp.Coordinate([None, None, None]),
                            gp.Coordinate([None, None, None]),
                        ))
                    ],
                ),
            )
            + gp.nodes.MergeProvider()
            + gp.Normalize(raw, dtype=np.float32)
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000])
        )

        request = gp.BatchRequest()
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))
    return validation_pipelines, (raw, labels, ground_truth)
def random_point_pairs_pipeline(model, loss, optimizer, dataset,
                                augmentation_parameters, point_density,
                                out_dir, normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        ds_key = f'train/raw/{i}'
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename,
                {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None)
            for raw in [raw_0, raw_1])

        random_point_generator = RandomPointGenerator(
            density=point_density, repetitions=2)

        point_sources = tuple((
            RandomPointSource(
                points_0,
                dim,
                random_point_generator=random_point_generator),
            RandomPointSource(
                points_1,
                dim,
                random_point_generator=random_point_generator),
        ))

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source) +
            gp.MergeProvider() +
            gp.SimpleAugment() +
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi / 2)) +
            gp.IntensityAugment(r,
                                scale_min=0.8,
                                scale_max=1.2,
                                shift_min=-0.2,
                                shift_max=0.2,
                                clip=False) +
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source
            in zip([raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())

        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])
    pipeline += PrepareBatch(raw_0, raw_1,
                             points_0, points_1,
                             locations_0, locations_1)

    # How does prepare batch relate to Stack?????
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:       (1, h, w)
    # raw_1:       (1, h, w)
    # locations_0: (n, 2)
    # locations_1: (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:       (b, 1, h, w)
    # raw_1:       (b, 1, h, w)
    # locations_0: (b, n, 2)
    # locations_1: (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model, loss, optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0: 'locations_0',
            # locations_1: 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
def validation_pipeline(config):
    """
    Per block {
        Raw -> predict -> scan
        gt -> rasterize
        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}
    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={
                    raw: "volume-rechunked",
                    raw_clahed: "volume-rechunked"
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size),
                    raw_clahed: gp.ArraySpec(interpolatable=True,
                                             voxel_size=voxel_size),
                },
            )
            + gp.Normalize(raw, dtype=np.float32)
            + gp.Normalize(raw_clahed, dtype=np.float32)
            + scipyCLAHE([raw_clahed], [20, 64, 64])
        )
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = (
            (swc_source, raw_source)
            + gp.nodes.MergeProvider()
            + gp.SpecifiedLocation(locations=[cube_roi.get_center()])
            + gp.Crop(raw, roi=input_roi)
            + gp.Crop(raw_clahed, roi=input_roi)
            + gp.Crop(ground_truth, roi=cube_roi_shifted)
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels,
                                      radii=[neuron_width * micron_scale])
            + gp.Crop(labels, roi=cube_roi_shifted)
            + gp.Snapshot(
                {
                    raw: f"volumes/{block}/raw",
                    raw_clahed: f"volumes/{block}/raw_clahe",
                    ground_truth: f"points/{block}/ground_truth",
                    labels: f"volumes/{block}/labels",
                },
                additional_request=additional_request,
                output_dir="validations",
                output_filename="validations.hdf",
            )
        )

        validation_pipelines.append(pipeline)

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines)
        + gp.MergeProvider()
        + gp.PrintProfilingStats()
    )
    return validation_pipeline, specs
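# Illustrative sketch (not part of the original code): one way to drive the
# merged validation pipeline returned above, assembling a single BatchRequest
# from the per-block `specs` dict. Assumes the same module-level imports used
# elsewhere in this file (gunpowder as gp); `_example_run_block_validation`
# is a hypothetical name.
def _example_run_block_validation(config):
    pipeline, specs = validation_pipeline(config)

    request = gp.BatchRequest()
    for block, block_spec in specs.items():
        # Each block maps its gunpowder keys to the spec it should be
        # requested with (raw over the grown input ROI, labels/gt over the
        # shifted cube ROI).
        for key, spec in block_spec.items():
            request[key] = spec

    with gp.build(pipeline):
        # The gp.Snapshot node in each block writes validations.hdf.
        pipeline.request_batch(request)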
def train_until(max_iteration):

    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)

    model = Convolve(unet, 12, 1)

    loss = torch.nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # start of gunpowder part:

    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))

    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)

    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(points, output_size)
    request.add(groundtruth, output_size)
    request.add(prediction, output_size)
    request.add(grad, output_size)

    # pos_samples and neg_samples are expected to be defined at module scope.
    pos_sources = tuple(
        gp.ZarrSource(filename,
                      {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddCenterPoint(points, raw) +
        gp.Pad(raw, None) +
        gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples) + gp.RandomProvider()
    neg_sources = tuple(
        gp.ZarrSource(filename,
                      {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddNoPoint(points, raw) +
        gp.RandomLocation()
        for filename in neg_samples) + gp.RandomProvider()

    data_sources = (pos_sources, neg_sources)
    data_sources += gp.RandomProvider(probabilities=[0.9, 0.1])
    data_sources += gp.Normalize(raw)

    train_pipeline = data_sources
    train_pipeline += gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                                        jitter_sigma=[0, 2, 2],
                                        rotation_interval=[0, math.pi / 2.0],
                                        prob_slip=0.05,
                                        prob_shift=0.05,
                                        max_misalign=10,
                                        subsample=8)
    train_pipeline += gp.SimpleAugment(transpose_only=[1, 2])
    train_pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                          z_section_wise=True)
    train_pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'))
    train_pipeline += gp.PreCache(cache_size=40, num_workers=10)

    train_pipeline += Reshape(raw, (1, 1) + input_shape)
    train_pipeline += Reshape(groundtruth, (1, 1) + output_shape)

    train_pipeline += gp_torch.Train(model=model,
                                     loss=loss,
                                     optimizer=optimizer,
                                     inputs={'x': raw},
                                     outputs={0: prediction},
                                     loss_inputs={
                                         0: prediction,
                                         1: groundtruth
                                     },
                                     gradients={0: grad},
                                     save_every=1000,
                                     log_dir='log')

    train_pipeline += Reshape(raw, input_shape)
    train_pipeline += Reshape(groundtruth, output_shape)
    train_pipeline += Reshape(prediction, output_shape)
    train_pipeline += Reshape(grad, output_shape)

    train_pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient'
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    train_pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(train_pipeline):
        for i in range(max_iteration):
            train_pipeline.request_batch(request)
def validation_pipeline(config):
    """
    Per block {
        Raw -> predict -> scan
        gt -> rasterize
        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = (config["COORDINATE_SCALE"] * np.array(voxel_size) /
                        micron_scale)

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}
    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size)
                },
            )
            + gp.Normalize(raw, dtype=np.float32)
            + mCLAHE([raw], [20, 64, 64])
        )
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block,
                                      fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = (
            (swc_source, pred_source)
            + gp.nodes.MergeProvider()
            + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            )
            + nl.gunpowder.GrowLabels(labels,
                                      radii=[neuron_width * micron_scale])
            + Skeletonize(fg, candidates, candidate_spacing,
                          candidate_threshold)
            + EMST(
                emb,
                candidates,
                mst,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            + gp.Snapshot(
                {
                    raw: f"volumes/{raw}",
                    ground_truth: f"points/{ground_truth}",
                    labels: f"volumes/{labels}",
                    fg: f"volumes/{fg}",
                    emb: f"volumes/{emb}",
                    candidates: f"volumes/{candidates}",
                    mst: f"points/{mst}",
                },
                additional_request=additional_request,
                output_dir="snapshots",
                output_filename="{id}.hdf",
                edge_attrs={mst: [distance_attr]},
            )
        )

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines)
        + gp.MergeProvider()
        + MergeGraphs(specs, full_gt, full_mst)
        + Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr)
        + gp.PrintProfilingStats()
    )
    return validation_pipeline, score
def emb_validation_pipeline(
    config,
    snapshot_file,
    candidates_path,
    raw_path,
    gt_path,
    candidates_mst_path=None,
    candidates_mst_dense_path=None,
    path_stat="max",
):
    checkpoint = config["EMB_EVAL_CHECKPOINT"]
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    coordinate_scale = (config["COORDINATE_SCALE"] * np.array(voxel_size) /
                        micron_scale)
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    edge_threshold_0 = config["EVAL_EDGE_THRESHOLD_0"]
    component_threshold_0 = config["COMPONENT_THRESHOLD_0"]
    component_threshold_1 = config["COMPONENT_THRESHOLD_1"]

    clip_limit = config["CLAHE_CLIP_LIMIT"]
    normalize = config["CLAHE_NORMALIZE"]

    validation_pipelines = []
    specs = {}

    emb_model = get_emb_model(config)
    emb_model.eval()

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates_1 = gp.ArrayKey(f"CANDIDATES_1_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst_0 = gp.GraphKey(f"MST_0_{block}")
        mst_dense_0 = gp.GraphKey(f"MST_DENSE_0_{block}")
        mst_1 = gp.GraphKey(f"MST_1_{block}")
        mst_dense_1 = gp.GraphKey(f"MST_DENSE_1_{block}")
        mst_2 = gp.GraphKey(f"MST_2_{block}")
        mst_dense_2 = gp.GraphKey(f"MST_DENSE_2_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")
        optimal_mst = gp.GraphKey(f"OPTIMAL_MST_{block}")

        # Volume Source
        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                candidates_1: candidates_path.format(block=block),
            },
        )

        # Graph Source
        graph_datasets = {gt: gt_path.format(block=block)}
        graph_directionality = {gt: False}
        edge_attrs = {}
        if candidates_mst_path is not None:
            graph_datasets[mst_0] = candidates_mst_path.format(block=block)
            graph_directionality[mst_0] = False
            edge_attrs[mst_0] = [distance_attr]
        if candidates_mst_dense_path is not None:
            graph_datasets[mst_dense_0] = candidates_mst_dense_path.format(
                block=block)
            graph_directionality[mst_dense_0] = False
            edge_attrs[mst_dense_0] = [distance_attr]

        gt_source = SnapshotSource(
            snapshot_file,
            datasets=graph_datasets,
            directed=graph_directionality,
            edge_attrs=edge_attrs,
        )

        if config["EVAL_CLAHE"]:
            raw_source = raw_source + scipyCLAHE(
                [raw],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=clip_limit,
                normalize=normalize,
            )

        emb_source, emb, neighborhood = add_emb_pred(config, raw_source, raw,
                                                     block, emb_model)

        reference_sizes = {
            raw: input_size,
            emb: output_size,
            candidates_1: output_size
        }
        if neighborhood is not None:
            reference_sizes[neighborhood] = output_size

        emb_source = add_scan(emb_source, reference_sizes)

        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        block_spec[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            block_spec[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_0] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_0] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        block_spec[mst_1] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_1] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        block_spec[mst_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        # block_spec[mst_dense_2] = gp.GraphSpec(cube_roi_shifted,
        #                                        directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)
        block_spec[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        additional_request[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            additional_request[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted,
                                              directed=False)
        additional_request[mst_0] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        additional_request[mst_dense_0] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)
        additional_request[mst_1] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        additional_request[mst_dense_1] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)
        additional_request[mst_2] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        # additional_request[mst_dense_2] = gp.GraphSpec(cube_roi_shifted,
        #                                                directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)
        additional_request[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)

        pipeline = (emb_source, gt_source) + gp.MergeProvider()

        if candidates_mst_path is not None and candidates_mst_dense_path is not None:
            # mst_0 provided, just need to calculate distances.
            pass
        elif config["EVAL_MINIMAX_EMBEDDING_DIST"]:
            # No mst_0 provided, must first calculate mst_0 and dense mst_0
            pipeline += MiniMaxEmbeddings(
                emb,
                candidates_1,
                decimated=mst_0,
                dense=mst_dense_0,
                distance_attr=distance_attr,
            )
        else:
            # mst/mst_dense not provided. Simply use euclidean distance on
            # candidates.
            pipeline += EMST(
                emb,
                candidates_1,
                mst_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            pipeline += EMST(
                emb,
                candidates_1,
                mst_dense_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )

        pipeline += ThresholdEdges(
            (mst_0, mst_1),
            edge_threshold_0,
            component_threshold_0,
            msts_dense=(mst_dense_0, mst_dense_1),
            distance_attr=distance_attr,
        )

        pipeline += ComponentWiseEMST(
            emb,
            mst_1,
            mst_2,
            distance_attr=distance_attr,
            coordinate_scale=coordinate_scale,
        )

        # pipeline += ScoreEdges(
        #     mst, mst_dense, emb,
        #     distance_attr=distance_attr,
        #     path_stat=path_stat)

        pipeline += Evaluate(
            gt,
            mst_2,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold_1,
            # connectivity=mst_1,
            output_graph=optimal_mst,
        )

        if config["EVAL_SNAPSHOT"]:
            snapshot_datasets = {
                raw: f"volumes/raw",
                emb: f"volumes/embeddings",
                candidates_1: f"volumes/candidates_1",
                mst_0: f"points/mst_0",
                mst_dense_0: f"points/mst_dense_0",
                mst_1: f"points/mst_1",
                mst_dense_1: f"points/mst_dense_1",
                # mst_2: f"points/mst_2",
                gt: f"points/gt",
                details: f"points/details",
                optimal_mst: f"points/optimal_mst",
            }
            if neighborhood is not None:
                snapshot_datasets[neighborhood] = f"volumes/neighborhood"
            pipeline += gp.Snapshot(
                snapshot_datasets,
                output_dir=config["EVAL_SNAPSHOT_DIR"],
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    checkpoint=checkpoint,
                    block=block,
                    coordinate_scale=",".join(
                        [str(x) for x in coordinate_scale]),
                ),
                edge_attrs={
                    mst_0: [distance_attr],
                    mst_dense_0: [distance_attr],
                    mst_1: [distance_attr],
                    mst_dense_1: [distance_attr],
                    # mst_2: [distance_attr],
                    # optimal_mst: [distance_attr],
                    # it is unclear how to add distances if using
                    # connectivity graph
                    # mst_dense_2: [distance_attr],
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines)
        + gp.MergeProvider()
        + MergeScores(final_score, specs)
        + gp.PrintProfilingStats()
    )
    return validation_pipeline, final_score
def pre_computed_fg_validation_pipeline(config, snapshot_file, raw_path,
                                        gt_path, fg_path):
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    candidate_spacing = config["CANDIDATE_SPACING"]
    candidate_threshold = config["CANDIDATE_THRESHOLD"]

    distance_attr = config["DISTANCE_ATTR"]
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    component_threshold = config["COMPONENT_THRESHOLD_1"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst = gp.GraphKey(f"MST_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        fg = gp.ArrayKey(f"FG_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")

        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                fg: fg_path.format(block=block),
            },
        )
        gt_source = SnapshotSource(
            snapshot_file,
            datasets={gt: gt_path.format(block=block)},
            directed={gt: False},
        )

        input_roi = cube_roi.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(
            (input_size - output_size) // 2, (input_size - output_size) // 2
        )

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates] = gp.ArraySpec(cube_roi_shifted)
        block_spec[fg] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates] = gp.ArraySpec(cube_roi_shifted)
        additional_request[fg] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted,
                                              directed=False)
        additional_request[mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)

        pipeline = (
            (raw_source, gt_source)
            + gp.MergeProvider()
            + Skeletonize(fg, candidates, candidate_spacing,
                          candidate_threshold)
            + MiniMax(fg, candidates, mst, distance_attr=distance_attr)
        )

        pipeline += Evaluate(
            gt,
            mst,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold,
        )

        if config["EVAL_SNAPSHOT"]:
            pipeline += gp.Snapshot(
                {
                    raw: f"volumes/raw",
                    fg: f"volumes/foreground",
                    candidates: f"volumes/candidates",
                    mst: f"points/mst",
                    gt: f"points/gt",
                    details: f"points/details",
                },
                output_dir="eval_results",
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    block=block),
                edge_attrs={
                    mst: [distance_attr],
                    details: ["details", "label_pair"]
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines)
        + gp.MergeProvider()
        + MergeScores(final_score, specs)
        + gp.PrintProfilingStats()
    )
    return validation_pipeline, final_score
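# Illustrative sketch (not part of the original code): one way the returned
# pipeline and score key might be evaluated. The non-spatial score request is
# an assumption consistent with how the score keys are declared above, and it
# assumes MergeScores expands the request to the per-block scores upstream.
# `_example_run_fg_evaluation` is a hypothetical name.
def _example_run_fg_evaluation(config, snapshot_file, raw_path, gt_path,
                               fg_path):
    pipeline, final_score = pre_computed_fg_validation_pipeline(
        config, snapshot_file, raw_path, gt_path, fg_path)

    request = gp.BatchRequest()
    request[final_score] = gp.ArraySpec(nonspatial=True)

    with gp.build(pipeline):
        batch = pipeline.request_batch(request)
        print(batch[final_score].data)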
def create_train_pipeline(self, model):

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    filename = self.params['data_file']
    datasets = self.params['dataset']

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    data = daisy.open_ds(filename, datasets[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    emb_voxel_size = voxel_size

    cv_loss = ContrastiveVolumeLoss(self.params['temperature'],
                                    self.params['point_density'],
                                    out_shape * voxel_size)

    # Add fake 3rd dim
    if is_2d:
        in_shape = gp.Coordinate((1, *in_shape))
        out_shape = gp.Coordinate((1, *out_shape))
        voxel_size = gp.Coordinate((1, *voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (data.shape[0], *source_roi.get_shape()))

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    random_point_generator = RandomPointGenerator(
        density=self.params['point_density'], repetitions=2)

    # Use volume to calculate probabilities, RandomSourceGenerator will
    # normalize volumes to probabilities
    probabilities = np.array([
        np.prod(daisy.open_ds(filename, dataset).shape)
        for dataset in datasets
    ])
    random_source_generator = RandomSourceGenerator(
        num_sources=len(datasets),
        probabilities=probabilities,
        repetitions=2)

    array_sources = tuple(
        tuple(
            gp.ZarrSource(
                filename,
                {raw: dataset},
                # fake 3D data
                array_specs={
                    raw: gp.ArraySpec(roi=source_roi,
                                      voxel_size=voxel_size,
                                      interpolatable=True)
                })
            for dataset in datasets)
        for raw in [raw_0, raw_1])

    # Choose a random dataset to pull from
    array_sources = tuple(
        arrays +
        RandomMultiBranchSource(random_source_generator) +
        gp.Normalize(raw, self.params['norm_factor']) +
        gp.Pad(raw, None)
        for raw, arrays in zip([raw_0, raw_1], array_sources))

    point_sources = tuple((
        RandomPointSource(points_0,
                          random_point_generator=random_point_generator),
        RandomPointSource(points_1,
                          random_point_generator=random_point_generator),
    ))

    # Merge the point and array sources together.
    # There is one array and point source per branch.
    sources = tuple(
        (array_source, point_source) + gp.MergeProvider()
        for array_source, point_source in zip(array_sources, point_sources))

    sources = tuple(
        self._make_train_augmentation_pipeline(raw, source)
        for raw, source in zip([raw_0, raw_1], sources))

    pipeline = (
        sources +
        gp.MergeProvider() +
        gp.Crop(raw_0, source_roi) +
        gp.Crop(raw_1, source_roi) +
        gp.RandomLocation() +
        PrepareBatch(raw_0, raw_1,
                     points_0, points_1,
                     locations_0, locations_1,
                     is_2d) +
        RejectArray(ensure_nonempty=locations_0) +
        RejectArray(ensure_nonempty=locations_1))

    if not is_2d:
        pipeline = (pipeline +
                    AddChannelDim(raw_0) +
                    AddChannelDim(raw_1))

    pipeline = (
        pipeline +
        gp.PreCache() +
        gp.torch.Train(
            model,
            cv_loss,
            optimizer,
            inputs={
                'raw_0': raw_0,
                'raw_1': raw_1
            },
            loss_inputs={
                'emb_0': emb_0,
                'emb_1': emb_1,
                'locations_0': locations_0,
                'locations_1': locations_1
            },
            outputs={
                2: emb_0,
                3: emb_1
            },
            array_specs={
                emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
                emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
            },
            checkpoint_basename=self.logdir + '/contrastive/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir + "/contrastive",
            log_every=self.log_every))

    if is_2d:
        pipeline = (
            pipeline +
            # everything is 3D, except emb_0 and emb_1
            AddSpatialDim(emb_0) +
            AddSpatialDim(emb_1))

    pipeline = (
        pipeline +
        # now everything is 3D
        RemoveChannelDim(raw_0) +
        RemoveChannelDim(raw_1) +
        RemoveChannelDim(emb_0) +
        RemoveChannelDim(emb_1) +
        gp.Snapshot(output_dir=self.logdir + '/contrastive/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw_0: 'raw_0',
                        raw_1: 'raw_1',
                        locations_0: 'locations_0',
                        locations_1: 'locations_1',
                        emb_0: 'emb_0',
                        emb_1: 'emb_1'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        gp.PrintProfilingStats(every=500))

    return pipeline, request