import numpy as np
import tensorflow as tf

# `vwn_conv` (variational weight-noise convolution), `concrete_dropout`, and
# the `check_required_params`/`set_default_params`/`check_optimizer_for_training`
# helpers are used below but defined elsewhere in this package.


def _layer(inputs, mode, layer_num, filters, kernel_size, dilation_rate,
           is_mc_b, is_mc_g, use_expectation):
    """Layer building block of MeshNet.

    Performs 3D convolution with variational weight noise, concrete dropout,
    and ReLU activation on `inputs` tensor.

    Args:
        inputs: float `Tensor`, input tensor.
        mode: string, a TensorFlow mode key (currently unused by this layer).
        layer_num: int, value to append to each operator name. This should be
            the layer number in the network.
        filters: int, number of 3D convolution filters.
        kernel_size: int or tuple, size of 3D convolution kernel.
        dilation_rate: int or tuple, rate of dilation in 3D convolution.
        is_mc_b: bool `Tensor`, whether to sample the dropout masks (Monte
            Carlo) in `concrete_dropout`.
        is_mc_g: bool `Tensor`, whether to sample the weight noise (Monte
            Carlo) in the variational convolution.
        use_expectation: bool, whether `concrete_dropout` should use the
            expected value rather than sampling.

    Returns:
        `Tensor` of same type as `inputs`.
    """
    with tf.variable_scope('layer_{}'.format(layer_num)):
        conv = vwn_conv.conv3d(
            inputs, filters=filters, kernel_size=kernel_size, padding='SAME',
            dilation_rate=dilation_rate, activation=None, is_mc=is_mc_g)
        conv = concrete_dropout(
            conv, is_mc_b, filters, use_expectation=use_expectation)
        return tf.nn.relu(conv)
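# A minimal usage sketch of `_layer` (illustrative only; the placeholder
# shape and mode are hypothetical, and `vwn_conv`/`concrete_dropout` must be
# importable in TF 1.x graph mode):
#
#   volume = tf.placeholder(tf.float32, shape=(1, 64, 64, 64, 1))
#   out = _layer(
#       volume, mode=tf.estimator.ModeKeys.PREDICT, layer_num=1, filters=21,
#       kernel_size=3, dilation_rate=(1, 1, 1),
#       is_mc_b=tf.constant(True, dtype=tf.bool),
#       is_mc_g=tf.constant(False, dtype=tf.bool),
#       use_expectation=True)
#   # `out` has shape (1, 64, 64, 64, 21).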
def model_fn(features, labels, mode, params, config=None):
    """MeshNet model function.

    Args:
        features: 5D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this is training, evaluation, or
            prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
            - n_filters: number of filters to use in each convolution. The
              original implementation used 21 filters to classify brainmask
              and 71 filters for the multi-class problem.
            - use_expectation: whether concrete dropout should use the
              expected value rather than Monte Carlo sampling.
            - prior_path: path to a NumPy file of priors for the variational
              means and standard deviations, or None to use a standard normal
              prior.
            - n_examples: number of training examples, used to scale the
              regularization term. Required if training.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'n_filters': 96,
        'use_expectation': True,  # default is a judgment call; see docstring
        'prior_path': None,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate by layer.
    dilation_rates = (
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, 1),
        (2, 2, 2),
        (4, 4, 4),
        (8, 8, 8),
        (1, 1, 1))

    # Monte Carlo flags: sample the dropout masks but not the weight noise.
    is_mc_g = tf.constant(False, dtype=tf.bool)
    is_mc_b = tf.constant(True, dtype=tf.bool)

    outputs = volume
    for ii, dilation_rate in enumerate(dilation_rates):
        outputs = _layer(
            outputs,
            mode=mode,
            layer_num=ii + 1,
            filters=params['n_filters'],
            kernel_size=3,
            dilation_rate=dilation_rate,
            is_mc_b=is_mc_b,
            is_mc_g=is_mc_g,
            use_expectation=params['use_expectation'])

    with tf.variable_scope('logits'):
        logits = vwn_conv.conv3d(
            inputs=outputs,
            filters=params['n_classes'],
            kernel_size=(1, 1, 1),
            padding='SAME',
            activation=None,
            is_mc=is_mc_g)

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=export_outputs)

    if params['prior_path'] is not None:
        prior_np = np.load(params['prior_path'])

    # Build non-trainable prior variables for the variational means and
    # standard deviations. Without a prior file, fall back to a standard
    # normal prior (zero mean, unit standard deviation).
    with tf.variable_scope("prior"):
        for i, v in enumerate(tf.get_collection('ms')):
            if params['prior_path'] is None:
                prior = tf.constant(0, dtype=v.dtype, shape=v.shape)
            else:
                prior = tf.convert_to_tensor(prior_np[0][i], dtype=tf.float32)
            tf.add_to_collection(
                'ms_prior', tf.Variable(prior, trainable=False))
        ms = tf.get_collection('ms')
        ms_prior = tf.get_collection('ms_prior')

        for i, v in enumerate(tf.get_collection('sigmas')):
            if params['prior_path'] is None:
                prior = tf.constant(1, dtype=v.dtype, shape=v.shape)
            else:
                prior = tf.convert_to_tensor(prior_np[1][i], dtype=tf.float32)
            tf.add_to_collection(
                'sigmas_prior', tf.Variable(prior, trainable=False))
        sigmas = tf.get_collection('sigmas')
        sigmas_prior = tf.get_collection('sigmas_prior')

    nll_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits))
    tf.summary.scalar('nll_loss', nll_loss)

    n_examples = tf.constant(params['n_examples'], dtype=ms[0].dtype)
    tf.summary.scalar('n_examples', n_examples)

    # Squared distance of each variational mean from its prior mean, scaled
    # by the prior variance and averaged over the number of training examples.
    l2_loss = tf.add_n(
        [tf.reduce_sum(
            tf.square(ms[i] - ms_prior[i])
            / ((tf.square(sigmas_prior[i]) + 1e-8) * 2.0))
         for i in range(len(ms))],
        name='l2_loss') / n_examples
    tf.summary.scalar('l2_loss', l2_loss)

    loss = nll_loss + l2_loss

    # Evaluation mode is not implemented; only TRAIN reaches this point.
    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
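# Example wiring of this `model_fn` into an Estimator (a sketch, not part of
# the original module; `my_input_fn`, the model_dir, and all hyperparameter
# values below are hypothetical):
#
#   estimator = tf.estimator.Estimator(
#       model_fn=model_fn,
#       model_dir='/tmp/meshnet',
#       params={
#           'n_classes': 2,
#           'optimizer': tf.train.AdamOptimizer(learning_rate=1e-3),
#           'n_filters': 21,
#           'prior_path': None,   # fall back to a standard normal prior
#           'n_examples': 1000,
#       })
#   estimator.train(input_fn=my_input_fn)  # input_fn must yield NDHWC volumes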