def predict(iteration, raw_file, raw_dataset, out_file, db_host, db_name,
            worker_config, network_config, out_properties={}, **kwargs):

    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(os.path.join(setup_dir,
                           '{}_net_config.json'.format(network_config)),
              'r') as f:
        net_config = json.load(f)

    # shapes in voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # sizes in nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # select the source node based on the input file format
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(
            raw_file,
            datasets={raw: raw_dataset},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True),
            })
    elif raw_file.endswith('.zarr') or raw_file.endswith('.n5'):
        pipeline = gp.ZarrSource(
            raw_file,
            datasets={raw: raw_dataset},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True),
            })
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)
    pipeline += gp.Normalize(raw)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)
    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    d_scale = parameters.get('d_scale')
    if d_scale is not None and d_scale != 1:
        # map back to nm world
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           1. / d_scale, 0)

    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] in ('int8', 'float32'), \
            'predict not adapted to dtype {}'.format(d_property['dtype'])
        if d_property['dtype'] == 'int8':
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors, 1, 0,
                                                clip=(-128, 127))

    pipeline += gp.ZarrWrite(
        dataset_names={
            pred_post_indicator: 'volumes/pred_syn_indicator',
            pred_postpre_vectors: 'volumes/pred_partner_vectors',
        },
        output_filename=out_file)
    pipeline += gp.PrintProfilingStats(every=10)
    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
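# Example invocation of the blockwise predict() above, as a sketch. The
# checkpoint iteration, file paths, and database coordinates below are
# hypothetical placeholders, not values shipped with this setup; the
# network_config value determines which <name>_net_config.json and
# <name>_net.meta are loaded from the setup directory.
def _example_blockwise_predict_call():
    predict(
        iteration=100000,                        # assumed checkpoint
        raw_file='sample_A.zarr',                # assumed input container
        raw_dataset='volumes/raw',
        out_file='sample_A_prediction.zarr',
        db_host='localhost',                     # assumed MongoDB host
        db_name='synapse_prediction_blocks',     # assumed database name
        worker_config={'num_cache_workers': 4},
        network_config='test')                   # expects test_net_config.json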
def train_until(max_iteration, name='train_net', output_folder='.',
                clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(
            tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True,
                                          dtype=np.uint16,
                                          voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False,
                                              dtype=bool,
                                              voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
        data_sources +
        gp.RandomProvider() +
        gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
        DistanceTransform(gt_mask, gt_dt, 3) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=5) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_dt']: gt_dt,
            },
            outputs={
                net_names['pred_dt']: pred_dt,
            },
            gradients={
                net_names['pred_dt']: loss_gradient,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot({
            raw: 'volumes/raw',
            gt_mask: 'volumes/gt_mask',
            gt_dt: 'volumes/gt_dt',
            pred_dt: 'volumes/pred_dt',
            loss_gradient: 'volumes/gradient',
        },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2000) +
        gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
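# Example invocation of the distance-transform train_until() above, as a
# sketch. It assumes that the module-level data_dir/data_files globals point
# at HDF5 samples containing '<sample>/raw' and '<sample>/fg' datasets and
# that <name>_config.json/<name>_names.json exist in the output folder; the
# concrete values here are illustrative only.
def _example_dt_train_call():
    train_until(
        max_iteration=100000,
        name='train_net',
        output_folder='experiments/run_01',  # assumed output location
        clip_max=2000)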
def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None):
    return gp.IntensityScaleShift(raw_key, scale=self.scale,
                                  shift=self.shift)
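# For context, a minimal sketch of how a node() factory like the one above
# can be used: an object carrying scale/shift attributes hands gunpowder a
# ready-made IntensityScaleShift node. The Normalizer class and the usage
# line are hypothetical illustrations, not part of this repository.
class Normalizer:
    def __init__(self, scale=2.0, shift=-1.0):
        self.scale = scale
        self.shift = shift

    def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None):
        # same shape as the method above; _gt_key/_mask_key are accepted
        # but unused so all such factories share one interface
        return gp.IntensityScaleShift(raw_key, scale=self.scale,
                                      shift=self.shift)

# usage: pipeline += Normalizer().node(gp.ArrayKey('RAW'))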
def build_pipeline(parameter, augment=True):

    voxel_size = gp.Coordinate(parameter['voxel_size'])

    # array specifications
    raw = gp.ArrayKey('RAW')
    gt_neurons = gp.ArrayKey('GT_NEURONS')
    gt_postpre_vectors = gp.ArrayKey('GT_POSTPRE_VECTORS')
    gt_post_indicator = gp.ArrayKey('GT_POST_INDICATOR')
    post_loss_weight = gp.ArrayKey('POST_LOSS_WEIGHT')
    vectors_mask = gp.ArrayKey('VECTORS_MASK')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')
    grad_syn_indicator = gp.ArrayKey('GRAD_SYN_INDICATOR')
    grad_partner_vectors = gp.ArrayKey('GRAD_PARTNER_VECTORS')

    # points specifications
    dummypostsyn = gp.PointsKey('DUMMYPOSTSYN')
    postsyn = gp.PointsKey('POSTSYN')
    presyn = gp.PointsKey('PRESYN')

    # AddPartnerVectorMap context in nm (pre-post distance)
    trg_context = 140

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_neurons, output_size)
    request.add(gt_postpre_vectors, output_size)
    request.add(gt_post_indicator, output_size)
    request.add(post_loss_weight, output_size)
    request.add(vectors_mask, output_size)
    request.add(dummypostsyn, output_size)

    for key, request_spec in request.items():
        print(key)
        print(request_spec.roi)

    snapshot_request = gp.BatchRequest({
        pred_post_indicator: request[gt_postpre_vectors],
        pred_postpre_vectors: request[gt_postpre_vectors],
        grad_syn_indicator: request[gt_postpre_vectors],
        grad_partner_vectors: request[gt_postpre_vectors],
        vectors_mask: request[gt_postpre_vectors]
    })

    postsyn_rastersetting = gp.RasterizationSettings(
        radius=parameter['blob_radius'],
        mask=gt_neurons,
        mode=parameter['blob_mode'])

    pipeline = tuple([
        create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                      gt_neurons)
        for sample in samples
    ])
    pipeline += gp.RandomProvider()

    if augment:
        pipeline += gp.ElasticAugment([4, 40, 40], [0, 2, 2],
                                      [0, math.pi / 2.0],
                                      prob_slip=0.05,
                                      prob_shift=0.05,
                                      max_misalign=10,
                                      subsample=8)
        pipeline += gp.SimpleAugment(transpose_only=[1, 2],
                                     mirror_only=[1, 2])
        pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                        z_section_wise=True)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)
    pipeline += gp.RasterizePoints(
        postsyn, gt_post_indicator,
        gp.ArraySpec(voxel_size=voxel_size, dtype=np.int32),
        postsyn_rastersetting)
    spec = gp.ArraySpec(voxel_size=voxel_size)
    pipeline += AddPartnerVectorMap(
        src_points=postsyn,
        trg_points=presyn,
        array=gt_postpre_vectors,
        radius=parameter['d_blob_radius'],
        trg_context=trg_context,  # enlarge
        array_spec=spec,
        mask=gt_neurons,
        pointmask=vectors_mask)
    pipeline += gp.BalanceLabels(labels=gt_post_indicator,
                                 scales=post_loss_weight,
                                 slab=(-1, -1, -1),
                                 clipmin=parameter['cliprange'][0],
                                 clipmax=parameter['cliprange'][1])
    if parameter['d_scale'] != 1:
        pipeline += gp.IntensityScaleShift(gt_postpre_vectors,
                                           scale=parameter['d_scale'],
                                           shift=0)
    pipeline += gp.PreCache(cache_size=40, num_workers=10)
    pipeline += gp.tensorflow.Train(
        './train_net',
        optimizer=net_config['optimizer'],
        loss=net_config['loss'],
        summary=net_config['summary'],
        log_dir='./tensorboard/',
        save_every=30000,  # 10000
        log_every=100,
        inputs={
            net_config['raw']: raw,
            net_config['gt_partner_vectors']: gt_postpre_vectors,
            net_config['gt_syn_indicator']: gt_post_indicator,
            net_config['vectors_mask']: vectors_mask,  # loss weights --> mask
            net_config['indicator_weight']: post_loss_weight,  # loss weights
        },
        outputs={
            net_config['pred_partner_vectors']: pred_postpre_vectors,
            net_config['pred_syn_indicator']: pred_post_indicator,
        },
        gradients={
            net_config['pred_partner_vectors']: grad_partner_vectors,
            net_config['pred_syn_indicator']: grad_syn_indicator,
        },
    )

    # visualize
    pipeline += gp.IntensityScaleShift(raw, 0.5, 0.5)
    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
            gt_post_indicator: 'volumes/gt_post_indicator',
            gt_postpre_vectors: 'volumes/gt_postpre_vectors',
            pred_postpre_vectors: 'volumes/pred_postpre_vectors',
            pred_post_indicator: 'volumes/pred_post_indicator',
            post_loss_weight: 'volumes/post_loss_weight',
            grad_syn_indicator: 'volumes/post_indicator_gradients',
            grad_partner_vectors: 'volumes/partner_vectors_gradients',
            vectors_mask: 'volumes/vectors_mask'
        },
        every=1000,
        output_filename='batch_{iteration}.hdf',
        compression_type='gzip',
        additional_request=snapshot_request)
    pipeline += gp.PrintProfilingStats(every=100)

    print("Starting training...")
    max_iteration = parameter['max_iteration']
    with gp.build(pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
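# Example parameter dictionary for build_pipeline() above, as a sketch.
# The keys mirror the lookups visible in the function body; create_source()
# and the module-level samples global may require more. All concrete values
# are illustrative assumptions, not tuned settings.
def _example_build_pipeline_call():
    parameter = {
        'voxel_size': (40, 4, 4),   # assumed nm voxel size (z, y, x)
        'blob_radius': 100,         # rasterization radius in nm (assumed)
        'blob_mode': 'ball',
        'd_blob_radius': 120,       # partner-vector radius in nm (assumed)
        'cliprange': (0.005, 0.5),  # BalanceLabels clip bounds (assumed)
        'd_scale': 1,
        'max_iteration': 700000,
    }
    build_pipeline(parameter, augment=True)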
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(kwargs['input_folder'],
                           name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'],
                           name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2
    chunksize = list(np.asarray(output_shape_world) // 2)
    raw_key = kwargs.get('raw_key', 'volumes/raw')

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)
    if kwargs['overlapping_inst']:
        request.add(pred_numinst, output_shape_world, voxel_size=voxel_size)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "predict node for %s not implemented yet" %
            kwargs['input_format'])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(os.path.join(kwargs['data_folder'],
                                    kwargs['sample'] + ".hdf"), 'r') as f:
            shape = f[raw_key].shape[1:]
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(os.path.join(kwargs['data_folder'],
                                   kwargs['sample'] + ".zarr"), 'r')
        shape = f[raw_key].shape[1:]

    source = sourceNode(
        os.path.join(kwargs['data_folder'],
                     kwargs['sample'] + "." + kwargs['input_format']),
        datasets={raw: raw_key})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # create the output zarr datasets ahead of time
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'), mode='w')
    zf.create('volumes/pred_affs',
              shape=[int(np.prod(kwargs['patchshape']))] + list(shape),
              chunks=[int(np.prod(kwargs['patchshape']))] + list(chunksize),
              dtype=np.float16)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    if kwargs['overlapping_inst']:
        zf.create('volumes/pred_numinst',
                  shape=[int(kwargs['max_num_inst']) + 1] + list(shape),
                  chunks=[int(kwargs['max_num_inst']) + 1] + list(chunksize),
                  dtype=np.float16)
        zf['volumes/pred_numinst'].attrs['offset'] = [0, 0]
        zf['volumes/pred_numinst'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        pred_affs: '/volumes/pred_affs',
    }
    if kwargs['overlapping_inst']:
        outputs[net_names['pred_numinst']] = pred_numinst
        outVolumes[pred_numinst] = '/volumes/pred_numinst'

    pipeline = (
        source +
        gp.Pad(raw, context) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # run one prediction for each passing batch (here we use the
        # tensor names earlier stored in the net config)
        gp.tensorflow.Predict(
            graph=os.path.join(kwargs['input_folder'], name + '.meta'),
            checkpoint=kwargs['checkpoint'],
            inputs={net_names['raw']: raw},
            outputs=outputs) +

        # write all passing batches into the same zarr file
        gp.ZarrWrite(
            outVolumes,
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip') +

        # show a summary of time spent in each node every 100 iterations
        gp.PrintProfilingStats(every=100) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the
        # dataset without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
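# Example kwargs for the scanning predict() above, as a sketch. The keys
# mirror the lookups in the function body; file locations, shapes, and the
# checkpoint path are illustrative assumptions.
def _example_scan_predict_call():
    predict(
        name='train_net',
        input_folder='experiments/run_01',  # holds *_config/_names.json
        data_folder='data/test',
        sample='sample_01',
        input_format='zarr',
        output_format='zarr',
        output_folder='predictions',
        checkpoint='experiments/run_01/train_net_checkpoint_400000',
        voxel_size=(1, 1),                  # assumed 2d data
        patchshape=(1, 25, 25),             # assumed patch shape
        max_num_inst=2,
        overlapping_inst=True)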
def predict(data_dir, train_dir, iteration, sample, test_net_name='train_net',
            train_net_name='train_net', output_dir='.', clip_max=1000):

    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    with open(os.path.join(train_dir, test_net_name + '_config.json'),
              'r') as f:
        net_config = json.load(f)
    with open(os.path.join(train_dir, test_net_name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    voxel_size = gp.Coordinate((1, 1, 1))
    context = gp.Coordinate(input_shape - output_shape) / 2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)
    print("chunk request %s" % request)

    source = (
        gp.Hdf5Source(
            data_dir,
            datasets={
                raw: sample + '/raw',
            },
            array_specs={
                raw: gp.ArraySpec(
                    interpolatable=True,
                    dtype=np.uint16,
                    voxel_size=voxel_size),
            },
        ) +
        gp.Pad(raw, context) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0 / clip_max) +
        gp.IntensityScaleShift(raw, 2, -1))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        sample_shape = raw_roi.grow(-context, -context).get_shape()
        print(sample_shape)

    # create zarr file with corresponding chunk size
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')
    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    zf['volumes/pred_mask'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_mask'].attrs['resolution'] = [1, 1, 1]

    pipeline = (
        source +
        gp.tensorflow.Predict(
            graph=os.path.join(train_dir, test_net_name + '.meta'),
            checkpoint=checkpoint,
            inputs={
                net_names['raw']: raw,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            array_specs={
                pred_mask: gp.ArraySpec(
                    roi=raw_roi.grow(-context, -context),
                    voxel_size=voxel_size),
            },
            max_shared_memory=1024 * 1024 * 1024) +
        Convert(pred_mask, np.float16) +
        gp.ZarrWrite(
            dataset_names={
                pred_mask: 'volumes/pred_mask',
            },
            output_dir=output_dir,
            output_filename=sample + '.zarr',
            compression_type='gzip',
            dataset_dtypes={pred_mask: np.float16}) +

        # show a summary of time spent in each node every 100 iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
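# Example invocation of the mask predict() above, as a sketch with
# illustrative paths; the checkpoint iteration and sample name are
# assumptions. data_dir must be an HDF5 file containing <sample>/raw.
def _example_mask_predict_call():
    predict(
        data_dir='data/test.hdf',        # assumed input file
        train_dir='experiments/run_01',  # assumed training directory
        iteration=100000,
        sample='sample_01',
        output_dir='predictions',
        clip_max=1000)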
def train_until(**kwargs):
    print("CUDA visible devices", os.environ["CUDA_VISIBLE_DEVICES"])

    # resume from the latest checkpoint, if any
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        request.add(gt_numinst, output_shape_world)
    # request.add(loss_weights_affs, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        snapshot_request.add(pred_numinst, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    raw_key = kwargs.get('raw_key', 'volumes/raw')
    print('raw key: ', raw_key)

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')[raw_key]
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')[raw_key]
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    # build the 2d patch neighborhood around each voxel
    neighborhood = []
    psH = np.array(kwargs['patchshape']) // 2
    for i in range(-psH[1], psH[1] + 1, kwargs['patchstride'][1]):
        for j in range(-psH[2], psH[2] + 1, kwargs['patchstride'][2]):
            neighborhood.append([i, j])

    datasets = {
        raw: raw_key,
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances'
    }
    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False)
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
        # net_names['loss_weights_affs']: loss_weights_affs,
    }
    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw: 'volumes/raw',
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: 'volumes/gt_affs',
        pred_affs: 'volumes/pred_affs',
        pred_affs_gradients: 'volumes/pred_affs_gradients',
    }
    if kwargs['overlapping_inst']:
        datasets[gt_numinst] = 'volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[gt_numinst] = 'volumes/gt_numinst'
        snapshot[pred_numinst] = 'volumes/pred_numinst'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']

    # sources biased towards foreground
    source_fg = tuple(
        sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +
        nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1) +
        # choose a random location for each requested batch
        gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_fg += gp.RandomProvider()

    # sources biased towards regions close to overlapping instances
    source_overlap = tuple(
        sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +
        nl.MaskCloseDistanceToOverlap(
            gt_labels, gt_sample_mask,
            sampling['overlap_min_dist'],
            sampling['overlap_max_dist']
        ) +
        # choose a random location for each requested batch
        gp.RandomLocation(
            min_masked=sampling['min_masked_overlap'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_overlap += gp.RandomProvider()

    pipeline = (
        (source_fg, source_overlap) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider(probabilities=[
            sampling['probability_fg'],
            sampling['probability_overlap']]) +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0]) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # convert labels into affinities between voxels
        nl.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs,
            multiple_labels=kwargs['overlapping_inst']) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we
        # use the tensor names earlier stored in the net config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start
            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
def train_until(max_iteration, name='train_net', output_folder='.',
                clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(
            tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    # loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    # request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    # snapshot_request.add(raw_base, input_shape)
    # snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    # snapshot_request.add(gt_mask_base, output_shape)
    # snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True,
                                               dtype=np.uint16,
                                               voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False,
                                                        dtype=np.uint16,
                                                        voxel_size=voxel_size),
                        gt_mask_base: gp.ArraySpec(interpolatable=False,
                                                   dtype=bool,
                                                   voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005, mask=gt_mask_base)
                # gp.Reject(gt_mask_base, min_masked=0.005,
                #           reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True,
                                              dtype=np.uint16,
                                              voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False,
                                                       dtype=np.uint16,
                                                       voxel_size=voxel_size),
                        gt_mask_add: gp.ArraySpec(interpolatable=False,
                                                  dtype=bool,
                                                  voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005,
                          reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()

    data_sources = (
        tuple([data_sources_base, data_sources_add]) + gp.MergeProvider())

    pipeline = (
        data_sources +
        nl.FusionAugment(
            raw_base, raw_add, gt_instances_base, gt_instances_add, raw,
            gt_instances,
            blend_mode='labels_mask',
            blend_smoothness=5,
            num_blended_objects=0) +
        BinarizeLabels(gt_instances, gt_mask) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +
        # gp.BalanceLabels(gt_mask, loss_weights) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=10) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt']: gt_mask,
                # net_names['loss_weights']: loss_weights,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            gradients={
                net_names['output']: loss_gradients,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot({
            raw: 'volumes/raw',
            pred_mask: 'volumes/pred_mask',
            gt_mask: 'volumes/gt_mask',
            # loss_weights: 'volumes/loss_weights',
            loss_gradients: 'volumes/loss_gradients',
        },
            output_filename=os.path.join(output_folder, 'snapshots',
                                         'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2500) +
        gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
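# Example invocation of the FusionAugment-based train_until() above, as a
# sketch. Like the distance-transform variant further up, it reads the
# module-level data_dir/data_files globals; the values here are illustrative.
def _example_fusion_train_call():
    train_until(
        max_iteration=100000,
        name='train_net',
        output_folder='experiments/fusion_run_01',  # assumed
        clip_max=2000)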
def train_until(**kwargs):
    # resume from the latest checkpoint, if any
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_threeclass = gp.ArrayKey('GT_THREECLASS')
    gt_sdt = gp.ArrayKey('GT_SDT')
    gt_cpv = gp.ArrayKey('GT_CPV')
    gt_points = gp.PointsKey('GT_CPV_POINTS')
    pred_sdt = gp.ArrayKey('PRED_SDT')
    pred_cpv = gp.ArrayKey('PRED_CPV')
    pred_sdt_gradients = gp.ArrayKey('PRED_SDT_GRADIENTS')
    pred_cpv_gradients = gp.ArrayKey('PRED_CPV_GRADIENTS')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_threeclass, output_shape_world)
    request.add(gt_sdt, output_shape_world)
    request.add(gt_cpv, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predictions and gradients of the loss wrt the outputs
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_sdt, output_shape_world)
    snapshot_request.add(gt_threeclass, output_shape_world)
    snapshot_request.add(gt_labels, output_shape_world)
    snapshot_request.add(pred_sdt, output_shape_world)
    snapshot_request.add(pred_cpv, output_shape_world)
    # snapshot_request.add(pred_sdt_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']

    pipeline = (
        tuple(
            (
                sourceNode(
                    fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        gt_labels: 'volumes/gt_labels',
                        gt_threeclass: 'volumes/gt_threeclass',
                        # gt_sdt: 'volumes/gt_tanh',
                        anchor: 'volumes/gt_tanh',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        gt_labels: gp.ArraySpec(interpolatable=False),
                        gt_threeclass: gp.ArraySpec(interpolatable=False),
                        # gt_sdt: gp.ArraySpec(interpolatable=False),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }
                ),
                gp.CsvIDPointsSource(
                    fls[t] + ".csv",
                    gt_points,
                    points_spec=gp.PointsSpec(
                        roi=gp.Roi(gp.Coordinate((0, 0, 0)),
                                   gp.Coordinate(shapes[t])))
                )
            ) +
            gp.MergeProvider() +
            gp.Pad(raw, None) +
            gp.Pad(gt_labels, None) +
            gp.Pad(gt_threeclass, None) +
            # gp.Pad(gt_sdt, None) +
            gp.Pad(gt_points, None) +
            # choose a random location for each requested batch
            gp.RandomLocation()
            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        gp.AddSdt(gt_labels, gt_threeclass, gt_sdt, -9) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
         if augmentation.get('intensity') is not None else NoOp()) +
        (gp.IntensityScaleShift(
            raw,
            scale=augmentation['scale_shift']['scale'],
            shift=augmentation['scale_shift']['shift'])
         if augmentation.get('scale_shift') is not None else NoOp()) +

        gp.AddCPV(gt_points, gt_labels, gt_cpv) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we
        # use the tensor names earlier stored in the net config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_labels']: gt_labels,
                net_names['gt_sdt']: gt_sdt,
                net_names['gt_cpv']: gt_cpv,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_sdt']: pred_sdt,
                net_names['pred_cpv']: pred_cpv,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_sdt']: pred_sdt_gradients,
                net_names['pred_cpv']: pred_cpv_gradients
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_sdt: 'volumes/gt_sdt',
                pred_sdt: 'volumes/pred_sdt',
                pred_cpv: 'volumes/pred_cpv',
                # pred_sdt_gradients: 'volumes/pred_sdt_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start
            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
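# Example kwargs for the sdt/cpv train_until() above, as a sketch. The keys
# mirror the lookups in the function body; values are illustrative
# assumptions. Each data file is expected alongside a <name>.csv with center
# points for gp.CsvIDPointsSource, and to contain volumes/raw,
# volumes/gt_labels, volumes/gt_threeclass, and volumes/gt_tanh datasets.
def _example_sdt_cpv_train_call():
    train_until(
        name='train_net',
        output_folder='experiments/run_01',
        max_iteration=400000,
        data_files=['data/train/sample_01.zarr'],
        input_format='zarr',
        voxel_size=(1, 1, 1),
        cache_size=40,
        num_workers=10,
        checkpoints=20000,
        snapshots=2000,
        profiling=500,
        augmentation={
            'elastic': {
                'control_point_spacing': [10, 10, 10],
                'jitter_sigma': [1, 1, 1],
                'rotation_min': -45,
                'rotation_max': 45,
            },
            'simple': {},  # mirror/transpose defaults
        })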