def prepare(self, request):
    """Declare the upstream dependency for ``self.array``.

    The upstream request for ``self.array`` is a copy of the spec that was
    requested downstream for the same key.
    """
    upstream_spec = request[self.array].copy()
    deps = gp.BatchRequest()
    deps[self.array] = upstream_spec
    return deps
def prepare(self, request):
    """Declare the upstream dependency needed to produce ``self.gt_binary``.

    ``self.gt`` is requested upstream with a copy of the spec that was
    requested downstream for ``self.gt_binary``.
    """
    spec_for_gt = request[self.gt_binary].copy()
    deps = gp.BatchRequest()
    deps[self.gt] = spec_for_gt
    return deps
def prepare(self, request):
    """Declare the upstream dependency for ``self.array``.

    ``self.array`` is requested upstream with a deep copy of the spec that
    was requested downstream for ``self.mask``; the deep copy keeps the
    downstream request object untouched.
    """
    mask_spec = copy.deepcopy(request[self.mask])
    deps = gp.BatchRequest()
    deps[self.array] = mask_spec
    return deps
def prepare(self, request):
    """Declare the upstream dependency for ``self.points``.

    The upstream request for ``self.points`` is a copy of the spec that was
    requested downstream for the same key.
    """
    points_spec = request[self.points].copy()
    deps = gp.BatchRequest()
    deps[self.points] = points_spec
    return deps
def predict_volume(model, dataset, out_dir, out_filename, out_ds_names,
                   input_key='0/raw', normalize_factor=None, model_output=0,
                   in_shape=None, out_shape=None, spawn_subprocess=True,
                   num_workers=0, input_name=0, checkpoint=None,
                   apply_voxel_size=True):
    """Run ``model`` blockwise over a zarr volume and write the prediction.

    Args:
        model: torch model handed to ``gp.torch.Predict``. Its ``in_shape`` /
            ``out_shape`` attributes are used when ``in_shape`` /
            ``out_shape`` are not given.
        dataset: object with ``filename`` and ``ds_names``; ``ds_names[0]``
            is read as the raw array.
        out_dir, out_filename: location of the zarr container to write.
        out_ds_names: output dataset names; only ``out_ds_names[0]`` (the
            prediction) is written.
        input_key: unused; kept for backward compatibility.
        normalize_factor: forwarded to ``gp.Normalize``; the string
            ``"skip"`` disables normalization.
        model_output: position/name of the model output that is written.
        in_shape, out_shape: network input/output shapes (in voxels when
            ``apply_voxel_size`` is true).
        spawn_subprocess: forwarded to ``gp.torch.Predict``.
        num_workers: number of ``gp.Scan`` workers.
        input_name: position/name of the model's ``forward`` argument that
            receives raw (as accepted by ``gp.torch.Predict``).
        checkpoint: optional checkpoint passed to ``gp.torch.Predict``.
        apply_voxel_size: multiply ``in_shape``/``out_shape`` by the dataset
            voxel size to obtain world units.
    """
    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    # half the difference between input and output on each side; integer
    # division keeps the Coordinate integral
    context = (in_shape - out_shape) // 2
    logger.info(f"context: {context}, in_shape: {in_shape}, "
                f"out_shape: {out_shape}")

    source = gp.ZarrSource(
        dataset.filename,
        {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)})

    # pad raw with leading dims until it has (sample, channel, *spatial)
    num_additional_channels = (2 + spatial_dims) - data_dims
    for _ in range(num_additional_channels):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (
        gp.Pad(raw, context) +
        gp.torch.Predict(
            model,
            inputs={input_name: raw},
            outputs={model_output: prediction},
            array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
            checkpoint=checkpoint,
            spawn_subprocess=spawn_subprocess))

    pipeline += (
        gp.ZarrWrite(
            {
                prediction: out_ds_names[0],
            },
            output_dir=out_dir,
            output_filename=out_filename,
            compression_type='gzip') +
        gp.Scan(request, num_workers=num_workers))

    # an empty request lets Scan cover the whole provided ROI
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def train_until(**kwargs):
    """Train the affinity + foreground/background TensorFlow network.

    Resumes from the latest checkpoint in ``kwargs['output_folder']`` and
    runs until ``kwargs['max_iteration']``; returns immediately if that
    iteration was already reached.

    Keys read from ``kwargs``: ``output_folder``, ``name``, ``max_iteration``,
    ``voxel_size``, ``input_format`` (``"hdf"`` or ``"zarr"``),
    ``data_files``, ``augmentation``, ``cache_size``, ``num_workers``,
    ``checkpoints``, ``snapshots``, ``profiling``.

    Raises:
        NotImplementedError: if ``input_format`` is neither hdf nor zarr.
    """
    # checkpoint names are expected to end in "_<iteration>"
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    # NOTE: only the fg/bg loss is balanced; affinity loss weights are off
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # snapshots additionally contain the predictions
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        raise NotImplementedError(
            f"train node for {input_format} not implemented yet")

    # report shape/dtype of every raw volume; files are reopened by the
    # gunpowder sources below, so the handles are closed right away
    fls = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if input_format == "hdf":
            with h5py.File(f, 'r') as h5_file:
                vol = h5_file['volumes/raw']
                vol_shape, vol_dtype = vol.shape, vol.dtype
        else:
            vol = zarr.open(f, 'r')['volumes/raw']
            vol_shape, vol_dtype = vol.shape, vol.dtype
        print(f, vol_shape, vol_dtype)
        if vol_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[:5])

    sourceNode = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            sourceNode(
                fl + "." + input_format,
                datasets={
                    raw: 'volumes/raw',
                    gt_labels: 'volumes/gt_labels',
                    gt_fgbg: 'volumes/gt_fgbg',
                    anchor: 'volumes/gt_fgbg',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_labels: gp.ArraySpec(interpolatable=False),
                    gt_fgbg: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            ) +
            gp.Pad(raw, None) +
            gp.Pad(gt_labels, None) +
            gp.Pad(gt_fgbg, None) +
            # choose a random location for each requested batch
            gp.RandomLocation()
            for fl in fls
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch (only if configured)
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min'] * np.pi / 180.0,
             augmentation['elastic']['rotation_max'] * np.pi / 180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(gt_labels, steps=1, only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        # create a weight array that balances positive and negative samples
        # in the fg/bg array
        gp.BalanceLabels(gt_fgbg, loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (tensor
        # names come from <name>_names.json)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_labels']: gt_labels,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
            },
            malis=True,
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start
            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
def train_until(**kwargs):
    """Train the centre-point (CP) TensorFlow network from CSV point lists.

    Resumes from the latest checkpoint in ``kwargs['output_folder']`` and
    runs until ``kwargs['max_iteration']``; returns immediately if that
    iteration was already reached. For every data file ``<stem>.<fmt>`` a
    matching ``<stem>.csv`` with point annotations is read and rasterized
    into ``gt_cp``.

    Keys read from ``kwargs``: ``output_folder``, ``name``, ``max_iteration``,
    ``voxel_size``, ``input_format`` (``"hdf"`` or ``"zarr"``),
    ``data_files``, ``augmentation``, ``cache_size``, ``num_workers``,
    ``checkpoints``, ``snapshots``, ``profiling``.

    Raises:
        NotImplementedError: if ``input_format`` is neither hdf nor zarr.
    """
    # checkpoint names are expected to end in "_<iteration>"
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # snapshots additionally contain the prediction
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        raise NotImplementedError(
            f"train node for {input_format} not implemented yet")

    # collect the raw shape of every file (needed for the points ROI);
    # h5py handles are closed right away, the gunpowder sources reopen them
    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if input_format == "hdf":
            with h5py.File(f, 'r') as h5_file:
                vol = h5_file['volumes/raw']
                vol_shape, vol_dtype = vol.shape, vol.dtype
        else:
            vol = zarr.open(f, 'r')['volumes/raw']
            vol_shape, vol_dtype = vol.shape, vol.dtype
        print(f, vol_shape, vol_dtype)
        shapes.append(vol_shape)
        if vol_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[:5])

    sourceNode = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource

    augmentation = kwargs['augmentation']
    # one branch per file: raw/anchor arrays merged with the CSV points
    # NOTE(review): the points ROI is built from the raw array shape in
    # voxels; confirm this matches world units when voxel_size != 1.
    sources = tuple(
        (
            sourceNode(
                fl + "." + input_format,
                datasets={
                    raw: 'volumes/raw',
                    anchor: 'volumes/gt_fgbg',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    anchor: gp.ArraySpec(interpolatable=False)
                }),
            gp.CsvIDPointsSource(
                fl + ".csv",
                points,
                points_spec=gp.PointsSpec(
                    roi=gp.Roi(gp.Coordinate((0, 0, 0)),
                               gp.Coordinate(shape)))),
        ) +
        gp.MergeProvider() +
        gp.Pad(raw, None) +
        gp.Pad(points, None) +
        # choose a random location for each requested batch
        gp.RandomLocation()
        for fl, shape in zip(fls, shapes))

    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch (only if configured)
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min'] * np.pi / 180.0,
             augmentation['elastic']['rotation_max'] * np.pi / 180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array (only if configured)
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
         if augmentation.get('intensity') is not None and
         augmentation.get('intensity') != {} else NoOp()) +

        # turn the point annotations into the gt_cp target array
        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (tensor
        # names come from <name>_names.json)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={},
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start
            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
def train(iterations):
    """Train the CREMI affinity network for ``iterations`` batches.

    Reads the network description from ``train_net_config.json`` in the
    working directory and the samples ``sample_{A,B,C}_padded_20160501.hdf``.
    Snapshots go to ``snapshots/`` every 100 iterations.

    Args:
        iterations: number of batches to request from the pipeline.
    """
    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')

    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')

    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # the gradient of the loss wrt the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities; both reuse the spec of gt_affs
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    # NOTE: node order matters; each node processes the batch produced by
    # the nodes before it.
    pipeline = (

        # a tuple of sources, one for each sample (A, B, and C) provided by
        # the CREMI challenge
        tuple(
            # read batches from the HDF5 file
            gp.Hdf5Source(
                'sample_' + s + '_padded_20160501.hdf',
                datasets={
                    raw: 'volumes/raw',
                    gt_labels: 'volumes/labels/neuron_ids'
                }) +

            # convert raw to float in [0, 1]
            gp.Normalize(raw) +

            # choose a random location for each requested batch
            gp.RandomLocation()

            for s in ['A', 'B', 'C']
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment(
            [4, 40, 40],
            [0, 2, 2],
            [0, math.pi / 2.0],
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=25) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(transpose_only=[1, 2]) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=0.9,
            scale_max=1.1,
            shift_min=-0.1,
            shift_max=0.1,
            z_section_wise=True) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=3,
            only_xy=True) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        # create a weight array that balances positive and negative samples
        # in the affinity array
        gp.BalanceLabels(
            gt_affs,
            loss_weights) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=10,
            num_workers=5) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net_config.json)
        # NOTE(review): save_every=1 presumably writes a checkpoint every
        # iteration — confirm this is intended.
        gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights
            },
            outputs={
                net_config['pred_affs']: pred_affs
            },
            gradients={
                net_config['pred_affs']: pred_affs_gradients
            },
            save_every=1) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients'
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=100,
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10)
    )

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for i in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
def create_train_pipeline(self, model):
    """Build the two-branch contrastive training pipeline.

    Two independent branches (suffix ``_0`` / ``_1``) each draw raw data
    from a randomly chosen dataset of ``self.params['data_file']`` together
    with random points; both are trained with ``ContrastiveVolumeLoss``.
    2D models are handled by adding a leading fake spatial dimension
    (size 1) so that the whole pipeline operates in 3D.

    Args:
        model: torch model with ``in_shape``/``out_shape`` attributes and
            ``parameters()``.

    Returns:
        ``(pipeline, request)`` — the assembled gunpowder pipeline and the
        request to issue per training iteration.
    """
    optimizer = self.params['optimizer'](model.parameters(),
                                         **self.params['optimizer_kwargs'])

    filename = self.params['data_file']
    datasets = self.params['dataset']

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    data = daisy.open_ds(filename, datasets[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)

    # Get in and out shape (out_shape drops the first two entries of
    # model.out_shape, presumably batch and channel)
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape[2:])
    is_2d = in_shape.dims() == 2

    # embeddings keep the original (possibly 2D) voxel size
    emb_voxel_size = voxel_size

    cv_loss = ContrastiveVolumeLoss(self.params['temperature'],
                                    self.params['point_density'],
                                    out_shape * voxel_size)

    # Add fake 3rd dim
    if is_2d:
        in_shape = gp.Coordinate((1, *in_shape))
        out_shape = gp.Coordinate((1, *out_shape))
        voxel_size = gp.Coordinate((1, *voxel_size))
        source_roi = gp.Roi((0, *source_roi.get_offset()),
                            (data.shape[0], *source_roi.get_shape()))

    # switch to world units
    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    random_point_generator = RandomPointGenerator(
        density=self.params['point_density'], repetitions=2)

    # Use volume to calculate probabilities, RandomSourceGenerator will
    # normalize volumes to probabilities. np.prod replaces np.product,
    # which is deprecated and removed in NumPy 2.0.
    probabilities = np.array([
        np.prod(daisy.open_ds(filename, dataset).shape)
        for dataset in datasets
    ])
    random_source_generator = RandomSourceGenerator(
        num_sources=len(datasets),
        probabilities=probabilities,
        repetitions=2)

    array_sources = tuple(
        tuple(
            gp.ZarrSource(
                filename,
                {
                    raw: dataset
                },
                # fake 3D data
                array_specs={
                    raw: gp.ArraySpec(
                        roi=source_roi,
                        voxel_size=voxel_size,
                        interpolatable=True)
                })
            for dataset in datasets)
        for raw in [raw_0, raw_1])

    # Choose a random dataset to pull from
    array_sources = tuple(
        arrays +
        RandomMultiBranchSource(random_source_generator) +
        gp.Normalize(raw, self.params['norm_factor']) +
        gp.Pad(raw, None)
        for raw, arrays in zip([raw_0, raw_1], array_sources))

    point_sources = (
        RandomPointSource(
            points_0,
            random_point_generator=random_point_generator),
        RandomPointSource(
            points_1,
            random_point_generator=random_point_generator))

    # Merge the point and array sources together.
    # There is one array and point source per branch.
    sources = tuple(
        (array_source, point_source) + gp.MergeProvider()
        for array_source, point_source in zip(array_sources, point_sources))

    sources = tuple(
        self._make_train_augmentation_pipeline(raw, source)
        for raw, source in zip([raw_0, raw_1], sources))

    pipeline = (
        sources +
        gp.MergeProvider() +
        gp.Crop(raw_0, source_roi) +
        gp.Crop(raw_1, source_roi) +
        gp.RandomLocation() +
        PrepareBatch(
            raw_0, raw_1,
            points_0, points_1,
            locations_0, locations_1,
            is_2d) +
        RejectArray(ensure_nonempty=locations_0) +
        RejectArray(ensure_nonempty=locations_1))

    if not is_2d:
        pipeline = (
            pipeline +
            AddChannelDim(raw_0) +
            AddChannelDim(raw_1)
        )

    pipeline = (
        pipeline +
        gp.PreCache() +
        gp.torch.Train(
            model, cv_loss, optimizer,
            inputs={
                'raw_0': raw_0,
                'raw_1': raw_1
            },
            loss_inputs={
                'emb_0': emb_0,
                'emb_1': emb_1,
                'locations_0': locations_0,
                'locations_1': locations_1
            },
            outputs={
                2: emb_0,
                3: emb_1
            },
            array_specs={
                emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
                emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
            },
            checkpoint_basename=self.logdir + '/contrastive/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir + "/contrastive",
            log_every=self.log_every))

    if is_2d:
        pipeline = (
            pipeline +
            # everything is 3D, except emb_0 and emb_1
            AddSpatialDim(emb_0) +
            AddSpatialDim(emb_1)
        )

    pipeline = (
        pipeline +
        # now everything is 3D
        RemoveChannelDim(raw_0) +
        RemoveChannelDim(raw_1) +
        RemoveChannelDim(emb_0) +
        RemoveChannelDim(emb_1) +
        gp.Snapshot(
            output_dir=self.logdir + '/contrastive/snapshots',
            output_filename='it{iteration}.hdf',
            dataset_names={
                raw_0: 'raw_0',
                raw_1: 'raw_1',
                locations_0: 'locations_0',
                locations_1: 'locations_1',
                emb_0: 'emb_0',
                emb_1: 'emb_1'
            },
            additional_request=snapshot_request,
            every=self.params['save_every']) +
        gp.PrintProfilingStats(every=500)
    )

    return pipeline, request
def train(n_iterations):
    """Train the synthetic 2D embedding network for ``n_iterations`` batches.

    ``add_loss``, ``tensor_names``, ``emst_name``, ``edges_u_name`` and
    ``edges_v_name`` are not defined in this function and must be provided
    by the enclosing module. A snapshot is written every 100 iterations.
    """
    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    # NOTE(review): identifier 'GT_FP' vs variable gt_fg — possible typo
    gt_fg = gp.ArrayKey('GT_FP')
    embedding = gp.ArrayKey('EMBEDDING')
    fg = gp.ArrayKey('FG')
    maxima = gp.ArrayKey('MAXIMA')
    gradient_embedding = gp.ArrayKey('GRADIENT_EMBEDDING')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    emst = gp.ArrayKey('EMST')
    edges_u = gp.ArrayKey('EDGES_U')
    edges_v = gp.ArrayKey('EDGES_V')

    request = gp.BatchRequest()
    request.add(raw, (200, 200))
    request.add(gt, (160, 160))

    # arrays aligned with gt reuse its spec; the remaining outputs are
    # requested with an empty ArraySpec (no ROI)
    snapshot_request = gp.BatchRequest()
    for aligned_key in (embedding, fg, gt_fg, maxima,
                        gradient_embedding, gradient_fg):
        snapshot_request[aligned_key] = request[gt]
    for unaligned_key in (emst, edges_u, edges_v):
        snapshot_request[unaligned_key] = gp.ArraySpec()

    pipeline = Synthetic2DSource(raw, gt) + gp.Normalize(raw)
    pipeline += gp.tensorflow.Train(
        'train_net',
        optimizer=add_loss,
        loss=None,
        inputs={
            tensor_names['raw']: raw,
            tensor_names['gt_labels']: gt,
        },
        outputs={
            tensor_names['embedding']: embedding,
            tensor_names['fg']: fg,
            'maxima:0': maxima,
            'gt_fg:0': gt_fg,
            emst_name: emst,
            edges_u_name: edges_u,
            edges_v_name: edges_v,
        },
        gradients={
            tensor_names['embedding']: gradient_embedding,
            tensor_names['fg']: gradient_fg,
        })
    pipeline += gp.Snapshot(
        output_filename='{iteration}.hdf',
        dataset_names={
            raw: 'volumes/raw',
            gt: 'volumes/gt',
            embedding: 'volumes/embedding',
            fg: 'volumes/fg',
            maxima: 'volumes/maxima',
            gt_fg: 'volumes/gt_fg',
            gradient_embedding: 'volumes/gradient_embedding',
            gradient_fg: 'volumes/gradient_fg',
            emst: 'emst',
            edges_u: 'edges_u',
            edges_v: 'edges_v',
        },
        dataset_dtypes={
            maxima: np.float32,
            gt_fg: np.float32,
        },
        every=100,
        additional_request=snapshot_request)

    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)
def predict_2d(raw_data, gt_data, predictor):
    """Run ``predictor`` over every sample of a stack of 2D images.

    The sample axis of ``raw_data`` is treated as a fake z dimension (voxel
    size 1) and the prediction is scanned over it with ``gp.Scan``.

    Args:
        raw_data: dataset object exposing ``num_channels``, ``shape``,
            ``roi``, ``voxel_size``, ``num_samples`` and ``get_source``.
        gt_data: optional dataset of the same kind; when truthy, gt and the
            target produced by ``predictor.add_target`` are also returned.
        predictor: torch module with ``input_shape``, ``output_shape``,
            ``target_channels`` and ``add_target``.

    Returns:
        dict with ``'raw'`` and ``'prediction'`` arrays, plus ``'gt'`` and
        ``'target'`` when ``gt_data`` is given.

    Raises:
        RuntimeError: if the raw data (without channel axis) is not 3D.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    # a channel axis is only present when raw has more than one channel
    channel_dims = 0 if raw_channels == 1 else 1

    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    # NOTE(review): this overwrites the num_samples computed above from
    # dataset_shape; confirm both agree or drop one of them.
    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(
        roi=gp.Roi(
            (0, ) + dataset_roi.get_begin(),
            (num_samples, ) + sample_size),
        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (
            raw_data.get_source(raw, overwrite_spec=spec),
            gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)

    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    # swap channel and sample axes so samples act as the batch dimension
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(
        model=predictor,
        inputs={'x': raw},
        outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    # NOTE(review): sample_size (and input/output_size) have the 2D rank of
    # voxel_size, while spec above is 3D ((1,) + voxel_size); confirm the
    # request sizes are compatible with the 3D ROI.
    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {
            'raw': batch[raw],
            'prediction': batch[prediction]
        }
        if gt_data:
            ret.update({
                'gt': batch[gt],
                'target': batch[target]
            })
        return ret
def predict_3d(raw_data, gt_data, model, predictor, aux_tasks):
    """Run ``model`` + heads blockwise over a single 3D volume.

    The raw volume is scanned with ``gp.Scan``. ``model`` produces a
    feature map that is fed to ``predictor`` and to every auxiliary head.

    Args:
        raw_data: dataset object exposing ``num_channels``, ``voxel_size``,
            ``num_samples``, ``roi`` and ``get_source``.
        gt_data: optional dataset; when truthy, output is restricted to
            where gt exists and gt/target are returned as well.
        model: torch backbone with ``input_shape``/``output_shape``.
        predictor: torch head applied to the backbone output; must provide
            ``add_target`` when ``gt_data`` is given.
        aux_tasks: iterable of ``(name, predictor, _)`` triples for
            auxiliary heads; each result is returned under ``name``.

    Returns:
        dict with ``'raw'``, ``'model_out'``, ``'prediction'``, one entry
        per auxiliary task, and ``'gt'``/``'target'`` when ``gt_data`` is
        given.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = model.input_shape
    output_shape = model.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    model_output = gp.ArrayKey('MODEL_OUTPUT')
    prediction = gp.ArrayKey('PREDICTION')

    # a channel axis is only present when raw has more than one channel
    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    if gt_data:
        sources = (
            raw_data.get_source(raw),
            gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)

    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    pipeline += gp_torch.Predict(
        model=model,
        inputs={'x': raw},
        outputs={0: model_output})
    pipeline += gp_torch.Predict(
        model=predictor,
        inputs={'x': model_output},
        outputs={0: prediction})
    aux_predictions = []
    for aux_name, aux_predictor, _ in aux_tasks:
        aux_pred_key = gp.ArrayKey(f"PRED_{aux_name.upper()}")
        pipeline += gp_torch.Predict(
            model=aux_predictor,
            inputs={'x': model_output},
            outputs={0: aux_pred_key})
        aux_predictions.append((aux_name, aux_pred_key))
    # remove "batch" dimension (only raw and prediction are squeezed;
    # model_output and auxiliary predictions keep it)
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(model_output, output_size)
    scan_request.add(prediction, output_size)
    for _, aux_key in aux_predictions:
        scan_request.add(aux_key, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    pipeline += gp.Scan(scan_request)

    # only output where the full input context is available and, if gt is
    # given, where the gt exists (previously crashed when gt_data was None)
    context = (input_size - output_size) // 2
    output_roi = raw_data.roi.grow(-context, -context)
    if gt_data:
        output_roi = gt_data.roi.intersect(output_roi)
    input_roi = output_roi.grow(context, context)

    # the total ROI must hold at least one scan chunk (equal is fine)
    assert all(a >= b for a, b in zip(input_roi.get_shape(), input_size))
    assert all(a >= b for a, b in zip(output_roi.get_shape(), output_size))

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=input_roi)
    total_request[model_output] = gp.ArraySpec(roi=output_roi)
    total_request[prediction] = gp.ArraySpec(roi=output_roi)
    for _, aux_key in aux_predictions:
        total_request[aux_key] = gp.ArraySpec(roi=output_roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=output_roi)
        total_request[target] = gp.ArraySpec(roi=output_roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {
            'raw': batch[raw],
            'model_out': batch[model_output],
            'prediction': batch[prediction]
        }
        if gt_data:
            ret.update({
                'gt': batch[gt],
                'target': batch[target]
            })
        for aux_name, aux_key in aux_predictions:
            ret[aux_name] = batch[aux_key]
        return ret
def train_until(**kwargs):
    """Train the affinity network with gunpowder until ``max_iteration``.

    Training resumes from the latest TensorFlow checkpoint in
    ``kwargs['output_folder']``; the iteration number is parsed from the
    checkpoint suffix (``..._<iteration>``).

    Keyword Args (all read from ``kwargs``):
        output_folder, name: location/prefix of ``<name>_config.json``,
            ``<name>_names.json`` and the checkpoints.
        max_iteration: stop once this iteration has been reached.
        voxel_size: voxel size used to convert network shapes to world units.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_files: training files; their extension is stripped and replaced
            by ``input_format``.
        patchshape, patchstride: define the 3D affinity neighborhood.
        augmentation: dict with ``elastic``, ``simple`` and ``intensity``
            sub-dicts.
        cache_size, num_workers: ``gp.PreCache`` settings.
        auto_mixed_precision: if true, ``optimizer``, ``args`` and ``kwargs``
            are also read and forwarded as ``optimizer_args``.
        use_tf_data (optional): enables the TFData node; defaults to False.
        checkpoints, snapshots, profiling: save/snapshot/profiling intervals.

    Raises:
        NotImplementedError: if ``input_format`` is neither hdf nor zarr.
    """
    # .get: the variable may legitimately be unset (e.g. CPU runs)
    print("cuda visibile devices", os.environ.get("CUDA_VISIBLE_DEVICES"))
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    net_prefix = os.path.join(kwargs['output_folder'], kwargs['name'])
    with open(net_prefix + '_config.json', 'r') as f:
        net_config = json.load(f)
    with open(net_prefix + '_names.json', 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    # NOTE(review): nothing is added to this request; the arrays fed to the
    # network are only described via ``array_specs`` of the Train node below
    # -- confirm the Train node in use fills the request itself.
    request = gp.BatchRequest()

    # when we make a snapshot for inspection (see below), we also want to
    # request the cropped raw, the predicted and the ground-truth affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    snapshot_request.add(gt_affs, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        raise NotImplementedError(
            "train node for %s not implemented yet" % input_format)
    source_node = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource

    # strip extensions; the input format is appended again below
    fls = [os.path.splitext(f)[0] for f in kwargs['data_files']]
    print("first 5 files: ", fls[:5])

    # affinity neighborhood: all 3D offsets within half the patch shape,
    # sampled with the given stride
    neighborhood = []
    psH = np.array(kwargs['patchshape']) // 2
    stride = kwargs['patchstride']
    for i in range(-psH[0], psH[0] + 1, stride[0]):
        for j in range(-psH[1], psH[1] + 1, stride[1]):
            for k in range(-psH[2], psH[2] + 1, stride[2]):
                neighborhood.append([i, j, k])

    datasets = {
        raw: 'volumes/raw',
        gt_labels: 'volumes/gt_labels',
        anchor: 'volumes/gt_fgbg',
    }

    input_specs = {
        raw: gp.ArraySpec(
            roi=gp.Roi((0, ) * len(input_shape_world), input_shape_world),
            interpolatable=True, dtype=np.float32),
        gt_labels: gp.ArraySpec(
            roi=gp.Roi((0, ) * len(output_shape_world), output_shape_world),
            interpolatable=False, dtype=np.uint16),
        anchor: gp.ArraySpec(
            roi=gp.Roi((0, ) * len(output_shape_world), output_shape_world),
            interpolatable=False, dtype=np.uint8),
        gt_affs: gp.ArraySpec(
            roi=gp.Roi((0, ) * len(output_shape_world), output_shape_world),
            interpolatable=False, dtype=np.uint8),
    }

    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
        net_names['anchor']: anchor,
    }

    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }

    snapshot = {
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_affs: '/volumes/pred_affs',
    }

    optimizer_args = None
    if kwargs['auto_mixed_precision']:
        optimizer_args = (kwargs['optimizer'], {
            'args': kwargs['args'],
            'kwargs': kwargs['kwargs'],
        })

    # read once so the TFData node and the Train node always agree
    use_tf_data = kwargs.get('use_tf_data', False)

    augmentation = kwargs['augmentation']
    elastic = augmentation['elastic']
    intensity = augmentation['intensity']

    pipeline = (
        tuple(
            source_node(fl + "." + input_format, datasets=datasets) +
            gp.Pad(raw, None) +
            gp.Pad(gt_labels, None) +
            # chose a random location for each requested batch
            gp.RandomLocation()
            for fl in fls) +
        # chose a random source (i.e., sample) from the above
        gp.RandomProvider() +
        # elastically deform the batch
        gp.ElasticAugment(
            elastic['control_point_spacing'],
            elastic['jitter_sigma'],
            [elastic['rotation_min'] * np.pi / 180.0,
             elastic['rotation_max'] * np.pi / 180.0],
            subsample=4) +
        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +
        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=intensity['scale'][0],
            scale_max=intensity['scale'][1],
            shift_min=intensity['shift'][0],
            shift_max=intensity['shift'][1],
            z_section_wise=False) +
        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +
        # convert labels into affinities between voxels
        gp.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs) +
        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +
        # pre-fetch batches from the point upstream
        (gp.tensorflow.TFData() if use_tf_data else NoOp()) +
        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            net_prefix,
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            array_specs=input_specs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            auto_mixed_precision=kwargs['auto_mixed_precision'],
            optimizer_args=optimizer_args,
            use_tf_data=use_tf_data,
            save_every=kwargs['checkpoints'],
            snapshot_every=kwargs['snapshots']) +
        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +
        # show a summary of time spend in each node every x iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    try:
        with gp.build(pipeline):
            print(pipeline)
            for i in range(trained_until, kwargs['max_iteration']):
                start = time.time()
                pipeline.request_batch(request)
                time_of_iteration = time.time() - start

                logger.info(
                    "Batch: iteration=%d, time=%f", i, time_of_iteration)
    except KeyboardInterrupt:
        sys.exit()

    print("Training finished")
def batch_data_aug_generator(input_path, batch_size=12, voxel_shape=(1, 1, 1),
                             input_shape=(240, 240, 4),
                             output_shape=(240, 240, 4),
                             without_background=False, mix_output=False,
                             validate=False, aug=seq):
    """Endlessly yield augmented ``(images, masks)`` batches from zarr files.

    Every entry of ``input_path`` is opened as a zarr container with the
    datasets ``raw`` and ``ground_truth``; each sample comes from a random
    container at a random location.

    Args:
        input_path: directory whose entries are zarr containers.
        batch_size: number of samples per yielded batch.
        voxel_shape: voxel size of both datasets (list, tuple or Coordinate).
        input_shape: size of the requested patches (used for raw and gt).
        output_shape: target size; if its second entry is smaller than the
            one of ``input_shape``, the gt patch is center-cropped on its
            first two axes by half of that difference.
        without_background: keep only gt channels 1-3 (``out[:, :, 1:4]``).
        mix_output: replace the gt by its argmax along axis 3 (as float).
        validate: resample until ``validate_mask`` accepts the gt patch.
        aug: augmenter forwarded to ``augmentation`` (defaults to ``seq``).

    Yields:
        ``augmentation(images, masks, aug)`` where ``images`` and ``masks``
        stack ``batch_size`` samples with ``np.asarray``.
    """
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')

    files = [os.path.join(input_path, f) for f in os.listdir(input_path)]

    # accept lists/tuples/Coordinates for the voxel size
    voxel_size = gp.Coordinate(voxel_shape)
    pipeline = (
        tuple(
            gp.ZarrSource(
                zarr_file,  # the zarr container
                # which dataset to associate to the array key
                {
                    raw: 'raw',
                    gt: 'ground_truth',
                },
                # meta-information
                {
                    raw: gp.ArraySpec(interpolatable=True,
                                      dtype=np.dtype('float32'),
                                      voxel_size=voxel_size),
                    gt: gp.ArraySpec(interpolatable=True,
                                     dtype=np.dtype('float32'),
                                     voxel_size=voxel_size),
                }) +
            gp.RandomLocation()
            for zarr_file in files) +
        gp.RandomProvider()
    )

    input_size = gp.Coordinate(input_shape)
    request = gp.BatchRequest()
    # the gt is requested at the full input size and cropped afterwards
    request.add(raw, input_size)
    request.add(gt, input_size)

    diff = (input_shape[1] - output_shape[1]) // 2
    max_p = input_shape[1] - diff
    different_shape = diff > 0
    if different_shape:
        print('Difference padding: {}'.format(diff))

    with gp.build(pipeline):
        while True:
            imgs = []
            masks = []
            for _ in range(batch_size):
                batch = pipeline.request_batch(request)
                # optionally resample until the gt passes validation
                while validate and not validate_mask(batch[gt].data):
                    batch = pipeline.request_batch(request)

                im = batch[raw].data
                out = batch[gt].data
                if different_shape:
                    # NOTE(review): the crop derived from axis 1 is applied
                    # to axes 0 and 1 -- assumes square patches.
                    out = out[diff:max_p, diff:max_p, :]
                if without_background:
                    out = out[:, :, 1:4]
                if mix_output:
                    out = out.argmax(axis=3).astype(float)

                imgs.append(im)
                masks.append(out)
            # honour the ``aug`` argument instead of the module-level ``seq``
            yield augmentation(np.asarray(imgs), np.asarray(masks), aug)
def train_until(**kwargs):
    """Train the instance-embedding network with gunpowder.

    Resumes from the latest TensorFlow checkpoint in
    ``kwargs['output_folder']`` and runs until ``kwargs['max_iteration']``.
    With ``kwargs['overlapping_inst']`` the network is additionally trained
    on the number of instances per voxel (``gt_numinst``) and half of the
    samples are drawn close to overlapping instances; otherwise it is
    trained on a foreground/background mask (``gt_fgbg``).

    Keyword Args (read from ``kwargs``):
        output_folder, name: location/prefix of the network description
            files and checkpoints.
        max_iteration: iteration to train up to.
        voxel_size: voxel size used to convert network shapes to world units.
        input_format: ``"hdf"`` or ``"zarr"``.
        data_files: training files; each must contain ``volumes/raw_bf``.
        patchshape, patchstride: define the 2D affinity neighborhood (only
            their last two entries are used).
        overlapping_inst: toggles the numinst vs. fgbg targets.
        augmentation, sampling: augmentation and sampling parameters.
        cache_size, num_workers, checkpoints, snapshots, profiling:
            forwarded to PreCache/Train/Snapshot/PrintProfilingStats.

    Raises:
        NotImplementedError: if ``input_format`` is neither hdf nor zarr.
    """
    # .get: the variable may legitimately be unset (e.g. CPU runs)
    print("cuda visibile devices", os.environ.get("CUDA_VISIBLE_DEVICES"))
    latest_checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')

    pred_code = gp.ArrayKey('PRED_CODE')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    net_prefix = os.path.join(kwargs['output_folder'], kwargs['name'])
    with open(net_prefix + '_config.json', 'r') as f:
        net_config = json.load(f)
    with open(net_prefix + '_names.json', 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2
    overlapping_inst = kwargs['overlapping_inst']

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if overlapping_inst:
        request.add(gt_numinst, output_shape_world)
    else:
        request.add(gt_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the network predictions
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_code, output_shape_world)
    if overlapping_inst:
        snapshot_request.add(pred_numinst, output_shape_world)
    else:
        snapshot_request.add(pred_fgbg, output_shape_world)

    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        raise NotImplementedError(
            "train node for %s not implemented yet" % input_format)
    source_node = gp.Hdf5Source if input_format == "hdf" else gp.ZarrSource

    fls = []
    for data_file in kwargs['data_files']:
        fls.append(os.path.splitext(data_file)[0])
        # only inspect the dtype; close the HDF5 file right away
        if input_format == "hdf":
            with h5py.File(data_file, 'r') as h5_file:
                raw_dtype = h5_file['volumes/raw_bf'].dtype
        else:
            raw_dtype = zarr.open(data_file, 'r')['volumes/raw_bf'].dtype
        if raw_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[:5])

    # 2D affinity neighborhood over the last two patch axes
    neighborhood = []
    psH = np.array(kwargs['patchshape']) // 2
    stride = kwargs['patchstride']
    for i in range(-psH[1], psH[1] + 1, stride[1]):
        for j in range(-psH[2], psH[2] + 1, stride[2]):
            neighborhood.append([i, j])

    datasets = {
        raw: 'volumes/raw_bf',
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances',
    }

    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False),
    }

    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
    }

    outputs = {
        net_names['pred_code']: pred_code,
        net_names['raw_cropped']: raw_cropped,
    }

    snapshot = {
        raw: '/volumes/raw',
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_code: '/volumes/pred_code',
    }

    if overlapping_inst:
        datasets[gt_numinst] = '/volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[pred_numinst] = '/volumes/pred_numinst'
    else:
        datasets[gt_fgbg] = '/volumes/gt_fgbg'
        array_specs[gt_fgbg] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_fgbg']] = gt_fgbg
        outputs[net_names['pred_fgbg']] = pred_fgbg
        snapshot[pred_fgbg] = '/volumes/pred_fgbg'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']
    elastic = augmentation['elastic']
    intensity = augmentation['intensity']

    source_fg = tuple(
        source_node(
            fl + "." + input_format,
            datasets=datasets,
            array_specs=array_specs) +
        gp.Pad(raw, context) +
        # chose a random location for each requested batch
        nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1) +
        gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask)
        for fl in fls)
    source_fg += gp.RandomProvider()

    if overlapping_inst:
        source_overlap = tuple(
            source_node(
                fl + "." + input_format,
                datasets=datasets,
                array_specs=array_specs) +
            gp.Pad(raw, context) +
            # chose a random location close to overlapping instances
            nl.MaskCloseDistanceToOverlap(
                gt_labels,
                gt_sample_mask,
                sampling['overlap_min_dist'],
                sampling['overlap_max_dist']) +
            gp.RandomLocation(
                min_masked=sampling['min_masked_overlap'],
                mask=gt_sample_mask)
            for fl in fls)
        source_overlap += gp.RandomProvider()

        source = (
            (source_fg, source_overlap) +
            # chose a random source (i.e., sample) from the above
            gp.RandomProvider(probabilities=[sampling['probability_fg'],
                                             sampling['probability_overlap']]))
    else:
        source = source_fg

    pipeline = (
        source +
        # elastically deform the batch
        gp.ElasticAugment(
            elastic['control_point_spacing'],
            elastic['jitter_sigma'],
            [elastic['rotation_min'] * np.pi / 180.0,
             elastic['rotation_max'] * np.pi / 180.0]) +
        gp.Reject(gt_sample_mask, min_masked=0.002, reject_probability=1) +
        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +
        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=intensity['scale'][0],
            scale_max=intensity['scale'][1],
            shift_min=intensity['shift'][0],
            shift_max=intensity['shift'][1],
            z_section_wise=False) +
        gp.IntensityScaleShift(raw, 2, -1) +
        # convert labels into affinities between voxels
        nl.AddAffinities(
            neighborhood,
            gt_labels if overlapping_inst else gt_instances,
            gt_affs,
            multiple_labels=overlapping_inst) +
        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +
        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            net_prefix,
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={},
            save_every=kwargs['checkpoints']) +
        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +
        # show a summary of time spend in each node every x iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f", i, time_of_iteration)

    print("Training finished")
def train_until(max_iteration, name='train_net', output_folder='.',
                clip_max=2000):
    """Train a foreground-mask network on HDF5 samples.

    Training resumes at the iteration encoded in the latest checkpoint in
    ``output_folder`` (suffix ``..._<iteration>``) and runs
    ``max_iteration - trained_until`` more iterations.

    Args:
        max_iteration: total iteration count to reach.
        name: prefix of ``<name>_config.json`` / ``<name>_names.json`` and of
            the checkpoints.
        output_folder: directory with network files, checkpoints, snapshots.
        clip_max: raw intensities are clipped to ``[0, clip_max]`` and
            normalized with ``factor=1.0 / clip_max``.

    NOTE(review): reads the module-level names ``data_files`` and
    ``data_dir``, which are not defined in this function.
    """
    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(pred_mask, output_shape)

    # specify data source: one branch per top-level member of every file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True,
                                          dtype=np.uint16,
                                          voxel_size=voxel_size),
                        # ``np.bool`` was removed in NumPy 1.24
                        gt_mask: gp.ArraySpec(interpolatable=False,
                                              dtype=np.bool_,
                                              voxel_size=voxel_size),
                    }) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
        data_sources +
        gp.RandomProvider() +
        gp.Reject(gt_mask, min_masked=0.005, reject_probability=0.98) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0 / clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi / 2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +
        # train
        gp.PreCache(
            cache_size=40,
            num_workers=5) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt']: gt_mask,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            gradients={},
            save_every=5000) +
        # visualize
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                pred_mask: 'volumes/pred_mask',
                gt_mask: 'volumes/gt_mask',
            },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2000) +
        gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def predict(data_dir, train_dir, iteration, sample, test_net_name='train_net',
            train_net_name='train_net', output_dir='.', clip_max=1000):
    """Predict a foreground mask for one HDF5 sample and store it as zarr.

    Does nothing unless ``data_dir`` contains ``"hdf"``.  The raw data
    (``<sample>/raw``) is clipped to ``[0, clip_max]``, scaled to ``[-1, 1]``
    and fed through the TensorFlow network; the float16 prediction is written
    to ``<output_dir>/<sample>.zarr`` under ``volumes/pred_mask``.

    Args:
        data_dir: path of the HDF5 file holding the samples.
        train_dir: directory with the network description and checkpoints.
        iteration: checkpoint iteration
            (``<train_net_name>_checkpoint_<iteration>``).
        sample: name of the sample group inside ``data_dir``.
        test_net_name: prefix of the config/names/meta files of the
            inference graph.
        train_net_name: prefix of the checkpoint files.
        output_dir: directory of the output zarr container.
        clip_max: upper clipping value and normalization divisor.
    """
    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    checkpoint = os.path.join(
        train_dir, train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    with open(os.path.join(train_dir, test_net_name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(train_dir, test_net_name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    voxel_size = gp.Coordinate((1, 1, 1))
    context = gp.Coordinate(input_shape - output_shape) / 2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)
    print("chunk request %s" % request)

    source = (
        gp.Hdf5Source(
            data_dir,
            datasets={
                raw: sample + '/raw',
            },
            array_specs={
                raw: gp.ArraySpec(
                    interpolatable=True,
                    dtype=np.uint16,
                    voxel_size=voxel_size),
            },
        ) +
        gp.Pad(raw, context) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0 / clip_max) +
        gp.IntensityScaleShift(raw, 2, -1))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
    print("raw_roi: %s" % raw_roi)

    # the padded raw ROI shrunk by the context is the region we predict
    output_roi = raw_roi.grow(-context, -context)
    sample_shape = output_roi.get_shape()
    print(sample_shape)

    # create zarr file with corresponding chunk size; store offset and
    # resolution of the predicted ROI instead of assuming a zero offset
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')
    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    zf['volumes/pred_mask'].attrs['offset'] = [
        int(o) for o in output_roi.get_offset()]
    zf['volumes/pred_mask'].attrs['resolution'] = [int(v) for v in voxel_size]

    pipeline = (
        source +
        gp.tensorflow.Predict(
            graph=os.path.join(train_dir, test_net_name + '.meta'),
            checkpoint=checkpoint,
            inputs={
                net_names['raw']: raw,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            array_specs={
                pred_mask: gp.ArraySpec(roi=output_roi, voxel_size=voxel_size),
            },
            max_shared_memory=1024 * 1024 * 1024) +
        Convert(pred_mask, np.float16) +
        gp.ZarrWrite(
            dataset_names={
                pred_mask: 'volumes/pred_mask',
            },
            output_dir=output_dir,
            output_filename=sample + '.zarr',
            compression_type='gzip',
            dataset_dtypes={pred_mask: np.float16}) +
        # show a summary of time spend in each node every x iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def predict(
        iteration,
        raw_file,
        raw_dataset,
        out_file,
        db_host,
        db_name,
        worker_config,
        network_config,
        out_properties=None,
        **kwargs):
    """Predict the synapse indicator blockwise (daisy) and write it to zarr.

    Args:
        iteration: checkpoint iteration; ``train_net_checkpoint_<iteration>``
            is loaded from the directory of this file.
        raw_file: input container; ``.hdf`` is read with ``gp.Hdf5Source``,
            ``.zarr``/``.n5`` with ``gp.ZarrSource`` (a trailing ``/`` is
            ignored when detecting the suffix).
        raw_dataset: dataset name of the raw data inside ``raw_file``.
        out_file: output container; the prediction is written to
            ``volumes/pred_syn_indicator``.
        db_host, db_name, worker_config: forwarded to ``block_done_callback``;
            ``worker_config['num_cache_workers']`` sets the daisy workers.
        network_config: prefix of ``<network_config>_net_config.json`` and
            ``<network_config>_net.meta``.
        out_properties: optional dict; if
            ``out_properties['pred_syn_indicator_out']['scale']`` exists and
            is not 1, the prediction is multiplied by it.
        **kwargs: ignored.

    Raises:
        RuntimeError: if the suffix of ``raw_file`` is not supported.
    """
    setup_dir = os.path.dirname(os.path.realpath(__file__))

    config_path = os.path.join(
        setup_dir, '{}_net_config.json'.format(network_config))
    with open(config_path, 'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    raw = gp.ArrayKey('RAW')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_post_indicator, output_size)

    # ``None`` (the default) means "no output properties"
    m_property = (out_properties or {}).get('pred_syn_indicator_out')

    # Select Source based on filesuffix; zarr/n5 containers are directories
    # and are often given with a trailing slash
    suffix_probe = raw_file.rstrip('/')
    source_kwargs = dict(
        datasets={
            raw: raw_dataset
        },
        array_specs={
            raw: gp.ArraySpec(interpolatable=True),
        })
    if suffix_probe.endswith('.hdf'):
        pipeline = gp.Hdf5Source(raw_file, **source_kwargs)
    elif suffix_probe.endswith(('.zarr', '.n5')):
        pipeline = gp.ZarrSource(raw_file, **source_kwargs)
    else:
        raise RuntimeError('unknwon input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={
            net_config['raw']: raw
        },
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config))
    )

    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(
                pred_post_indicator, m_property['scale'], 0)

    pipeline += gp.ZarrWrite(
        dataset_names={
            pred_post_indicator: 'volumes/pred_syn_indicator',
        },
        output_filename=out_file
    )
    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host,
            db_name,
            worker_config,
            b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
def train_until(max_iteration, name='train_net', output_folder='.',
                clip_max=2000):
    """Train a foreground-mask network on fused ("base" + "add") samples.

    Two random samples are drawn: a "base" sample located with at least
    0.5% foreground, and an "add" sample filtered by ``gp.Reject``.
    ``nl.FusionAugment`` blends them into ``raw``/``gt_instances``, which are
    binarized into ``gt_mask`` and used for training.  Training resumes at
    the iteration encoded in the latest checkpoint in ``output_folder``.

    Args:
        max_iteration: total iteration count to reach.
        name: prefix of ``<name>_config.json`` / ``<name>_names.json`` and of
            the checkpoints.
        output_folder: directory with network files, checkpoints, snapshots.
        clip_max: raw intensities are clipped to ``[0, clip_max]`` and
            normalized with ``factor=1.0 / clip_max``.

    NOTE(review): reads the module-level names ``data_files`` and
    ``data_dir``, which are not defined in this function.
    """
    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume: one branch per top-level member of each file
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True,
                                               dtype=np.uint16,
                                               voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False,
                                                        dtype=np.uint16,
                                                        voxel_size=voxel_size),
                        # ``np.bool`` was removed in NumPy 1.24
                        gt_mask_base: gp.ArraySpec(interpolatable=False,
                                                   dtype=np.bool_,
                                                   voxel_size=voxel_size),
                    }) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005, mask=gt_mask_base)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True,
                                              dtype=np.uint16,
                                              voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False,
                                                       dtype=np.uint16,
                                                       voxel_size=voxel_size),
                        # ``np.bool`` was removed in NumPy 1.24
                        gt_mask_add: gp.ArraySpec(interpolatable=False,
                                                  dtype=np.bool_,
                                                  voxel_size=voxel_size),
                    }) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()

    data_sources = tuple([data_sources_base, data_sources_add]) + gp.MergeProvider()

    pipeline = (
        data_sources +
        nl.FusionAugment(
            raw_base, raw_add, gt_instances_base, gt_instances_add, raw,
            gt_instances,
            blend_mode='labels_mask',
            blend_smoothness=5,
            num_blended_objects=0) +
        BinarizeLabels(gt_instances, gt_mask) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0 / clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi / 2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +
        # train
        gp.PreCache(
            cache_size=40,
            num_workers=10) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt']: gt_mask,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            gradients={
                net_names['output']: loss_gradients,
            },
            save_every=5000) +
        # visualize
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                pred_mask: 'volumes/pred_mask',
                gt_mask: 'volumes/gt_mask',
                loss_gradients: 'volumes/loss_gradients',
            },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2500) +
        gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(max_iteration):
    """Train a foreground network on two fused neuron volumes.

    Two random samples ("base" and "add") are drawn from ``files``.  Each
    combines a raw HDF5 volume (``/volume``) with its SWC reconstruction
    (``/reconstruction``), which is rasterized into a label volume.
    ``FusionAugment`` blends both samples into one ``raw``/``labels`` pair,
    which is augmented, binarized into ``labels_fg`` and used for training.
    Training resumes from the latest checkpoint in the working directory.

    NOTE(review): relies on the module-level names ``net_config``,
    ``net_names`` and ``files``, which are not defined in this function.
    """
    newest_checkpoint = tf.train.latest_checkpoint(".")
    trained_until = (
        int(newest_checkpoint.split("_")[-1]) if newest_checkpoint else 0)
    if trained_until >= max_iteration:
        return

    # fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # "base" volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # "add" volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # network outputs
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    request = gp.BatchRequest()
    for key, size in (
            (raw, input_size),
            (labels, output_size),
            (labels_fg, output_size),
            (loss_weights, output_size),
            (swc_center_base, output_size),
            (swc_center_add, output_size)):
        request.add(key, size)

    snapshot_request = gp.BatchRequest()
    for key, size in (
            (fg, output_size),
            (labels_fg, output_size),
            (gradient_fg, output_size),
            (raw_base, input_size),
            (raw_add, input_size),
            (labels_base, input_size),
            (labels_add, input_size)):
        snapshot_request.add(key, size)

    def skeleton_source(path, raw_key, swc_key, center_key, labels_key):
        # raw + SWC points of one file, sampled at a random location that
        # contains a skeleton point, with the skeleton rasterized to labels
        raw_and_points = (
            gp.Hdf5Source(
                path,
                datasets={raw_key: "/volume"},
                array_specs={
                    raw_key: gp.ArraySpec(
                        interpolatable=True,
                        voxel_size=voxel_size,
                        dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=path,
                dataset="/reconstruction",
                points=(center_key, swc_key),
                scale=voxel_size,
            ),
        )
        return (
            raw_and_points
            + gp.MergeProvider()
            + gp.RandomLocation(ensure_nonempty=center_key)
            + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=gp.ArraySpec(
                    interpolatable=False,
                    voxel_size=voxel_size,
                    dtype=np.uint32),
                radius=5.0,
            )
        )

    base_branches = tuple(
        skeleton_source(path, raw_base, swc_base, swc_center_base, labels_base)
        for path in files)
    add_branches = tuple(
        skeleton_source(path, raw_add, swc_add, swc_center_add, labels_add)
        for path in files)

    merged_sources = (
        base_branches + gp.RandomProvider(),
        add_branches + gp.RandomProvider(),
    ) + gp.MergeProvider()

    snapshot_datasets = {
        raw: "volumes/raw",
        raw_base: "volumes/raw_base",
        raw_add: "volumes/raw_add",
        labels: "volumes/labels",
        labels_base: "volumes/labels_base",
        labels_add: "volumes/labels_add",
        fg: "volumes/fg",
        labels_fg: "volumes/labels_fg",
        gradient_fg: "volumes/gradient_fg",
    }

    pipeline = (
        merged_sources
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        )
        # augment
        + gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                            subsample=4)
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        + BinarizeGt(labels, labels_fg)
        + gp.BalanceLabels(labels_fg, loss_weights)
        # train
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        )
        # visualize
        + gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names=snapshot_datasets,
            additional_request=snapshot_request,
            every=100,
        )
        + gp.PrintProfilingStats(every=100)
    )

    remaining = max_iteration - trained_until
    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(remaining):
            pipeline.request_batch(request)
def train_until(max_iteration):
    """Visualize random SWC-centered raw crops of the 2018-08-01 samples.

    Despite its name this does not train: for every remaining iteration it
    requests a batch (raw data plus the SWC points, centered in the ROI) and
    hands it to ``vis_points_with_array``.  The number of remaining
    iterations is derived from the latest checkpoint in the working
    directory.

    NOTE(review): relies on the module-level names ``net_config`` and
    ``sample_dir``, which are not defined in this function.
    """
    newest_checkpoint = tf.train.latest_checkpoint(".")
    trained_until = (
        int(newest_checkpoint.split("_")[-1]) if newest_checkpoint else 0)
    if trained_until >= max_iteration:
        return

    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    # twice the network input size, in world units
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    def sample_source(sample):
        # raw N5 volume + one SWC file of a sample directory, sampled at a
        # random location centered on an SWC point
        n5_path = sample / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
        swc_path = sample / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
        transform_path = sample / "transform.txt"
        providers = (
            gp.N5Source(
                filename=str(n5_path.absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True,
                        voxel_size=voxel_size,
                        dtype=np.uint16)
                },
            ),
            MouselightSwcFileSource(
                filename=str(swc_path.absolute()),
                points=(swcs, ),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str(transform_path.absolute()),
            ),
        )
        return (
            providers
            + gp.MergeProvider()
            + gp.RandomLocation(ensure_nonempty=swcs, ensure_centered=True)
        )

    data_sources = tuple(
        sample_source(sample)
        for sample in Path(sample_dir).iterdir()
        if "2018-08-01" in sample.name)

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(
                batch[raw].data,
                points_to_graph(batch[swcs].data),
                np.array(voxel_size))