Beispiel #1
0
  def metric_fn(labels_r2, logits_r2):
    """Builds the evaluation metric ops from flattened labels and logits.

    Args:
      labels_r2: label Tensor, presumably flattened to rank 2 by the caller;
        it is reshaped back to `original_shape` (captured from the enclosing
        scope) before any metric is computed.
      logits_r2: logits Tensor, flattened the same way as `labels_r2`.

    Returns:
      Dict mapping metric name to the result of `tf.metrics.accuracy` or
      `tf.metrics.mean`.
    """

    def _as_float32(tensor):
      # Only bfloat16 inputs are upcast; every other dtype passes through.
      if tensor.dtype != tf.bfloat16:
        return tensor
      return tf.cast(tensor, tf.float32)

    labels_f32 = _as_float32(labels_r2)
    logits_f32 = _as_float32(logits_r2)

    labels_nd = tf.reshape(labels_f32, original_shape)
    logits_nd = tf.reshape(logits_f32, original_shape)

    class_probs = tf.nn.softmax(logits_nd)
    crossentropy = tf.keras.losses.categorical_crossentropy(
        labels_nd, class_probs, from_logits=False)
    dice_value = metrics.adaptive_dice32(labels_nd, class_probs)

    # Class index = argmax over the last (channel) axis for both tensors.
    accuracy_metric = tf.metrics.accuracy(
        labels=tf.argmax(labels_nd, -1),
        predictions=tf.argmax(class_probs, -1))
    dice_metric = tf.metrics.mean(dice_value, name='adaptive_dice32')
    crossentropy_metric = tf.metrics.mean(
        crossentropy, name='categorical_crossentropy')

    # NOTE(review): the 'adaptice_dice32' key is misspelled; it is kept
    # verbatim because consumers of the eval results may look it up by name.
    return {
        'accuracy': accuracy_metric,
        'adaptice_dice32': dice_metric,
        'categorical_crossentropy': crossentropy_metric,
    }
Beispiel #2
0
def _unet_model_fn(image, labels, mode, params):
    """Builds the UNet model graph, train op and eval metrics.

    Args:
        image: input image Tensor. Shape [x, y, z, num_channels].
        labels: input label Tensor. Shape [x, y, z, num_classes]. In TRAIN
            mode with params['use_index_label_in_train'] set, labels are class
            indices instead, with one fewer dimension than the logits.
        mode: TRAIN, EVAL or PREDICT.
        params: model parameters dictionary. Keys read here include
            'use_bfloat16', 'num_classes', 'deconvolution', 'depth',
            'num_base_filters', 'use_batch_norm', 'data_format', 'loss',
            'use_index_label_in_train', 'init_learning_rate',
            'lr_decay_steps', 'lr_decay_rate', 'use_tpu', 'model_dir',
            'eval_batch_size' and 'input_image_size'.

    Returns:
        EstimatorSpec or TPUEstimatorSpec.
    """
    with tf.variable_scope('base', reuse=tf.AUTO_REUSE):
        # The float32 path opens an empty-named variable scope, presumably to
        # mirror the scope bfloat16_scope() opens so variable names match in
        # both precisions — TODO confirm against checkpoints.
        if params['use_bfloat16']:
            precision_scope = tf.contrib.tpu.bfloat16_scope()
        else:
            precision_scope = tf.variable_scope('')
        with precision_scope:
            logits = unet3d_base(
                image,
                pool_size=(2, 2, 2),
                n_labels=params['num_classes'],
                deconvolution=params['deconvolution'],
                depth=params['depth'],
                n_base_filters=params['num_base_filters'],
                batch_normalization=params['use_batch_norm'],
                data_format=params['data_format'])

    loss = None
    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        with tf.variable_scope('loss', reuse=tf.AUTO_REUSE):
            if params['loss'] == 'adaptive_dice32':
                predictions = tf.nn.softmax(logits)
                assert (
                    labels.get_shape().as_list() ==
                    predictions.get_shape().as_list()
                ), 'predictions shape {} is not equal to label shape {}'.format(
                    predictions.get_shape().as_list(),
                    labels.get_shape().as_list())
                loss = metrics.adaptive_dice32(labels, predictions)
            else:
                if mode == tf.estimator.ModeKeys.TRAIN and params[
                        'use_index_label_in_train']:
                    # Labels are already class indices (rank = logits rank - 1).
                    assert (
                        len(labels.get_shape().as_list()) + 1 == len(
                            logits.get_shape().as_list())
                    ), 'logits shape {} is not equal to label shape {} plus one'.format(
                        logits.get_shape().as_list(),
                        labels.get_shape().as_list())
                    labels_idx = tf.cast(labels, dtype=tf.int32)
                else:
                    assert (
                        labels.get_shape().as_list() ==
                        logits.get_shape().as_list()
                    ), 'logits shape {} is not equal to label shape {}'.format(
                        logits.get_shape().as_list(),
                        labels.get_shape().as_list())
                    # Convert the one-hot encoding to label index.
                    channel_dim = -1
                    labels_idx = tf.argmax(labels,
                                           axis=channel_dim,
                                           output_type=tf.int32)
                # The cross-entropy loss is computed in float32 even when the
                # model ran in bfloat16; the cast logits are also used below.
                logits = tf.cast(logits, dtype=tf.float32)
                loss = tf.losses.sparse_softmax_cross_entropy(
                    labels=labels_idx, logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        learning_rate = tf.compat.v1.train.exponential_decay(
            float(params['init_learning_rate']),
            global_step,
            decay_steps=params['lr_decay_steps'],
            decay_rate=params['lr_decay_rate'])

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        optimizer = create_optimizer(learning_rate, params)
        if params['use_tpu']:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # tf.control_dependencies only affects ops created inside the context,
        # so minimize() must be called here for the UPDATE_OPS (typically
        # batch-norm moving-average updates) to run on every training step.
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if params['use_tpu']:

            def host_call_fn(gs, lr):
                """Training host call. Creates scalar summaries for training metrics.

                Args:
                    gs: `Tensor` with shape `[batch]` for the global_step.
                    lr: `Tensor` with shape `[batch]` for the learning_rate.

                Returns:
                    List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                with summary.create_file_writer(
                        params['model_dir']).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('learning_rate', lr[0], step=gs)
                        return summary.all_summary_ops()

            # To log the current learning rate for TensorBoard, the summary op
            # needs to run on the host CPU via host_call. host_call expects
            # [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            lr_t = tf.reshape(learning_rate, [1])
            host_call = (host_call_fn, [gs_t, lr_t])

            return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                   loss=loss,
                                                   train_op=train_op,
                                                   host_call=host_call)
        # Note: hooks cannot access tensors defined in model_fn in
        # TPUEstimator, so the logging hook is only attached on this path.
        logging_hook = tf.train.LoggingTensorHook({'loss': loss},
                                                  every_n_iter=10)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          training_hooks=[logging_hook],
                                          train_op=train_op)

    if mode == tf.estimator.ModeKeys.EVAL:
        # Reshape labels/logits to R2 tensor to avoid TPU padding issue.
        # TPU tends to pad the last dimension to 128x,
        # and the second to last dimension to 8x.
        labels_r2 = tf.reshape(labels, [params['eval_batch_size'], -1])
        logits_r2 = tf.reshape(logits, [params['eval_batch_size'], -1])
        # list(...) so a tuple-valued input_image_size works too.
        original_shape = ([params['eval_batch_size']] +
                          list(params['input_image_size']) + [-1])
        metric_fn = get_metric_fn(original_shape)
        if params['use_tpu']:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(metric_fn, [labels_r2, logits_r2]))
        # EstimatorSpec has no `eval_metrics` argument; it takes the metric
        # dict itself through `eval_metric_ops`, so build it directly here.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops=metric_fn(labels_r2, logits_r2))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.identity(tf.math.argmax(logits, axis=-1), 'Classes'),
            'scores': tf.identity(tf.nn.softmax(logits, axis=-1), 'Scores'),
        }
        export_outputs = {
            'classify': tf.estimator.export.PredictOutput(predictions)
        }
        if params['use_tpu']:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs=export_outputs)