Example #1
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """MeshNet model function.

    Args:
        features: 5D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this is training, evaluation, or
            prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - n_filters: number of filters to use in each convolution. The
                original implementation used 21 filters for brainmask
                classification and 71 filters for the multi-class problem.
            - dropout_rate: rate of dropout. For example, 0.1 would drop 10% of
                input units.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    required_keys = {'n_classes'}
    default_params = {'optimizer': None, 'n_filters': 21, 'dropout_rate': 0.25}
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate by layer.
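    # Three layers at dilation 1, then dilations 2, 4, and 8 to enlarge the
    # receptive field, and a final layer back at dilation 1.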
    dilation_rates = (
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, 1),
        (2, 2, 2),
        (4, 4, 4),
        (8, 8, 8),
        (1, 1, 1),
    )

    outputs = features

    for ii, dilation_rate in enumerate(dilation_rates):
        outputs = _layer(
            outputs, mode=mode, layer_num=ii + 1, filters=params['n_filters'],
            kernel_size=3, dilation_rate=dilation_rate,
            dropout_rate=params['dropout_rate'],
        )

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            inputs=outputs, filters=params['n_classes'], kernel_size=(1, 1, 1),
            padding='SAME', activation=None,
        )
    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Sparse cross entropy requires integer labels; cast in case the input
    # pipeline provides floats, as the docstring allows.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.int32), logits=logits)
    loss = tf.reduce_mean(cross_entropy)

    # Add evaluation metrics for class 1.
    labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(labels, params['n_classes'])
    predictions_onehot = tf.one_hot(predicted_classes, params['n_classes'])
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
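    # Ops in `UPDATE_OPS` (e.g., batch-norm moving-average updates) must run
    # alongside each training step.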
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
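
The `streaming_dice` and `streaming_hamming` metrics used in `eval_metric_ops` are helpers defined elsewhere in the module. For orientation only, here is a minimal sketch of what a TF 1.x streaming Dice metric could look like; the variable names, collections, and epsilon guard are assumptions, not the actual helper.

import tensorflow as tf

def streaming_dice(labels, predictions):
    """Hypothetical streaming Dice coefficient (illustrative sketch only).

    Accumulates 2 * |labels AND predictions| and |labels| + |predictions|
    across batches and reports their ratio, following the TF 1.x
    (value, update_op) metric contract.
    """
    total_intersection = tf.get_variable(
        'dice_total_intersection', shape=[], dtype=tf.float32,
        initializer=tf.zeros_initializer(), trainable=False,
        collections=[tf.GraphKeys.LOCAL_VARIABLES,
                     tf.GraphKeys.METRIC_VARIABLES])
    total_size = tf.get_variable(
        'dice_total_size', shape=[], dtype=tf.float32,
        initializer=tf.zeros_initializer(), trainable=False,
        collections=[tf.GraphKeys.LOCAL_VARIABLES,
                     tf.GraphKeys.METRIC_VARIABLES])
    labels = tf.cast(labels, tf.float32)
    predictions = tf.cast(predictions, tf.float32)
    update_intersection = tf.assign_add(
        total_intersection, tf.reduce_sum(labels * predictions))
    update_size = tf.assign_add(
        total_size, tf.reduce_sum(labels) + tf.reduce_sum(predictions))
    eps = 1e-7  # guard against division by zero on empty masks
    value = 2.0 * total_intersection / (total_size + eps)
    update_op = 2.0 * update_intersection / (update_size + eps)
    return value, update_op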
Example #2
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """HighRes3DNet model function.

    Args:
        features: 5D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this is training, evaluation, or
            prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    required_keys = {'n_classes'}
    default_params = {'optimizer': None}
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    training = mode == tf.estimator.ModeKeys.TRAIN

    with tf.variable_scope('conv_0'):
        conv = tf.layers.conv3d(
            features, filters=16, kernel_size=3, padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        conv = tf.layers.batch_normalization(
            conv, training=training, fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        outputs = tf.nn.relu(conv)

    offset = 1
    for ii in range(3):
        layer_num = ii + offset
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16, kernel_size=3,
            dilation_rate=1)

    offset = 4
    for ii in range(3):
        layer_num = ii + offset
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16, kernel_size=3,
            dilation_rate=2)

    offset = 7
    for ii in range(3):
        layer_num = ii + offset
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16, kernel_size=3,
            dilation_rate=4)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            outputs, filters=params['n_classes'], kernel_size=1,
            padding='SAME')

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Sparse cross entropy requires integer labels; cast in case the input
    # pipeline provides floats, as the docstring allows.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.int32), logits=logits)
    loss = tf.reduce_mean(cross_entropy)

    # Add evaluation metrics for class 1.
    labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(labels, params['n_classes'])
    predictions_onehot = tf.one_hot(predicted_classes, params['n_classes'])
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
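
Both HighRes3DNet examples call a `_resblock` helper that is not shown on this page. As a rough sketch only (the real implementation may differ), a pre-activation residual block matching the call sites above could look like this; the scope names and the two-unit structure are assumptions:

import tensorflow as tf

def _resblock(inputs, mode, layer_num, filters, kernel_size,
              dilation_rate, one_batchnorm=False):
    """Hypothetical residually connected block (illustrative sketch only)."""
    training = mode == tf.estimator.ModeKeys.TRAIN
    with tf.variable_scope('resblock_{}'.format(layer_num)):
        outputs = inputs
        for unit in range(2):
            with tf.variable_scope('unit_{}'.format(unit)):
                # With `one_batchnorm`, normalize only in the first unit
                # (see the docstring in Example #3).
                if unit == 0 or not one_batchnorm:
                    outputs = tf.layers.batch_normalization(
                        outputs, training=training)
                outputs = tf.nn.relu(outputs)
                outputs = tf.layers.conv3d(
                    outputs, filters=filters, kernel_size=kernel_size,
                    padding='SAME', dilation_rate=dilation_rate)
        # Identity shortcut; assumes `inputs` already has `filters` channels,
        # which holds here because every block uses 16 filters.
        return outputs + inputs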
Example #3
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """HighRes3DNet model function.

    Args:
        features: 5D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this is training, evaluation, or
            prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - one_batchnorm_per_resblock: (default false) if true, only apply
                the first batch normalization layer in each residually
                connected block. Empirically, using only the first batch
                normalization layer allowed the model to be trained on 128**3
                float32 inputs.
            - dropout_rate: (default 0) value between 0 and 1, dropout rate
                applied immediately before the last convolution layer. If 0
                or false, dropout is not applied.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'one_batchnorm_per_resblock': False,
        'dropout_rate': 0,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    training = mode == tf.estimator.ModeKeys.TRAIN

    with tf.variable_scope('conv_0'):
        conv = tf.layers.conv3d(
            volume, filters=16, kernel_size=3, padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        conv = tf.layers.batch_normalization(
            conv, training=training, fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        outputs = tf.nn.relu(conv)

    layer_num = 0
    one_batchnorm = params['one_batchnorm_per_resblock']

    for ii in range(3):
        layer_num += 1
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16, kernel_size=3,
            dilation_rate=1, one_batchnorm=one_batchnorm)

    for ii in range(3):
        layer_num += 1
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16, kernel_size=3,
            dilation_rate=2, one_batchnorm=one_batchnorm)

    for ii in range(3):
        layer_num += 1
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16, kernel_size=3,
            dilation_rate=4, one_batchnorm=one_batchnorm)

    if params['dropout_rate']:
        outputs = tf.layers.dropout(
            outputs, rate=params['dropout_rate'], training=training)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            outputs, filters=params['n_classes'], kernel_size=1,
            padding='SAME')

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits}
        # Outputs for SavedModel.
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=export_outputs)

    # Sparse cross entropy requires integer labels; cast in case the input
    # pipeline provides floats, as the docstring allows.
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.cast(labels, tf.int32),
        logits=logits)

    # Add evaluation metrics for class 1.
    labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(labels, params['n_classes'])
    predictions_onehot = tf.one_hot(predicted_classes, params['n_classes'])
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op)
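
All three variants plug into the Estimator API the same way. The following is a minimal usage sketch, assuming TF 1.x; the model directory, optimizer, parameter values, and the 128**3 single-channel input shape are illustrative assumptions, not taken from the examples above:

import tensorflow as tf

# Hypothetical wiring; `model_fn` is any of the functions above.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/highres3dnet',
    params={
        'n_classes': 2,
        'optimizer': tf.train.AdamOptimizer(learning_rate=1e-3),
        'one_batchnorm_per_resblock': True,
        'dropout_rate': 0.25,
    })

# `train_input_fn` / `eval_input_fn` must yield a 5D NDHWC float volume and
# a 4D integer label map, matching the docstrings above.
# estimator.train(input_fn=train_input_fn, max_steps=1000)
# estimator.evaluate(input_fn=eval_input_fn)

def serving_input_receiver_fn():
    # Shape is an assumption: unknown batch size, 128**3 voxels, one channel.
    volume = tf.placeholder(
        dtype=tf.float32, shape=[None, 128, 128, 128, 1], name='volume')
    return tf.estimator.export.ServingInputReceiver(
        features={'volume': volume}, receiver_tensors={'volume': volume})

# Example #3 defines `export_outputs` explicitly, so it can be exported as a
# SavedModel for serving:
# estimator.export_savedmodel('/tmp/export', serving_input_receiver_fn)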