def add_augmentation_pipeline(
        pipeline,
        raw,
        simple=None,
        elastic=None,
        blur=None,
        noise=None):
    '''Add an augmentation pipeline to an existing pipeline.

    All optional arguments are kwargs for the corresponding augmentation
    node. If not given, those augmentations are not added.
    '''

    if simple is not None:
        pipeline = pipeline + gp.SimpleAugment(**simple)
    if elastic is not None:
        pipeline = pipeline + gp.ElasticAugment(**elastic)
    if blur is not None:
        pipeline = pipeline + Blur(raw, **blur)
    if noise is not None:
        pipeline = pipeline + gp.NoiseAugment(raw, **noise)

    return pipeline
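# A minimal usage sketch (hypothetical, not part of the original code): builds
# an augmented source with add_augmentation_pipeline. The kwarg values below
# are illustrative placeholders, not recommended settings.
def example_augmented_source(source, raw):
    return add_augmentation_pipeline(
        source,
        raw,
        simple={'transpose_only': [0, 1]},
        elastic={
            'control_point_spacing': (10, 10),
            'jitter_sigma': (0.5, 0.5),
            'rotation_interval': (0, math.pi / 2)},
        noise={'var': 0.01, 'clip': False})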
def _augmentation_pipeline(self, raw, source):

    if 'elastic' in self.params and self.params['elastic']:
        source = source + gp.ElasticAugment(**self.params['elastic_params'])
    if 'blur' in self.params and self.params['blur']:
        source = source + Blur(raw, **self.params['blur_params'])
    if 'simple' in self.params and self.params['simple']:
        source = source + gp.SimpleAugment(**self.params['simple_params'])
    if 'noise' in self.params and self.params['noise']:
        source = source + gp.NoiseAugment(raw, **self.params['noise_params'])

    return source
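# Example of the structure self.params is expected to have (hypothetical
# values, shown for illustration only): each boolean flag enables one
# augmentation node and the matching *_params dict supplies its kwargs.
#
#     {
#         'simple': True,
#         'simple_params': {},
#         'elastic': True,
#         'elastic_params': {
#             'control_point_spacing': (10, 10),
#             'jitter_sigma': (0.5, 0.5),
#             'rotation_interval': (0, math.pi / 2)},
#         'blur': False,
#         'noise': True,
#         'noise_params': {'var': 0.01, 'clip': False},
#     }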
def random_point_pairs_pipeline(
        model, loss, optimizer,
        dataset,
        augmentation_parameters,
        point_density,
        out_dir,
        normalize_factor=None,
        checkpoint_interval=5000,
        snapshot_interval=5000):
    '''Build a gunpowder training pipeline that provides pairs of augmented
    views of the same image, together with matched random point locations.'''

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO: parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # get input and output shapes of the model
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # let's hardcode this for now
    # TODO: read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        ds_key = f'train/raw/{i}'
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename,
                {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None)
            for raw in [raw_0, raw_1])

        random_point_generator = RandomPointGenerator(
            density=point_density,
            repetitions=2)

        point_sources = tuple((
            RandomPointSource(
                points_0,
                dim,
                random_point_generator=random_point_generator),
            RandomPointSource(
                points_1,
                dim,
                random_point_generator=random_point_generator)))

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source) +
            gp.MergeProvider() +
            gp.SimpleAugment() +
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2)) +
            gp.IntensityAugment(
                r,
                scale_min=0.8,
                scale_max=1.2,
                shift_min=-0.2,
                shift_max=0.2,
                clip=False) +
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source in
            zip([raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()

    pipeline += gp.Unsqueeze([raw_0, raw_1])

    pipeline += PrepareBatch(
        raw_0, raw_1,
        points_0, points_1,
        locations_0, locations_1)

    # TODO: how does PrepareBatch relate to Stack?
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content:
    # raw_0:       (1, h, w)
    # raw_1:       (1, h, w)
    # locations_0: (n, 2)
    # locations_1: (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content:
    # raw_0:       (b, 1, h, w)
    # raw_1:       (b, 1, h, w)
    # locations_0: (b, n, 2)
    # locations_1: (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model, loss, optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0: 'locations_0',
            # locations_1: 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
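# A minimal sketch (not part of the original code) of how the returned
# pipeline and request might be driven. `model`, `loss`, `optimizer` and
# `dataset` are assumed to exist; `point_density` and `out_dir` are
# illustrative placeholders.
def example_train_loop(model, loss, optimizer, dataset, n_iterations):
    pipeline, request = random_point_pairs_pipeline(
        model, loss, optimizer,
        dataset,
        augmentation_parameters={},
        point_density=0.0005,
        out_dir='checkpoints')
    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)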
def train(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    '''Train the network on synthetic skeleton data generated on the fly.'''

    # network hyperparameters
    INPUT_SHAPE = setup_config["INPUT_SHAPE"]
    OUTPUT_SHAPE = setup_config["OUTPUT_SHAPE"]

    # skeleton generation hyperparameters
    SKEL_GEN_RADIUS = setup_config["SKEL_GEN_RADIUS"]
    THETAS = np.array(setup_config["THETAS"]) * math.pi
    SPLIT_PS = setup_config["SPLIT_PS"]
    NOISE_VAR = setup_config["NOISE_VAR"]
    N_OBJS = setup_config["N_OBJS"]

    # skeleton variation hyperparameters
    LABEL_RADII = setup_config["LABEL_RADII"]
    RAW_RADII = setup_config["RAW_RADII"]
    RAW_INTENSITIES = setup_config["RAW_INTENSITIES"]

    # training hyperparameters
    CACHE_SIZE = setup_config["CACHE_SIZE"]
    NUM_WORKERS = setup_config["NUM_WORKERS"]
    SNAPSHOT_EVERY = setup_config["SNAPSHOT_EVERY"]
    CHECKPOINT_EVERY = setup_config["CHECKPOINT_EVERY"]

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    gt_fg = gp.ArrayKey("GT_FG")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # tensorflow tensors
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    snapshot_request.add(
        embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(maxima, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(
        gradient_embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(
        gradient_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request[emst] = gp.ArraySpec()
    snapshot_request[edges_u] = gp.ArraySpec()
    snapshot_request[edges_v] = gp.ArraySpec()
    snapshot_request[ratio_pos] = gp.ArraySpec()
    snapshot_request[ratio_neg] = gp.ArraySpec()
    snapshot_request[dist] = gp.ArraySpec()
    snapshot_request[num_pos_pairs] = gp.ArraySpec()
    snapshot_request[num_neg_pairs] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            raw,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        )
        + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
            use_component=True,
            n_objs=int(setup_config["HIDE_SIGNAL"]),
        )
        + nl.GrowLabels(labels, radii=LABEL_RADII)
        + nl.GrowLabels(raw, radii=RAW_RADII)
        + LabelToFloat32(raw, intensities=RAW_INTENSITIES)
        + gp.NoiseAugment(raw, var=NOISE_VAR)
        + gp.PreCache(cache_size=CACHE_SIZE, num_workers=NUM_WORKERS)
        + gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels,
            },
            outputs={
mknet_tensor_names["embedding"]: embedding, mknet_tensor_names["fg"]: fg, "strided_slice_1:0": maxima, "gt_fg:0": gt_fg, loss_tensor_names["emst"]: emst, loss_tensor_names["edges_u"]: edges_u, loss_tensor_names["edges_v"]: edges_v, loss_tensor_names["ratio_pos"]: ratio_pos, loss_tensor_names["ratio_neg"]: ratio_neg, loss_tensor_names["dist"]: dist, loss_tensor_names["num_pos_pairs"]: num_pos_pairs, loss_tensor_names["num_neg_pairs"]: num_neg_pairs, }, gradients={ mknet_tensor_names["embedding"]: gradient_embedding, mknet_tensor_names["fg"]: gradient_fg, }, save_every=CHECKPOINT_EVERY, summary="Merge/MergeSummary:0", log_dir="tensorflow_logs", ) + gp.Snapshot( output_filename="{iteration}.hdf", dataset_names={ raw: "volumes/raw", labels: "volumes/labels", point_trees: "point_trees", embedding: "volumes/embedding", fg: "volumes/fg", maxima: "volumes/maxima", gt_fg: "volumes/gt_fg", gradient_embedding: "volumes/gradient_embedding", gradient_fg: "volumes/gradient_fg", emst: "emst", edges_u: "edges_u", edges_v: "edges_v", ratio_pos: "ratio_pos", ratio_neg: "ratio_neg", dist: "dist", num_pos_pairs: "num_pos_pairs", num_neg_pairs: "num_neg_pairs", }, dataset_dtypes={ maxima: np.float32, gt_fg: np.float32 }, every=SNAPSHOT_EVERY, additional_request=snapshot_request, ) # + gp.PrintProfilingStats(every=100) ) with gp.build(pipeline): for i in range(n_iterations + 1): pipeline.request_batch(request) request._update_random_seed()
def train(n_iterations):
    '''Generate synthetic skeleton batches and write snapshots; the
    tensorflow training node is currently commented out.'''

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    # gt_fg = gp.ArrayKey("GT_FG")
    # embedding = gp.ArrayKey("EMBEDDING")
    # fg = gp.ArrayKey("FG")
    # maxima = gp.ArrayKey("MAXIMA")
    # gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    # gradient_fg = gp.ArrayKey("GRADIENT_FG")
    # emst = gp.ArrayKey("EMST")
    # edges_u = gp.ArrayKey("EDGES_U")
    # edges_v = gp.ArrayKey("EDGES_V")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    # snapshot_request.add(embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(maxima, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(
    #     gradient_embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1))
    # )
    # snapshot_request.add(gradient_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        )
        + gp.Copy(labels, raw)
        + nl.GrowLabels(labels, radii=LABEL_RADII)
        + nl.GrowLabels(raw, radii=RAW_RADII)
        + LabelToFloat32(raw, intensities=RAW_INTENSITIES)
        + gp.NoiseAugment(raw, var=NOISE_VAR)
        # + gp.PreCache(cache_size=40, num_workers=10)
        # + gp.tensorflow.Train(
        #     "train_net",
        #     optimizer=add_loss,
        #     loss=None,
        #     inputs={tensor_names["raw"]: raw, tensor_names["gt_labels"]: labels},
        #     outputs={
        #         tensor_names["embedding"]: embedding,
        #         tensor_names["fg"]: fg,
        #         "maxima:0": maxima,
        #         "gt_fg:0": gt_fg,
        #         emst_name: emst,
        #         edges_u_name: edges_u,
        #         edges_v_name: edges_v,
        #     },
        #     gradients={
        #         tensor_names["embedding"]: gradient_embedding,
        #         tensor_names["fg"]: gradient_fg,
        #     },
        # )
        + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                labels: "volumes/labels",
                point_trees: "point_trees",
                # embedding: "volumes/embedding",
                # fg: "volumes/fg",
                # maxima: "volumes/maxima",
                # gt_fg: "volumes/gt_fg",
                # gradient_embedding: "volumes/gradient_embedding",
                # gradient_fg: "volumes/gradient_fg",
                # emst: "emst",
                # edges_u: "edges_u",
                # edges_v: "edges_v",
            },
            # dataset_dtypes={maxima: np.float32, gt_fg: np.float32},
            every=100,
            additional_request=snapshot_request,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):
        for i in range(n_iterations):
            pipeline.request_batch(request)
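# A hypothetical entry point for the synthetic-data variant above, assuming
# the module-level hyperparameters it reads (INPUT_SHAPE, OUTPUT_SHAPE,
# SKEL_GEN_RADIUS, ...) are defined; the iteration count is an arbitrary
# example and would normally come from a config file or the command line.
if __name__ == "__main__":
    train(n_iterations=1000)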