Example no. 1
def define_data_input(model, queue_batch=None):
    """Adds TF ops to load input data."""

    label_volume_map = {}
    for vol in FLAGS.label_volumes.split(','):
        volname, path, dataset = vol.split(':')
        label_volume_map[volname] = h5py.File(path, 'r')[dataset]

    image_volume_map = {}
    for vol in FLAGS.data_volumes.split(','):
        volname, path, dataset = vol.split(':')
        image_volume_map[volname] = h5py.File(path, 'r')[dataset]

    if queue_batch is None:
        queue_batch = FLAGS.batch_size

    # Fetch sizes of images and labels
    label_size = train_labels_size(model)
    image_size = train_image_size(model)

    label_radii = (label_size // 2).tolist()
    label_size = label_size.tolist()
    image_radii = (image_size // 2).tolist()
    image_size = image_size.tolist()

    # Fetch a single coordinate and volume name from a queue reading the
    # coordinate files or from saved hard/important examples
    coord, volname = inputs.load_patch_coordinates(FLAGS.train_coords)

    # Load object labels (segmentation).
    labels = inputs.load_from_numpylike(coord, volname, label_size,
                                        label_volume_map)

    label_shape = [1] + label_size[::-1] + [1]
    labels = tf.reshape(labels, label_shape)

    loss_weights = tf.constant(np.ones(label_shape, dtype=np.float32))

    # Load image data.
    patch = inputs.load_from_numpylike(coord, volname, image_size,
                                       image_volume_map)
    if FLAGS.with_membrane:
        data_shape = [1] + image_size[::-1] + [2]
    else:
        data_shape = [1] + image_size[::-1] + [1]
    patch = tf.reshape(patch, shape=data_shape)

    if ((FLAGS.image_stddev is None or FLAGS.image_mean is None)
            and not FLAGS.image_offset_scale_map):
        raise ValueError(
            '--image_mean, --image_stddev or --image_offset_scale_map '
            'need to be defined')

    # Convert segmentation into a soft object mask.
    lom = tf.logical_and(
        labels > 0,
        tf.equal(labels, labels[0, label_radii[2], label_radii[1],
                                label_radii[0], 0]))
    labels = inputs.soften_labels(lom)

    # Apply basic augmentations.
    transform_axes = augmentation.PermuteAndReflect(
        rank=5,
        permutable_axes=_get_permutable_axes(),
        reflectable_axes=_get_reflectable_axes())
    labels = transform_axes(labels)
    patch = transform_axes(patch)
    loss_weights = transform_axes(loss_weights)

    # Normalize image data.
    patch = inputs.offset_and_scale_patches(
        patch,
        volname[0],
        offset_scale_map=_get_offset_and_scale_map(),
        default_offset=FLAGS.image_mean,
        default_scale=FLAGS.image_stddev)

    # Create a batch of examples. Note that any TF operation before this line
    # will be hidden behind a queue, so expensive/slow ops can take advantage
    # of multithreading.
    patches, labels, loss_weights = tf.train.shuffle_batch(
        [patch, labels, loss_weights],
        queue_batch,
        num_threads=max(1, FLAGS.batch_size // 2),
        capacity=32 * FLAGS.batch_size,
        min_after_dequeue=4 * FLAGS.batch_size,
        enqueue_many=True)

    return patches, labels, loss_weights, coord, volname
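
This example is an excerpt from a larger training script and relies on module-level names defined there (FLAGS, inputs, augmentation, h5py, np, tf, and the train_labels_size/train_image_size helpers). The --label_volumes and --data_volumes flags are expected to hold comma-separated 'volname:path:dataset' triples. Below is a minimal sketch, assuming TF 1.x graph mode and an already-constructed model object (both hypothetical here), of how the returned ops might be consumed; tf.train.shuffle_batch is queue-based, so the queue runners have to be started before a batch can be fetched.

import tensorflow as tf

patches, labels, loss_weights, coord, volname = define_data_input(model)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Start the queue runners that feed the shuffle_batch queue.
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

    # Each run() call dequeues one batch of normalized patches, soft labels
    # and per-voxel loss weights.
    batch_patches, batch_labels, batch_weights = sess.run(
        [patches, labels, loss_weights])

    coordinator.request_stop()
    coordinator.join(threads)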
Example no. 2
def define_data_input(model, queue_batch=None):
  """Adds TF ops to load input data."""

  label_volume_map = {}
  for vol in FLAGS.label_volumes.split(','):
    volname, path, dataset = vol.split(':')
    label_volume_map[volname] = h5py.File(path, 'r')[dataset]

  image_volume_map = {}
  for vol in FLAGS.data_volumes.split(','):
    volname, path, dataset = vol.split(':')
    image_volume_map[volname] = h5py.File(path, 'r')[dataset]

  if queue_batch is None:
    queue_batch = FLAGS.batch_size

  # Fetch sizes of images and labels
  label_size = train_labels_size(model)
  image_size = train_image_size(model)

  label_radii = (label_size // 2).tolist()
  label_size = label_size.tolist()
  image_radii = (image_size // 2).tolist()
  image_size = image_size.tolist()

  # Fetch a single coordinate and volume name from a queue reading the
  # coordinate files or from saved hard/important examples
  import os.path
  if os.path.isfile(FLAGS.train_coords):
    logging.info('{} exists.'.format(FLAGS.train_coords))
  else:
    logging.error('{} does not exist.'.format(FLAGS.train_coords))
  if FLAGS.sharding_rule == 0:
    coord, volname = inputs.load_patch_coordinates(FLAGS.train_coords)
  elif FLAGS.sharding_rule == 1 and 'horovod' in sys.modules:
    d = tf.data.TFRecordDataset(FLAGS.train_coords, compression_type='GZIP')
    d = d.shard(hvd.size(), hvd.rank())
    d = d.map(parser_fn)
    iterator = d.make_one_shot_iterator()
    coord, volname = iterator.get_next()
  else:
    logging.warning(
        'You need to install Horovod to use sharding. Turning sharding off.')
    FLAGS.sharding_rule = 0
    coord, volname = inputs.load_patch_coordinates(FLAGS.train_coords)

  # Load object labels (segmentation).
  labels = inputs.load_from_numpylike(
      coord, volname, label_size, label_volume_map)

  label_shape = [1] + label_size[::-1] + [1]
  # label_shape = [1] + [1] + label_size[::-1]  # NCDHW
  labels = tf.reshape(labels, label_shape)

  loss_weights = tf.constant(np.ones(label_shape, dtype=np.float32))

  # Load image data.
  patch = inputs.load_from_numpylike(
      coord, volname, image_size, image_volume_map)
  data_shape = [1] + image_size[::-1] + [1]
  patch = tf.reshape(patch, shape=data_shape)

  if ((FLAGS.image_stddev is None or FLAGS.image_mean is None) and
      not FLAGS.image_offset_scale_map):
    raise ValueError('--image_mean, --image_stddev or --image_offset_scale_map '
                     'need to be defined')

  # Convert segmentation into a soft object mask.
  lom = tf.logical_and(
      labels > 0,
      tf.equal(labels, labels[0,
                              label_radii[2],
                              label_radii[1],
                              label_radii[0],
                              0]))
  labels = inputs.soften_labels(lom)

  # Apply basic augmentations.
  transform_axes = augmentation.PermuteAndReflect(
      rank=5, permutable_axes=_get_permutable_axes(),
      reflectable_axes=_get_reflectable_axes())
  labels = transform_axes(labels)
  patch = transform_axes(patch)
  loss_weights = transform_axes(loss_weights)

  # Normalize image data.
  patch = inputs.offset_and_scale_patches(
      patch, volname[0],
      offset_scale_map=_get_offset_and_scale_map(),
      default_offset=FLAGS.image_mean,
      default_scale=FLAGS.image_stddev)

  # Create a batch of examples. Note that any TF operation before this line
  # will be hidden behind a queue, so expensive/slow ops can take advantage
  # of multithreading.
  # MK TODO: check num_threads usage here
  patches, labels, loss_weights = tf.train.shuffle_batch(
      [patch, labels, loss_weights], queue_batch,
      num_threads=max(1, FLAGS.batch_size // 2),
      capacity=32 * FLAGS.batch_size,
      min_after_dequeue=4 * FLAGS.batch_size,
      enqueue_many=True)

  return patches, labels, loss_weights, coord, volname
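
The Horovod sharding path in this example maps a parser_fn over a TFRecordDataset of patch coordinates, but parser_fn itself is not shown. A possible sketch, assuming the records are gzipped tf.Example protos in the FFN coordinate-file layout (a 'center' feature holding three int64 values and a 'label_volume_name' bytes feature); the field names are taken from that convention and should be checked against the actual coordinate files:

def parser_fn(serialized):
    """Parses one serialized coordinate tf.Example into (coord, volname)."""
    features = tf.parse_single_example(
        serialized,
        features={
            # Patch center in (x, y, z) voxel coordinates.
            'center': tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
            # Name of the volume this coordinate belongs to.
            'label_volume_name': tf.FixedLenFeature(shape=[1], dtype=tf.string),
        })
    return features['center'], features['label_volume_name']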