def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):
    """Predict the mask for one sample of an HDF file and write it to zarr.

    The function returns immediately (without output) unless the substring
    ``"hdf"`` occurs in ``data_dir``.

    Args:
        data_dir: Path handed to ``gp.Hdf5Source``.
        train_dir: Directory holding ``<test_net_name>_config.json``,
            ``<test_net_name>_names.json``, ``<test_net_name>.meta`` and the
            checkpoint ``<train_net_name>_checkpoint_<iteration>``.
        iteration: Checkpoint number to load.
        sample: Name of the sample; raw data is read from ``<sample>/raw``
            and the result is written to ``<sample>.zarr``.
        test_net_name: Prefix of the config/names/meta files.
        train_net_name: Prefix of the checkpoint file.
        output_dir: Directory in which the zarr container is created.
        clip_max: Upper clipping bound for raw intensities; also used as the
            normalization divisor.
    """
    # Anything that does not look like an HDF input is skipped silently.
    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    checkpoint = os.path.join(
        train_dir, train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    config_path = os.path.join(train_dir, test_net_name + '_config.json')
    with open(config_path, 'r') as config_file:
        net_config = json.load(config_file)
    names_path = os.path.join(train_dir, test_net_name + '_names.json')
    with open(names_path, 'r') as names_file:
        net_names = json.load(names_file)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # Half of the difference between network input and output size.
    context = gp.Coordinate(input_shape - output_shape) / 2

    # Request describing a single chunk; used as the reference for gp.Scan.
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)
    print("chunk request %s" % request)

    raw_spec = gp.ArraySpec(
        interpolatable=True,
        dtype=np.uint16,
        voxel_size=voxel_size)
    source = (
        gp.Hdf5Source(
            data_dir,
            datasets={raw: sample + '/raw'},
            array_specs={raw: raw_spec},
        ) +
        gp.Pad(raw, context) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0 / clip_max) +
        gp.IntensityScaleShift(raw, 2, -1)
    )

    # Build the source once up front just to learn the (padded) raw ROI.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        # Shrinking by the context on both sides undoes the padding above.
        output_roi = raw_roi.grow(-context, -context)
        sample_shape = output_roi.get_shape()
        print(sample_shape)

    # Create the zarr file with a chunk size matching the network output.
    zarr_name = sample + '.zarr'
    zf = zarr.open(os.path.join(output_dir, zarr_name), mode='w')
    zf.create(
        'volumes/pred_mask',
        shape=sample_shape,
        chunks=output_shape,
        dtype=np.float16)
    pred_dataset = zf['volumes/pred_mask']
    pred_dataset.attrs['offset'] = [0, 0, 0]
    pred_dataset.attrs['resolution'] = [1, 1, 1]

    predict_node = gp.tensorflow.Predict(
        graph=os.path.join(train_dir, test_net_name + '.meta'),
        checkpoint=checkpoint,
        inputs={net_names['raw']: raw},
        outputs={net_names['pred']: pred_mask},
        array_specs={
            pred_mask: gp.ArraySpec(roi=output_roi, voxel_size=voxel_size),
        },
        max_shared_memory=1024 * 1024 * 1024)
    write_node = gp.ZarrWrite(
        dataset_names={pred_mask: 'volumes/pred_mask'},
        output_dir=output_dir,
        output_filename=zarr_name,
        compression_type='gzip',
        dataset_dtypes={pred_mask: np.float16})

    pipeline = (
        source +
        predict_node +
        Convert(pred_mask, np.float16) +
        write_node +
        # show a summary of time spent in each node every x iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50)
    )

    # An empty request lets gp.Scan drive the chunk-wise processing.
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
def train_until(max_iteration,
                name='train_net',
                output_folder='.',
                clip_max=2000):
    """Train the distance-transform network up to ``max_iteration``.

    Training resumes from the newest checkpoint found in ``output_folder``;
    the iteration number is parsed from the text after the last underscore
    of the checkpoint path. If that number already reaches
    ``max_iteration`` the function returns without doing anything.

    NOTE(review): ``data_dir`` and ``data_files`` are not parameters; they
    are read from the enclosing (module) scope, which is not visible in this
    block -- confirm they are defined before this function is called.

    Args:
        max_iteration: Iteration count at which training stops.
        name: Prefix of ``<name>_config.json`` / ``<name>_names.json`` inside
            ``output_folder``; also the path stem passed to
            ``gp.tensorflow.Train``.
        output_folder: Directory with the network files and checkpoints;
            snapshots are written to ``<output_folder>/snapshots``.
        clip_max: Upper clipping bound for raw intensities; also used as the
            normalization divisor.
    """
    # get the latest checkpoint (looked up once and reused)
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # Half of the difference between network input and output size.
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source: one provider per top-level group of every file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        # The file is opened only to enumerate its top-level group names.
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(
                            interpolatable=True,
                            dtype=np.uint16,
                            voxel_size=voxel_size),
                        # Builtin bool instead of the np.bool alias, which
                        # was removed in NumPy 1.24.
                        gt_mask: gp.ArraySpec(
                            interpolatable=False,
                            dtype=bool,
                            voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f
            )

    pipeline = (
        data_sources +
        gp.RandomProvider() +
        gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
        DistanceTransform(gt_mask, gt_dt, 3) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=5) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_dt']: gt_dt,
            },
            outputs={
                net_names['pred_dt']: pred_dt,
            },
            gradients={
                net_names['pred_dt']: loss_gradient,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                gt_mask: 'volumes/gt_mask',
                gt_dt: 'volumes/gt_dt',
                pred_dt: 'volumes/pred_dt',
                loss_gradient: 'volumes/gradient',
            },
            output_filename=os.path.join(
                output_folder, 'snapshots', 'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2000) +
        gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        print("Starting training...")
        # Only the iterations still missing after the checkpoint are run.
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(max_iteration,
                name='train_net',
                output_folder='.',
                clip_max=2000):
    """Train the mask network on fused volume pairs up to ``max_iteration``.

    Two independent random providers (a "base" and an "add" volume) are
    merged and passed to ``nl.FusionAugment``, which produces the ``raw`` and
    ``gt_instances`` arrays used for training; ``BinarizeLabels`` then
    provides ``gt_mask`` from ``gt_instances``.

    Training resumes from the newest checkpoint found in ``output_folder``;
    the iteration number is parsed from the text after the last underscore
    of the checkpoint path. If that number already reaches
    ``max_iteration`` the function returns without doing anything.

    NOTE(review): ``data_dir`` and ``data_files`` are not parameters; they
    are read from the enclosing (module) scope, which is not visible in this
    block -- confirm they are defined before this function is called.

    Args:
        max_iteration: Iteration count at which training stops.
        name: Prefix of ``<name>_config.json`` / ``<name>_names.json`` inside
            ``output_folder``; also the path stem passed to
            ``gp.tensorflow.Train``.
        output_folder: Directory with the network files and checkpoints;
            snapshots are written to ``<output_folder>/snapshots``.
        clip_max: Upper clipping bound for raw intensities; also used as the
            normalization divisor.
    """
    # get the latest checkpoint (looked up once and reused)
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # Half of the difference between network input and output size.
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume: locations are constrained through the
    # mask directly in RandomLocation
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        # The file is opened only to enumerate its top-level group names.
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(
                            interpolatable=True,
                            dtype=np.uint16,
                            voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(
                            interpolatable=False,
                            dtype=np.uint16,
                            voxel_size=voxel_size),
                        # Builtin bool instead of the np.bool alias, which
                        # was removed in NumPy 1.24.
                        gt_mask_base: gp.ArraySpec(
                            interpolatable=False,
                            dtype=bool,
                            voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005, mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f
            )
    data_sources_base += gp.RandomProvider()

    # data source for add volume: random locations, filtered afterwards by
    # a Reject node with reject_probability=0.95
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(
                            interpolatable=True,
                            dtype=np.uint16,
                            voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(
                            interpolatable=False,
                            dtype=np.uint16,
                            voxel_size=voxel_size),
                        # Builtin bool instead of the removed np.bool alias.
                        gt_mask_add: gp.ArraySpec(
                            interpolatable=False,
                            dtype=bool,
                            voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005,
                          reject_probability=0.95)
                for sample in f
            )
    data_sources_add += gp.RandomProvider()

    # Both branches are requested together and merged into a single batch.
    data_sources = (data_sources_base, data_sources_add) + gp.MergeProvider()

    pipeline = (
        data_sources +
        nl.FusionAugment(
            raw_base, raw_add, gt_instances_base, gt_instances_add,
            raw, gt_instances,
            blend_mode='labels_mask',
            blend_smoothness=5,
            num_blended_objects=0
        ) +
        BinarizeLabels(gt_instances, gt_mask) +
        nl.Clip(raw, 0, clip_max) +
        gp.Normalize(raw, factor=1.0/clip_max) +
        gp.ElasticAugment(
            control_point_spacing=[20, 20, 20],
            jitter_sigma=[1, 1, 1],
            rotation_interval=[0, math.pi/2.0],
            subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
        gp.IntensityScaleShift(raw, 2, -1) +
        #gp.BalanceLabels(gt_mask, loss_weights) +

        # train
        gp.PreCache(
            cache_size=40,
            num_workers=10) +
        gp.tensorflow.Train(
            os.path.join(output_folder, name),
            optimizer=net_names['optimizer'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt']: gt_mask,
                #net_names['loss_weights']: loss_weights,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            gradients={
                net_names['output']: loss_gradients,
            },
            save_every=5000) +

        # visualize
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                pred_mask: 'volumes/pred_mask',
                gt_mask: 'volumes/gt_mask',
                #loss_weights: 'volumes/loss_weights',
                loss_gradients: 'volumes/loss_gradients',
            },
            output_filename=os.path.join(
                output_folder, 'snapshots', 'batch_{iteration}.hdf'),
            additional_request=snapshot_request,
            every=2500) +
        gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        print("Starting training...")
        # Only the iterations still missing after the checkpoint are run.
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)