def test_unsqueeze(self):
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")

    voxel_size = gp.Coordinate((50, 5, 5))
    input_voxels = gp.Coordinate((10, 10, 10))
    input_size = input_voxels * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, input_size)

    pipeline = (
        ExampleSourceUnsqueeze(voxel_size)
        + gp.Unsqueeze([raw, labels])
        + gp.Unsqueeze([raw], axis=1)
    )

    with gp.build(pipeline) as p:
        batch = p.request_batch(request)

    assert batch[raw].data.shape == (1,) + (1,) + input_voxels
    assert batch[labels].data.shape == (1,) + input_voxels
def test_gp_dacapo_array_source(array_config):
    # Create Array from config
    array = array_config.array_type(array_config)

    # Make sure the DaCapoArraySource can properly read
    # the data in `array`
    key = gp.ArrayKey("TEST")
    source_node = DaCapoArraySource(array, key)

    with gp.build(source_node):
        request = gp.BatchRequest()
        request[key] = gp.ArraySpec(roi=array.roi)
        batch = source_node.request_batch(request)
        data = batch[key].data
        assert (data - array[array.roi]).sum() == 0
def test_pipeline3(self):
    array_key = gp.ArrayKey("TEST_ARRAY")
    points_key = gp.PointsKey("TEST_POINTS")
    voxel_size = gp.Coordinate((1, 1))
    spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

    hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                {array_key: 'testdata'},
                                array_specs={array_key: spec})
    csv_source = gp.CsvPointsSource(
        self.fake_points_file, points_key,
        gp.PointsSpec(
            roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

    request = gp.BatchRequest()
    shape = gp.Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = gp.ShiftAugment(prob_slip=0.2,
                                 prob_shift=0.2,
                                 sigma=5,
                                 shift_axis=0)
    pipeline = ((hdf5_source, csv_source) + gp.MergeProvider() +
                gp.RandomLocation(ensure_nonempty=points_key) + shift_node)

    with gp.build(pipeline) as b:
        request = b.request_batch(request)
        # print(request[points_key])

    target_vals = [
        self.fake_data[point[0]][point[1]] for point in self.fake_points
    ]
    result_data = request[array_key].data
    result_points = request[points_key].data
    result_vals = [
        result_data[int(point.location[0])][int(point.location[1])]
        for point in result_points.values()
    ]

    for result_val in result_vals:
        self.assertTrue(
            result_val in target_vals,
            msg="result value {} at points {} not in target values {} at "
                "points {}".format(result_val, list(result_points.values()),
                                   target_vals, self.fake_points))
def add_fg_pred(config, pipeline, raw, block, model):
    checkpoint_file = config["FG_MODEL_CHECKPOINT"]
    device = config.get("DEVICE", "cuda")

    fg_pred = gp.ArrayKey(f"FG_PRED_{block}")

    pipeline = (
        pipeline
        + nl.gunpowder.nodes.helpers.UnSqueeze(raw)
        + gp.torch.Predict(
            model,
            inputs={"raw": raw},
            outputs={0: fg_pred},
            checkpoint=checkpoint_file,
            device=device,
        )
        + nl.gunpowder.nodes.helpers.Squeeze(raw)
        + nl.gunpowder.nodes.helpers.Squeeze(fg_pred))

    return pipeline, fg_pred
def test_prepare1(self):
    key = gp.ArrayKey("TEST_ARRAY")
    spec = gp.ArraySpec(voxel_size=gp.Coordinate((1, 1)), interpolatable=True)

    hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                {key: 'testdata'},
                                array_specs={key: spec})

    request = gp.BatchRequest()
    shape = gp.Coordinate((3, 3))
    request.add(key, shape, voxel_size=gp.Coordinate((1, 1)))

    shift_node = gp.ShiftAugment(sigma=1, shift_axis=0)
    with gp.build((hdf5_source + shift_node)):
        shift_node.prepare(request)
        self.assertTrue(shift_node.ndim == 2)
        self.assertTrue(shift_node.shift_sigmas == tuple([0.0, 1.0]))
def test_pipeline2(self):
    key = gp.ArrayKey("TEST_ARRAY")
    spec = gp.ArraySpec(voxel_size=gp.Coordinate((3, 1)), interpolatable=True)

    hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                {key: 'testdata'},
                                array_specs={key: spec})

    request = gp.BatchRequest()
    shape = gp.Coordinate((3, 3))
    request.add(key, shape, voxel_size=gp.Coordinate((3, 1)))

    shift_node = gp.ShiftAugment(prob_slip=0.2,
                                 prob_shift=0.2,
                                 sigma=1,
                                 shift_axis=0)
    with gp.build((hdf5_source + shift_node)) as b:
        b.request_batch(request)
def get_labels_snapshot_source(config, blocks):
    validation_blocks = Path(config["VALIDATION_BLOCKS"])

    labels = gp.ArrayKey("LABELS")
    gt = gp.GraphKey("GT")

    block_pipelines = []
    for block in blocks:
        pipeline = SnapshotSource(
            validation_blocks / f"block_{block}.hdf",
            {
                labels: "volumes/labels",
                gt: "points/gt"
            },
            directed={gt: True},
        )
        block_pipelines.append(pipeline)

    return block_pipelines, (labels, gt)
def add_nms(pipeline, config, foreground):
    model_config = copy.deepcopy(DEFAULT_CONFIG)
    model_config.update(json.load(open(config["EMB_MODEL_CONFIG"])))

    # Data properties
    voxel_size = gp.Coordinate(model_config["VOXEL_SIZE"])
    micron_scale = voxel_size[0]

    # Config options
    window_size = gp.Coordinate(model_config["NMS_WINDOW_SIZE"]) * micron_scale
    threshold = model_config["NMS_THRESHOLD"]

    # New array key
    maxima = gp.ArrayKey("MAXIMA")

    pipeline = (
        pipeline
        + nl.gunpowder.nodes.helpers.UnSqueeze(foreground)
        + nl.gunpowder.nodes.NonMaxSuppression(foreground, maxima,
                                               window_size, threshold)
        + nl.gunpowder.nodes.helpers.Squeeze(foreground)
        + nl.gunpowder.nodes.helpers.Squeeze(maxima))

    return pipeline, maxima
def add_emb_preds(config, pipelines, raw):
    model_config = copy.deepcopy(DEFAULT_CONFIG)
    model_config.update(json.load(open(config["EMB_MODEL_CONFIG"])))

    device = config.get("DEVICE", "cuda")
    checkpoint_file = config["EMB_MODEL_CHECKPOINT"]

    emb_pred = gp.ArrayKey("EMB_PRED")

    emb_pipelines = []
    for pipeline in pipelines:
        emb_pipelines.append(
            pipeline
            + nl.gunpowder.nodes.helpers.UnSqueeze(raw)
            + gp.torch.Predict(
                nl.networks.pytorch.EmbeddingUnet(model_config),
                inputs={"raw": raw},
                outputs={0: emb_pred},
                checkpoint=checkpoint_file,
                device=device,
            )
            + nl.gunpowder.nodes.helpers.Squeeze(raw)
            + nl.gunpowder.nodes.helpers.Squeeze(emb_pred))

    return emb_pipelines, (emb_pred, )
def predict_2d(raw_data, gt_data, predictor): raw_channels = max(1, raw_data.num_channels) input_shape = predictor.input_shape output_shape = predictor.output_shape dataset_shape = raw_data.shape dataset_roi = raw_data.roi voxel_size = raw_data.voxel_size # switch to world units input_size = voxel_size * input_shape output_size = voxel_size * output_shape raw = gp.ArrayKey('RAW') gt = gp.ArrayKey('GT') target = gp.ArrayKey('TARGET') prediction = gp.ArrayKey('PREDICTION') channel_dims = 0 if raw_channels == 1 else 1 data_dims = len(dataset_shape) - channel_dims if data_dims == 3: num_samples = dataset_shape[0] sample_shape = dataset_shape[channel_dims + 1:] else: raise RuntimeError( "For 2D validation, please provide a 3D array where the first " "dimension indexes the samples.") num_samples = raw_data.num_samples sample_shape = gp.Coordinate(sample_shape) sample_size = sample_shape * voxel_size scan_request = gp.BatchRequest() scan_request.add(raw, input_size) scan_request.add(prediction, output_size) if gt_data: scan_request.add(gt, output_size) scan_request.add(target, output_size) # overwrite source ROI to treat samples as z dimension spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(), (num_samples, ) + sample_size), voxel_size=(1, ) + voxel_size) if gt_data: sources = (raw_data.get_source(raw, overwrite_spec=spec), gt_data.get_source(gt, overwrite_spec=spec)) pipeline = sources + gp.MergeProvider() else: pipeline = raw_data.get_source(raw, overwrite_spec=spec) pipeline += gp.Pad(raw, None) if gt_data: pipeline += gp.Pad(gt, None) # raw: ([c,] s, h, w) # gt: ([c,] s, h, w) pipeline += gp.Normalize(raw) # raw: ([c,] s, h, w) # gt: ([c,] s, h, w) if gt_data: pipeline += predictor.add_target(gt, target) # raw: ([c,] s, h, w) # gt: ([c,] s, h, w) # target: ([c,] s, h, w) if channel_dims == 0: pipeline += AddChannelDim(raw) if gt_data and predictor.target_channels == 0: pipeline += AddChannelDim(target) # raw: (c, s, h, w) # gt: ([c,] s, h, w) # target: (c, s, h, w) pipeline += TransposeDims(raw, (1, 0, 2, 3)) if gt_data: pipeline += TransposeDims(target, (1, 0, 2, 3)) # raw: (s, c, h, w) # gt: ([c,] s, h, w) # target: (s, c, h, w) pipeline += gp_torch.Predict(model=predictor, inputs={'x': raw}, outputs={0: prediction}) # raw: (s, c, h, w) # gt: ([c,] s, h, w) # target: (s, c, h, w) # prediction: (s, c, h, w) pipeline += gp.Scan(scan_request) total_request = gp.BatchRequest() total_request.add(raw, sample_size) total_request.add(prediction, sample_size) if gt_data: total_request.add(gt, sample_size) total_request.add(target, sample_size) with gp.build(pipeline): batch = pipeline.request_batch(total_request) ret = {'raw': batch[raw], 'prediction': batch[prediction]} if gt_data: ret.update({'gt': batch[gt], 'target': batch[target]}) return ret
    fmap_inc_factor=4,  # this needs to be increased later (3)
    downsample_factors=[
        [1, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    kernel_size_down=[[[3, 3, 3], [3, 3, 3]]] * 4,
    kernel_size_up=[[[3, 3, 3], [3, 3, 3]]] * 3,
    padding='valid')

model = torch.nn.Sequential(
    unet,
    ConvPass(24, 1, [(1, 1, 1)], activation='Sigmoid'))
loss = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# declare gunpowder arrays
raw = gp.ArrayKey('RAW')
seg = gp.ArrayKey('SEGMENTATION')
out_cage_map = gp.ArrayKey('OUT_CAGE_MAP')
out_density_map = gp.ArrayKey('OUT_DENSITY_MAP')
prediction = gp.ArrayKey('PREDICTION')


class PrepareTrainingData(gp.BatchFilter):

    def process(self, batch, request):
        batch[out_cage_map].data = batch[out_cage_map].data.astype(np.float32)
        batch[out_cage_map].spec.dtype = np.float32


# assemble pipeline
sourceA = gp.ZarrSource('../data/cropped_sample_A.zarr', {
def predict(iteration): ################## # DECLARE ARRAYS # ################## # raw intensities raw = gp.ArrayKey('RAW') # the predicted affinities pred_affs = gp.ArrayKey('PRED_AFFS') #################### # DECLARE REQUESTS # #################### with open('test_net_config.json', 'r') as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) input_size = gp.Coordinate(net_config['input_shape']) * voxel_size output_size = gp.Coordinate(net_config['output_shape']) * voxel_size context = input_size - output_size # formulate the request for what a batch should contain request = gp.BatchRequest() request.add(raw, input_size) request.add(pred_affs, output_size) ############################# # ASSEMBLE TESTING PIPELINE # ############################# source = gp.Hdf5Source('sample_A_padded_20160501.hdf', datasets={raw: 'volumes/raw'}) # get the ROI provided for raw (we need it later to calculate the ROI in # which we can make predictions) with gp.build(source): raw_roi = source.spec[raw].roi pipeline = ( # read from HDF5 file source + # convert raw to float in [0, 1] gp.Normalize(raw) + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Predict( graph='test_net.meta', checkpoint='train_net_checkpoint_%d' % iteration, inputs={net_config['raw']: raw}, outputs={net_config['pred_affs']: pred_affs}, array_specs={ pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context)) }) + # store all passing batches in the same HDF5 file gp.Hdf5Write({ raw: '/volumes/raw', pred_affs: '/volumes/pred_affs', }, output_filename='predictions_sample_A.hdf', compression_type='gzip') + # show a summary of time spend in each node every 10 iterations gp.PrintProfilingStats(every=10) + # iterate over the whole dataset in a scanning fashion, emitting # requests that match the size of the network gp.Scan(reference=request)) with gp.build(pipeline): # request an empty batch from Scan to trigger scanning of the dataset # without keeping the complete dataset in memory pipeline.request_batch(gp.BatchRequest())
def predict(iteration, raw_file, raw_dataset, out_file, db_host, db_name,
            worker_config, network_config, out_properties={}, **kwargs):

    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(
            os.path.join(setup_dir,
                         '{}_net_config.json'.format(network_config)),
            'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties[
        'pred_partner_vectors'] if 'pred_partner_vectors' in out_properties else None
    m_property = out_properties[
        'pred_syn_indicator_out'] if 'pred_syn_indicator_out' in out_properties else None

    # Hdf5Source
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    elif raw_file.endswith('.zarr') or raw_file.endswith('.n5'):
        pipeline = gp.ZarrSource(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)
    pipeline += gp.Normalize(raw)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    d_scale = parameters['d_scale'] if 'd_scale' in parameters else None
    if d_scale != 1 and d_scale is not None:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale,
                                           0)  # Map back to nm world.

    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)

    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] == 'int8' or d_property[
            'dtype'] == 'float32', 'predict not adapted to dtype {}'.format(
                d_property['dtype'])
        if d_property['dtype'] == 'int8':
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1, 0, clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator: 'volumes/pred_syn_indicator',
        pred_postpre_vectors: 'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
def build_pipeline(parameter, augment=True): voxel_size = gp.Coordinate(parameter['voxel_size']) # Array Specifications. raw = gp.ArrayKey('RAW') gt_neurons = gp.ArrayKey('GT_NEURONS') gt_postpre_vectors = gp.ArrayKey('GT_POSTPRE_VECTORS') gt_post_indicator = gp.ArrayKey('GT_POST_INDICATOR') post_loss_weight = gp.ArrayKey('POST_LOSS_WEIGHT') vectors_mask = gp.ArrayKey('VECTORS_MASK') pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS') pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR') grad_syn_indicator = gp.ArrayKey('GRAD_SYN_INDICATOR') grad_partner_vectors = gp.ArrayKey('GRAD_PARTNER_VECTORS') # Points specifications dummypostsyn = gp.PointsKey('DUMMYPOSTSYN') postsyn = gp.PointsKey('POSTSYN') presyn = gp.PointsKey('PRESYN') trg_context = 140 # AddPartnerVectorMap context in nm - pre-post distance with open('train_net_config.json', 'r') as f: net_config = json.load(f) input_size = gp.Coordinate(net_config['input_shape']) * voxel_size output_size = gp.Coordinate(net_config['output_shape']) * voxel_size request = gp.BatchRequest() request.add(raw, input_size) request.add(gt_neurons, output_size) request.add(gt_postpre_vectors, output_size) request.add(gt_post_indicator, output_size) request.add(post_loss_weight, output_size) request.add(vectors_mask, output_size) request.add(dummypostsyn, output_size) for (key, request_spec) in request.items(): print(key) print(request_spec.roi) request_spec.roi.contains(request_spec.roi) # slkfdms snapshot_request = gp.BatchRequest({ pred_post_indicator: request[gt_postpre_vectors], pred_postpre_vectors: request[gt_postpre_vectors], grad_syn_indicator: request[gt_postpre_vectors], grad_partner_vectors: request[gt_postpre_vectors], vectors_mask: request[gt_postpre_vectors] }) postsyn_rastersetting = gp.RasterizationSettings( radius=parameter['blob_radius'], mask=gt_neurons, mode=parameter['blob_mode']) pipeline = tuple([ create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter, gt_neurons) for sample in samples ]) pipeline += gp.RandomProvider() if augment: pipeline += gp.ElasticAugment([4, 40, 40], [0, 2, 2], [0, math.pi / 2.0], prob_slip=0.05, prob_shift=0.05, max_misalign=10, subsample=8) pipeline += gp.SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) pipeline += gp.IntensityScaleShift(raw, 2, -1) pipeline += gp.RasterizePoints( postsyn, gt_post_indicator, gp.ArraySpec(voxel_size=voxel_size, dtype=np.int32), postsyn_rastersetting) spec = gp.ArraySpec(voxel_size=voxel_size) pipeline += AddPartnerVectorMap( src_points=postsyn, trg_points=presyn, array=gt_postpre_vectors, radius=parameter['d_blob_radius'], trg_context=trg_context, # enlarge array_spec=spec, mask=gt_neurons, pointmask=vectors_mask) pipeline += gp.BalanceLabels(labels=gt_post_indicator, scales=post_loss_weight, slab=(-1, -1, -1), clipmin=parameter['cliprange'][0], clipmax=parameter['cliprange'][1]) if parameter['d_scale'] != 1: pipeline += gp.IntensityScaleShift(gt_postpre_vectors, scale=parameter['d_scale'], shift=0) pipeline += gp.PreCache(cache_size=40, num_workers=10) pipeline += gp.tensorflow.Train( './train_net', optimizer=net_config['optimizer'], loss=net_config['loss'], summary=net_config['summary'], log_dir='./tensorboard/', save_every=30000, # 10000 log_every=100, inputs={ net_config['raw']: raw, net_config['gt_partner_vectors']: gt_postpre_vectors, net_config['gt_syn_indicator']: gt_post_indicator, net_config['vectors_mask']: vectors_mask, # Loss weights --> mask 
            net_config['indicator_weight']: post_loss_weight,  # Loss weights
        },
        outputs={
            net_config['pred_partner_vectors']: pred_postpre_vectors,
            net_config['pred_syn_indicator']: pred_post_indicator,
        },
        gradients={
            net_config['pred_partner_vectors']: grad_partner_vectors,
            net_config['pred_syn_indicator']: grad_syn_indicator,
        },
    )

    # Visualize.
    pipeline += gp.IntensityScaleShift(raw, 0.5, 0.5)
    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
            gt_post_indicator: 'volumes/gt_post_indicator',
            gt_postpre_vectors: 'volumes/gt_postpre_vectors',
            pred_postpre_vectors: 'volumes/pred_postpre_vectors',
            pred_post_indicator: 'volumes/pred_post_indicator',
            post_loss_weight: 'volumes/post_loss_weight',
            grad_syn_indicator: 'volumes/post_indicator_gradients',
            grad_partner_vectors: 'volumes/partner_vectors_gradients',
            vectors_mask: 'volumes/vectors_mask'
        },
        every=1000,
        output_filename='batch_{iteration}.hdf',
        compression_type='gzip',
        additional_request=snapshot_request)

    pipeline += gp.PrintProfilingStats(every=100)

    print("Starting training...")
    max_iteration = parameter['max_iteration']
    with gp.build(pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
def predict(**kwargs): name = kwargs['name'] raw = gp.ArrayKey('RAW') raw_cropped = gp.ArrayKey('RAW_CROPPED') pred_affs = gp.ArrayKey('PRED_AFFS') pred_fgbg = gp.ArrayKey('PRED_FGBG') with open(os.path.join(kwargs['input_folder'], name + '_config.json'), 'r') as f: net_config = json.load(f) with open(os.path.join(kwargs['input_folder'], name + '_names.json'), 'r') as f: net_names = json.load(f) voxel_size = gp.Coordinate(kwargs['voxel_size']) input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size context = (input_shape_world - output_shape_world) // 2 # formulate the request for what a batch should contain request = gp.BatchRequest() request.add(raw, input_shape_world) request.add(raw_cropped, output_shape_world) request.add(pred_affs, output_shape_world) request.add(pred_fgbg, output_shape_world) if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr": raise NotImplementedError("predict node for %s not implemented yet", kwargs['input_format']) if kwargs['input_format'] == "hdf": sourceNode = gp.Hdf5Source with h5py.File( os.path.join(kwargs['data_folder'], kwargs['sample'] + ".hdf"), 'r') as f: shape = f['volumes/raw'].shape elif kwargs['input_format'] == "zarr": sourceNode = gp.ZarrSource f = zarr.open( os.path.join(kwargs['data_folder'], kwargs['sample'] + ".zarr"), 'r') shape = f['volumes/raw'].shape source = sourceNode(os.path.join( kwargs['data_folder'], kwargs['sample'] + "." + kwargs['input_format']), datasets={raw: 'volumes/raw'}) if kwargs['output_format'] != "zarr": raise NotImplementedError("Please use zarr as prediction output") # pre-create zarr file zf = zarr.open(os.path.join(kwargs['output_folder'], kwargs['sample'] + '.zarr'), mode='w') zf.create('volumes/pred_affs', shape=[3] + list(shape), chunks=[3] + list(shape), dtype=np.float32) zf['volumes/pred_affs'].attrs['offset'] = [0, 0, 0] zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size'] zf.create('volumes/pred_fgbg', shape=[1] + list(shape), chunks=[1] + list(shape), dtype=np.float32) zf['volumes/pred_fgbg'].attrs['offset'] = [0, 0, 0] zf['volumes/pred_fgbg'].attrs['resolution'] = kwargs['voxel_size'] zf.create('volumes/raw_cropped', shape=[1] + list(shape), chunks=[1] + list(shape), dtype=np.float32) zf['volumes/raw_cropped'].attrs['offset'] = [0, 0, 0] zf['volumes/raw_cropped'].attrs['resolution'] = kwargs['voxel_size'] pipeline = ( # read from HDF5 file source + gp.Pad(raw, context) + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'], name + '.meta'), checkpoint=kwargs['checkpoint'], inputs={net_names['raw']: raw}, outputs={ net_names['pred_affs']: pred_affs, net_names['pred_fgbg']: pred_fgbg, net_names['raw_cropped']: raw_cropped }) + # store all passing batches in the same HDF5 file gp.ZarrWrite( { raw_cropped: '/volumes/raw_cropped', pred_affs: '/volumes/pred_affs', pred_fgbg: '/volumes/pred_fgbg', }, output_dir=kwargs['output_folder'], output_filename=kwargs['sample'] + ".zarr", compression_type='gzip') + # show a summary of time spend in each node every 10 iterations gp.PrintProfilingStats(every=10) + # iterate over the whole dataset in a scanning fashion, emitting # requests that match the size of the network gp.Scan(reference=request)) with gp.build(pipeline): # request an empty batch from Scan to trigger scanning of the dataset # without keeping the 
        # complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
def create_train_pipeline(self, model): print( f"Creating training pipeline with batch size {self.params['batch_size']}" ) filename = self.params['data_file'] raw_dataset = self.params['dataset']['train']['raw'] gt_dataset = self.params['dataset']['train']['gt'] optimizer = self.params['optimizer'](model.parameters(), **self.params['optimizer_kwargs']) raw = gp.ArrayKey('RAW') gt_labels = gp.ArrayKey('LABELS') gt_aff = gp.ArrayKey('AFFINITIES') predictions = gp.ArrayKey('PREDICTIONS') emb = gp.ArrayKey('EMBEDDING') raw_data = daisy.open_ds(filename, raw_dataset) source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape()) source_voxel_size = gp.Coordinate(raw_data.voxel_size) out_voxel_size = gp.Coordinate(raw_data.voxel_size) # Get in and out shape in_shape = gp.Coordinate(model.in_shape) out_shape = gp.Coordinate(model.out_shape[2:]) is_2d = in_shape.dims() == 2 in_shape = in_shape * out_voxel_size out_shape = out_shape * out_voxel_size context = (in_shape - out_shape) / 2 gt_labels_out_shape = out_shape # Add fake 3rd dim if is_2d: source_voxel_size = gp.Coordinate((1, *source_voxel_size)) source_roi = gp.Roi((0, *source_roi.get_offset()), (raw_data.shape[0], *source_roi.get_shape())) context = gp.Coordinate((0, *context)) aff_neighborhood = [[0, -1, 0], [0, 0, -1]] gt_labels_out_shape = (1, *gt_labels_out_shape) else: aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] logger.info(f"source roi: {source_roi}") logger.info(f"in_shape: {in_shape}") logger.info(f"out_shape: {out_shape}") logger.info(f"voxel_size: {out_voxel_size}") logger.info(f"context: {context}") request = gp.BatchRequest() request.add(raw, in_shape) request.add(gt_aff, out_shape) request.add(predictions, out_shape) snapshot_request = gp.BatchRequest() snapshot_request[emb] = gp.ArraySpec( roi=gp.Roi((0, ) * in_shape.dims(), gp.Coordinate((*model.base_encoder.out_shape[2:], )) * out_voxel_size)) snapshot_request[gt_labels] = gp.ArraySpec( roi=gp.Roi(context, gt_labels_out_shape)) source = ( gp.ZarrSource(filename, { raw: raw_dataset, gt_labels: gt_dataset }, array_specs={ raw: gp.ArraySpec(roi=source_roi, voxel_size=source_voxel_size, interpolatable=True), gt_labels: gp.ArraySpec(roi=source_roi, voxel_size=source_voxel_size) }) + gp.Normalize(raw, self.params['norm_factor']) + gp.Pad(raw, context) + gp.Pad(gt_labels, context) + gp.RandomLocation() # raw : (l=1, h, w) # gt_labels: (l=1, h, w) ) source = self._augmentation_pipeline(raw, source) pipeline = ( source + # raw : (l=1, h, w) # gt_labels: (l=1, h, w) gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) + SetDtype(gt_aff, np.float32) + # raw : (l=1, h, w) # gt_aff : (c=2, l=1, h, w) AddChannelDim(raw) # raw : (c=1, l=1, h, w) # gt_aff : (c=2, l=1, h, w) ) if is_2d: pipeline = ( pipeline + RemoveSpatialDim(raw) + RemoveSpatialDim(gt_aff) # raw : (c=1, h, w) # gt_aff : (c=2, h, w) ) pipeline = ( pipeline + gp.Stack(self.params['batch_size']) + gp.PreCache() + # raw : (b, c=1, h, w) # gt_aff : (b, c=2, h, w) # (which is what train requires) gp.torch.Train( model, self.loss, optimizer, inputs={'raw': raw}, loss_inputs={ 0: predictions, 1: gt_aff }, outputs={ 0: predictions, 1: emb }, array_specs={ predictions: gp.ArraySpec(voxel_size=out_voxel_size), }, checkpoint_basename=self.logdir + '/checkpoints/model', save_every=self.params['save_every'], log_dir=self.logdir, log_every=self.log_every) + # everything is 2D at this point, plus extra dimensions for # channels and batch # raw : (b, c=1, h, w) # gt_aff : (b, c=2, h, w) # predictions: (b, c=2, h, w) 
            # Crop GT to look at labels
            gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            gt_aff: 'gt_aff',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request
                output_roi.get_end() / input_spec.voxel_size,
            ))]
        output_spec = copy.deepcopy(input_spec)
        output_spec.roi = output_roi

        output_array = gp.Array(output_data, output_spec)
        batch[self.output_array] = output_array


input_size = Coordinate([74, 260, 260])
output_size = Coordinate([42, 168, 168])
path_to_data = Path("/nrs/funke/mouselight-v2")

# array keys for data sources
raw = gp.ArrayKey("RAW")
swcs = gp.PointsKey("SWCS")
labels = gp.ArrayKey("LABELS")

# array keys for base volume
raw_base = gp.ArrayKey("RAW_BASE")
labels_base = gp.ArrayKey("LABELS_BASE")
swc_base = gp.PointsKey("SWC_BASE")

# array keys for add volume
raw_add = gp.ArrayKey("RAW_ADD")
labels_add = gp.ArrayKey("LABELS_ADD")
swc_add = gp.PointsKey("SWC_ADD")

# array keys for fused volume
raw_fused = gp.ArrayKey("RAW_FUSED")
def predict( model: Model, raw_array: Array, prediction_array_identifier: LocalArrayIdentifier, num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), output_roi: Optional[Roi] = None, ): # get the model's input and output size input_voxel_size = Coordinate(raw_array.voxel_size) output_voxel_size = model.scale(input_voxel_size) input_shape = Coordinate(model.eval_input_shape) input_size = input_voxel_size * input_shape output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] logger.info( "Predicting with input size %s, output size %s", input_size, output_size ) # calculate input and output rois context = (input_size - output_size) / 2 if output_roi is None: input_roi = raw_array.roi output_roi = input_roi.grow(-context, -context) else: input_roi = output_roi.grow(context, context) logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi) # prepare prediction dataset axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"] ZarrArray.create_from_array_identifier( prediction_array_identifier, axes, output_roi, model.num_out_channels, output_voxel_size, np.float32, ) # create gunpowder keys raw = gp.ArrayKey("RAW") prediction = gp.ArrayKey("PREDICTION") # assemble prediction pipeline # prepare data source pipeline = DaCapoArraySource(raw_array, raw) # raw: (c, d, h, w) pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) # raw: (c, d, h, w) pipeline += gp.Unsqueeze([raw]) # raw: (1, c, d, h, w) gt_padding = (output_size - output_roi.shape) % output_size prediction_roi = output_roi.grow(gt_padding) # predict pipeline += gp_torch.Predict( model=model, inputs={"x": raw}, outputs={0: prediction}, array_specs={ prediction: gp.ArraySpec( roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32 ) }, spawn_subprocess=False, device=str(compute_context.device), ) # raw: (1, c, d, h, w) # prediction: (1, [c,] d, h, w) # prepare writing pipeline += gp.Squeeze([raw, prediction]) # raw: (c, d, h, w) # prediction: (c, d, h, w) # raw: (c, d, h, w) # prediction: (c, d, h, w) # write to zarr pipeline += gp.ZarrWrite( {prediction: prediction_array_identifier.dataset}, prediction_array_identifier.container.parent, prediction_array_identifier.container.name, ) # create reference batch request ref_request = gp.BatchRequest() ref_request.add(raw, input_size) ref_request.add(prediction, output_size) pipeline += gp.Scan(ref_request) # build pipeline and predict in complete output ROI with gp.build(pipeline): pipeline.request_batch(gp.BatchRequest()) container = zarr.open(prediction_array_identifier.container) dataset = container[prediction_array_identifier.dataset] dataset.attrs["axes"] = ( raw_array.axes if "c" in raw_array.axes else ["c"] + raw_array.axes )
def train(iterations): ################## # DECLARE ARRAYS # ################## # raw intensities raw = gp.ArrayKey('RAW') # objects labelled with unique IDs gt_labels = gp.ArrayKey('LABELS') # array of per-voxel affinities to direct neighbors gt_affs = gp.ArrayKey('AFFINITIES') # weights to use to balance the loss loss_weights = gp.ArrayKey('LOSS_WEIGHTS') # the predicted affinities pred_affs = gp.ArrayKey('PRED_AFFS') # the gredient of the loss wrt to the predicted affinities pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS') #################### # DECLARE REQUESTS # #################### with open('train_net_config.json', 'r') as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((8, 8, 8)) input_size = gp.Coordinate(net_config['input_shape']) * voxel_size output_size = gp.Coordinate(net_config['output_shape']) * voxel_size # formulate the request for what a batch should (at least) contain request = gp.BatchRequest() request.add(raw, input_size) request.add(gt_affs, output_size) request.add(loss_weights, output_size) # when we make a snapshot for inspection (see below), we also want to # request the predicted affinities and gradients of the loss wrt the # affinities snapshot_request = gp.BatchRequest() snapshot_request[pred_affs] = request[gt_affs] snapshot_request[pred_affs_gradients] = request[gt_affs] ############################## # ASSEMBLE TRAINING PIPELINE # ############################## pipeline = ( # a tuple of sources, one for each sample (A, B, and C) provided by the # CREMI challenge tuple( # read batches from the HDF5 file gp.Hdf5Source(os.path.join(data_dir, 'fib.hdf'), datasets={ raw: 'volumes/raw', gt_labels: 'volumes/labels/neuron_ids' }) + # convert raw to float in [0, 1] gp.Normalize(raw) + # chose a random location for each requested batch gp.RandomLocation()) + # chose a random source (i.e., sample) from the above gp.RandomProvider() + # elastically deform the batch gp.ElasticAugment([8, 8, 8], [0, 2, 2], [0, math.pi / 2.0], prob_slip=0.05, prob_shift=0.05, max_misalign=25) + # apply transpose and mirror augmentations gp.SimpleAugment(transpose_only=[1, 2]) + # scale and shift the intensity of the raw array gp.IntensityAugment(raw, scale_min=0.9, scale_max=1.1, shift_min=-0.1, shift_max=0.1, z_section_wise=True) + # grow a boundary between labels gp.GrowBoundary(gt_labels, steps=3, only_xy=True) + # convert labels into affinities between voxels gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels, gt_affs) + # create a weight array that balances positive and negative samples in # the affinity array gp.BalanceLabels(gt_affs, loss_weights) + # pre-cache batches from the point upstream gp.PreCache(cache_size=10, num_workers=5) + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Train( 'train_net', net_config['optimizer'], net_config['loss'], inputs={ net_config['raw']: raw, net_config['gt_affs']: gt_affs, net_config['loss_weights']: loss_weights }, outputs={net_config['pred_affs']: pred_affs}, gradients={net_config['pred_affs']: pred_affs_gradients}, save_every=10000) + # save the passing batch as an HDF5 file for inspection gp.Snapshot( { raw: '/volumes/raw', gt_labels: '/volumes/labels/neuron_ids', gt_affs: '/volumes/labels/affs', pred_affs: '/volumes/pred_affs', pred_affs_gradients: '/volumes/pred_affs_gradients' }, output_dir='snapshots', output_filename='batch_{iteration}.hdf', every=1000, 
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 1000 iterations
        gp.PrintProfilingStats(every=1000))

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for i in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
"cc_threshold": 0.50, "loc_type": "edt", "score_thr": 0, "score_type": "sum", "nms_radius": None } parameters = synful.detection.SynapseExtractionParameters( extract_type=parameter_dic['extract_type'], cc_threshold=parameter_dic['cc_threshold'], loc_type=parameter_dic['loc_type'], score_thr=parameter_dic['score_thr'], score_type=parameter_dic['score_type'], nms_radius=parameter_dic['nms_radius']) gp.ArrayKey('M_PRED') gp.ArrayKey('D_PRED') gp.PointsKey('PRESYN') gp.PointsKey('POSTSYN') class TestSource(gp.BatchProvider): def __init__(self, m_pred, d_pred, voxel_size): self.voxel_size = voxel_size self.m_pred = m_pred self.d_pred = d_pred def setup(self): self.provides( gp.ArrayKeys.M_PRED, gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
def create_train_pipeline(self, model): print(f"Creating training pipeline with batch size \ {self.params['batch_size']}") filename = self.params['data_file'] raw_dataset = self.params['dataset']['train']['raw'] gt_dataset = self.params['dataset']['train']['gt'] optimizer = self.params['optimizer'](model.parameters(), **self.params['optimizer_kwargs']) raw = gp.ArrayKey('RAW') gt_labels = gp.ArrayKey('LABELS') points = gp.GraphKey("POINTS") locations = gp.ArrayKey("LOCATIONS") predictions = gp.ArrayKey('PREDICTIONS') emb = gp.ArrayKey('EMBEDDING') raw_data = daisy.open_ds(filename, raw_dataset) source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape()) source_voxel_size = gp.Coordinate(raw_data.voxel_size) out_voxel_size = gp.Coordinate(raw_data.voxel_size) # Get in and out shape in_shape = gp.Coordinate(model.in_shape) out_roi = gp.Coordinate(model.base_encoder.out_shape[2:]) is_2d = in_shape.dims() == 2 in_shape = in_shape * out_voxel_size out_roi = out_roi * out_voxel_size out_shape = gp.Coordinate( (self.params["num_points"], *model.out_shape[2:])) context = (in_shape - out_roi) / 2 gt_labels_out_shape = out_roi # Add fake 3rd dim if is_2d: source_voxel_size = gp.Coordinate((1, *source_voxel_size)) source_roi = gp.Roi((0, *source_roi.get_offset()), (raw_data.shape[0], *source_roi.get_shape())) context = gp.Coordinate((0, *context)) gt_labels_out_shape = (1, *gt_labels_out_shape) points_roi = out_voxel_size * tuple((*self.params["point_roi"], )) points_pad = (0, *points_roi) context = gp.Coordinate((0, None, None)) else: points_roi = source_voxel_size * tuple(self.params["point_roi"]) points_pad = points_roi context = gp.Coordinate((None, None, None)) logger.info(f"source roi: {source_roi}") logger.info(f"in_shape: {in_shape}") logger.info(f"out_shape: {out_shape}") logger.info(f"voxel_size: {out_voxel_size}") logger.info(f"context: {context}") logger.info(f"out_voxel_size: {out_voxel_size}") request = gp.BatchRequest() request.add(raw, in_shape) request.add(points, points_roi) request.add(gt_labels, out_roi) request[locations] = gp.ArraySpec(nonspatial=True) request[predictions] = gp.ArraySpec(nonspatial=True) snapshot_request = gp.BatchRequest() snapshot_request[emb] = gp.ArraySpec( roi=gp.Roi((0, ) * in_shape.dims(), gp.Coordinate((*model.base_encoder.out_shape[2:], )) * out_voxel_size)) source = ( (gp.ZarrSource(filename, { raw: raw_dataset, gt_labels: gt_dataset }, array_specs={ raw: gp.ArraySpec(roi=source_roi, voxel_size=source_voxel_size, interpolatable=True), gt_labels: gp.ArraySpec(roi=source_roi, voxel_size=source_voxel_size) }), PointsLabelsSource(points, self.data, scale=source_voxel_size)) + gp.MergeProvider() + gp.Pad(raw, context) + gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) + gp.RandomLocation(ensure_nonempty=points) + gp.Normalize(raw, self.params['norm_factor']) # raw : (source_roi) # gt_labels: (source_roi) # points : (c=1, source_locations_shape) # If 2d then source_roi = (1, input_shape) in order to select a RL ) source = self._augmentation_pipeline(raw, source) pipeline = ( source + # Batches seem to be rejected because points are chosen near the # edge of the points ROI and the augmentations remove them. # TODO: Figure out if this is an actual issue, and if anything can # be done. 
gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) + # raw : (source_roi) # gt_labels: (source_roi) # points : (c=1, source_locations_shape) AddChannelDim(raw) + AddChannelDim(gt_labels) # raw : (c=1, source_roi) # gt_labels: (c=2, source_roi) # points : (c=1, source_locations_shape) ) if is_2d: pipeline = ( # Remove extra dim the 2d roi had pipeline + RemoveSpatialDim(raw) + RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points) # raw : (c=1, roi) # gt_labels: (c=1, roi) # points : (c=1, locations_shape) ) pipeline = ( pipeline + FillLocations(raw, points, locations, is_2d=False, max_points=1) + gp.Stack(self.params['batch_size']) + gp.PreCache() + # raw : (b, c=1, roi) # gt_labels: (b, c=1, roi) # locations: (b, c=1, locations_shape) # (which is what train requires) gp.torch.Train( model, self.loss, optimizer, inputs={ 'raw': raw, 'points': locations }, loss_inputs={ 0: predictions, 1: gt_labels, 2: locations }, outputs={ 0: predictions, 1: emb }, array_specs={ predictions: gp.ArraySpec(nonspatial=True), emb: gp.ArraySpec(voxel_size=out_voxel_size) }, checkpoint_basename=self.logdir + '/checkpoints/model', save_every=self.params['save_every'], log_dir=self.logdir, log_every=self.log_every) + # everything is 2D at this point, plus extra dimensions for # channels and batch # raw : (b, c=1, roi) # gt_labels : (b, c=1, roi) # predictions: (b, num_points) gp.Snapshot(output_dir=self.logdir + '/snapshots', output_filename='it{iteration}.hdf', dataset_names={ raw: 'raw', gt_labels: 'gt_labels', predictions: 'predictions', emb: 'emb' }, additional_request=snapshot_request, every=self.params['save_every']) + InspectBatch('END') + gp.PrintProfilingStats(every=500)) return pipeline, request
def train(until): model = SpineUNet() loss = torch.nn.BCELoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) input_size = (8, 96, 96) raw = gp.ArrayKey('RAW') labels = gp.ArrayKey('LABELS') affs = gp.ArrayKey('AFFS') affs_predicted = gp.ArrayKey('AFFS_PREDICTED') pipeline = ( ( gp.ZarrSource( 'data/20200201.zarr', { raw: 'train/sample1/raw', labels: 'train/sample1/labels' }), gp.ZarrSource( 'data/20200201.zarr', { raw: 'train/sample2/raw', labels: 'train/sample2/labels' }), gp.ZarrSource( 'data/20200201.zarr', { raw: 'train/sample3/raw', labels: 'train/sample3/labels' }) ) + gp.RandomProvider() + gp.Normalize(raw) + gp.RandomLocation() + gp.SimpleAugment(transpose_only=(1, 2)) + gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi]) + gp.AddAffinities( [(1, 0, 0), (0, 1, 0), (0, 0, 1)], labels, affs) + gp.Normalize(affs, factor=1.0) + #gp.PreCache(num_workers=1) + # raw: (d, h, w) # affs: (3, d, h, w) gp.Stack(1) + # raw: (1, d, h, w) # affs: (1, 3, d, h, w) AddChannelDim(raw) + # raw: (1, 1, d, h, w) # affs: (1, 3, d, h, w) gp_torch.Train( model, loss, optimizer, inputs={'x': raw}, outputs={0: affs_predicted}, loss_inputs={0: affs_predicted, 1: affs}, save_every=10000) + RemoveChannelDim(raw) + RemoveChannelDim(raw) + RemoveChannelDim(affs) + RemoveChannelDim(affs_predicted) + # raw: (d, h, w) # affs: (3, d, h, w) # affs_predicted: (3, d, h, w) gp.Snapshot( { raw: 'raw', labels: 'labels', affs: 'affs', affs_predicted: 'affs_predicted' }, every=500, output_filename='iteration_{iteration}.hdf') ) request = gp.BatchRequest() request.add(raw, input_size) request.add(labels, input_size) request.add(affs, input_size) request.add(affs_predicted, input_size) with gp.build(pipeline): for i in range(until): pipeline.request_batch(request)
def make_pipeline(self): raw = gp.ArrayKey('RAW') pred_affs = gp.ArrayKey('PREDICTIONS') source_shape = zarr.open(self.data_file)[self.dataset].shape raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:]) data = daisy.open_ds(self.data_file, self.dataset) source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape()) voxel_size = gp.Coordinate(data.voxel_size) # Get in and out shape in_shape = gp.Coordinate(self.model.in_shape) out_shape = gp.Coordinate(self.model.out_shape[2:]) is_2d = in_shape.dims() == 2 in_shape = in_shape * voxel_size out_shape = out_shape * voxel_size logger.info(f"source roi: {source_roi}") logger.info(f"in_shape: {in_shape}") logger.info(f"out_shape: {out_shape}") logger.info(f"voxel_size: {voxel_size}") request = gp.BatchRequest() request.add(raw, in_shape) request.add(pred_affs, out_shape) context = (in_shape - out_shape) / 2 source = (gp.ZarrSource(self.data_file, { raw: self.dataset, }, array_specs={ raw: gp.ArraySpec(roi=source_roi, interpolatable=False) })) in_dims = len(self.model.in_shape) if is_2d: # 2D: [samples, y, x] or [samples, channels, y, x] needs_channel_fix = (len(data.shape) - in_dims == 1) if needs_channel_fix: source = (source + AddChannelDim(raw, axis=1)) # raw [samples, channels, y, x] else: # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x] needs_channel_fix = (len(data.shape) - in_dims == 0) needs_batch_fix = (len(data.shape) - in_dims <= 1) if needs_channel_fix: source = (source + AddChannelDim(raw, axis=0)) # Batch fix if needs_batch_fix: source = (source + AddChannelDim(raw)) # raw: [sample, channels, z, y, x] with gp.build(source): raw_roi = source.spec[raw].roi logger.info(f"raw_roi: {raw_roi}") pipeline = (source + gp.Normalize(raw, factor=self.params['norm_factor']) + gp.Pad(raw, context) + gp.PreCache() + gp.torch.Predict( self.model, inputs={'raw': raw}, outputs={0: pred_affs}, array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)})) pipeline = (pipeline + gp.ZarrWrite({ pred_affs: 'predictions', }, output_dir=self.curr_log_dir, output_filename='predictions.zarr', compression_type='gzip') + gp.Scan(request)) return pipeline, request, pred_affs
args = parser.parse_args()
directory = args.directory
max_dist = args.max_dist

# download some test data
url = "https://cremi.org/static/data/sample_A_20160501.hdf"
urllib.request.urlretrieve(url, os.path.join(directory, 'sample_A.hdf'))

# configure where to store results
result_file = "sample_A.n5"
ds_name = "neuron_ids_stardists_downsampled"
if max_dist is not None:
    ds_name += "_max{0:}".format(max_dist)

# declare arrays to use
raw = gp.ArrayKey("RAW")
labels = gp.ArrayKey("LABELS")
stardists = gp.ArrayKey("STARDIST")

# prepare requests for scanning (i.e. chunks) and overall
scan_request = gp.BatchRequest()
scan_request[stardists] = gp.Roi(
    gp.Coordinate((0, 0, 0)),
    gp.Coordinate((40, 100, 100)) * gp.Coordinate((40, 8, 8)))
voxel_size = gp.Coordinate((40, 4, 4))

# empty request will loop over whole area with scanning
request = gp.BatchRequest()
request[stardists] = gp.Roi(
    gp.Coordinate((40, 200, 200)) * gp.Coordinate((40, 8, 8)),
    gp.Coordinate((40, 100, 100)) * gp.Coordinate((40, 8, 8)) *
    gp.Coordinate((2, 2, 2)))
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names): input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"]) output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"]) voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"]) num_iterations = setup_config["NUM_ITERATIONS"] cache_size = setup_config["CACHE_SIZE"] num_workers = setup_config["NUM_WORKERS"] snapshot_every = setup_config["SNAPSHOT_EVERY"] checkpoint_every = setup_config["CHECKPOINT_EVERY"] profile_every = setup_config["PROFILE_EVERY"] seperate_by = setup_config["SEPERATE_BY"] gap_crossing_dist = setup_config["GAP_CROSSING_DIST"] match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"] point_balance_radius = setup_config["POINT_BALANCE_RADIUS"] neuron_radius = setup_config["NEURON_RADIUS"] samples_path = Path(setup_config["SAMPLES_PATH"]) mongo_url = setup_config["MONGO_URL"] input_size = input_shape * voxel_size output_size = output_shape * voxel_size # voxels have size ~= 1 micron on z axis # use this value to scale anything that depends on world unit distance micron_scale = voxel_size[0] seperate_distance = (np.array(seperate_by)).tolist() # array keys for data sources raw = gp.ArrayKey("RAW") consensus = gp.PointsKey("CONSENSUS") skeletonization = gp.PointsKey("SKELETONIZATION") matched = gp.PointsKey("MATCHED") labels = gp.ArrayKey("LABELS") labels_fg = gp.ArrayKey("LABELS_FG") labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN") loss_weights = gp.ArrayKey("LOSS_WEIGHTS") # tensorflow tensors gt_fg = gp.ArrayKey("GT_FG") fg_pred = gp.ArrayKey("FG_PRED") embedding = gp.ArrayKey("EMBEDDING") fg = gp.ArrayKey("FG") maxima = gp.ArrayKey("MAXIMA") gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING") gradient_fg = gp.ArrayKey("GRADIENT_FG") emst = gp.ArrayKey("EMST") edges_u = gp.ArrayKey("EDGES_U") edges_v = gp.ArrayKey("EDGES_V") ratio_pos = gp.ArrayKey("RATIO_POS") ratio_neg = gp.ArrayKey("RATIO_NEG") dist = gp.ArrayKey("DIST") num_pos_pairs = gp.ArrayKey("NUM_POS") num_neg_pairs = gp.ArrayKey("NUM_NEG") # add request request = gp.BatchRequest() request.add(labels_fg, output_size) request.add(labels_fg_bin, output_size) request.add(loss_weights, output_size) request.add(raw, input_size) request.add(labels, input_size) request.add(matched, input_size) request.add(skeletonization, input_size) request.add(consensus, input_size) # add snapshot request snapshot_request = gp.BatchRequest() request.add(labels_fg, output_size) # tensorflow requests # snapshot_request.add(raw, input_size) # input_size request for positioning # snapshot_request.add(embedding, output_size, voxel_size=voxel_size) # snapshot_request.add(fg, output_size, voxel_size=voxel_size) # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size) # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size) # snapshot_request.add(maxima, output_size, voxel_size=voxel_size) # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size) # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size) # snapshot_request[emst] = gp.ArraySpec() # snapshot_request[edges_u] = gp.ArraySpec() # snapshot_request[edges_v] = gp.ArraySpec() # snapshot_request[ratio_pos] = gp.ArraySpec() # snapshot_request[ratio_neg] = gp.ArraySpec() # snapshot_request[dist] = gp.ArraySpec() # snapshot_request[num_pos_pairs] = gp.ArraySpec() # snapshot_request[num_neg_pairs] = gp.ArraySpec() data_sources = tuple( ( gp.N5Source( filename=str((sample / "fluorescence-near-consensus.n5").absolute()), datasets={raw: 
"volume"}, array_specs={ raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size, dtype=np.uint16) }, ), gp.DaisyGraphProvider( f"mouselight-{sample.name}-consensus", mongo_url, points=[consensus], directed=True, node_attrs=[], edge_attrs=[], ), gp.DaisyGraphProvider( f"mouselight-{sample.name}-skeletonization", mongo_url, points=[skeletonization], directed=False, node_attrs=[], edge_attrs=[], ), ) + gp.MergeProvider() + gp.RandomLocation( ensure_nonempty=consensus, ensure_centered=True, point_balance_radius=point_balance_radius * micron_scale, ) + TopologicalMatcher( skeletonization, consensus, matched, failures=Path("matching_failures_slow"), match_distance_threshold=match_distance_threshold * micron_scale, max_gap_crossing=gap_crossing_dist * micron_scale, try_complete=False, use_gurobi=True, ) + RejectIfEmpty(matched) + RasterizeSkeleton( points=matched, array=labels, array_spec=gp.ArraySpec( interpolatable=False, voxel_size=voxel_size, dtype=np.uint32), ) + GrowLabels(labels, radii=[neuron_radius * micron_scale]) # TODO: Do these need to be scaled by world units? + gp.ElasticAugment( [40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0], subsample=4, use_fast_points_transform=True, recompute_missing_points=False, ) # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) for sample in samples_path.iterdir() if sample.name in ("2018-07-02", "2018-08-01")) pipeline = ( data_sources + gp.RandomProvider() + Crop(labels, labels_fg) + BinarizeGt(labels_fg, labels_fg_bin) + gp.BalanceLabels(labels_fg_bin, loss_weights) + gp.PreCache(cache_size=cache_size, num_workers=num_workers) + gp.tensorflow.Train( "train_net", optimizer=create_custom_loss(mknet_tensor_names, setup_config), loss=None, inputs={ mknet_tensor_names["loss_weights"]: loss_weights, mknet_tensor_names["raw"]: raw, mknet_tensor_names["gt_labels"]: labels_fg, }, outputs={ mknet_tensor_names["embedding"]: embedding, mknet_tensor_names["fg"]: fg, loss_tensor_names["fg_pred"]: fg_pred, loss_tensor_names["maxima"]: maxima, loss_tensor_names["gt_fg"]: gt_fg, loss_tensor_names["emst"]: emst, loss_tensor_names["edges_u"]: edges_u, loss_tensor_names["edges_v"]: edges_v, loss_tensor_names["ratio_pos"]: ratio_pos, loss_tensor_names["ratio_neg"]: ratio_neg, loss_tensor_names["dist"]: dist, loss_tensor_names["num_pos_pairs"]: num_pos_pairs, loss_tensor_names["num_neg_pairs"]: num_neg_pairs, }, gradients={ mknet_tensor_names["embedding"]: gradient_embedding, mknet_tensor_names["fg"]: gradient_fg, }, save_every=checkpoint_every, summary="Merge/MergeSummary:0", log_dir="tensorflow_logs", ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot( additional_request=snapshot_request, output_filename="snapshot_{}_{}.hdf".format( int(np.min(seperate_distance)), "{id}"), dataset_names={ # raw data raw: "volumes/raw", # labeled data labels: "volumes/labels", # trees skeletonization: "points/skeletonization", consensus: "points/consensus", matched: "points/matched", # output volumes embedding: "volumes/embedding", fg: "volumes/fg", maxima: "volumes/maxima", gt_fg: "volumes/gt_fg", fg_pred: "volumes/fg_pred", gradient_embedding: "volumes/gradient_embedding", gradient_fg: "volumes/gradient_fg", # output trees emst: "emst", edges_u: "edges_u", edges_v: "edges_v", # output debug data ratio_pos: "ratio_pos", ratio_neg: "ratio_neg", dist: "dist", num_pos_pairs: "num_pos_pairs", num_neg_pairs: "num_neg_pairs", loss_weights: "volumes/loss_weights", }, every=snapshot_every, )) 
    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
def train_until(**kwargs): if tf.train.latest_checkpoint(kwargs['output_folder']): trained_until = int( tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1]) else: trained_until = 0 if trained_until >= kwargs['max_iteration']: return anchor = gp.ArrayKey('ANCHOR') raw = gp.ArrayKey('RAW') raw_cropped = gp.ArrayKey('RAW_CROPPED') gt_threeclass = gp.ArrayKey('GT_THREECLASS') loss_weights_threeclass = gp.ArrayKey('LOSS_WEIGHTS_THREECLASS') pred_threeclass = gp.ArrayKey('PRED_THREECLASS') pred_threeclass_gradients = gp.ArrayKey('PRED_THREECLASS_GRADIENTS') with open( os.path.join(kwargs['output_folder'], kwargs['name'] + '_config.json'), 'r') as f: net_config = json.load(f) with open( os.path.join(kwargs['output_folder'], kwargs['name'] + '_names.json'), 'r') as f: net_names = json.load(f) voxel_size = gp.Coordinate(kwargs['voxel_size']) input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size # formulate the request for what a batch should (at least) contain request = gp.BatchRequest() request.add(raw, input_shape_world) request.add(raw_cropped, output_shape_world) request.add(gt_threeclass, output_shape_world) request.add(anchor, output_shape_world) request.add(loss_weights_threeclass, output_shape_world) # when we make a snapshot for inspection (see below), we also want to # request the predicted affinities and gradients of the loss wrt the # affinities snapshot_request = gp.BatchRequest() snapshot_request.add(raw_cropped, output_shape_world) snapshot_request.add(gt_threeclass, output_shape_world) snapshot_request.add(pred_threeclass, output_shape_world) # snapshot_request.add(pred_threeclass_gradients, output_shape_world) if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr": raise NotImplementedError("train node for {} not implemented".format( kwargs['input_format'])) fls = [] shapes = [] for f in kwargs['data_files']: fls.append(os.path.splitext(f)[0]) if kwargs['input_format'] == "hdf": vol = h5py.File(f, 'r')['volumes/raw'] elif kwargs['input_format'] == "zarr": vol = zarr.open(f, 'r')['volumes/raw'] print(f, vol.shape, vol.dtype) shapes.append(vol.shape) if vol.dtype != np.float32: print("please convert to float32") ln = len(fls) print("first 5 files: ", fls[0:4]) # padR = 46 # padGT = 32 if kwargs['input_format'] == "hdf": sourceNode = gp.Hdf5Source elif kwargs['input_format'] == "zarr": sourceNode = gp.ZarrSource augmentation = kwargs['augmentation'] pipeline = ( tuple( # read batches from the HDF5 file sourceNode( fls[t] + "." 
+ kwargs['input_format'], datasets={ raw: 'volumes/raw', gt_threeclass: 'volumes/gt_threeclass', anchor: 'volumes/gt_threeclass', }, array_specs={ raw: gp.ArraySpec(interpolatable=True), gt_threeclass: gp.ArraySpec(interpolatable=False), anchor: gp.ArraySpec(interpolatable=False) } ) + gp.MergeProvider() + gp.Pad(raw, None) + gp.Pad(gt_threeclass, None) + gp.Pad(anchor, gp.Coordinate((2,2,2))) # chose a random location for each requested batch + gp.RandomLocation() for t in range(ln) ) + # chose a random source (i.e., sample) from the above gp.RandomProvider() + # elastically deform the batch (gp.ElasticAugment( augmentation['elastic']['control_point_spacing'], augmentation['elastic']['jitter_sigma'], [augmentation['elastic']['rotation_min']*np.pi/180.0, augmentation['elastic']['rotation_max']*np.pi/180.0], subsample=augmentation['elastic'].get('subsample', 1)) \ if augmentation.get('elastic') is not None else NoOp()) + # apply transpose and mirror augmentations gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"), transpose_only=augmentation['simple'].get("transpose")) + # # scale and shift the intensity of the raw array gp.IntensityAugment( raw, scale_min=augmentation['intensity']['scale'][0], scale_max=augmentation['intensity']['scale'][1], shift_min=augmentation['intensity']['shift'][0], shift_max=augmentation['intensity']['shift'][1], z_section_wise=False) + # grow a boundary between labels # TODO: check # gp.GrowBoundary( # gt_threeclass, # steps=1, # only_xy=False) + gp.BalanceLabels( gt_threeclass, loss_weights_threeclass, num_classes=3) + # pre-cache batches from the point upstream gp.PreCache( cache_size=kwargs['cache_size'], num_workers=kwargs['num_workers']) + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Train( os.path.join(kwargs['output_folder'], kwargs['name']), optimizer=net_names['optimizer'], summary=net_names['summaries'], log_dir=kwargs['output_folder'], loss=net_names['loss'], inputs={ net_names['raw']: raw, net_names['anchor']: anchor, net_names['gt_threeclass']: gt_threeclass, net_names['loss_weights_threeclass']: loss_weights_threeclass }, outputs={ net_names['pred_threeclass']: pred_threeclass, net_names['raw_cropped']: raw_cropped, }, gradients={ net_names['pred_threeclass']: pred_threeclass_gradients, }, save_every=kwargs['checkpoints']) + # save the passing batch as an HDF5 file for inspection gp.Snapshot( { raw: '/volumes/raw', raw_cropped: 'volumes/raw_cropped', gt_threeclass: '/volumes/gt_threeclass', pred_threeclass: '/volumes/pred_threeclass', }, output_dir=os.path.join(kwargs['output_folder'], 'snapshots'), output_filename='batch_{iteration}.hdf', every=kwargs['snapshots'], additional_request=snapshot_request, compression_type='gzip') + # show a summary of time spend in each node every 10 iterations gp.PrintProfilingStats(every=kwargs['profiling']) ) ######### # TRAIN # ######### print("Starting training...") with gp.build(pipeline): print(pipeline) for i in range(trained_until, kwargs['max_iteration']): # print("request", request) start = time.time() pipeline.request_batch(request) time_of_iteration = time.time() - start logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration) # exit() print("Training finished")
def predict_3d(raw_data, gt_data, predictor):
    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(
        model=predictor,
        inputs={'x': raw},
        outputs={0: prediction})
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=roi)
    total_request[prediction] = gp.ArraySpec(roi=roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=roi)
        total_request[target] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
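# A hedged usage sketch of predict_3d. It assumes data objects that expose
# num_channels, voxel_size, roi, num_samples, and get_source(), and a torch
# predictor module with input_shape, output_shape, and add_target(), as used
# above; `validation_raw`, `validation_gt`, and `my_predictor` are
# hypothetical names, not part of the original code.
results = predict_3d(validation_raw, validation_gt, my_predictor)

raw_array = results['raw']            # gp.Array covering the padded ROI
pred_array = results['prediction']    # gp.Array holding the model output
if 'gt' in results:
    gt_array = results['gt']
    target_array = results['target']
print(pred_array.data.shape, pred_array.spec.roi)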
from __future__ import print_function
import gunpowder as gp
from gunpowder import *
from gunpowder.ext import malis
import tensorflow as tf
import mala

############
# Training #
############

# Declare arrays
raw = gp.ArrayKey('RAW')
gt = gp.ArrayKey('GT')
mask = gp.ArrayKey('mask')
prediction = gp.ArrayKey('prediction')
grad = gp.ArrayKey('gradient')

# define training values
input_shape = (40, 300, 300)
voxel_size = (40, 4, 4)

# define network parameters; these will be used to define the feed dict for
# the network
raw_tf = tf.placeholder(tf.float32, shape=input_shape)
raw_batched = tf.reshape(raw_tf, (1, 1) + input_shape)
unet = mala.networks.unet(raw_batched, 3, 3, [[1, 1, 1], [1, 1, 1], [3, 3, 3]])

# since we want binary predictions we will be using 1 feature map per output
output = mala.networks.conv_pass(unet, kernel_size=1,
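# The snippet above is cut off mid-way through the final convolution. Once the
# output and a loss tensor exist, the placeholders defined here are typically
# handed to gp.tensorflow.Train by their TensorFlow tensor names. The sketch
# below is hedged and illustrative only: the graph basename, optimizer op,
# loss, and output tensor names are placeholders, not values from this file.
train_node = gp.tensorflow.Train(
    'train_net',                          # meta-graph basename (placeholder)
    optimizer='optimizer',                # placeholder op name
    loss='loss:0',                        # placeholder tensor name
    inputs={raw_tf.name: raw},            # feed the RAW array into raw_tf
    outputs={'prediction:0': prediction}, # placeholder output tensor name
    gradients={'prediction:0': grad},
    save_every=1000)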
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16,
                                          voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=bool,
                                              voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
        data_sources +
        gp.RandomProvider() +
        gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
        DistanceTransform(gt_mask, gt_dt, 3) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=5) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_dt']: gt_dt,
            },
            outputs={
                net_names['pred_dt']: pred_dt,
            },
            gradients={
                net_names['pred_dt']: loss_gradient,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot({
                raw: 'volumes/raw',
                gt_mask: 'volumes/gt_mask',
                gt_dt: 'volumes/gt_dt',
                pred_dt: 'volumes/pred_dt',
                loss_gradient: 'volumes/gradient',
            },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2000) +
        gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
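# A hedged usage sketch for the distance-transform training function above.
# Note that it reads `data_dir` and `data_files` as module-level globals, so
# they must be defined before the call; the values below are placeholders.
data_dir = '/path/to/data'        # placeholder directory
data_files = ['sample_01.hdf']    # placeholder file list
train_until(max_iteration=100000,
            name='train_net',
            output_folder='experiments/dt_run',
            clip_max=2000)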
def batch__aug_data_generator(input_path, batch_size=12, voxel_shape=[1, 1, 1],
                              input_shape=[240, 240, 4], output_shape=[240, 240, 4],
                              without_background=False, mix_output=False,
                              validate=False, seq=None):
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')
    files = os.listdir(input_path)
    files = [os.path.join(input_path, f) for f in files]
    pipeline = (
        tuple(
            gp.ZarrSource(
                files[t],  # the zarr container
                {
                    raw: 'raw',
                    gt: 'ground_truth'
                },  # which dataset to associate to the array key
                {
                    raw: gp.ArraySpec(interpolatable=True,
                                      dtype=np.dtype('float32'),
                                      voxel_size=voxel_shape),
                    gt: gp.ArraySpec(interpolatable=True,
                                     dtype=np.dtype('float32'),
                                     voxel_size=voxel_shape)
                }  # meta-information
            ) +
            gp.RandomLocation()
            for t in range(len(files))) +
        gp.RandomProvider()
        # + gp.Stack(batch_size)
    )
    input_size = gp.Coordinate(input_shape)
    output_size = gp.Coordinate(output_shape)
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, input_size)

    diff = input_shape[1] - output_shape[1]
    diff = int(diff / 2)
    max_p = input_shape[1] - diff
    different_shape = diff > 0
    if different_shape:
        print('Difference padding: {}'.format(diff))

    with gp.build(pipeline):
        while True:
            b = 0
            imgs = []
            masks = []
            while b < batch_size:
                valid = False
                batch = pipeline.request_batch(request)
                if validate:
                    valid = validate_mask(batch[gt].data)
                else:
                    valid = True
                while not valid:
                    batch = pipeline.request_batch(request)
                    valid = validate_mask(batch[gt].data)

                im = batch[raw].data
                out = batch[gt].data
                # if different_shape:
                #     out = out[diff:max_p, diff:max_p, :]
                if without_background:
                    out = out[:, :, 1:4]
                if mix_output:
                    out = out.argmax(axis=3).astype(float)
                imgs.append(im)
                masks.append(out)
                b = b + 1
            imgs = np.asarray(imgs)
            masks = np.asarray(masks)
            if seq is not None:
                imgs, masks = augmentation(imgs, masks, seq)
            if different_shape:
                out = []
                for m in masks:
                    out.append(m[diff:max_p, diff:max_p, :])
                masks = np.asarray(out)
            yield imgs, masks
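# Sketch only: pull a few batches from the generator and inspect their shapes.
# 'zarr_dir' is a hypothetical directory of zarr containers holding 'raw' and
# 'ground_truth' datasets, as the generator above expects.
gen = batch__aug_data_generator('zarr_dir', batch_size=4,
                                input_shape=[240, 240, 4],
                                output_shape=[200, 200, 4])
for _ in range(3):
    imgs, masks = next(gen)
    # with the shapes above, e.g. (4, 240, 240, 4) and (4, 200, 200, 4)
    print(imgs.shape, masks.shape)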