def define_data_input(model, queue_batch=None):
  """Adds TF ops to load input data."""
  label_volume_map = {}
  for vol in FLAGS.label_volumes.split(','):
    volname, path, dataset = vol.split(':')
    label_volume_map[volname] = h5py.File(path, 'r')[dataset]

  image_volume_map = {}
  for vol in FLAGS.data_volumes.split(','):
    volname, path, dataset = vol.split(':')
    image_volume_map[volname] = h5py.File(path, 'r')[dataset]

  if queue_batch is None:
    queue_batch = FLAGS.batch_size

  # Fetch sizes of images and labels.
  label_size = train_labels_size(model)
  image_size = train_image_size(model)

  label_radii = (label_size // 2).tolist()
  label_size = label_size.tolist()
  image_radii = (image_size // 2).tolist()
  image_size = image_size.tolist()

  # Fetch a single coordinate and volume name from a queue reading the
  # coordinate files or from saved hard/important examples.
  coord, volname = inputs.load_patch_coordinates(FLAGS.train_coords)

  # Load object labels (segmentation).
  labels = inputs.load_from_numpylike(
      coord, volname, label_size, label_volume_map)

  label_shape = [1] + label_size[::-1] + [1]
  labels = tf.reshape(labels, label_shape)

  loss_weights = tf.constant(np.ones(label_shape, dtype=np.float32))

  # Load image data.
  patch = inputs.load_from_numpylike(
      coord, volname, image_size, image_volume_map)

  if FLAGS.with_membrane:
    data_shape = [1] + image_size[::-1] + [2]
  else:
    data_shape = [1] + image_size[::-1] + [1]
  patch = tf.reshape(patch, shape=data_shape)

  if ((FLAGS.image_stddev is None or FLAGS.image_mean is None) and
      not FLAGS.image_offset_scale_map):
    raise ValueError('--image_mean, --image_stddev or '
                     '--image_offset_scale_map need to be defined')

  # Convert segmentation into a soft object mask.
  lom = tf.logical_and(
      labels > 0,
      tf.equal(labels, labels[0,
                              label_radii[2],
                              label_radii[1],
                              label_radii[0],
                              0]))
  labels = inputs.soften_labels(lom)

  # Apply basic augmentations.
  transform_axes = augmentation.PermuteAndReflect(
      rank=5,
      permutable_axes=_get_permutable_axes(),
      reflectable_axes=_get_reflectable_axes())
  labels = transform_axes(labels)
  patch = transform_axes(patch)
  loss_weights = transform_axes(loss_weights)

  # Normalize image data.
  patch = inputs.offset_and_scale_patches(
      patch, volname[0],
      offset_scale_map=_get_offset_and_scale_map(),
      default_offset=FLAGS.image_mean,
      default_scale=FLAGS.image_stddev)

  # Create a batch of examples. Note that any TF operation before this line
  # will be hidden behind a queue, so expensive/slow ops can take advantage
  # of multithreading.
  patches, labels, loss_weights = tf.train.shuffle_batch(
      [patch, labels, loss_weights], queue_batch,
      num_threads=max(1, FLAGS.batch_size // 2),
      capacity=32 * FLAGS.batch_size,
      min_after_dequeue=4 * FLAGS.batch_size,
      enqueue_many=True)

  return patches, labels, loss_weights, coord, volname
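
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original source).
# define_data_input builds a TF1 queue-based pipeline: tf.train.shuffle_batch
# registers queue runners, so the tensors it returns only yield data after
# tf.train.start_queue_runners is called on a live session. The driver below
# is a minimal, hypothetical smoke test of that handshake; the helper name is
# invented, and `model` is assumed to be the same model object passed to
# define_data_input elsewhere.
def _example_pipeline_smoke_test(model, num_batches=3):
  """Hypothetical helper: pulls a few batches to verify the input pipeline."""
  patches, labels, loss_weights, coord, volname = define_data_input(model)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coordinator = tf.train.Coordinator()
    # Without this call, the first sess.run on `patches` blocks forever,
    # because nothing feeds the shuffle_batch queue.
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
    try:
      for _ in range(num_batches):
        batch, names = sess.run([patches, volname])
        logging.info('patch batch shape: %s from %s', batch.shape, names)
    finally:
      coordinator.request_stop()
      coordinator.join(threads)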
def define_data_input(model, queue_batch=None):
  """Adds TF ops to load input data."""
  label_volume_map = {}
  for vol in FLAGS.label_volumes.split(','):
    volname, path, dataset = vol.split(':')
    label_volume_map[volname] = h5py.File(path, 'r')[dataset]

  image_volume_map = {}
  for vol in FLAGS.data_volumes.split(','):
    volname, path, dataset = vol.split(':')
    image_volume_map[volname] = h5py.File(path, 'r')[dataset]

  if queue_batch is None:
    queue_batch = FLAGS.batch_size

  # Fetch sizes of images and labels.
  label_size = train_labels_size(model)
  image_size = train_image_size(model)

  label_radii = (label_size // 2).tolist()
  label_size = label_size.tolist()
  image_radii = (image_size // 2).tolist()
  image_size = image_size.tolist()

  # Fetch a single coordinate and volume name from a queue reading the
  # coordinate files or from saved hard/important examples.
  import os.path
  if os.path.isfile(FLAGS.train_coords):
    logging.info('%s exists.', FLAGS.train_coords)
  else:
    logging.error('%s does not exist.', FLAGS.train_coords)

  if FLAGS.sharding_rule == 0:
    coord, volname = inputs.load_patch_coordinates(FLAGS.train_coords)
  elif FLAGS.sharding_rule == 1 and 'horovod' in sys.modules:
    # Shard the coordinate TFRecords across Horovod workers so that each
    # rank trains on a disjoint subset of examples.
    d = tf.data.TFRecordDataset(FLAGS.train_coords, compression_type='GZIP')
    d = d.shard(hvd.size(), hvd.rank())
    d = d.map(parser_fn)
    iterator = d.make_one_shot_iterator()
    coord, volname = iterator.get_next()
  else:
    logging.warning('You need to install Horovod to use sharding. '
                    'Turning sharding off.')
    FLAGS.sharding_rule = 0
    coord, volname = inputs.load_patch_coordinates(FLAGS.train_coords)

  # Load object labels (segmentation).
  labels = inputs.load_from_numpylike(
      coord, volname, label_size, label_volume_map)

  label_shape = [1] + label_size[::-1] + [1]
  # label_shape = [1] + [1] + label_size[::-1]  # NCDHW
  labels = tf.reshape(labels, label_shape)

  loss_weights = tf.constant(np.ones(label_shape, dtype=np.float32))

  # Load image data.
  patch = inputs.load_from_numpylike(
      coord, volname, image_size, image_volume_map)

  data_shape = [1] + image_size[::-1] + [1]
  patch = tf.reshape(patch, shape=data_shape)

  if ((FLAGS.image_stddev is None or FLAGS.image_mean is None) and
      not FLAGS.image_offset_scale_map):
    raise ValueError('--image_mean, --image_stddev or '
                     '--image_offset_scale_map need to be defined')

  # Convert segmentation into a soft object mask.
  lom = tf.logical_and(
      labels > 0,
      tf.equal(labels, labels[0,
                              label_radii[2],
                              label_radii[1],
                              label_radii[0],
                              0]))
  labels = inputs.soften_labels(lom)

  # Apply basic augmentations.
  transform_axes = augmentation.PermuteAndReflect(
      rank=5,
      permutable_axes=_get_permutable_axes(),
      reflectable_axes=_get_reflectable_axes())
  labels = transform_axes(labels)
  patch = transform_axes(patch)
  loss_weights = transform_axes(loss_weights)

  # Normalize image data.
  patch = inputs.offset_and_scale_patches(
      patch, volname[0],
      offset_scale_map=_get_offset_and_scale_map(),
      default_offset=FLAGS.image_mean,
      default_scale=FLAGS.image_stddev)

  # Create a batch of examples. Note that any TF operation before this line
  # will be hidden behind a queue, so expensive/slow ops can take advantage
  # of multithreading.
  # TODO(MK): check num_threads usage here.
  patches, labels, loss_weights = tf.train.shuffle_batch(
      [patch, labels, loss_weights], queue_batch,
      num_threads=max(1, FLAGS.batch_size // 2),
      capacity=32 * FLAGS.batch_size,
      min_after_dequeue=4 * FLAGS.batch_size,
      enqueue_many=True)

  return patches, labels, loss_weights, coord, volname
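
# ---------------------------------------------------------------------------
# Hedged sketch of parser_fn (referenced by the Horovod sharding branch above
# but not defined in this file). It assumes each coordinate TFRecord holds a
# tf.train.Example with an int64 `center` feature (the patch center) and a
# string `label_volume_name` feature, matching what
# inputs.load_patch_coordinates yields on the non-sharded path; adjust the
# feature spec if your records differ.
def parser_fn(serialized_example):
  """Parses one coordinate record into (coord, volname) tensors."""
  features = tf.parse_single_example(
      serialized_example,
      features={
          'center': tf.FixedLenFeature([1, 3], tf.int64),
          'label_volume_name': tf.FixedLenFeature([1], tf.string),
      })
  return features['center'], features['label_volume_name']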