示例#1
0
def _layer(inputs, mode, layer_num, filters, kernel_size, dilation_rate,
           is_mc_b, is_mc_g, use_expectation):
    """Layer building block of MeshNet.

    Applies a 3D convolution (`vwn_conv.conv3d`, 'SAME' padding, no
    activation), then `concrete_dropout`, then a ReLU activation to the
    `inputs` tensor. All ops are created inside the variable scope
    `layer_{layer_num}`.

    Args:
        inputs : float `Tensor`, input tensor.
        mode : string, a TensorFlow mode key. Not used by this function; kept
            so the signature stays stable for callers.
        layer_num : int, value appended to the variable scope name
            (`layer_{layer_num}`). This should be the layer number in the
            network.
        filters : int, number of 3D convolution filters. Also passed to
            `concrete_dropout`.
        kernel_size : int or tuple, size of 3D convolution kernel.
        dilation_rate : int or tuple, dilation rate of the 3D convolution.
        is_mc_b : passed as the second positional argument of
            `concrete_dropout`; presumably a boolean tensor that toggles
            Monte Carlo sampling of the dropout mask -- TODO confirm against
            `concrete_dropout`.
        is_mc_g : passed as `is_mc` to `vwn_conv.conv3d`; presumably toggles
            Monte Carlo sampling of the convolution weights -- TODO confirm.
        use_expectation : forwarded unchanged to `concrete_dropout` as its
            `use_expectation` argument.

    Returns:
        `Tensor`, the ReLU of the dropout-processed convolution output.
    """
    with tf.variable_scope('layer_{}'.format(layer_num)):
        # No activation here: the ReLU is applied after dropout below.
        conv = vwn_conv.conv3d(inputs,
                               filters=filters,
                               kernel_size=kernel_size,
                               padding='SAME',
                               dilation_rate=dilation_rate,
                               activation=None,
                               is_mc=is_mc_g)
        conv = concrete_dropout(conv,
                                is_mc_b,
                                filters,
                                use_expectation=use_expectation)
        return tf.nn.relu(conv)
示例#2
0
def model_fn(features,
             labels,
             mode,
             params,
             config=None):
    """MeshNet model function.

    Builds a stack of dilated 3D convolution layers followed by a 1x1x1
    convolution that produces per-voxel class logits. Outside PREDICT mode the
    loss is the mean sparse softmax cross-entropy plus a quadratic penalty
    pulling each tensor in the 'ms' graph collection towards a prior mean,
    scaled by the prior sigmas and divided by `n_examples`.

    Args:
        features: 5D float `Tensor`, input tensor, or a `dict` holding it
            under the key 'volume'. This is the first item returned from the
            `input_fn` passed to `train`, `evaluate`, and `predict`. Use
            `NDHWC` format.
        labels: 4D `Tensor` of class ids, the second item returned from the
            `input_fn`. Labels should not be one-hot encoded; they are fed to
            `tf.nn.sparse_softmax_cross_entropy_with_logits`, which requires
            an integer dtype. Unused in PREDICT mode.
        mode: a `tf.estimator.ModeKeys` value specifying training,
            evaluation or prediction.
        params: `dict` of parameters. Missing optional keys are filled in by
            `set_default_params`.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
                training.
            - n_filters: number of filters to use in each convolution
                (default 96). The original implementation used 21 filters to
                classify brainmask and 71 filters for the multi-class problem.
            - keep_prob: forwarded to every layer (default 0.5).
            - prior_path: optional path to a `.npy` file whose item 0 holds
                the prior means and item 1 the prior sigmas, one entry per
                tensor in the 'ms' collection. If absent or None, a prior
                with mean 0 and sigma 1 is used.
            - n_examples: number of training examples; the prior penalty is
                divided by it. Required in TRAIN and EVAL modes.
        config: configuration object. Unused.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`, or if `mode`
        is not a PREDICT, EVAL or TRAIN mode key.
        `KeyError` if `n_examples` is missing outside PREDICT mode.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {'optimizer': None, 'n_filters': 96, 'keep_prob': 0.5}
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate by layer.
    dilation_rates = (
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, 1),
        (2, 2, 2),
        (4, 4, 4),
        (8, 8, 8),
        (1, 1, 1))

    is_mc_v = tf.constant(False, dtype=tf.bool)
    is_mc_b = tf.constant(True, dtype=tf.bool)

    outputs = volume
    for ii, dilation_rate in enumerate(dilation_rates):
        # NOTE(review): `keep_prob` / `is_mc_v` do not match the
        # `_layer(..., is_mc_b, is_mc_g, use_expectation)` signature defined
        # earlier in this file; confirm which `_layer` this call targets.
        outputs = _layer(
            outputs, mode=mode, layer_num=ii + 1, filters=params['n_filters'],
            kernel_size=3, dilation_rate=dilation_rate,
            keep_prob=params['keep_prob'], is_mc_v=is_mc_v, is_mc_b=is_mc_b)

    with tf.variable_scope('logits'):
        logits = vwn_conv.conv3d(
            inputs=outputs, filters=params['n_classes'], kernel_size=(1, 1, 1),
            padding='SAME', activation=None, is_mc=is_mc_v)

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=export_outputs)

    # `prior_path` is optional: use .get so a missing key means "no prior
    # file" instead of raising KeyError.
    prior_path = params.get('prior_path')
    if prior_path is not None:
        # NOTE(review): if this file stores an object array (per-tensor arrays
        # of different shapes), NumPy >= 1.16.3 refuses to load it unless
        # allow_pickle=True, which unpickles the file -- only enable that for
        # trusted prior files.
        prior_np = np.load(prior_path)

    with tf.variable_scope("prior"):
        # Prior means and sigmas are built in two separate passes, means
        # first, so the auto-generated variable names (and therefore
        # checkpoint keys) stay in the same order as before.
        for i, v in enumerate(tf.get_collection('ms')):
            if prior_path is None:
                mean_init = tf.constant(0, dtype=v.dtype, shape=v.shape)
            else:
                mean_init = tf.convert_to_tensor(prior_np[0][i],
                                                 dtype=tf.float32)
            tf.add_to_collection('ms_prior',
                                 tf.Variable(mean_init, trainable=False))

        ms = tf.get_collection('ms')
        ms_prior = tf.get_collection('ms_prior')

        for i, v in enumerate(tf.get_collection('ms')):
            if prior_path is None:
                sigma_init = tf.constant(1, dtype=v.dtype, shape=v.shape)
            else:
                sigma_init = tf.convert_to_tensor(prior_np[1][i],
                                                  dtype=tf.float32)
            tf.add_to_collection('sigmas_prior',
                                 tf.Variable(sigma_init, trainable=False))

        sigmas_prior = tf.get_collection('sigmas_prior')

    nll_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))
    tf.summary.scalar('nll_loss', nll_loss)

    n_examples = tf.constant(params['n_examples'], dtype=ms[0].dtype)
    tf.summary.scalar('n_examples', n_examples)

    # Sum over tensors of (m - m_prior)^2 / (2 * sigma_prior^2); the 1e-8
    # guards against a zero prior sigma. zip stops at len(ms), matching the
    # original index range.
    l2_loss = tf.add_n(
        [tf.reduce_sum(tf.square(m - m_prior)
                       / ((tf.square(s_prior) + 1e-8) * 2.0))
         for m, m_prior, s_prior in zip(ms, ms_prior, sigmas_prior)],
        name='l2_loss') / n_examples
    tf.summary.scalar('l2_loss', l2_loss)

    loss = nll_loss + l2_loss

    if mode == tf.estimator.ModeKeys.EVAL:
        # Evaluation needs only the loss; no optimizer or train_op.
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode != tf.estimator.ModeKeys.TRAIN:
        raise ValueError("Unsupported mode: {}".format(mode))

    global_step = tf.train.get_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)