Exemplo n.º 1
0
def test_check_required_params():
    """`check_required_params` raises only while a required key is absent."""
    needed = {'foo', 'bar'}
    supplied = {'foo': 0, 'baz': 1}

    # 'bar' is required but missing, so validation must fail.
    with pytest.raises(ValueError):
        check_required_params(required_keys=needed, params=supplied)

    # After supplying 'bar', the same call must succeed without raising.
    supplied.update(bar=2)
    check_required_params(required_keys=needed, params=supplied)
Exemplo n.º 2
0
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """Bayesian MeshNet model function with a Gaussian prior on weights.

    Args:
        features: 5D float `Tensor` in `NDHWC` format, or a `dict` holding
            that tensor under the key `'volume'`. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`.
        labels: 4D `Tensor` of class indices (not one-hot encoded). This is
            the second item returned from the `input_fn`.
        mode: one of `tf.estimator.ModeKeys` (train, evaluate or predict).
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - n_filters: (default 96) number of filters to use in each
                convolution. The original implementation used 21 filters to
                classify brainmask and 71 filters for the multi-class problem.
            - keep_prob: (default 0.5) keep probability forwarded to every
                `_layer` call.
            - prior_path: (default None) path to a NumPy file whose item 0
                holds the prior means and item 1 the prior standard
                deviations, one entry per variable in the `'ms'` collection
                (in collection order). If None, a zero-mean, unit-sigma prior
                is used.
            - n_examples: (required unless predicting) number of training
                examples; the prior term of the loss is divided by it.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`. In EVAL mode it carries only the loss.

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'n_filters': 96,
        'keep_prob': 0.5,
        # Previously a missing 'prior_path' raised KeyError; None selects the
        # default zero-mean, unit-sigma prior below.
        'prior_path': None,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate by layer.
    dilation_rates = (
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, 1),
        (2, 2, 2),
        (4, 4, 4),
        (8, 8, 8),
        (1, 1, 1))

    is_mc_v = tf.constant(False, dtype=tf.bool)
    is_mc_b = tf.constant(True, dtype=tf.bool)

    outputs = volume
    for layer_num, dilation_rate in enumerate(dilation_rates, start=1):
        outputs = _layer(
            outputs, mode=mode, layer_num=layer_num,
            filters=params['n_filters'], kernel_size=3,
            dilation_rate=dilation_rate, keep_prob=params['keep_prob'],
            is_mc_v=is_mc_v, is_mc_b=is_mc_b)

    with tf.variable_scope('logits'):
        logits = vwn_conv.conv3d(
            inputs=outputs, filters=params['n_classes'], kernel_size=(1, 1, 1),
            padding='SAME', activation=None, is_mc=is_mc_v)

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=export_outputs)

    prior_path = params['prior_path']
    # NOTE(review): if the prior file stores ragged (object) arrays, recent
    # NumPy versions require `allow_pickle=True`; only load trusted files.
    prior_np = np.load(prior_path) if prior_path is not None else None

    with tf.variable_scope("prior"):
        # Non-trainable prior means, one per variable in the 'ms' collection.
        for i, v in enumerate(tf.get_collection('ms')):
            if prior_path is None:
                init = tf.constant(0, dtype=v.dtype, shape=v.shape)
            else:
                init = tf.convert_to_tensor(prior_np[0][i], dtype=tf.float32)
            tf.add_to_collection(
                'ms_prior', tf.Variable(init, trainable=False))

        ms = tf.get_collection('ms')
        ms_prior = tf.get_collection('ms_prior')

        # Non-trainable prior standard deviations, shaped like 'ms' entries.
        for i, v in enumerate(ms):
            if prior_path is None:
                init = tf.constant(1, dtype=v.dtype, shape=v.shape)
            else:
                init = tf.convert_to_tensor(prior_np[1][i], dtype=tf.float32)
            tf.add_to_collection(
                'sigmas_prior', tf.Variable(init, trainable=False))

        sigmas_prior = tf.get_collection('sigmas_prior')

    nll_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))
    tf.summary.scalar('nll_loss', nll_loss)

    n_examples = tf.constant(params['n_examples'], dtype=ms[0].dtype)
    tf.summary.scalar('n_examples', n_examples)

    # sum((w - mu)^2 / (2 * sigma^2)) over all prior-tracked variables, i.e.
    # a diagonal-Gaussian negative log-prior up to a constant, scaled by the
    # dataset size. 1e-8 guards against division by zero.
    l2_loss = tf.add_n(
        [tf.reduce_sum(tf.square(m - m_prior)
                       / ((tf.square(s_prior) + 1e-8) * 2.0))
         for m, m_prior, s_prior in zip(ms, ms_prior, sigmas_prior)],
        name='l2_loss') / n_examples
    tf.summary.scalar('l2_loss', l2_loss)

    loss = nll_loss + l2_loss

    # Without this branch `evaluate` fell through to the TRAIN assertion.
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
Exemplo n.º 3
0
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """HighRes3DNet model function.

    Builds a stem (convolution with kernel size 3, batch normalization,
    ReLU), nine 16-filter residual blocks in three groups of three using
    dilation rates 1, 2 and 4, and a kernel-size-1 convolution that produces
    per-class logits.

    Args:
        features: 5D float `Tensor` in `NDHWC` format; the first item returned
            from the `input_fn` passed to `train`, `evaluate`, and `predict`.
        labels: 4D `Tensor` of class indices (not one-hot encoded); the
            second item returned from the `input_fn`.
        mode: one of `tf.estimator.ModeKeys` (train, evaluate or predict).
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    check_required_params(params=params, required_keys={'n_classes'})
    set_default_params(params=params, defaults={'optimizer': None})
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    n_classes = params['n_classes']
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    with tf.variable_scope('conv_0'):
        x = tf.layers.conv3d(
            features, filters=16, kernel_size=3, padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        x = tf.layers.batch_normalization(
            x, training=is_training, fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        x = tf.nn.relu(x)

    # Three residual blocks per dilation rate; blocks are numbered 1..9.
    block_num = 0
    for rate in (1, 2, 4):
        for _ in range(3):
            block_num += 1
            x = _resblock(
                x, mode=mode, layer_num=block_num, filters=16, kernel_size=3,
                dilation_rate=rate)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            x, filters=n_classes, kernel_size=1, padding='SAME')

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'class_ids': predicted_classes,
                'probabilities': tf.nn.softmax(logits),
                'logits': logits,
            })

    per_voxel_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(per_voxel_xent)

    # Evaluation metrics are reported for class 1 only.
    int_labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(int_labels, n_classes)
    predictions_onehot = tf.one_hot(predicted_classes, n_classes)
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(int_labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    step = tf.train.get_global_step()
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = params['optimizer'].minimize(loss, global_step=step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
Exemplo n.º 4
0
def model_fn(features, labels, mode, params, config=None):
    """QuickNAT model function (work in progress; currently always raises).

    Intended architecture, as written below: four encoder dense blocks, each
    followed by 2D max pooling that keeps argmax indices; a 5x5 convolution
    plus batch-normalization bottleneck; four decoder stages that unpool with
    the mirrored encoder indices, concatenate the matching encoder
    dense-block output, and apply another dense block; and a 1x1 convolution
    producing per-class logits.

    Args:
        features: 4D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NHWC` format.
        labels: 3D `Tensor` of class indices. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: one of `tf.estimator.ModeKeys` (train, evaluate or predict).
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: (required) instance of TensorFlow optimizer; this
                function requires it in every mode, not only training.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `NotImplementedError` unconditionally, until the implementation is
            finished and the guard at the top of the body is removed.
        `ValueError` (once the guard is removed) if required parameters are
            not in `params`.
    """
    # TODO (kaczmarj): remove this error once implementation is fixed.
    raise NotImplementedError("This QuickNAT implementation is not complete.")
    # NOTE: everything below is unreachable until the `raise` above is
    # removed; it is kept as the in-progress implementation.
    required_keys = {'n_classes', 'optimizer'}
    check_required_params(params=params, required_keys=required_keys)

    # Read directly only by the bottleneck batch normalization;
    # `_dense_block` receives `mode` instead.
    training = mode == tf.estimator.ModeKeys.TRAIN

    # ENCODING
    # Each dense block is followed by max pooling whose argmax indices are
    # kept so the decoder can unpool to the same positions.
    with tf.variable_scope('dense_block_1'):
        dense1 = _dense_block(features, block_num=1, mode=mode)
    with tf.variable_scope('maxpool_1'):
        pool1, poolargmax1 = tf.nn.max_pool_with_argmax(
            dense1,
            ksize=MAX_POOL_KSIZE,
            strides=MAX_POOL_STRIDES,
            padding='SAME',
        )

    with tf.variable_scope('dense_block_2'):
        dense2 = _dense_block(pool1, block_num=2, mode=mode)
    with tf.variable_scope('maxpool_2'):
        pool2, poolargmax2 = tf.nn.max_pool_with_argmax(
            dense2,
            ksize=MAX_POOL_KSIZE,
            strides=MAX_POOL_STRIDES,
            padding='SAME',
        )

    with tf.variable_scope('dense_block_3'):
        dense3 = _dense_block(pool2, block_num=3, mode=mode)
    with tf.variable_scope('maxpool_3'):
        pool3, poolargmax3 = tf.nn.max_pool_with_argmax(
            dense3,
            ksize=MAX_POOL_KSIZE,
            strides=MAX_POOL_STRIDES,
            padding='SAME',
        )

    with tf.variable_scope('dense_block_4'):
        dense4 = _dense_block(pool3, block_num=4, mode=mode)
    with tf.variable_scope('maxpool_4'):
        pool4, poolargmax4 = tf.nn.max_pool_with_argmax(
            dense4,
            ksize=MAX_POOL_KSIZE,
            strides=MAX_POOL_STRIDES,
            padding='SAME',
        )

    # BOTTLENECK
    with tf.variable_scope('bottleneck'):
        conv_bottleneck = tf.layers.conv2d(pool4,
                                           64,
                                           kernel_size=(5, 5),
                                           padding='SAME')
        bn_bottleneck = tf.layers.batch_normalization(conv_bottleneck,
                                                      training=training)

    # DECODING
    # Stages mirror the encoder in reverse (pool 4 -> 1): unpool with that
    # pool's argmax indices, concatenate the matching encoder dense-block
    # output along the last axis (skip connection), then apply a dense block.
    with tf.variable_scope('unpool_1'):
        unpool1 = unpool_2d(bn_bottleneck,
                            ind=poolargmax4,
                            stride=MAX_POOL_STRIDES)

    concat1 = tf.concat([dense4, unpool1], axis=-1)

    with tf.variable_scope('dense_block_5'):
        dense5 = _dense_block(concat1, block_num=5, mode=mode)

    with tf.variable_scope('unpool_2'):
        unpool2 = unpool_2d(dense5, ind=poolargmax3, stride=MAX_POOL_STRIDES)

    concat2 = tf.concat([dense3, unpool2], axis=-1)

    with tf.variable_scope('dense_block_6'):
        dense6 = _dense_block(concat2, block_num=6, mode=mode)

    with tf.variable_scope('unpool_3'):
        unpool3 = unpool_2d(dense6, ind=poolargmax2, stride=MAX_POOL_STRIDES)

    concat3 = tf.concat([dense2, unpool3], axis=-1)

    with tf.variable_scope('dense_block_7'):
        dense7 = _dense_block(concat3, block_num=7, mode=mode)

    with tf.variable_scope('unpool_4'):
        unpool4 = unpool_2d(dense7, ind=poolargmax1, stride=MAX_POOL_STRIDES)

    concat4 = tf.concat([dense1, unpool4], axis=-1)

    with tf.variable_scope('dense_block_8'):
        dense8 = _dense_block(concat4, block_num=8, mode=mode)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv2d(dense8,
                                  filters=params['n_classes'],
                                  kernel_size=(1, 1))

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # QUESTION (kaczmarj): is this the same as
    # `tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(...))`
    # NOTE(review): with the default `weights=1.0`, SUM_BY_NONZERO_WEIGHTS
    # should divide the summed loss by the number of elements, i.e. give the
    # plain mean -- confirm against the `tf.losses` reduction docs.
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels,
        logits=logits,
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    )

    # Compute metrics here...
    # Use `tf.summary.scalar` to add summaries to tensorboard.

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops=None,
        )

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Unlike sibling model functions, UPDATE_OPS (e.g. batch-norm moving
    # statistics) are not added as control dependencies here.
    train_op = params['optimizer'].minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
Exemplo n.º 5
0
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """MeshNet model function.

    Applies seven `_layer` blocks (kernel size 3) with dilation rates
    1, 1, 1, 2, 4, 8, 1, then a kernel-size-1 convolution that produces
    per-class logits.

    Args:
        features: 5D float `Tensor` in `NDHWC` format; the first item returned
            from the `input_fn` passed to `train`, `evaluate`, and `predict`.
        labels: 4D `Tensor` of class indices (not one-hot encoded); the
            second item returned from the `input_fn`.
        mode: one of `tf.estimator.ModeKeys` (train, evaluate or predict).
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - n_filters: (default 21) number of filters to use in each
                convolution. The original implementation used 21 filters to
                classify brainmask and 71 filters for the multi-class problem.
            - dropout_rate: (default 0.25) rate of dropout. For example, 0.1
                would drop 10% of input units.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    check_required_params(params=params, required_keys={'n_classes'})
    set_default_params(
        params=params,
        defaults={'optimizer': None, 'n_filters': 21, 'dropout_rate': 0.25})
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate for each of the seven layers, in order.
    rates = ((1, 1, 1),) * 3 + ((2, 2, 2), (4, 4, 4), (8, 8, 8), (1, 1, 1))

    x = features
    for layer_num, rate in enumerate(rates, start=1):
        x = _layer(
            x,
            mode=mode,
            layer_num=layer_num,
            filters=params['n_filters'],
            kernel_size=3,
            dilation_rate=rate,
            dropout_rate=params['dropout_rate'],
        )

    n_classes = params['n_classes']
    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            inputs=x, filters=n_classes, kernel_size=(1, 1, 1),
            padding='SAME', activation=None)
    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'class_ids': predicted_classes,
                'probabilities': tf.nn.softmax(logits),
                'logits': logits,
            })

    per_voxel_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(per_voxel_xent)

    # Evaluation metrics are reported for class 1 only.
    int_labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(int_labels, n_classes)
    predictions_onehot = tf.one_hot(predicted_classes, n_classes)
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(int_labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    step = tf.train.get_global_step()
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = params['optimizer'].minimize(loss, global_step=step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
Exemplo n.º 6
0
def model_fn(features, labels, mode, params, config=None):
    """HighRes3DNet model function trained with a Tversky loss.

    Architecture: a 16-filter stem (convolution, batch normalization, ReLU);
    three groups of three residual blocks with 16, 32 and 64 filters and
    dilation rates 1, 2 and 4; an 80-filter kernel-size-1 convolution;
    optional dropout; and a kernel-size-1 convolution producing per-class
    logits. The first 32- and 64-filter blocks pad their inputs along the
    last dimension (by 8+8 and 16+16) so the residual sums match.

    Args:
        features: 5D float `Tensor` in `NDHWC` format, or a `dict` holding
            that tensor under the key `'volume'`. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`.
        labels: 4D `Tensor` of class indices (not one-hot encoded). This is
            the second item returned from the `input_fn`.
        mode: one of `tf.estimator.ModeKeys` (train, evaluate or predict).
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - one_batchnorm_per_resblock: (default False) if true, only apply
                the first batch normalization layer in each residually
                connected block. Empirically, only using the first batch
                normalization layer allowed the model to be trained on
                128**3 float32 inputs.
            - dropout_rate: (default 0) value between 0 and 1, dropout rate
                applied immediately before the logits convolution. If 0 or
                false, dropout is not applied.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`. In training mode it also carries a
        hook that logs the loss and per-class Dice every 100 iterations.

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'one_batchnorm_per_resblock': False,
        'dropout_rate': 0,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    training = mode == tf.estimator.ModeKeys.TRAIN

    with tf.variable_scope('conv_0'):
        x = tf.layers.conv3d(volume, filters=16, kernel_size=3, padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        x = tf.layers.batch_normalization(x,
                                          training=training,
                                          fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        x = tf.nn.relu(x)

    layer_num = 0
    one_batchnorm = params['one_batchnorm_per_resblock']

    # 16-filter residually connected blocks.
    for _ in range(3):
        layer_num += 1
        x = _resblock(x,
                      mode=mode,
                      layer_num=layer_num,
                      filters=16,
                      kernel_size=3,
                      dilation_rate=1,
                      one_batchnorm=one_batchnorm)

    # 32-filter residually connected blocks. Pad inputs immediately before
    # first elementwise sum to match shape of last dimension (16 + 8 + 8).
    layer_num += 1
    paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [8, 8]]
    x = _resblock(x,
                  mode=mode,
                  layer_num=layer_num,
                  filters=32,
                  kernel_size=3,
                  dilation_rate=2,
                  paddings=paddings,
                  one_batchnorm=one_batchnorm)
    for _ in range(2):
        layer_num += 1
        x = _resblock(x,
                      mode=mode,
                      layer_num=layer_num,
                      filters=32,
                      kernel_size=3,
                      dilation_rate=2,
                      one_batchnorm=one_batchnorm)

    # 64-filter residually connected blocks. Pad inputs immediately before
    # first elementwise sum to match shape of last dimension (32 + 16 + 16).
    layer_num += 1
    paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [16, 16]]
    x = _resblock(x,
                  mode=mode,
                  layer_num=layer_num,
                  filters=64,
                  kernel_size=3,
                  dilation_rate=4,
                  paddings=paddings,
                  one_batchnorm=one_batchnorm)
    for _ in range(2):
        layer_num += 1
        x = _resblock(x,
                      mode=mode,
                      layer_num=layer_num,
                      filters=64,
                      kernel_size=3,
                      dilation_rate=4,
                      one_batchnorm=one_batchnorm)

    with tf.variable_scope('conv_1'):
        x = tf.layers.conv3d(x, filters=80, kernel_size=1, padding='SAME')

    if params['dropout_rate']:
        x = tf.layers.dropout(x,
                              rate=params['dropout_rate'],
                              training=training)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(x,
                                  filters=params['n_classes'],
                                  kernel_size=1,
                                  padding='SAME')

    # Kept distinct from the `predictions` dict built for PREDICT mode.
    probabilities = tf.nn.softmax(logits=logits)
    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': probabilities,
            'logits': logits
        }
        # Outputs for SavedModel.
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          export_outputs=export_outputs)

    onehot_labels = tf.one_hot(labels, params['n_classes'])
    # NOTE(review): axis=(1, 2, 3) presumably reduces over the spatial axes
    # of NDHWC volumes, leaving the batch and class axes -- confirm against
    # `losses.tversky`.
    loss = losses.tversky(labels=onehot_labels,
                          predictions=probabilities,
                          axis=(1, 2, 3))

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops={
                'dice': metrics.streaming_dice(
                    labels, predicted_classes, axis=(1, 2, 3)),
            })

    # Format the actual mode value (previously the literal string "mode").
    assert mode == tf.estimator.ModeKeys.TRAIN, "unknown mode key {}".format(
        mode)

    global_step = tf.train.get_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    # Dice score of each class on hard (argmax) predictions, averaged over
    # axis 0 (presumably the batch). Argmax of the softmax equals argmax of
    # the logits, so `predicted_classes` is reused.
    dice_coefficients = tf.reduce_mean(
        metrics.dice(onehot_labels,
                     tf.one_hot(predicted_classes, params['n_classes']),
                     axis=(1, 2, 3)),
        axis=0)

    logging_hook = tf.train.LoggingTensorHook(
        {
            "loss": loss,
            "dice": dice_coefficients
        }, every_n_iter=100)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=[logging_hook])
Exemplo n.º 7
0
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """HighRes3DNet model function.

    Builds a stem (convolution with kernel size 3, batch normalization,
    ReLU), nine 16-filter residual blocks in three groups of three using
    dilation rates 1, 2 and 4, optional dropout, and a kernel-size-1
    convolution that produces per-class logits.

    Args:
        features: 5D float `Tensor` in `NDHWC` format, or a `dict` holding
            that tensor under the key `'volume'`; the first item returned
            from the `input_fn` passed to `train`, `evaluate`, and `predict`.
        labels: 4D `Tensor` of class indices (not one-hot encoded); the
            second item returned from the `input_fn`.
        mode: one of `tf.estimator.ModeKeys` (train, evaluate or predict).
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - one_batchnorm_per_resblock: (default False) if true, only apply
                the first batch normalization layer in each residually
                connected block. Empirically, only using the first batch
                normalization layer allowed the model to be trained on
                128**3 float32 inputs.
            - dropout_rate: (default 0) value between 0 and 1, dropout rate
                applied immediately before the logits convolution. If 0 or
                false, dropout is not applied.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features['volume'] if isinstance(features, dict) else features

    check_required_params(params=params, required_keys={'n_classes'})
    set_default_params(
        params=params,
        defaults={
            'optimizer': None,
            'one_batchnorm_per_resblock': False,
            'dropout_rate': 0,
        })
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    n_classes = params['n_classes']
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    single_bn = params['one_batchnorm_per_resblock']

    with tf.variable_scope('conv_0'):
        x = tf.layers.conv3d(volume, filters=16, kernel_size=3, padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        x = tf.layers.batch_normalization(
            x, training=is_training, fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        x = tf.nn.relu(x)

    # Three residual blocks per dilation rate; blocks are numbered 1..9.
    block_num = 0
    for rate in (1, 2, 4):
        for _ in range(3):
            block_num += 1
            x = _resblock(
                x, mode=mode, layer_num=block_num, filters=16, kernel_size=3,
                dilation_rate=rate, one_batchnorm=single_bn)

    if params['dropout_rate']:
        x = tf.layers.dropout(
            x, rate=params['dropout_rate'], training=is_training)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            x, filters=n_classes, kernel_size=1, padding='SAME')

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        outputs_by_key = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits}
        # Signature used when exporting a SavedModel.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=outputs_by_key,
            export_outputs={
                'outputs': tf.estimator.export.PredictOutput(outputs_by_key)})

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Evaluation metrics are reported for class 1 only.
    int_labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(int_labels, n_classes)
    predictions_onehot = tf.one_hot(predicted_classes, n_classes)
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(int_labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    step = tf.train.get_global_step()
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = params['optimizer'].minimize(loss, global_step=step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
Exemplo n.º 8
0
def model_fn(features, labels, mode, params, config=None):
    """3D U-Net model function.

    Builds a four-level 3D U-Net. The encoder has three `_conv_block`s, each
    followed by 2x2x2 max pooling, then a bottleneck `_conv_block`. The
    decoder has three stages; each upsamples with a 2x2x2 stride-2 transposed
    convolution, concatenates the matching encoder output (skip connection)
    and applies a `_conv_block`. A final 1x1x1 convolution produces
    per-voxel class logits.

    Args:
        features: input volume `Tensor`, or a `dict` holding it under the
            `'volume'` key. Presumably `NDHWC` like the other model functions
            in this module (TODO confirm). Pooling uses 'same' padding while
            the transposed convolutions exactly double each spatial size, so
            spatial dimensions need to be divisible by 8 for the skip
            connections to concatenate.
        labels: class-label `Tensor` (not one-hot encoded). It is cast to
            the dtype of the predicted class ids before it is used.
        mode: a `tf.estimator.ModeKeys` value.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - batchnorm: forwarded to `_conv_block` as `batchnorm`.
                Default `True`.
        config: configuration object (unused).

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'batchnorm': True,
    }

    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    bn = params['batchnorm']
    n_classes = params['n_classes']

    # start encoding
    shortcut_1 = _conv_block(volume,
                             filters1=32,
                             filters2=64,
                             mode=mode,
                             layer_num=0,
                             batchnorm=bn)

    with tf.variable_scope('maxpool_1'):
        x = tf.layers.max_pooling3d(inputs=shortcut_1,
                                    pool_size=(2, 2, 2),
                                    strides=(2, 2, 2),
                                    padding='same')

    shortcut_2 = _conv_block(x,
                             filters1=64,
                             filters2=128,
                             mode=mode,
                             layer_num=1,
                             batchnorm=bn)

    with tf.variable_scope('maxpool_2'):
        x = tf.layers.max_pooling3d(inputs=shortcut_2,
                                    pool_size=(2, 2, 2),
                                    strides=(2, 2, 2),
                                    padding='same')

    shortcut_3 = _conv_block(x,
                             filters1=128,
                             filters2=256,
                             mode=mode,
                             layer_num=2,
                             batchnorm=bn)

    with tf.variable_scope('maxpool_3'):
        x = tf.layers.max_pooling3d(inputs=shortcut_3,
                                    pool_size=(2, 2, 2),
                                    strides=(2, 2, 2),
                                    padding='same')

    # bottleneck
    x = _conv_block(x,
                    filters1=256,
                    filters2=512,
                    mode=mode,
                    layer_num=3,
                    batchnorm=bn)

    # start decoding
    with tf.variable_scope("upconv_0"):
        x = tf.layers.conv3d_transpose(inputs=x,
                                       filters=512,
                                       kernel_size=(2, 2, 2),
                                       strides=(2, 2, 2),
                                       kernel_regularizer=_regularizer)

    with tf.variable_scope('concat_1'):
        x = tf.concat((shortcut_3, x), axis=-1)

    x = _conv_block(x,
                    filters1=256,
                    filters2=256,
                    mode=mode,
                    layer_num=4,
                    batchnorm=bn)

    with tf.variable_scope("upconv_1"):
        x = tf.layers.conv3d_transpose(inputs=x,
                                       filters=256,
                                       kernel_size=(2, 2, 2),
                                       strides=(2, 2, 2),
                                       kernel_regularizer=_regularizer)

    with tf.variable_scope('concat_2'):
        x = tf.concat((shortcut_2, x), axis=-1)

    x = _conv_block(x,
                    filters1=128,
                    filters2=128,
                    mode=mode,
                    layer_num=5,
                    batchnorm=bn)

    with tf.variable_scope("upconv_2"):
        x = tf.layers.conv3d_transpose(inputs=x,
                                       filters=128,
                                       kernel_size=(2, 2, 2),
                                       strides=(2, 2, 2),
                                       kernel_regularizer=_regularizer)

    with tf.variable_scope('concat_3'):
        x = tf.concat((shortcut_1, x), axis=-1)

    x = _conv_block(x,
                    filters1=64,
                    filters2=64,
                    mode=mode,
                    layer_num=6,
                    batchnorm=bn)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(inputs=x,
                                  filters=n_classes,
                                  kernel_size=(1, 1, 1),
                                  padding='same',
                                  activation=None,
                                  kernel_regularizer=_regularizer)
    # end decoding

    with tf.variable_scope('predictions'):
        probabilities = tf.nn.softmax(logits=logits)

    class_ids = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': class_ids,
            'probabilities': probabilities,
            'logits': logits,
        }
        # Outputs for SavedModel.
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions),
        }
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          export_outputs=export_outputs)

    # `tf.one_hot` requires integer indices, and the Dice metrics compare
    # labels against `class_ids`; align the dtypes (mirrors the MeshNet
    # model function).
    labels = tf.cast(labels, class_ids.dtype)
    onehot_labels = tf.one_hot(labels, n_classes)

    # NOTE: independent per-class sigmoid loss on one-hot targets;
    # `tf.losses.softmax_cross_entropy` is the alternative for mutually
    # exclusive classes.
    loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=onehot_labels,
                                           logits=logits)

    # Add the kernel regularization terms collected from the layers above.
    l2_loss = tf.losses.get_regularization_loss()
    loss += l2_loss

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = {
            'dice': metrics.streaming_dice(labels, class_ids, axis=(1, 2, 3)),
        }
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Run pending update ops (e.g. batch-norm statistics) with each step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(
            loss, global_step=tf.train.get_global_step())

    # Dice over the spatial axes, averaged to a scalar for the logging hook.
    # `tf.one_hot` needs an explicit depth and only accepts an int `axis`.
    # The spatial axes belong to the Dice reduction, as in the
    # `streaming_dice` call above. Assumes `metrics.dice` reduces over
    # `axis` like `metrics.streaming_dice` -- TODO confirm.
    predictions_onehot = tf.one_hot(class_ids, n_classes)
    dice_coefficients = tf.reduce_mean(
        metrics.dice(onehot_labels, predictions_onehot, axis=(1, 2, 3)))

    logging_hook = tf.train.LoggingTensorHook(
        {
            "loss": loss,
            "dice": dice_coefficients,
        }, every_n_iter=100)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=[logging_hook])