def predict(iteration, path_to_dataGP,
            checkpoint_dir='C:/Users/filip/spine_yodl'):
    """Predict affinities for the validation sample of a zarr container.

    Args:
        iteration: Training iteration whose checkpoint is loaded; it is
            substituted into ``model_checkpoint_{iteration}``.
        path_to_dataGP: Path of the zarr container. The raw volume is read
            from the dataset ``validate/sample1/raw``.
        checkpoint_dir: Directory that contains the checkpoints. The default
            is the location that used to be hard-coded in this function, so
            existing callers behave exactly as before.

    Returns:
        Tuple ``(raw, affs_predicted)`` with the ``.data`` of both arrays of
        the batch assembled by ``gp.Scan``.
    """
    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    # Padding applied to the raw data on each side before scanning.
    amount_size = gp.Coordinate((2, 16, 16))

    # NOTE(review): the string literal 'output_size' is passed here, not the
    # ``output_size`` tuple defined above -- confirm that this is what
    # SpineUNet expects.
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # Size of a single network pass; Scan tiles the total request with it.
    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)

    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        }
    )

    # Build the source once on its own to learn the ROI it provides.
    with gp.build(source):
        source_roi = source.spec[raw].roi

    # Ask for raw and prediction over the whole source ROI.
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'{checkpoint_dir}/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
def test_delete_points_in_context(self):
    """Request points, parent vectors and mask once; the request must succeed.

    The test makes no assertions on the returned batch, it only checks that
    the pipeline can serve the request.
    """
    points_key = gp.PointsKey("POINTS")
    vectors_key = gp.ArrayKey("PARENT_VECTORS")
    mask_key = gp.ArrayKey("MASK")

    track_source = TracksSource(TEST_FILE, points_key)
    add_vectors = AddParentVectors(
        points_key, vectors_key, mask_key, [0.1, 0.1, 0.1, 0.1])

    # All three keys are requested with the same shape.
    request = gp.BatchRequest()
    for key in (points_key, vectors_key, mask_key):
        request.add(key, gp.Coordinate((1, 4, 4, 4)))

    pipeline = track_source + gp.Pad(points_key, None) + add_vectors
    with gp.build(pipeline):
        pipeline.request_batch(request)
def test_add_parent_vectors(self):
    """Check the mask and the three parent-vector channels of one batch."""
    points_key = gp.PointsKey("POINTS")
    vectors_key = gp.ArrayKey("PARENT_VECTORS")
    mask_key = gp.ArrayKey("MASK")

    track_source = TracksSource(TEST_FILE, points_key)
    add_vectors = AddParentVectors(
        points_key, vectors_key, mask_key, [0.1, 0.1, 0.1, 0.1])

    # The points are requested over three frames, the arrays over one.
    request = gp.BatchRequest()
    request.add(points_key, gp.Coordinate((3, 4, 4, 4)))
    for array_key in (vectors_key, mask_key):
        request.add(array_key, gp.Coordinate((1, 4, 4, 4)))

    pipeline = track_source + gp.Pad(points_key, None) + add_vectors
    with gp.build(pipeline):
        batch = pipeline.request_batch(request)

    # The points are looked up (as before) but not inspected any further.
    _ = batch[points_key].data

    volume_shape = (1, 4, 4, 4)

    # Two voxels are expected to be set in the mask.
    expected_mask = np.zeros(shape=volume_shape)
    for voxel in ((0, 0, 0, 0), (0, 1, 2, 3)):
        expected_mask[voxel] = 1
    self.assertListEqual(expected_mask.tolist(),
                         batch[mask_key].data.tolist())

    # Channels 0, 1, 2 (z, y, x) carry -1, -2, -3 at voxel (0, 1, 2, 3) and
    # zero everywhere else.
    parent_vectors = batch[vectors_key].data
    for channel, component in enumerate((-1.0, -2.0, -3.0)):
        expected = np.zeros(shape=volume_shape)
        expected[0, 1, 2, 3] = component
        self.assertListEqual(expected.tolist(),
                             parent_vectors[channel].tolist())
def build_pipeline(data_dir, model, checkpoint_file, input_size, output_size,
                   raw, labels, affs_predicted, dataset_shape, num_samples,
                   sample_size):
    """Load model weights and assemble the scanning prediction pipeline.

    The model state is restored from ``checkpoint_file`` (entry
    ``'model_state_dict'``). Raw data and labels are read from the
    ``validate/raw`` and ``validate/gt`` datasets of the zarr container at
    ``data_dir``.

    ``dataset_shape``, ``num_samples`` and ``sample_size`` are accepted but
    not used by this function.

    Returns:
        The assembled pipeline, ending in a ``gp.Scan`` node.
    """
    state = torch.load(checkpoint_file)
    model.load_state_dict(state['model_state_dict'])

    # Reference request for one network pass, used by Scan.
    scan_request = gp.BatchRequest()
    for key, size in ((raw, input_size),
                      (affs_predicted, output_size),
                      (labels, output_size)):
        scan_request.add(key, size)

    pipeline = gp.ZarrSource(
        str(data_dir),
        {
            raw: 'validate/raw',
            labels: 'validate/gt'
        })
    pipeline += gp.Pad(raw, size=None)
    pipeline += gp.Normalize(raw)
    # raw: (s, h, w)
    # labels: (s, h, w)
    pipeline += train.AddChannelDim(raw)
    # raw: (c=1, s, h, w)
    # labels: (s, h, w)
    pipeline += train.TransposeDims(raw, (1, 0, 2, 3))
    # raw: (s, c=1, h, w)
    # labels: (s, h, w)
    pipeline += Predict(
        model=model,
        inputs={'x': raw},
        outputs={0: affs_predicted})
    # raw: (s, c=1, h, w)
    # affs_predicted: (s, c=2, h, w)
    # labels: (s, h, w)
    pipeline += train.TransposeDims(raw, (1, 0, 2, 3))
    pipeline += train.RemoveChannelDim(raw)
    # raw: (s, h, w)
    # affs_predicted: (s, c=2, h, w)
    # labels: (s, h, w)
    pipeline += gp.PrintProfilingStats(every=100)
    pipeline += gp.Scan(scan_request)

    return pipeline
def predict(
    model: Model,
    raw_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    num_cpu_workers: int = 4,
    compute_context: ComputeContext = LocalTorch(),
    output_roi: Optional[Roi] = None,
):
    """Run ``model`` over ``raw_array`` and write the result to a zarr dataset.

    Args:
        model: Network to apply; also provides the evaluation input shape,
            the output shape, the voxel-size scaling and the number of
            output channels.
        raw_array: Array to read the raw data from.
        prediction_array_identifier: Container and dataset the prediction is
            written to.
        num_cpu_workers: Not used in the body of this function.
        compute_context: Its ``device`` is handed to the predict node.
        output_roi: ROI to predict in. If ``None``, the ROI of ``raw_array``
            shrunk by the network context is used.
    """
    # sizes of one network pass, in world units
    in_voxel_size = Coordinate(raw_array.voxel_size)
    out_voxel_size = model.scale(in_voxel_size)
    in_shape = Coordinate(model.eval_input_shape)
    input_size = in_voxel_size * in_shape
    output_size = out_voxel_size * model.compute_output_shape(in_shape)[1]

    logger.info(
        "Predicting with input size %s, output size %s", input_size, output_size
    )

    # derive whichever of the two ROIs was not given
    context = (input_size - output_size) / 2
    if output_roi is not None:
        input_roi = output_roi.grow(context, context)
    else:
        input_roi = raw_array.roi
        output_roi = input_roi.grow(-context, -context)

    logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi)

    # create the output dataset, channel axis first
    axes = ["c", *(axis for axis in raw_array.axes if axis != "c")]
    ZarrArray.create_from_array_identifier(
        prediction_array_identifier,
        axes,
        output_roi,
        model.num_out_channels,
        output_voxel_size := out_voxel_size,
        np.float32,
    ) if False else ZarrArray.create_from_array_identifier(
        prediction_array_identifier,
        axes,
        output_roi,
        model.num_out_channels,
        out_voxel_size,
        np.float32,
    )

    # gunpowder keys
    raw = gp.ArrayKey("RAW")
    prediction = gp.ArrayKey("PREDICTION")

    # grow the output ROI so that its shape is a multiple of the output size
    padding = (output_size - output_roi.shape) % output_size
    prediction_roi = output_roi.grow(padding)

    # data source
    pipeline = DaCapoArraySource(raw_array, raw)
    # raw: (c, d, h, w)
    pipeline += gp.Pad(raw, Coordinate((None,) * in_voxel_size.dims))
    # raw: (c, d, h, w)
    pipeline += gp.Unsqueeze([raw])
    # raw: (1, c, d, h, w)

    # predict
    prediction_spec = gp.ArraySpec(
        roi=prediction_roi, voxel_size=out_voxel_size, dtype=np.float32
    )
    pipeline += gp_torch.Predict(
        model=model,
        inputs={"x": raw},
        outputs={0: prediction},
        array_specs={prediction: prediction_spec},
        spawn_subprocess=False,
        device=str(compute_context.device),
    )
    # raw: (1, c, d, h, w)
    # prediction: (1, [c,] d, h, w)

    # drop the batch dimension again before writing
    pipeline += gp.Squeeze([raw, prediction])
    # raw: (c, d, h, w)
    # prediction: (c, d, h, w)

    # write to zarr
    pipeline += gp.ZarrWrite(
        {prediction: prediction_array_identifier.dataset},
        prediction_array_identifier.container.parent,
        prediction_array_identifier.container.name,
    )

    # reference request: one network pass
    ref_request = gp.BatchRequest()
    ref_request.add(raw, input_size)
    ref_request.add(prediction, output_size)
    pipeline += gp.Scan(ref_request)

    # an empty request makes Scan process the complete output ROI
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())

    # record the axis names on the written dataset
    container = zarr.open(prediction_array_identifier.container)
    dataset = container[prediction_array_identifier.dataset]
    if "c" in raw_array.axes:
        dataset.attrs["axes"] = raw_array.axes
    else:
        dataset.attrs["axes"] = ["c"] + raw_array.axes
def create_pipeline_3d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):
    """Assemble the 3D training pipeline and its batch request.

    Args:
        task: Provides the raw/gt training data, the augmentations and the
            loss.
        predictor: Model to train; also provides input/output shapes and the
            target node.
        optimizer: Passed on to the train node.
        batch_size: Number of samples stacked into one batch.
        outdir: Snapshots are stored in ``<outdir>/snapshots``.
        snapshot_every: Snapshot interval; no snapshot node is added unless
            this is greater than 0.

    Returns:
        Tuple ``(pipeline, request)``.
    """
    train_raw = task.data.raw.train

    # A channel count of 1 is treated like "no channel axis".
    has_channel_axis = max(1, task.data.raw.num_channels) != 1

    # switch to world units
    voxel_size = train_raw.voxel_size
    input_size = voxel_size * predictor.input_shape
    output_size = voxel_size * predictor.output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    num_samples = train_raw.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D training not yet implemented")

    pipeline = (
        train_raw.get_source(raw),
        task.data.gt.train.get_source(gt)
    ) + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    pipeline += gp.RandomLocation()

    # NOTE(review): the augmentations are obtained by eval()-ing a string
    # taken from the task -- only safe if that string is trusted.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation

    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)

    # The loss may or may not contribute a weights node.
    loss_inputs = {0: prediction, 1: target}
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs[2] = weights
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]

    if not has_channel_axis:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)

    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]

    pipeline += gp_torch.Train(
        model=predictor,
        loss=task.loss,
        optimizer=optimizer,
        inputs={'x': raw},
        loss_inputs=loss_inputs,
        outputs={0: prediction},
        save_every=1e6)
    # prediction: (b, [c,] d, h, w)

    if snapshot_every > 0:

        # get channels first
        channels_first = (1, 0, 2, 3, 4)
        pipeline += TransposeDims(raw, channels_first)
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, channels_first)
            if weights_node:
                pipeline += TransposeDims(weights, channels_first)
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, channels_first)
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)

        if not has_channel_axis:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)

        pipeline += gp.Snapshot(
            dataset_names={
                raw: 'raw',
                target: 'target',
                prediction: 'prediction',
                weights: 'weights'
            },
            every=snapshot_every,
            output_dir=os.path.join(outdir, 'snapshots'),
            output_filename="{iteration}.hdf")

    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
def create_pipeline_2d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):
    """Assemble the 2D training pipeline and its batch request.

    The training data has to be a stack of 2D samples, i.e. an array with
    three data dimensions (plus an optional leading channel dimension). The
    sample dimension is exposed to gunpowder as a z dimension with voxel
    size 1 and squashed away before training.

    Args:
        task: Provides the raw/gt training data, the augmentations and the
            loss.
        predictor: Model to train; also provides input/output shapes and the
            target node.
        optimizer: Passed on to the train node.
        batch_size: Number of samples stacked into one batch.
        outdir: Snapshots are stored in ``<outdir>/snapshots``.
        snapshot_every: Snapshot interval; no snapshot node is added unless
            this is greater than 0.

    Returns:
        Tuple ``(pipeline, request)``.

    Raises:
        RuntimeError: If the training data does not have exactly three data
            dimensions.
    """
    # Clamp to at least 1, as create_pipeline_3d and predict_2d do: a channel
    # count of 0 must not be mistaken for "has a channel axis".
    raw_channels = max(1, task.data.raw.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = task.data.raw.train.shape
    dataset_roi = task.data.raw.train.roi
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims != 3:
        raise RuntimeError("For 2D training, please provide a 3D array where "
                           "the first dimension indexes the samples.")

    # The sample axis follows the (optional) channel axis, consistent with
    # the slice used for sample_shape below.
    num_samples = dataset_shape[channel_dims]
    sample_shape = gp.Coordinate(dataset_shape[channel_dims + 1:])
    sample_size = sample_shape * voxel_size

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(
        roi=gp.Roi(
            (0, ) + dataset_roi.get_begin(),
            (num_samples, ) + sample_size),
        voxel_size=(1, ) + voxel_size)
    sources = (
        task.data.raw.train.get_source(raw, overwrite_spec=spec),
        task.data.gt.train.get_source(gt, overwrite_spec=spec))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)

    # NOTE(review): the augmentations are obtained by eval()-ing a string
    # taken from the task -- only safe if that string is trusted.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation

    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)

    # The loss may or may not contribute a weights node.
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    # [weights: ([c,] d=1, h, w)]

    # get rid of z dim:
    pipeline += Squash(dim=-3)
    # raw: ([c,] h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]

    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]

    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]

    pipeline += gp_torch.Train(
        model=predictor,
        loss=task.loss,
        optimizer=optimizer,
        inputs={'x': raw},
        loss_inputs=loss_inputs,
        outputs={0: prediction},
        save_every=1e6)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    # prediction: (b, [c,] h, w)

    if snapshot_every > 0:

        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3))
        # raw: (c, b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)

        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)

        pipeline += gp.Snapshot(
            dataset_names={
                raw: 'raw',
                target: 'target',
                prediction: 'prediction',
                weights: 'weights'
            },
            every=snapshot_every,
            output_dir=os.path.join(outdir, 'snapshots'),
            output_filename="{iteration}.hdf")

    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
def train_until(max_iteration, name='train_net', output_folder='.',
                clip_max=2000):
    """Train the network stored in ``output_folder`` up to ``max_iteration``.

    Training resumes from the latest checkpoint found in ``output_folder``
    (the iteration is parsed from the text after the last ``_`` of the
    checkpoint name). Nothing is done if that iteration already reaches
    ``max_iteration``.

    Each batch is produced by fusing a "base" and an "add" volume, both drawn
    from the HDF5 files listed in ``data_files`` below ``data_dir``; these two
    names are not parameters and are resolved from the enclosing scope.

    Args:
        max_iteration: Iteration to train up to.
        name: Prefix of the ``<name>_config.json`` / ``<name>_names.json``
            files and of the network files in ``output_folder``.
        output_folder: Folder with the network files; checkpoints and
            snapshots are written there as well.
        clip_max: Raw intensities are clipped to ``[0, clip_max]`` and then
            normalized with factor ``1 / clip_max``.
    """
    # get the latest checkpoint (queried once instead of twice)
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys (loss weights are currently disabled throughout)
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    def padded_source(path, sample, raw_key, instances_key, mask_key):
        # Read one sample group of an HDF5 file, convert its mask to uint8
        # and pad all three arrays by the network context.
        return (
            gp.Hdf5Source(
                path,
                datasets={
                    raw_key: sample + '/raw',
                    instances_key: sample + '/gt',
                    mask_key: sample + '/fg',
                },
                array_specs={
                    raw_key: gp.ArraySpec(
                        interpolatable=True,
                        dtype=np.uint16,
                        voxel_size=voxel_size),
                    instances_key: gp.ArraySpec(
                        interpolatable=False,
                        dtype=np.uint16,
                        voxel_size=voxel_size),
                    # builtin bool: the np.bool alias was removed in
                    # NumPy 1.24 and denotes the same dtype
                    mask_key: gp.ArraySpec(
                        interpolatable=False,
                        dtype=bool,
                        voxel_size=voxel_size),
                }
            ) +
            Convert(mask_key, np.uint8) +
            gp.Pad(raw_key, context) +
            gp.Pad(instances_key, context) +
            gp.Pad(mask_key, context))

    # data source for base volume: locations must contain enough foreground
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                padded_source(current_path, sample, raw_base,
                              gt_instances_base, gt_mask_base) +
                gp.RandomLocation(min_masked=0.005, mask=gt_mask_base)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume: random locations, mostly rejected if they
    # contain too little foreground
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                padded_source(current_path, sample, raw_add,
                              gt_instances_add, gt_mask_add) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005,
                          reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()

    data_sources = (
        tuple([data_sources_base, data_sources_add]) + gp.MergeProvider())

    pipeline = (
        data_sources +
        nl.FusionAugment(
            raw_base, raw_add, gt_instances_base, gt_instances_add,
            raw, gt_instances,
            blend_mode='labels_mask',
            blend_smoothness=5,
            num_blended_objects=0
        ) +
        BinarizeLabels(gt_instances, gt_mask) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=10) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt']: gt_mask,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            gradients={
                net_names['output']: loss_gradients,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot({
                raw: 'volumes/raw',
                pred_mask: 'volumes/pred_mask',
                gt_mask: 'volumes/gt_mask',
                loss_gradients: 'volumes/loss_gradients',
            },
            output_filename=os.path.join(
                output_folder, 'snapshots', 'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2500) +
        gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        print("Starting training...")
        # one batch request per remaining iteration
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def predict(**kwargs):
    """Predict affinities for one sample and write them to a zarr container.

    Keyword Args:
        name: Prefix of the network files (``<name>_config.json``,
            ``<name>_names.json``, ``<name>.meta``).
        input_folder: Folder containing the network files.
        checkpoint: Checkpoint passed to the predict node.
        voxel_size: Voxel size of the data.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_folder: Folder containing ``<sample>.<input_format>``.
        sample: Name of the sample (file name without extension).
        output_format: Has to be ``"zarr"``.
        output_folder: Folder ``<sample>.zarr`` is written to.
        patchshape: Its product is used as the number of output channels.

    Raises:
        NotImplementedError: For an unsupported input or output format.
    """
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    with open(os.path.join(kwargs['input_folder'],
                           name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'],
                           name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = (input_shape_world - output_shape_world)//2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        # The exception does not format logging-style arguments itself, so
        # the message is built explicitly.
        raise NotImplementedError(
            "predict node for %s not implemented yet"
            % kwargs['input_format'])

    # pick the source node and read the shape of the raw volume
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(os.path.join(kwargs['data_folder'],
                                    kwargs['sample'] + ".hdf"), 'r') as f:
            shape = f['volumes/raw'].shape
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(os.path.join(kwargs['data_folder'],
                                   kwargs['sample'] + ".zarr"), 'r')
        shape = f['volumes/raw'].shape

    source = sourceNode(
        os.path.join(kwargs['data_folder'],
                     kwargs['sample'] + "." + kwargs['input_format']),
        datasets={
            raw: 'volumes/raw'
        },
    )

    # Where the volume is smaller than the network output (checked for the
    # last three dimensions), add half of the difference to the padding.
    crop = []
    for d in range(-3, 0):
        network_extent = net_config['output_shape'][d]
        if shape[d] < network_extent:
            crop.append((network_extent - shape[d])//2)
        else:
            crop.append(0)
    print("cropping", crop)
    context += gp.Coordinate(crop)

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # open zarr file and pre-create the output datasets
    num_channels = int(np.prod(kwargs['patchshape']))
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'), mode='w')
    zf.create('volumes/pred_affs',
              shape=[num_channels] + list(shape),
              chunks=[num_channels] + list(shape)[:-1] + [20],
              dtype=np.float32)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/raw',
              shape=list(shape),
              chunks=list(shape)[:-1] + [20],
              dtype=np.float32)
    zf['volumes/raw'].attrs['offset'] = [0, 0, 0]
    zf['volumes/raw'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        pred_affs: '/volumes/pred_affs',
    }

    pipeline = (
        source +
        gp.Pad(raw, context) +

        # run the network on each passing batch (using the tensor names
        # stored in <name>_names.json)
        gp.tensorflow.Predict(
            graph=os.path.join(kwargs['input_folder'], name + '.meta'),
            checkpoint=kwargs['checkpoint'],
            inputs={
                net_names['raw']: raw
            },
            outputs=outputs) +

        # store all passing batches in the same zarr container
        gp.ZarrWrite(
            outVolumes,
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip'
        ) +

        # show a summary of time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request)
    )

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
def predict_volume(model, dataset, out_dir, out_filename, out_ds_names,
                   input_key='0/raw', normalize_factor=None, model_output=0,
                   in_shape=None, out_shape=None, spawn_subprocess=True,
                   num_workers=0):
    """Scan ``model`` over a whole zarr dataset and write the prediction.

    Args:
        model: Network to apply; supplies ``in_shape``/``out_shape`` when
            these are not given.
        dataset: Object with ``filename`` and ``ds_names``; the first entry
            of ``ds_names`` is read as raw data.
        out_dir: Output directory of the written container.
        out_filename: File name of the written container.
        out_ds_names: The first entry names the prediction dataset.
        input_key: Not used in the body of this function.
        normalize_factor: Factor for ``gp.Normalize``; the string ``"skip"``
            disables normalization.
        model_output: Key of the model output stored as prediction.
        in_shape: Network input shape, defaults to ``model.in_shape``.
        out_shape: Network output shape, defaults to ``model.out_shape``.
        spawn_subprocess: Passed on to the predict node.
        num_workers: Passed on to ``gp.Scan``.

    NOTE(review): ``apply_voxel_size``, ``input_name`` and ``checkpoint`` are
    neither parameters nor locals; they have to be provided by the enclosing
    scope -- confirm.
    """
    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    raw_ds_name = dataset.ds_names[0]
    data = daisy.open_ds(dataset.filename, raw_ds_name)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Fall back to the model's own shapes where none were given.
    in_shape = gp.Coordinate(model.in_shape if in_shape is None else in_shape)
    out_shape = gp.Coordinate(
        model.out_shape if out_shape is None else out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    # Reference request for a single network pass.
    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2
    print("context", context, in_shape, out_shape)

    source = gp.ZarrSource(
        dataset.filename,
        {
            raw: raw_ds_name,
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)})

    # Add leading dimensions until raw has two more than the spatial ones.
    missing_dims = (2 + spatial_dims) - data_dims
    for _ in range(missing_dims):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    swapped_leading_axes = (1, 0) + tuple(range(2, 2 + spatial_dims))
    source += TransposeDims(raw, swapped_leading_axes)

    # Build the source on its own to learn the ROI it provides.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source
    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    predict_node = gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess)
    pipeline = pipeline + (gp.Pad(raw, context) + predict_node)

    write_node = gp.ZarrWrite(
        {
            prediction: out_ds_names[0],
        },
        output_dir=out_dir,
        output_filename=out_filename,
        compression_type='gzip')
    pipeline += (write_node + gp.Scan(request, num_workers=num_workers))

    # An empty request makes Scan process everything that is provided.
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def predict_2d(raw_data, gt_data, predictor):
    """Predict a stack of 2D samples, treating the sample axis as z.

    Args:
        raw_data: Raw data; has to have three data dimensions (plus an
            optional leading channel dimension).
        gt_data: Optional ground truth; if given, gt and target are requested
            as well.
        predictor: Model to apply; also provides input/output shapes and the
            target node.

    Returns:
        Dict with the batch entries ``'raw'`` and ``'prediction'``, plus
        ``'gt'`` and ``'target'`` if ``gt_data`` was given.

    Raises:
        RuntimeError: If the data does not have exactly three data
            dimensions.
    """
    # A channel count of 1 is treated like "no channel axis".
    channel_dims = 1 if max(1, raw_data.num_channels) != 1 else 0

    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * predictor.input_shape
    output_size = voxel_size * predictor.output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    if len(dataset_shape) - channel_dims != 3:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    # The number of samples is taken from the dataset object.
    num_samples = raw_data.num_samples
    sample_shape = gp.Coordinate(dataset_shape[channel_dims + 1:])
    sample_size = sample_shape * voxel_size

    # Reference request for one network pass.
    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(
        roi=gp.Roi(
            (0, ) + dataset_roi.get_begin(),
            (num_samples, ) + sample_size),
        voxel_size=(1, ) + voxel_size)

    if not gt_data:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    else:
        pipeline = (
            raw_data.get_source(raw, overwrite_spec=spec),
            gt_data.get_source(gt, overwrite_spec=spec)
        ) + gp.MergeProvider()

    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)

    pipeline += gp.Normalize(raw)

    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # target: ([c,] s, h, w)

    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)

    samples_first = (1, 0, 2, 3)
    pipeline += TransposeDims(raw, samples_first)
    if gt_data:
        pipeline += TransposeDims(target, samples_first)
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)

    pipeline += gp_torch.Predict(
        model=predictor,
        inputs={'x': raw},
        outputs={0: prediction})
    # prediction: (s, c, h, w)

    pipeline += gp.Scan(scan_request)

    # Everything is requested with the size of one full sample.
    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)

        result = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            result['gt'] = batch[gt]
            result['target'] = batch[target]

    return result
def predict(**kwargs):
    """Predict affinities and foreground for one sample into a zarr container.

    Besides the two predictions, the network's cropped raw output is written
    as well.

    Keyword Args:
        name: Prefix of the network files (``<name>_config.json``,
            ``<name>_names.json``, ``<name>.meta``).
        input_folder: Folder containing the network files.
        checkpoint: Checkpoint passed to the predict node.
        voxel_size: Voxel size of the data.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_folder: Folder containing ``<sample>.<input_format>``.
        sample: Name of the sample (file name without extension).
        output_format: Has to be ``"zarr"``.
        output_folder: Folder ``<sample>.zarr`` is written to.

    Raises:
        NotImplementedError: For an unsupported input or output format.
    """
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    with open(os.path.join(kwargs['input_folder'],
                           name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'],
                           name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2

    # formulate the request for what a batch should contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(pred_affs, output_shape_world)
    request.add(pred_fgbg, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        # The exception does not format logging-style arguments itself, so
        # the message is built explicitly.
        raise NotImplementedError(
            "predict node for %s not implemented yet"
            % kwargs['input_format'])

    # pick the source node and read the shape of the raw volume
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(
                os.path.join(kwargs['data_folder'],
                             kwargs['sample'] + ".hdf"), 'r') as f:
            shape = f['volumes/raw'].shape
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(
            os.path.join(kwargs['data_folder'],
                         kwargs['sample'] + ".zarr"), 'r')
        shape = f['volumes/raw'].shape

    source = sourceNode(
        os.path.join(kwargs['data_folder'],
                     kwargs['sample'] + "." + kwargs['input_format']),
        datasets={raw: 'volumes/raw'})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # pre-create zarr file; every dataset is stored as a single chunk
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'), mode='w')
    for ds_name, num_channels in (('volumes/pred_affs', 3),
                                  ('volumes/pred_fgbg', 1),
                                  ('volumes/raw_cropped', 1)):
        zf.create(ds_name,
                  shape=[num_channels] + list(shape),
                  chunks=[num_channels] + list(shape),
                  dtype=np.float32)
        zf[ds_name].attrs['offset'] = [0, 0, 0]
        zf[ds_name].attrs['resolution'] = kwargs['voxel_size']

    pipeline = (
        # read from the input file
        source +
        gp.Pad(raw, context) +

        # run the network on each passing batch (using the tensor names
        # stored in <name>_names.json)
        gp.tensorflow.Predict(
            graph=os.path.join(kwargs['input_folder'], name + '.meta'),
            checkpoint=kwargs['checkpoint'],
            inputs={net_names['raw']: raw},
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['raw_cropped']: raw_cropped
            }) +

        # store all passing batches in the same zarr container
        gp.ZarrWrite(
            {
                raw_cropped: '/volumes/raw_cropped',
                pred_affs: '/volumes/pred_affs',
                pred_fgbg: '/volumes/pred_fgbg',
            },
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip') +

        # show a summary of time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties=None,
            **kwargs):
    """Run block-wise TensorFlow inference and write the results to zarr.

    The network description is read from ``<network_config>_net_config.json``
    and ``<network_config>_net.meta`` next to this file; weights come from
    ``train_net_checkpoint_<iteration>`` in the same directory. Blocks are
    requested through ``gp.DaisyRequestBlocks``.

    Args:
        iteration: Checkpoint number used to build the checkpoint file name.
        raw_file: Input container; must end in ``.hdf``, ``.zarr`` or ``.n5``.
        raw_dataset: Dataset name of the raw data inside ``raw_file``.
        out_file: Passed to ``gp.ZarrWrite`` as ``output_filename``.
        db_host: Forwarded unchanged to ``block_done_callback``.
        db_name: Forwarded unchanged to ``block_done_callback``.
        worker_config: Mapping; ``'num_cache_workers'`` is read here and the
            whole mapping is forwarded to ``block_done_callback``.
        network_config: Name prefix of the network config / meta graph files.
        out_properties: Optional mapping. The entries
            ``'pred_partner_vectors'`` and ``'pred_syn_indicator_out'`` may
            each carry ``'scale'`` (and ``'dtype'`` for the vectors).
            Defaults to an empty mapping.
        **kwargs: Accepted and ignored.

    Raises:
        RuntimeError: If ``raw_file`` has an unsupported extension.
        AssertionError: If the requested vector dtype is neither ``'int8'``
            nor ``'float32'``.
    """
    # A literal ``{}`` default would be one dict shared by every call.
    if out_properties is None:
        out_properties = {}

    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(
            os.path.join(setup_dir,
                         '{}_net_config.json'.format(network_config)),
            'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # Optional extra parameters (currently only 'd_scale' is read).
    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # Pick the source node from the file extension.
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    elif raw_file.endswith(('.zarr', '.n5')):
        pipeline = gp.ZarrSource(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    else:
        raise RuntimeError('unknwon input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)
    pipeline += gp.Normalize(raw)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    d_scale = parameters.get('d_scale')
    if d_scale is not None and d_scale != 1:
        # Map back to nm world.
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           1. / d_scale, 0)

    # Optional output rescaling requested through out_properties.
    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] == 'int8' or d_property[
            'dtype'] == 'float32', 'predict not adapted to dtype {}'.format(
                d_property['dtype'])
        if d_property['dtype'] == 'int8':
            # Clip to the int8 value range.
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1, 0, clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator: 'volumes/pred_syn_indicator',
        pred_postpre_vectors: 'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        # An empty request drives the DaisyRequestBlocks node.
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
def predict_frame(in_shape,
                  out_shape,
                  model_output,
                  model_configfile,
                  model_checkpoint,
                  input_dataset_file,
                  inference_frame,
                  out_dir,
                  out_filename,
                  out_key_or_index=1,
                  intermediate_layer=None,
                  dataset_raw_key="train/raw",
                  dataset_prediction_key="train/prediction",
                  dataset_intermediate_key="train/prediction_interm",
                  model_input_tensor_name="patches",
                  model_architecture="PatchedResnet",
                  num_workers=5):
    """Predict one frame of a zarr dataset and write the result to zarr.

    The frame ``<dataset_raw_key>/<inference_frame>`` is scanned with
    requests of ``in_shape`` / ``out_shape``; the prediction (and, when
    ``intermediate_layer`` is given, that layer's output) is written below
    ``dataset_prediction_key`` / ``dataset_intermediate_key``.

    ``model_output`` is accepted but not referenced in this function.

    Raises:
        NotImplementedError: For an unknown ``model_architecture``.
    """
    # --- model -----------------------------------------------------------
    if model_architecture not in ("PatchedResnet", "unet"):
        raise NotImplementedError(f"{model_architecture} not implemented")
    model = (PatchedResnet(1, 2, resnet_size=18)
             if model_architecture == "PatchedResnet"
             else lisl.models.create(model_configfile))
    model.add_spatial_dim = True
    model.eval()

    # --- keys and dataset names -------------------------------------------
    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)

    raw = gp.ArrayKey(f'RAW_{inference_frame}')
    prediction = gp.ArrayKey(f'PREDICTION_{inference_frame}')
    intermediate_prediction = gp.ArrayKey(f'ITERM_{inference_frame}')

    ds_key = f'{dataset_raw_key}/{inference_frame}'
    out_key = f'{dataset_prediction_key}/{inference_frame}'
    interm_key = f'{dataset_intermediate_key}/{inference_frame}'

    with_intermediate = intermediate_layer is not None

    # --- source -----------------------------------------------------------
    zsource = gp.ZarrSource(
        input_dataset_file,
        {raw: ds_key},
        {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})

    # Build the bare source once to learn the ROI of the raw data.
    with gp.build(zsource):
        raw_roi = zsource.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = (zsource
                + AddChannelDim(raw)
                + AddChannelDim(raw)
                + gp.Pad(raw, None))

    # --- prediction node --------------------------------------------------
    model_outputs = {out_key_or_index: prediction}
    output_specs = {prediction: gp.ArraySpec(roi=raw_roi)}
    if with_intermediate:
        model_outputs[intermediate_layer] = intermediate_prediction
        output_specs[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    pipeline += gp.torch.Predict(model,
                                 inputs={model_input_tensor_name: raw},
                                 outputs=model_outputs,
                                 array_specs=output_specs,
                                 checkpoint=model_checkpoint,
                                 spawn_subprocess=True)

    # --- per-chunk request and output datasets ----------------------------
    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, in_shape)
    chunk_request.add(prediction, out_shape)

    written_datasets = {prediction: out_key}
    if with_intermediate:
        written_datasets[intermediate_prediction] = interm_key
        chunk_request.add(intermediate_prediction, out_shape)

    pipeline += gp.Scan(chunk_request, num_workers=num_workers)
    pipeline += gp.ZarrWrite(written_datasets,
                             output_dir=out_dir,
                             output_filename=out_filename,
                             compression_type='gzip')

    # --- request the whole frame ------------------------------------------
    total_request = gp.BatchRequest()
    total_request[prediction] = gp.ArraySpec(roi=raw_roi)
    if with_intermediate:
        total_request[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    with gp.build(pipeline):
        pipeline.request_batch(total_request)
def train_until(**kwargs):
    """Train the three-class TensorFlow network up to ``max_iteration``.

    Training resumes from the latest checkpoint found in
    ``kwargs['output_folder']`` and returns immediately if that checkpoint
    already reaches ``kwargs['max_iteration']``.

    Keyword arguments read here:
        output_folder: Folder holding checkpoints, ``<name>_config.json``,
            ``<name>_names.json``, logs and snapshots.
        name: Network name (file prefix inside ``output_folder``).
        max_iteration: Iteration at which training stops.
        voxel_size: Voxel size used to convert network shapes to world units.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_files: Input files; each must contain ``volumes/raw`` and
            ``volumes/gt_threeclass``.
        augmentation: Mapping with ``'simple'`` and ``'intensity'`` entries
            and an optional ``'elastic'`` entry.
        cache_size, num_workers: Forwarded to ``gp.PreCache``.
        checkpoints: Checkpoint interval (``save_every``).
        snapshots: Snapshot interval.
        profiling: Profiling print interval.

    Raises:
        NotImplementedError: If ``input_format`` is neither hdf nor zarr.
    """
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_threeclass = gp.ArrayKey('GT_THREECLASS')

    loss_weights_threeclass = gp.ArrayKey('LOSS_WEIGHTS_THREECLASS')

    pred_threeclass = gp.ArrayKey('PRED_THREECLASS')

    pred_threeclass_gradients = gp.ArrayKey('PRED_THREECLASS_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_threeclass, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(loss_weights_threeclass, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the prediction
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_threeclass, output_shape_world)
    snapshot_request.add(pred_threeclass, output_shape_world)
    # snapshot_request.add(pred_threeclass_gradients, output_shape_world)

    input_format = kwargs['input_format']
    if input_format != "hdf" and input_format != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            input_format))

    # Collect extension-less file names and report shape/dtype of each
    # raw volume.
    fls = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if input_format == "hdf":
            # Close the HDF5 handle again; only shape and dtype are needed.
            with h5py.File(f, 'r') as h5_file:
                vol = h5_file['volumes/raw']
                vol_shape, vol_dtype = vol.shape, vol.dtype
        else:
            vol = zarr.open(f, 'r')['volumes/raw']
            vol_shape, vol_dtype = vol.shape, vol.dtype
        print(f, vol_shape, vol_dtype)
        if vol_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[0:5])

    # padR = 46
    # padGT = 32

    if input_format == "hdf":
        sourceNode = gp.Hdf5Source
    else:
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            # read batches from the input file
            sourceNode(
                fl + "." + input_format,
                datasets={
                    raw: 'volumes/raw',
                    gt_threeclass: 'volumes/gt_threeclass',
                    anchor: 'volumes/gt_threeclass',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_threeclass: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_threeclass, None)
            + gp.Pad(anchor, gp.Coordinate((2, 2, 2)))

            # chose a random location for each requested batch
            + gp.RandomLocation()

            for fl in fls
        ) +

        # chose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch (skipped when no 'elastic' entry)
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        # TODO: check
        # gp.GrowBoundary(
        #         gt_threeclass,
        #         steps=1,
        #         only_xy=False) +

        gp.BalanceLabels(
            gt_threeclass,
            loss_weights_threeclass,
            num_classes=3) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['anchor']: anchor,
                net_names['gt_threeclass']: gt_threeclass,
                net_names['loss_weights_threeclass']: loss_weights_threeclass
            },
            outputs={
                net_names['pred_threeclass']: pred_threeclass,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_threeclass']: pred_threeclass_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_threeclass: '/volumes/gt_threeclass',
                pred_threeclass: '/volumes/pred_threeclass',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spend in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
def build_batch_provider(self, datasets, model, task, snapshot_container=None):
    """Assemble the training data pipeline and the matching batch request.

    For every dataset a raw, ground-truth and mask source are merged and
    sampled at random locations; the datasets are then combined with a
    random provider, augmented, filtered by mask coverage and turned into
    targets/weights by the task's predictor. The resulting pipeline,
    request, keys and loss are stored on ``self``.
    """
    # Array keys used throughout the pipeline.
    raw_key = gp.ArrayKey("RAW")
    gt_key = gp.ArrayKey("GT")
    mask_key = gp.ArrayKey("MASK")
    target_key = gp.ArrayKey("TARGET")
    weight_key = gp.ArrayKey("WEIGHT")

    # Voxel sizes of the input and of the model's prediction.
    raw_voxel_size = datasets[0].raw.voxel_size
    prediction_voxel_size = model.scale(raw_voxel_size)

    # Request sizes in world units.
    input_size = raw_voxel_size * Coordinate(model.input_shape)
    output_size = prediction_voxel_size * Coordinate(model.output_shape)

    # Padding applied to ground truth and mask.
    gt_mask_padding = output_size + task.predictor.padding(prediction_voxel_size)

    per_dataset = []
    for dataset in datasets:
        raw_source = (DaCapoArraySource(dataset.raw, raw_key)
                      + gp.Pad(raw_key, None, 0))
        gt_source = (DaCapoArraySource(dataset.gt, gt_key)
                     + gp.Pad(gt_key, gt_mask_padding, 0))

        # A mask is always provided: when the dataset has none, an
        # all-ones array shaped like the ground truth stands in, so the
        # no-mask case needs no special handling downstream.
        if dataset.mask is None:
            mask_array = OnesArray.like(dataset.gt)
        else:
            mask_array = dataset.mask
        mask_source = (DaCapoArraySource(mask_array, mask_key)
                       + gp.Pad(mask_key, gt_mask_padding, 0))

        per_dataset.append(
            (raw_source, gt_source, mask_source)
            + gp.MergeProvider()
            + gp.RandomLocation()
        )

    pipeline = tuple(per_dataset) + gp.RandomProvider()

    for augment in self.augments:
        pipeline += augment.node(raw_key, gt_key, mask_key)

    pipeline += gp.Reject(mask_key, min_masked=self.min_masked)

    # Predictor-specific target and weight generation.
    pipeline += DaCapoTargetFilter(
        task.predictor,
        gt_key=gt_key,
        target_key=target_key,
        weights_key=weight_key,
        mask_key=mask_key,
    )

    # Only pre-cache when more than one fetcher is configured.
    if self.num_data_fetchers > 1:
        pipeline += gp.PreCache(num_workers=self.num_data_fetchers)

    # Stacking creates the batch dimension.
    pipeline += gp.Stack(self.batch_size)
    pipeline += gp.PrintProfilingStats(every=self.print_profiling)

    # Training inputs first, then the extra keys wanted for snapshots.
    request = gp.BatchRequest()
    for key, size in (
        (raw_key, input_size),
        (target_key, output_size),
        (weight_key, output_size),
        (gt_key, output_size),
        (mask_key, output_size),
    ):
        request.add(key, size)

    self._request = request
    self._pipeline = pipeline
    self._raw_key = raw_key
    self._gt_key = gt_key
    self._mask_key = mask_key
    self._weight_key = weight_key
    self._target_key = target_key
    self._loss = task.loss

    self.snapshot_container = snapshot_container
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):
    """Scan a whole dataset with ``model`` and write the prediction to zarr.

    ``in_shape`` / ``out_shape`` default to ``model.in_shape`` /
    ``model.out_shape`` and are multiplied by the dataset's voxel size.
    Passing ``normalize_factor="skip"`` omits the normalization node.
    ``apply_voxel_size`` is accepted but not referenced in this function.
    """
    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Resolve shapes, falling back to the model's own.
    in_shape = gp.Coordinate(model.in_shape if in_shape is None else in_shape)
    out_shape = gp.Coordinate(
        model.out_shape if out_shape is None else out_shape)
    spatial_dims = in_shape.dims()
    is_2d = spatial_dims == 2

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    source = gp.ZarrSource(
        dataset.filename,
        {raw: dataset.ds_names[0]},
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)})

    # Pad raw up to (2 + spatial_dims) dimensions so that it carries both
    # a sample (n) and a channel (c) axis:
    #   2D raw arrives as (n, y, x) or (c, n, y, x)
    #   3D raw arrives as (z, y, x) or (c, z, y, x)
    missing_dims = (2 + spatial_dims) - data_dims
    for _ in range(missing_dims):
        source += AddChannelDim(raw)

    # Swaps the two leading axes and keeps the spatial ones in place.
    swap_leading = (1, 0) + tuple(range(2, 2 + spatial_dims))

    # prediction requires samples first, channels second:
    #   2D raw: (c, n, y, x)      -> (n, c, y, x)
    #   3D raw: (c, n=1, z, y, x) -> (n=1, c, z, y, x)
    source += TransposeDims(raw, swap_leading)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    predict_stage = gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess)
    pipeline = pipeline + predict_stage

    # After prediction, raw and prediction are both sample-first:
    #   2D: (n, c, y, x)    3D: (n=1, c, z, y, x)
    if not is_2d:
        # 3D: drop the sample dimension -> (c, z, y, x)
        pipeline += RemoveChannelDim(raw)
        pipeline += RemoveChannelDim(prediction)
    else:
        # 2D: restore channels first -> (c, n, y, x)
        pipeline += TransposeDims(raw, swap_leading)
        pipeline += TransposeDims(prediction, swap_leading)

    write_and_scan = gp.ZarrWrite(
        {prediction: out_ds_names[0]},
        output_dir=out_dir,
        output_filename=out_filename,
        compression_type='gzip') + gp.Scan(request, num_workers=num_workers)
    pipeline += write_and_scan

    logger.info("Writing prediction to %s/%s[%s]", out_dir, out_filename,
                out_ds_names[0])

    with gp.build(pipeline):
        # An empty request lets Scan iterate over the whole dataset.
        pipeline.request_batch(gp.BatchRequest())
def train_until(max_iteration):
    """Train the point-detection U-Net for ``max_iteration`` batches.

    Positive samples (files in ``pos_samples``) get a center point, negative
    samples (``neg_samples``) get none; they are mixed 0.9 / 0.1, augmented,
    rasterized and fed to ``gp_torch.Train`` with an MSE loss.
    """
    # ----- network, loss and optimizer -----
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)
    model = Convolve(unet, 12, 1)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # ----- gunpowder keys -----
    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    # ----- sizes (voxels -> world units) -----
    voxel_size = gp.Coordinate((40, 4, 4))
    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)
    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    for key in (points, groundtruth, prediction, grad):
        request.add(key, output_size)

    def raw_source(filename):
        # Zarr source providing only the raw volume of one sample.
        return gp.ZarrSource(filename, {raw: 'volumes/raw'},
                             {raw: gp.ArraySpec(interpolatable=True)})

    pos_sources = tuple(
        raw_source(filename)
        + AddCenterPoint(points, raw)
        + gp.Pad(raw, None)
        + gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples
    ) + gp.RandomProvider()

    neg_sources = tuple(
        raw_source(filename)
        + AddNoPoint(points, raw)
        + gp.RandomLocation()
        for filename in neg_samples
    ) + gp.RandomProvider()

    train_pipeline = (
        (pos_sources, neg_sources)
        + gp.RandomProvider(probabilities=[0.9, 0.1])
        + gp.Normalize(raw)
        + gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                            jitter_sigma=[0, 2, 2],
                            rotation_interval=[0, math.pi / 2.0],
                            prob_slip=0.05,
                            prob_shift=0.05,
                            max_misalign=10,
                            subsample=8)
        + gp.SimpleAugment(transpose_only=[1, 2])
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                              z_section_wise=True)
        + gp.RasterizePoints(
            points,
            groundtruth,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(radius=(100, 100, 100),
                                              mode='peak'))
        + gp.PreCache(cache_size=40, num_workers=10)
        # add two leading singleton axes before training
        + Reshape(raw, (1, 1) + input_shape)
        + Reshape(groundtruth, (1, 1) + output_shape)
        + gp_torch.Train(model=model,
                         loss=loss,
                         optimizer=optimizer,
                         inputs={'x': raw},
                         outputs={0: prediction},
                         loss_inputs={
                             0: prediction,
                             1: groundtruth
                         },
                         gradients={0: grad},
                         save_every=1000,
                         log_dir='log')
        # back to plain spatial shapes for the snapshot
        + Reshape(raw, input_shape)
        + Reshape(groundtruth, output_shape)
        + Reshape(prediction, output_shape)
        + Reshape(grad, output_shape)
        + gp.Snapshot(
            {
                raw: 'volumes/raw',
                groundtruth: 'volumes/groundtruth',
                prediction: 'volumes/prediction',
                grad: 'volumes/gradient'
            },
            every=500,
            output_filename='test_{iteration}.hdf')
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(train_pipeline):
        for _ in range(max_iteration):
            train_pipeline.request_batch(request)
def make_pipeline(self):
    """Build the embedding-prediction pipeline for ``self.dataset``.

    Reads ``self.data_file`` / ``self.dataset``, normalizes and pads the raw
    data, runs ``self.model`` and writes the output as dataset ``'embs'`` to
    ``<self.dataset>_embs.zarr`` inside ``self.curr_log_dir``.

    Returns:
        tuple: ``(pipeline, request, embs)``, where ``request`` is the
        per-chunk request handed to ``gp.Scan`` and ``embs`` is the array
        key of the prediction.
    """
    raw = gp.ArrayKey('RAW')
    embs = gp.ArrayKey('EMBS')

    data = daisy.open_ds(self.data_file, self.dataset)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(self.model.in_shape)
    out_shape = gp.Coordinate(self.model.out_shape[2:])

    is_2d = in_shape.dims() == 2

    # Logged once in voxels ...
    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    # ... and once more after scaling by the voxel size.
    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(embs, out_shape)

    context = (in_shape - out_shape) / 2

    source = (gp.ZarrSource(self.data_file, {
        raw: self.dataset,
    },
                            array_specs={
                                raw:
                                gp.ArraySpec(roi=source_roi,
                                             interpolatable=False)
                            }))

    if is_2d:
        source = (source + AddChannelDim(raw, axis=1))
    else:
        source = (source + AddChannelDim(raw, axis=0) + AddChannelDim(raw))

    # Build the bare source once to learn the ROI of raw after the
    # channel dimensions were added.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = (
        source +
        gp.Normalize(raw, factor=self.params['norm_factor']) +
        gp.Pad(raw, context) +
        gp.PreCache() +
        gp.torch.Predict(self.model,
                         inputs={'raw': raw},
                         outputs={0: embs},
                         array_specs={embs: gp.ArraySpec(roi=raw_roi)}))

    pipeline = (pipeline + gp.ZarrWrite({
        embs: 'embs',
    },
                                        output_dir=self.curr_log_dir,
                                        output_filename=self.dataset +
                                        '_embs.zarr',
                                        compression_type='gzip') +
                gp.Scan(request))

    return pipeline, request, embs
def create_pipeline_3d(
    task, data, predictor, optimizer, batch_size, outdir, snapshot_every
):
    """Build the 3D training pipeline and its batch request.

    The pipeline merges the raw and ground-truth training sources, picks
    random locations, applies the task's augmentations, derives targets
    (and optional weights) for the main predictor and every auxiliary task,
    and feeds stacked batches to ``Train``. When ``snapshot_every > 0`` a
    ``gp.Snapshot`` node is appended.

    Args:
        task: Provides ``model``, ``predictor``, ``loss``, ``augmentations``
            and ``aux_tasks`` (iterated as ``(name, predictor, loss)``).
        data: Provides ``raw`` / ``gt`` with ``train`` sources.
        predictor: Main predictor used for target generation and as the
            ``"opt"`` head.
        optimizer: Passed through to ``Train``.
        batch_size: Number of samples stacked per batch.
        outdir: Snapshots are written to ``<outdir>/snapshots``.
        snapshot_every: Snapshot interval; no snapshot node if ``<= 0``.

    Returns:
        tuple: ``(pipeline, request)``.

    Side effect: ``task.predictor`` is replaced by its ``.to("cuda")`` result.
    """

    raw_channels = max(1, data.raw.num_channels)
    input_shape = gp.Coordinate(task.model.input_shape)
    output_shape = gp.Coordinate(task.model.output_shape)
    voxel_size = data.raw.train.voxel_size

    # NOTE(review): only task.predictor is moved to the GPU here; the
    # `predictor` argument is used below as-is. Presumably both refer to the
    # same object -- confirm against the caller.
    task.predictor = task.predictor.to("cuda")

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    mask = gp.ArrayKey("MASK")
    target = gp.ArrayKey("TARGET")
    weights = gp.ArrayKey("WEIGHTS")
    model_outputs = gp.ArrayKey("MODEL_OUTPUTS")
    model_output_grads = gp.ArrayKey("MODEL_OUT_GRAD")
    prediction = gp.ArrayKey("PREDICTION")
    pred_gradients = gp.ArrayKey("PRED_GRADIENTS")

    # Array key -> dataset name used by the snapshot node; extended below
    # with the keys of every auxiliary task.
    snapshot_dataset_names = {
        raw: "raw",
        model_outputs: "model_outputs",
        model_output_grads: "model_out_grad",
        target: "target",
        prediction: "prediction",
        pred_gradients: "pred_gradients",
        weights: "weights",
    }

    # Per auxiliary task: (prediction key, target key, weights key or None)
    # and the key of the prediction gradient.
    aux_keys = {}
    aux_grad_keys = {}
    for name, _, _ in task.aux_tasks:
        aux_keys[name] = (
            gp.ArrayKey(f"{name.upper()}_PREDICTION"),
            gp.ArrayKey(f"{name.upper()}_TARGET"),
            None,
        )
        aux_grad_keys[name] = gp.ArrayKey(f"{name.upper()}_PRED_GRAD")

        aux_pred, aux_target, _ = aux_keys[name]

        snapshot_dataset_names[aux_pred] = f"{name}_pred"
        snapshot_dataset_names[aux_target] = f"{name}_target"

        aux_grad = aux_grad_keys[name]
        snapshot_dataset_names[aux_grad] = f"{name}_aux_grad"

    # 0 when raw has a single channel (a channel axis is added later).
    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = data.raw.train.num_samples
    assert num_samples == 0, "Multiple samples for 3D training not yet implemented"

    sources = (data.raw.train.get_source(raw), data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, input_shape / 2 * voxel_size)
    # pipeline += gp.Pad(gt, input_shape / 2 * voxel_size)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)

    mask_node = task.loss.add_mask(gt, mask)
    if mask_node is not None:
        pipeline += mask_node
        pipeline += gp.RandomLocation(min_masked=1e-6, mask=mask)
    else:
        # raw: ([c,] d, h, w)
        # gt: ([c,] d, h, w)
        pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # NOTE(review): eval() executes task.augmentations as Python code; this
    # is only safe if that string comes from a trusted configuration.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    # One loss-input mapping per head, main head first.
    loss_inputs = []
    if weights_node:
        pipeline += weights_node
        loss_inputs.append({0: prediction, 1: target, 2: weights})
    else:
        loss_inputs.append({0: prediction, 1: target})

    head_outputs = []
    head_gradients = []
    for name, aux_predictor, aux_loss in task.aux_tasks:
        aux_prediction, aux_target, aux_weights = aux_keys[name]
        pipeline += aux_predictor.add_target(gt, aux_target)
        # NOTE(review): aux_weights is still None at this call; the weights
        # key is only created afterwards. Confirm add_weights tolerates a
        # None key, or whether the key should be created first.
        aux_weights_node = aux_loss.add_weights(aux_target, aux_weights)
        if aux_weights_node:
            aux_weights = gp.ArrayKey(f"{name.upper()}_WEIGHTS")
            aux_keys[name] = (
                aux_prediction,
                aux_target,
                aux_weights,
            )
            pipeline += aux_weights_node
            loss_inputs.append({0: aux_prediction, 1: aux_target, 2: aux_weights})
            snapshot_dataset_names[aux_weights] = f"{name}_weights"
        else:
            loss_inputs.append({0: aux_prediction, 1: aux_target})
        head_outputs.append({0: aux_prediction})
        aux_pred_gradient = aux_grad_keys[name]
        head_gradients.append({0: aux_pred_gradient})
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += Train(
        model=task.model,
        heads=[("opt", predictor)]
        + [(name, aux_pred) for name, aux_pred, _ in task.aux_tasks],
        losses=[task.loss] + [loss for _, _, loss in task.aux_tasks],
        optimizer=optimizer,
        inputs={"x": raw},
        outputs={0: model_outputs},
        head_outputs=[{0: prediction}] + head_outputs,
        loss_inputs=loss_inputs,
        gradients=[{0: model_output_grads}, {0: pred_gradients}] + head_gradients,
        save_every=1e6,
    )
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(
            dataset_names=snapshot_dataset_names,
            every=snapshot_every,
            output_dir=os.path.join(outdir, "snapshots"),
            output_filename="{iteration}.hdf",
        )
    pipeline += gp.PrintProfilingStats(every=10)

    # Request everything the training step and the snapshot need.
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    if mask_node is not None:
        request.add(mask, output_size)
    request.add(target, output_size)
    for name, _, _ in task.aux_tasks:
        aux_pred, aux_target, aux_weight = aux_keys[name]
        request.add(aux_pred, output_size)
        request.add(aux_target, output_size)
        if aux_weight is not None:
            request.add(aux_weight, output_size)
        aux_pred_grad = aux_grad_keys[name]
        request.add(aux_pred_grad, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)
    request.add(pred_gradients, output_size)

    return pipeline, request
def create_train_pipeline(self, model):
    """Build the affinity-training pipeline for ``model``.

    Raw data and labels are read from ``self.params['data_file']``,
    normalized, padded, randomly located and augmented; affinities are
    derived from the labels and the stacked batch is passed to
    ``gp.torch.Train``. Snapshots additionally request the embedding and the
    cropped ground-truth labels.

    Returns:
        tuple: ``(pipeline, request)``.
    """
    params = self.params

    print(
        f"Creating training pipeline with batch size {self.params['batch_size']}"
    )

    filename = params['data_file']
    train_datasets = params['dataset']['train']
    raw_dataset = train_datasets['raw']
    gt_dataset = train_datasets['gt']

    optimizer = params['optimizer'](model.parameters(),
                                    **params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    gt_aff = gp.ArrayKey('AFFINITIES')
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Network shapes, converted to world units.
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_shape = out_shape * out_voxel_size

    context = (in_shape - out_shape) / 2
    gt_labels_out_shape = out_shape

    if not is_2d:
        aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    else:
        # 2D data gets a fake leading dimension on the source side.
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        context = gp.Coordinate((0, *context))
        aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
        gt_labels_out_shape = (1, *gt_labels_out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(gt_aff, out_shape)
    request.add(predictions, out_shape)

    # Extra arrays that are only needed when a snapshot is written.
    emb_roi = gp.Roi(
        (0, ) * in_shape.dims(),
        gp.Coordinate((*model.base_encoder.out_shape[2:], )) * out_voxel_size)
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(roi=emb_roi)
    snapshot_request[gt_labels] = gp.ArraySpec(
        roi=gp.Roi(context, gt_labels_out_shape))

    source = (
        gp.ZarrSource(
            filename,
            {raw: raw_dataset, gt_labels: gt_dataset},
            array_specs={
                raw: gp.ArraySpec(roi=source_roi,
                                  voxel_size=source_voxel_size,
                                  interpolatable=True),
                gt_labels: gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size)
            })
        + gp.Normalize(raw, self.params['norm_factor'])
        + gp.Pad(raw, context)
        + gp.Pad(gt_labels, context)
        + gp.RandomLocation()
    )
    # 2D case: raw and gt_labels are (l=1, h, w) here.

    source = self._augmentation_pipeline(raw, source)

    # Derive affinities from the labels and give raw a channel axis:
    #   raw    : (c=1, l=1, h, w)
    #   gt_aff : (c=2, l=1, h, w)
    pipeline = (
        source
        + gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff)
        + SetDtype(gt_aff, np.float32)
        + AddChannelDim(raw)
    )

    if is_2d:
        # Drop the fake dimension again:
        #   raw    : (c=1, h, w)
        #   gt_aff : (c=2, h, w)
        pipeline = (
            pipeline
            + RemoveSpatialDim(raw)
            + RemoveSpatialDim(gt_aff)
        )

    # After stacking (2D case):
    #   raw    : (b, c=1, h, w)
    #   gt_aff : (b, c=2, h, w)
    train_node = gp.torch.Train(
        model,
        self.loss,
        optimizer,
        inputs={'raw': raw},
        loss_inputs={
            0: predictions,
            1: gt_aff
        },
        outputs={
            0: predictions,
            1: emb
        },
        array_specs={
            predictions: gp.ArraySpec(voxel_size=out_voxel_size),
        },
        checkpoint_basename=self.logdir + '/checkpoints/model',
        save_every=self.params['save_every'],
        log_dir=self.logdir,
        log_every=self.log_every)

    snapshot_node = gp.Snapshot(
        output_dir=self.logdir + '/snapshots',
        output_filename='it{iteration}.hdf',
        dataset_names={
            raw: 'raw',
            gt_labels: 'gt_labels',
            predictions: 'predictions',
            gt_aff: 'gt_aff',
            emb: 'emb'
        },
        additional_request=snapshot_request,
        every=self.params['save_every'])

    pipeline = (
        pipeline
        + gp.Stack(self.params['batch_size'])
        + gp.PreCache()
        + train_node
        # Crop GT so the labels can be inspected in the snapshot.
        + gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape))
        + snapshot_node
        + gp.PrintProfilingStats(every=500)
    )

    return pipeline, request
def random_point_pairs_pipeline(model,
                                loss,
                                optimizer,
                                dataset,
                                augmentation_parameters,
                                point_density,
                                out_dir,
                                normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):
    """Build a training pipeline over pairs of augmented views with points.

    For each of the hard-coded ``n_samples`` datasets ``train/raw/<i>`` two
    image sources and two random point sources (sharing one
    ``RandomPointGenerator``) are created, augmented independently and
    merged. One sample is picked at random per batch.

    Args:
        model: Network to train; ``in_shape`` and ``out_shape`` are read.
        loss: Loss passed to ``gp.torch.Train``.
        optimizer: Optimizer passed to ``gp.torch.Train``.
        dataset: Object whose ``filename`` attribute names the data file.
        augmentation_parameters: Not used by this function.
        point_density: Density handed to ``RandomPointGenerator``.
        out_dir: Directory prefix for model checkpoints.
        normalize_factor: Not used by this function.
        checkpoint_interval: ``save_every`` of the train node.
        snapshot_interval: ``every`` of the snapshot node.

    Returns:
        tuple: ``(pipeline, request)``.
    """
    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    reference_key = 'train/raw/0'

    reference_ds = daisy.open_ds(dataset.filename, reference_key)
    source_roi = gp.Roi(reference_ds.roi.get_offset(),
                        reference_ds.roi.get_shape())
    voxel_size = gp.Coordinate(reference_ds.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    for raw_key in (raw_0, raw_1):
        request.add(raw_key, in_shape)
    for points_key in (points_0, points_1):
        request.add(points_key, out_shape)
    for locations_key in (locations_0, locations_1):
        request[locations_key] = gp.ArraySpec(nonspatial=True)

    # Embeddings are only requested when a snapshot is written.
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    raw_keys = [raw_0, raw_1]
    sample_sources = []
    for sample_index in range(n_samples):
        ds_key = f'train/raw/{sample_index}'

        # One image source per view, both reading the same dataset.
        image_sources = []
        for raw_key in raw_keys:
            zarr_source = gp.ZarrSource(
                dataset.filename,
                {raw_key: ds_key},
                {raw_key: gp.ArraySpec(interpolatable=True,
                                       voxel_size=(1, 1))})
            image_sources.append(zarr_source + gp.Pad(raw_key, None))

        # Both point sources share a single generator instance.
        point_generator = RandomPointGenerator(density=point_density,
                                               repetitions=2)
        point_sources = (
            RandomPointSource(points_0,
                              dim,
                              random_point_generator=point_generator),
            RandomPointSource(points_1,
                              dim,
                              random_point_generator=point_generator),
        )

        # TODO: get augmentation parameters from some config file!
        augmented_views = []
        for raw_key, image_source, point_source in zip(
                raw_keys, image_sources, point_sources):
            view = (image_source, point_source) + gp.MergeProvider()
            view += gp.SimpleAugment()
            view += gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2))
            view += gp.IntensityAugment(raw_key,
                                        scale_min=0.8,
                                        scale_max=1.2,
                                        shift_min=-0.2,
                                        shift_max=0.2,
                                        clip=False)
            view += gp.NoiseAugment(raw_key, var=0.01, clip=False)
            augmented_views.append(view)

        sample_source = tuple(augmented_views) + gp.MergeProvider()

        # Restrict both views to this sample's ROI, then pad.
        sample_ds = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(sample_ds.roi.get_offset(),
                            sample_ds.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sample_sources.append(sample_source)

    pipeline = tuple(sample_sources) + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])
    pipeline += PrepareBatch(raw_0, raw_1,
                             points_0, points_1,
                             locations_0, locations_1)

    # How does prepare batch relate to Stack?????
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:          (1, h, w)
    # raw_1:          (1, h, w)
    # locations_0:    (n, 2)
    # locations_1:    (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:          (b, 1, h, w)
    # raw_1:          (b, 1, h, w)
    # locations_0:    (b, n, 2)
    # locations_1:    (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0 : 'locations_0',
            # locations_1 : 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
def create_train_pipeline(self, model):
    """Build the gunpowder pipeline that trains ``model`` on point labels.

    Configuration (data file, dataset names, optimizer, normalization
    factor, batch size, ``num_points``, ``point_roi``, save interval) is
    read from ``self.params``; point annotations come from ``self.data``.

    Args:
        model: Network to train. This method reads ``model.in_shape``,
            ``model.out_shape`` and ``model.base_encoder.out_shape`` and
            calls ``model.parameters()``.

    Returns:
        tuple: ``(pipeline, request)`` -- the assembled pipeline and the
        batch request to issue for each training iteration.
    """
    # The message is kept on one logical string: a backslash continuation
    # inside the literal would embed the next line's indentation in it.
    print(
        f"Creating training pipeline with batch size {self.params['batch_size']}"
    )

    filename = self.params['data_file']
    raw_dataset = self.params['dataset']['train']['raw']
    gt_dataset = self.params['dataset']['train']['gt']

    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    raw = gp.ArrayKey('RAW')
    gt_labels = gp.ArrayKey('LABELS')
    points = gp.GraphKey("POINTS")
    locations = gp.ArrayKey("LOCATIONS")
    predictions = gp.ArrayKey('PREDICTIONS')
    emb = gp.ArrayKey('EMBEDDING')

    raw_data = daisy.open_ds(filename, raw_dataset)
    source_roi = gp.Roi(raw_data.roi.get_offset(), raw_data.roi.get_shape())
    source_voxel_size = gp.Coordinate(raw_data.voxel_size)
    out_voxel_size = gp.Coordinate(raw_data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
    is_2d = in_shape.dims() == 2

    in_shape = in_shape * out_voxel_size
    out_roi = out_roi * out_voxel_size
    # Only used for logging below.
    out_shape = gp.Coordinate(
        (self.params["num_points"], *model.out_shape[2:]))

    # NOTE: ``context`` is deliberately not derived from in_shape/out_roi
    # here; both branches set it to an all-``None`` (plus a leading 0 in 2D)
    # coordinate, which is what the Pad nodes below receive.
    if is_2d:
        # Add fake 3rd dim
        source_voxel_size = gp.Coordinate((1, *source_voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (raw_data.shape[0], *source_roi.get_shape()))
        points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
        points_pad = (0, *points_roi)
        context = gp.Coordinate((0, None, None))
    else:
        points_roi = source_voxel_size * tuple(self.params["point_roi"])
        points_pad = points_roi
        context = gp.Coordinate((None, None, None))

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {out_voxel_size}")
    logger.info(f"context: {context}")
    logger.info(f"out_voxel_size: {out_voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(points, points_roi)
    request.add(gt_labels, out_roi)
    request[locations] = gp.ArraySpec(nonspatial=True)
    request[predictions] = gp.ArraySpec(nonspatial=True)

    # The embedding is only requested when a snapshot is written.
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb] = gp.ArraySpec(
        roi=gp.Roi((0, ) * in_shape.dims(),
                   gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                   out_voxel_size))

    source = (
        (gp.ZarrSource(filename, {
            raw: raw_dataset,
            gt_labels: gt_dataset
        },
                       array_specs={
                           raw:
                           gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size,
                                        interpolatable=True),
                           gt_labels:
                           gp.ArraySpec(roi=source_roi,
                                        voxel_size=source_voxel_size)
                       }),
         PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
        gp.MergeProvider() +
        gp.Pad(raw, context) +
        gp.Pad(gt_labels, context) +
        gp.Pad(points, points_pad) +
        gp.RandomLocation(ensure_nonempty=points) +
        gp.Normalize(raw, self.params['norm_factor'])
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        # If 2d then source_roi = (1, input_shape) in order to select a RL
    )
    source = self._augmentation_pipeline(raw, source)

    pipeline = (
        source +
        # Batches seem to be rejected because points are chosen near the
        # edge of the points ROI and the augmentations remove them.
        # TODO: Figure out if this is an actual issue, and if anything can
        # be done.
        gp.Reject(ensure_nonempty=points) +
        SetDtype(gt_labels, np.int64) +
        # raw      : (source_roi)
        # gt_labels: (source_roi)
        # points   : (c=1, source_locations_shape)
        AddChannelDim(raw) +
        AddChannelDim(gt_labels)
        # raw      : (c=1, source_roi)
        # gt_labels: (c=1, source_roi)
        # points   : (c=1, source_locations_shape)
    )

    if is_2d:
        pipeline = (
            # Remove extra dim the 2d roi had
            pipeline +
            RemoveSpatialDim(raw) +
            RemoveSpatialDim(gt_labels) +
            RemoveSpatialDim(points)
            # raw      : (c=1, roi)
            # gt_labels: (c=1, roi)
            # points   : (c=1, locations_shape)
        )

    pipeline = (
        pipeline +
        # NOTE(review): is_2d is hard-coded to False here even when the
        # data is 2D -- confirm this is intended.
        FillLocations(raw, points, locations, is_2d=False, max_points=1) +
        gp.Stack(self.params['batch_size']) +
        gp.PreCache() +
        # raw      : (b, c=1, roi)
        # gt_labels: (b, c=1, roi)
        # locations: (b, c=1, locations_shape)
        # (which is what train requires)
        gp.torch.Train(
            model, self.loss, optimizer,
            inputs={
                'raw': raw,
                'points': locations
            },
            loss_inputs={
                0: predictions,
                1: gt_labels,
                2: locations
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(nonspatial=True),
                emb: gp.ArraySpec(voxel_size=out_voxel_size)
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every) +
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, roi)
        # gt_labels  : (b, c=1, roi)
        # predictions: (b, num_points)
        gp.Snapshot(output_dir=self.logdir + '/snapshots',
                    output_filename='it{iteration}.hdf',
                    dataset_names={
                        raw: 'raw',
                        gt_labels: 'gt_labels',
                        predictions: 'predictions',
                        emb: 'emb'
                    },
                    additional_request=snapshot_request,
                    every=self.params['save_every']) +
        InspectBatch('END') +
        gp.PrintProfilingStats(every=500))

    return pipeline, request
def train_until(**kwargs):
    """Train the TensorFlow affinity network up to ``max_iteration``.

    Resumes from the latest checkpoint found in ``output_folder`` and
    returns immediately if that checkpoint already reaches
    ``max_iteration``.

    Keyword Args:
        output_folder: Directory holding the ``<name>_config.json`` /
            ``<name>_names.json`` files, checkpoints, logs and snapshots.
        name: Base name of the network files.
        max_iteration: Iteration to train until.
        voxel_size: Voxel size used to convert network shapes to world
            units.
        overlapping_inst: If true, the instance-count arrays are also
            requested, fed to and read from the network.
        input_format: ``"hdf"`` or ``"zarr"``.
        raw_key: Optional dataset name of the raw data (default
            ``'volumes/raw'``).
        data_files: Iterable of input file paths.
        patchshape, patchstride: Define the affinity neighborhood.
        augmentation, sampling: Nested dicts with augmentation and
            sampling settings.
        cache_size, num_workers: Passed to ``gp.PreCache``.
        checkpoints, snapshots, profiling: Intervals for checkpointing,
            snapshots and profiling output.

    Raises:
        NotImplementedError: If ``input_format`` is neither ``"hdf"`` nor
            ``"zarr"``.
    """
    # Informational only; do not fail when the variable is unset.
    print("cuda visibile devices", os.environ.get("CUDA_VISIBLE_DEVICES"))

    output_folder = kwargs['output_folder']
    input_format = kwargs['input_format']

    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        # Checkpoint names end in "_<iteration>".
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(output_folder,
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder,
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        request.add(gt_numinst, output_shape_world)
    # request.add(loss_weights_affs, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        snapshot_request.add(pred_numinst, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)

    if input_format != "hdf" and input_format != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % input_format)

    raw_key = kwargs.get('raw_key', 'volumes/raw')
    print('raw key: ', raw_key)

    fls = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        # Only the dtype is inspected; close HDF5 handles right away.
        if input_format == "hdf":
            with h5py.File(f, 'r') as hdf_file:
                vol_dtype = hdf_file[raw_key].dtype
        else:
            vol_dtype = zarr.open(f, 'r')[raw_key].dtype
        if vol_dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    if input_format == "hdf":
        sourceNode = gp.Hdf5Source
    else:
        sourceNode = gp.ZarrSource

    # Affinity offsets: a strided grid over the last two patch axes,
    # centred on the origin.
    psH = np.array(kwargs['patchshape'])//2
    neighborhood = [
        [i, j]
        for i in range(-psH[1], psH[1]+1, kwargs['patchstride'][1])
        for j in range(-psH[2], psH[2]+1, kwargs['patchstride'][2])
    ]

    datasets = {
        raw: raw_key,
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances'
    }
    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False)
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
        # net_names['loss_weights_affs']: loss_weights_affs,
    }

    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw: '/volumes/raw',
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_affs: '/volumes/pred_affs',
        pred_affs_gradients: '/volumes/pred_affs_gradients',
    }
    if kwargs['overlapping_inst']:
        datasets[gt_numinst] = 'volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[gt_numinst] = '/volumes/gt_numinst'
        snapshot[pred_numinst] = '/volumes/pred_numinst'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']

    # One source per file, sampling locations from the mask produced by
    # nl.CountOverlap.
    source_fg = tuple(
        sourceNode(
            fls[t] + "." + input_format,
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +

        # chose a random location for each requested batch
        nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1) +
        gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_fg += gp.RandomProvider()

    # One source per file, sampling locations from the mask produced by
    # nl.MaskCloseDistanceToOverlap.
    source_overlap = tuple(
        sourceNode(
            fls[t] + "." + input_format,
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +

        # chose a random location for each requested batch
        nl.MaskCloseDistanceToOverlap(
            gt_labels, gt_sample_mask,
            sampling['overlap_min_dist'],
            sampling['overlap_max_dist']
        ) +
        gp.RandomLocation(
            min_masked=sampling['min_masked_overlap'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_overlap += gp.RandomProvider()

    pipeline = (
        (source_fg, source_overlap) +

        # chose a random source (i.e., sample) from the above
        gp.RandomProvider(probabilities=[sampling['probability_fg'],
                                         sampling['probability_overlap']]) +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0]) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        gp.IntensityScaleShift(raw, 2, -1) +

        # convert labels into affinities between voxels
        nl.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs,
            multiple_labels=kwargs['overlapping_inst']) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(output_folder, kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=output_folder,
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(output_folder, 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every
        # kwargs['profiling'] iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f",
                i, time_of_iteration)
            # exit()
    print("Training finished")
def make_pipeline(self):
    """Build the prediction pipeline for ``self.model`` over one dataset.

    Reads ``self.data_file`` / ``self.dataset``, adds channel and batch
    dimensions as needed for the model input, and writes predictions to
    ``predictions.zarr`` in ``self.curr_log_dir`` while scanning.

    Returns:
        tuple: ``(pipeline, request, pred_affs)`` -- the pipeline, the
        per-chunk reference request used by ``gp.Scan``, and the array key
        of the predictions.
    """
    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PREDICTIONS')

    # NOTE: the ROI of ``raw`` is taken from the built source below; no
    # separate pre-computation from the zarr array shape is needed.
    data = daisy.open_ds(self.data_file, self.dataset)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape
    in_shape = gp.Coordinate(self.model.in_shape)
    out_shape = gp.Coordinate(self.model.out_shape[2:])

    is_2d = in_shape.dims() == 2

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(pred_affs, out_shape)

    # Border between network input and output on each side.
    context = (in_shape - out_shape) / 2

    source = (gp.ZarrSource(self.data_file, {
        raw: self.dataset,
    },
                            array_specs={
                                raw:
                                gp.ArraySpec(roi=source_roi,
                                             interpolatable=False)
                            }))

    in_dims = len(self.model.in_shape)
    if is_2d:
        # 2D: [samples, y, x] or [samples, channels, y, x]
        needs_channel_fix = (len(data.shape) - in_dims == 1)
        if needs_channel_fix:
            source = (source + AddChannelDim(raw, axis=1))
        # raw [samples, channels, y, x]
    else:
        # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
        needs_channel_fix = (len(data.shape) - in_dims == 0)
        needs_batch_fix = (len(data.shape) - in_dims <= 1)

        if needs_channel_fix:
            source = (source + AddChannelDim(raw, axis=0))
        # Batch fix
        if needs_batch_fix:
            source = (source + AddChannelDim(raw))
        # raw: [sample, channels, z, y, x]

    # Build the source once to learn the ROI it provides for ``raw``.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = (source +
                gp.Normalize(raw, factor=self.params['norm_factor']) +
                gp.Pad(raw, context) +
                gp.PreCache() +
                gp.torch.Predict(
                    self.model,
                    inputs={'raw': raw},
                    outputs={0: pred_affs},
                    array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)}))

    pipeline = (pipeline +
                gp.ZarrWrite({
                    pred_affs: 'predictions',
                },
                             output_dir=self.curr_log_dir,
                             output_filename='predictions.zarr',
                             compression_type='gzip') +
                gp.Scan(request))

    return pipeline, request, pred_affs
def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):
    """Predict a mask for one sample of an HDF5 file and store it as zarr.

    Does nothing if ``"hdf"`` does not occur in ``data_dir``.

    Args:
        data_dir: Path of the HDF5 input file.
        train_dir: Directory with the network meta graph, the
            ``*_config.json`` / ``*_names.json`` files and checkpoints.
        iteration: Checkpoint iteration to load.
        sample: Group name inside the input file; also the base name of the
            output ``.zarr``.
        test_net_name: Base name of the network used for prediction.
        train_net_name: Base name of the checkpoint files.
        output_dir: Directory the output zarr is written to.
        clip_max: Upper clipping value; also the normalization divisor.
    """
    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)

    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    config_path = os.path.join(train_dir, test_net_name + '_config.json')
    with open(config_path, 'r') as f:
        net_config = json.load(f)

    names_path = os.path.join(train_dir, test_net_name + '_names.json')
    with open(names_path, 'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    voxel_size = gp.Coordinate((1, 1, 1))
    context = gp.Coordinate(input_shape - output_shape) / 2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    source = gp.Hdf5Source(
        data_dir,
        datasets={
            raw: sample + '/raw',
        },
        array_specs={
            raw: gp.ArraySpec(
                interpolatable=True,
                dtype=np.uint16,
                voxel_size=voxel_size),
        },
    )
    source += gp.Pad(raw, context)
    source += nl.Clip(raw, 0, clip_max)
    source += gp.Normalize(raw, factor=1.0 / clip_max)
    source += gp.IntensityScaleShift(raw, 2, -1)

    # Build the padded source once to learn its ROI; shrinking that ROI by
    # the context on both sides gives the shape of the output volume.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        sample_shape = raw_roi.grow(-context, -context).get_shape()

    print(sample_shape)

    # create zarr file with corresponding chunk size
    output_name = sample + '.zarr'
    zf = zarr.open(os.path.join(output_dir, output_name), mode='w')
    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    mask_attrs = zf['volumes/pred_mask'].attrs
    mask_attrs['offset'] = [0, 0, 0]
    mask_attrs['resolution'] = [1, 1, 1]

    pipeline = source + gp.tensorflow.Predict(
        graph=os.path.join(train_dir, test_net_name + '.meta'),
        checkpoint=checkpoint,
        inputs={
            net_names['raw']: raw,
        },
        outputs={
            net_names['pred']: pred_mask,
        },
        array_specs={
            pred_mask: gp.ArraySpec(roi=raw_roi.grow(-context, -context),
                                    voxel_size=voxel_size),
        },
        max_shared_memory=1024 * 1024 * 1024)
    pipeline += Convert(pred_mask, np.float16)
    pipeline += gp.ZarrWrite(
        dataset_names={
            pred_mask: 'volumes/pred_mask',
        },
        output_dir=output_dir,
        output_filename=output_name,
        compression_type='gzip',
        dataset_dtypes={pred_mask: np.float16})
    # show a summary of time spend in each node every x iterations
    pipeline += gp.PrintProfilingStats(every=100)
    pipeline += gp.Scan(reference=request, num_workers=5, cache_size=50)

    # An empty request makes Scan process the whole dataset chunk by chunk.
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def predict_3d(raw_data, gt_data, predictor):
    """Run ``predictor`` over a whole 3D dataset by scanning.

    Args:
        raw_data: Raw dataset; ``num_channels``, ``num_samples``,
            ``voxel_size``, ``roi`` and ``get_source`` are used.
        gt_data: Optional ground-truth dataset. If truthy, ground truth and
            the target derived via ``predictor.add_target`` are returned
            as well.
        predictor: Model with ``input_shape``, ``output_shape`` and
            ``add_target``.

    Returns:
        dict: ``'raw'`` and ``'prediction'`` arrays, plus ``'gt'`` and
        ``'target'`` when ``gt_data`` is given.
    """
    # Shapes in world units.
    voxel_size = raw_data.voxel_size
    input_size = voxel_size * predictor.input_shape
    output_size = voxel_size * predictor.output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    # Single-channel raw data carries no channel axis and needs one added.
    needs_channel_dim = max(1, raw_data.num_channels) == 1

    assert raw_data.num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    has_gt = bool(gt_data)

    # Per-chunk request used by Scan.
    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if has_gt:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if has_gt:
        pipeline = (raw_data.get_source(raw),
                    gt_data.get_source(gt)) + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)

    pipeline += gp.Pad(raw, None)
    if has_gt:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)

    pipeline += gp.Normalize(raw)

    if has_gt:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)

    if needs_channel_dim:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)

    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)

    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})

    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # prediction: ([c,] d, h, w)

    if needs_channel_dim:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)

    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    half_input = input_size / 2
    roi = raw_data.roi.grow(half_input, half_input)

    requested_keys = [raw, prediction]
    if has_gt:
        requested_keys += [gt, target]
    total_request = gp.BatchRequest()
    for key in requested_keys:
        total_request[key] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)

    ret = {'raw': batch[raw], 'prediction': batch[prediction]}
    if has_gt:
        ret.update({'gt': batch[gt], 'target': batch[target]})
    return ret
def predict(**kwargs):
    """Predict affinities (and optionally instance counts) for one sample.

    The sample is scanned chunk by chunk and results are written to
    ``<output_folder>/<sample>.zarr``.

    Keyword Args:
        name: Base name of the network files in ``input_folder``.
        input_folder: Directory with ``<name>.meta``, ``<name>_config.json``
            and ``<name>_names.json``.
        checkpoint: Checkpoint to restore.
        voxel_size: Voxel size of the data.
        raw_key: Optional dataset name of the raw data (default
            ``'volumes/raw'``).
        overlapping_inst: If true, ``pred_numinst`` is predicted and stored
            as well.
        max_num_inst: Used for the channel count of ``pred_numinst`` (only
            read when ``overlapping_inst`` is true).
        input_format: ``"hdf"`` or ``"zarr"``.
        output_format: Must be ``"zarr"``.
        data_folder, sample: Locate the input file
            ``<data_folder>/<sample>.<input_format>``.
        output_folder: Directory the output zarr is written to.
        patchshape: Its product is the channel count of ``pred_affs``.

    Raises:
        NotImplementedError: For unsupported input or output formats.
    """
    name = kwargs['name']
    input_format = kwargs['input_format']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(kwargs['input_folder'],
                           name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'],
                           name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2
    # Output chunks are half the network output size along each axis.
    chunksize = list(np.asarray(output_shape_world) // 2)

    raw_key = kwargs.get('raw_key', 'volumes/raw')

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)
    if kwargs['overlapping_inst']:
        request.add(pred_numinst, output_shape_world, voxel_size=voxel_size)

    if input_format != "hdf" and input_format != "zarr":
        raise NotImplementedError(
            "predict node for %s not implemented yet" % input_format)

    # The leading axis of the raw dataset is dropped to get the spatial
    # shape of the output volumes.
    if input_format == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(
                os.path.join(kwargs['data_folder'],
                             kwargs['sample'] + ".hdf"), 'r') as f:
            shape = f[raw_key].shape[1:]
    elif input_format == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(
            os.path.join(kwargs['data_folder'],
                         kwargs['sample'] + ".zarr"), 'r')
        shape = f[raw_key].shape[1:]
    source = sourceNode(os.path.join(kwargs['data_folder'],
                                     kwargs['sample'] + "." + input_format),
                        datasets={raw: raw_key})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # open zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'), mode='w')
    num_aff_channels = int(np.prod(kwargs['patchshape']))
    zf.create('volumes/pred_affs',
              shape=[num_aff_channels] + list(shape),
              chunks=[num_aff_channels] + list(chunksize),
              dtype=np.float16)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    if kwargs['overlapping_inst']:
        num_inst_channels = int(kwargs['max_num_inst']) + 1
        zf.create('volumes/pred_numinst',
                  shape=[num_inst_channels] + list(shape),
                  chunks=[num_inst_channels] + list(chunksize),
                  dtype=np.float16)
        zf['volumes/pred_numinst'].attrs['offset'] = [0, 0]
        zf['volumes/pred_numinst'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        pred_affs: '/volumes/pred_affs',
    }
    if kwargs['overlapping_inst']:
        outputs[net_names['pred_numinst']] = pred_numinst
        outVolumes[pred_numinst] = '/volumes/pred_numinst'

    pipeline = (
        source +
        gp.Pad(raw, context) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # run the network on each passing batch (here we use the tensor
        # names earlier stored in the *_names.json file)
        gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'],
                                                 name + '.meta'),
                              checkpoint=kwargs['checkpoint'],
                              inputs={net_names['raw']: raw},
                              outputs=outputs) +

        # store all passing batches in the same zarr file
        gp.ZarrWrite(outVolumes,
                     output_dir=kwargs['output_folder'],
                     output_filename=kwargs['sample'] + ".zarr",
                     compression_type='gzip') +

        # show a summary of time spent in each node every 100 iterations
        gp.PrintProfilingStats(every=100) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
def train_until(max_iteration,
                name='train_net',
                output_folder='.',
                clip_max=2000):
    """Train the distance-transform network up to ``max_iteration``.

    Resumes from the latest checkpoint in ``output_folder`` and returns
    immediately if that checkpoint already reaches ``max_iteration``.

    NOTE(review): ``data_files`` and ``data_dir`` are not parameters; they
    are presumably module-level globals -- verify against the rest of the
    file.

    Args:
        max_iteration: Iteration to train until.
        name: Base name of the network files in ``output_folder``.
        output_folder: Directory with ``<name>_config.json`` /
            ``<name>_names.json``, checkpoints and snapshots.
        clip_max: Upper clipping value for raw data; also the normalization
            divisor.
    """
    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        # Checkpoint names end in "_<iteration>".
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source: one Hdf5Source per top-level group of each file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True,
                                          dtype=np.uint16,
                                          voxel_size=voxel_size),
                        # Builtin bool: the ``np.bool`` alias was removed
                        # in NumPy 1.24.
                        gt_mask: gp.ArraySpec(interpolatable=False,
                                              dtype=bool,
                                              voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
        data_sources +
        gp.RandomProvider() +
        gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
        DistanceTransform(gt_mask, gt_dt, 3) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=5) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_dt']: gt_dt,
            },
            outputs={
                net_names['pred_dt']: pred_dt,
            },
            gradients={
                net_names['pred_dt']: loss_gradient,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot({
                raw: 'volumes/raw',
                gt_mask: 'volumes/gt_mask',
                gt_dt: 'volumes/gt_dt',
                pred_dt: 'volumes/pred_dt',
                loss_gradient: 'volumes/gradient',
            },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2000) +
        gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):

        print("Starting training...")
        # Only the remaining iterations are run when resuming.
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(**kwargs):
    """Train the center-point network up to ``kwargs['max_iteration']``.

    Resumes from the latest TensorFlow checkpoint found in
    ``kwargs['output_folder']`` and requests one batch from a gunpowder
    pipeline per remaining iteration, logging the time of each request.

    Keyword Args:
        output_folder: Directory with ``<name>_config.json``,
            ``<name>_names.json`` and the checkpoints; also used as the
            ``log_dir`` of the train node and as parent of ``snapshots``.
        name: Base name of the network files.
        max_iteration: Iteration to train until.
        voxel_size: Voxel size; network shapes are multiplied by it to
            obtain the requested sizes.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_files: Paths of the input volumes. Each must contain
            ``volumes/raw`` and ``volumes/gt_fgbg``; a ``.csv`` file with
            the same base name is used as point source.
        augmentation: Dict with an optional ``elastic`` entry, a required
            ``simple`` entry and an optional ``intensity`` entry.
        cache_size, num_workers: Passed to ``gp.PreCache``.
        checkpoints: Passed as ``save_every`` to the train node.
        snapshots: Passed as ``every`` to ``gp.Snapshot``.
        profiling: Passed as ``every`` to ``gp.PrintProfilingStats``.

    Returns:
        None.

    Raises:
        NotImplementedError: If ``input_format`` is neither ``"hdf"`` nor
            ``"zarr"``.
        KeyError: If one of the required keys above is missing.
    """
    # Resume from the latest checkpoint; the iteration is parsed from the
    # text after the last '_' of the checkpoint path.
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = (
        gp.Coordinate(net_config['output_shape']) * voxel_size)

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted center points (loss gradients are not requested)
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        # Interpolate explicitly: exceptions do not apply %-formatting to
        # extra positional arguments the way logging calls do.
        raise NotImplementedError(
            "train node for %s not implemented yet" % (input_format,))

    # Collect base names (path without extension) and raw volume shapes.
    # Per-volume min/max are intentionally not computed: nothing downstream
    # consumes them and they would require reading every full volume.
    fls = []
    shapes = []
    for data_file in kwargs['data_files']:
        fls.append(os.path.splitext(data_file)[0])
        if input_format == "hdf":
            # Only metadata is needed, so close the file right away.
            with h5py.File(data_file, 'r') as h5_file:
                vol = h5_file['volumes/raw']
                vol_shape, vol_dtype = vol.shape, vol.dtype
        else:
            vol = zarr.open(data_file, 'r')['volumes/raw']
            vol_shape, vol_dtype = vol.shape, vol.dtype
        print(data_file, vol_shape, vol_dtype)
        shapes.append(vol_shape)
        if vol_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[:5])

    source_node = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource

    augmentation = kwargs['augmentation']

    # One (volume source, point source) pair per input file, merged into a
    # single provider that serves random locations.
    sources = tuple(
        (
            source_node(
                base_name + "." + input_format,
                datasets={
                    raw: 'volumes/raw',
                    anchor: 'volumes/gt_fgbg',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    anchor: gp.ArraySpec(interpolatable=False)
                }),
            gp.CsvIDPointsSource(
                base_name + ".csv",
                points,
                points_spec=gp.PointsSpec(
                    roi=gp.Roi(gp.Coordinate((0, 0, 0)),
                               gp.Coordinate(shape))))
        ) +
        gp.MergeProvider() +
        gp.Pad(raw, None) +
        gp.Pad(points, None) +
        # chose a random location for each requested batch
        gp.RandomLocation()
        for base_name, shape in zip(fls, shapes))

    # elastically deform the batch (skipped if not configured)
    elastic_cfg = augmentation.get('elastic')
    if elastic_cfg is not None:
        elastic_node = gp.ElasticAugment(
            elastic_cfg['control_point_spacing'],
            elastic_cfg['jitter_sigma'],
            [elastic_cfg['rotation_min']*np.pi/180.0,
             elastic_cfg['rotation_max']*np.pi/180.0],
            subsample=elastic_cfg.get('subsample', 1))
    else:
        elastic_node = NoOp()

    # apply transpose and mirror augmentations ('simple' entry is required)
    simple_node = gp.SimpleAugment(
        mirror_only=augmentation['simple'].get("mirror"),
        transpose_only=augmentation['simple'].get("transpose"))

    # scale and shift the intensity of the raw array (skipped if the entry
    # is missing or empty)
    intensity_cfg = augmentation.get('intensity')
    if intensity_cfg is not None and intensity_cfg != {}:
        intensity_node = gp.IntensityAugment(
            raw,
            scale_min=intensity_cfg['scale'][0],
            scale_max=intensity_cfg['scale'][1],
            shift_min=intensity_cfg['shift'][0],
            shift_max=intensity_cfg['shift'][1],
            z_section_wise=False)
    else:
        intensity_node = NoOp()

    pipeline = (
        sources +

        # chose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        elastic_node +
        simple_node +
        intensity_node +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in <name>_names.json)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={},
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every
        # kwargs['profiling'] iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f",
                i, time_of_iteration)
    print("Training finished")