def model_fn(features, labels, mode, params, config=None):
    """MeshNet model function with variational weight normalization.

    Args:
        features: 5D float `Tensor`, input tensor, or a dict with key
            'volume' holding it. This is the first item returned from the
            `input_fn` passed to `train`, `evaluate`, and `predict`. Use
            `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this training, evaluation or prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - n_examples: (required if training) number of training examples,
              used to scale the prior (l2) term of the loss.
            - prior_path: (default None) path to a `.npy` file holding prior
              means (`prior_np[0]`) and sigmas (`prior_np[1]`) for the
              variational parameters. If None, a standard normal prior
              (mean 0, sigma 1) is used.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
            - n_filters: number of filters to use in each convolution.
            - keep_prob: probability of keeping a unit (passed to each
              layer). For example, 0.5 keeps 50% of input units.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'n_filters': 96,
        'keep_prob': 0.5,
        # The prior handling below explicitly supports a missing prior, so
        # default it to None instead of raising KeyError.
        'prior_path': None,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate by layer.
    dilation_rates = (
        (1, 1, 1), (1, 1, 1), (1, 1, 1),
        (2, 2, 2), (4, 4, 4), (8, 8, 8),
        (1, 1, 1))

    # Monte Carlo sampling flags passed to the variational layers: weight
    # sampling disabled, bias sampling enabled.
    is_mc_v = tf.constant(False, dtype=tf.bool)
    is_mc_b = tf.constant(True, dtype=tf.bool)

    outputs = volume
    for ii, dilation_rate in enumerate(dilation_rates):
        outputs = _layer(
            outputs, mode=mode, layer_num=ii + 1,
            filters=params['n_filters'], kernel_size=3,
            dilation_rate=dilation_rate, keep_prob=params['keep_prob'],
            is_mc_v=is_mc_v, is_mc_b=is_mc_b)

    with tf.variable_scope('logits'):
        logits = vwn_conv.conv3d(
            inputs=outputs, filters=params['n_classes'],
            kernel_size=(1, 1, 1), padding='SAME', activation=None,
            is_mc=is_mc_v)

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions,
            export_outputs=export_outputs)

    # Load the empirical prior if one was provided; otherwise fall back to a
    # standard normal prior below.
    prior_np = None
    if params['prior_path'] is not None:
        prior_np = np.load(params['prior_path'])

    with tf.variable_scope("prior"):
        # Prior means: one non-trainable variable per variational mean.
        for i, v in enumerate(tf.get_collection('ms')):
            if prior_np is None:
                prior_m = tf.constant(0, dtype=v.dtype, shape=v.shape)
            else:
                prior_m = tf.convert_to_tensor(prior_np[0][i],
                                               dtype=tf.float32)
            tf.add_to_collection(
                'ms_prior', tf.Variable(prior_m, trainable=False))
        ms = tf.get_collection('ms')
        ms_prior = tf.get_collection('ms_prior')

        # Prior sigmas. NOTE(review): shapes are taken from the 'ms'
        # collection, matching the original code — presumably the sigma
        # variables have identical shapes; confirm against _layer/vwn_conv.
        for i, v in enumerate(tf.get_collection('ms')):
            if prior_np is None:
                prior_s = tf.constant(1, dtype=v.dtype, shape=v.shape)
            else:
                prior_s = tf.convert_to_tensor(prior_np[1][i],
                                               dtype=tf.float32)
            tf.add_to_collection(
                'sigmas_prior', tf.Variable(prior_s, trainable=False))
        sigmas_prior = tf.get_collection('sigmas_prior')

    # Negative log-likelihood (cross-entropy) term.
    nll_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits))
    tf.summary.scalar('nll_loss', nll_loss)

    n_examples = tf.constant(params['n_examples'], dtype=ms[0].dtype)
    tf.summary.scalar('n_examples', n_examples)

    # Gaussian prior penalty on the variational means, scaled by the number
    # of training examples. The 1e-8 guards against division by zero.
    l2_loss = tf.add_n(
        [tf.reduce_sum(
            tf.square(ms[i] - ms_prior[i])
            / ((tf.square(sigmas_prior[i]) + 1e-8) * 2.0))
         for i in range(len(ms))],
        name='l2_loss') / n_examples
    tf.summary.scalar('l2_loss', l2_loss)

    loss = nll_loss + l2_loss

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    # Run any pending update ops (e.g. from normalization layers) alongside
    # the optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def model_fn(features, labels, mode, params, config=None):
    """HighRes3DNet model function.

    Args:
        features: 5D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this training, evaluation or prediction.
        params: `dict` of parameters. All parameters below are required.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    check_required_params(params=params, required_keys={'n_classes'})
    set_default_params(params=params, defaults={'optimizer': None})
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Batch normalization behaves differently in training vs inference.
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Stem: conv -> batch norm -> ReLU.
    with tf.variable_scope('conv_0'):
        x = tf.layers.conv3d(features, filters=16, kernel_size=3,
                             padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        x = tf.layers.batch_normalization(x, training=training,
                                          fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        x = tf.nn.relu(x)

    # Nine residually connected blocks, numbered 1..9: three each at
    # dilation rates 1, 2, and 4, all with 16 filters.
    block_dilations = [1] * 3 + [2] * 3 + [4] * 3
    for layer_num, dilation_rate in enumerate(block_dilations, start=1):
        x = _resblock(x, mode=mode, layer_num=layer_num, filters=16,
                      kernel_size=3, dilation_rate=dilation_rate)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(x, filters=params['n_classes'],
                                  kernel_size=1, padding='SAME')

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'class_ids': predicted_classes,
                'probabilities': tf.nn.softmax(logits),
                'logits': logits,
            })

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

    # Add evaluation metrics for class 1.
    labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(labels, params['n_classes'])
    predictions_onehot = tf.one_hot(predicted_classes, params['n_classes'])
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predicted_classes),
        'dice': streaming_dice(labels_onehot[..., 1],
                               predictions_onehot[..., 1]),
        'hamming': streaming_hamming(labels_onehot[..., 1],
                                     predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Run batch-norm moving-average updates together with the train op.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = params['optimizer'].minimize(
            loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def model_fn(features, labels, mode, params, config=None):
    """MeshNet model function.

    Args:
        features: 5D float `Tensor`, input tensor, or a dict with key
            'volume' holding it. This is the first item returned from the
            `input_fn` passed to `train`, `evaluate`, and `predict`. Use
            `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this training, evaluation or prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
            - n_filters: number of filters to use in each convolution. The
              original implementation used 21 filters to classify brainmask
              and 71 filters for the multi-class problem.
            - dropout_rate: rate of dropout. For example, 0.1 would drop 10%
              of input units.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    # Accept either the raw volume tensor or a dict with key 'volume', for
    # consistency with the other model functions in this module.
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {'optimizer': None, 'n_filters': 21,
                      'dropout_rate': 0.25}
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Dilation rate by layer.
    dilation_rates = (
        (1, 1, 1), (1, 1, 1), (1, 1, 1),
        (2, 2, 2), (4, 4, 4), (8, 8, 8),
        (1, 1, 1),
    )

    outputs = volume
    for ii, dilation_rate in enumerate(dilation_rates):
        outputs = _layer(
            outputs, mode=mode, layer_num=ii + 1,
            filters=params['n_filters'], kernel_size=3,
            dilation_rate=dilation_rate,
            dropout_rate=params['dropout_rate'],
        )

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            inputs=outputs, filters=params['n_classes'],
            kernel_size=(1, 1, 1), padding='SAME', activation=None,
        )

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(cross_entropy)

    # Add evaluation metrics for class 1.
    labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(labels, params['n_classes'])
    predictions_onehot = tf.one_hot(predicted_classes, params['n_classes'])
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    # Run any pending update ops (e.g. batch norm) alongside the train op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def model_fn(features, labels, mode, params, config=None):
    """HighRes3DNet model function.

    Args:
        features: 5D float `Tensor`, input tensor, or a dict with key
            'volume' holding it. This is the first item returned from the
            `input_fn` passed to `train`, `evaluate`, and `predict`. Use
            `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this training, evaluation or prediction.
        params: `dict` of parameters. All parameters below are required.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
            - one_batchnorm_per_resblock: (default false) if true, only apply
              first batch normalization layer in each residually connected
              block. Empirically, only using first batch normalization layer
              allowed the model to be trained on 128**3 float32 inputs.
            - dropout_rate: (default 0), value between 0 and 1, dropout rate
              to be applied immediately before last convolution layer. If 0
              or false, dropout is not applied.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'one_batchnorm_per_resblock': False,
        'dropout_rate': 0,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Batch norm and dropout behave differently in training vs inference.
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Stem: conv -> batch norm -> ReLU.
    with tf.variable_scope('conv_0'):
        x = tf.layers.conv3d(volume, filters=16, kernel_size=3,
                             padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        x = tf.layers.batch_normalization(x, training=training,
                                          fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        x = tf.nn.relu(x)

    layer_num = 0
    one_batchnorm = params['one_batchnorm_per_resblock']

    # 16-filter residually connected blocks.
    for ii in range(3):
        layer_num += 1
        x = _resblock(x, mode=mode, layer_num=layer_num, filters=16,
                      kernel_size=3, dilation_rate=1,
                      one_batchnorm=one_batchnorm)

    # 32-filter residually connected blocks. Pad inputs immediately before
    # first elementwise sum to match shape of last dimension (16 -> 32).
    layer_num += 1
    paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [8, 8]]
    x = _resblock(x, mode=mode, layer_num=layer_num, filters=32,
                  kernel_size=3, dilation_rate=2, paddings=paddings,
                  one_batchnorm=one_batchnorm)
    for ii in range(2):
        layer_num += 1
        x = _resblock(x, mode=mode, layer_num=layer_num, filters=32,
                      kernel_size=3, dilation_rate=2,
                      one_batchnorm=one_batchnorm)

    # 64-filter residually connected blocks. Pad inputs immediately before
    # first elementwise sum to match shape of last dimension (32 -> 64).
    layer_num += 1
    paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [16, 16]]
    x = _resblock(x, mode=mode, layer_num=layer_num, filters=64,
                  kernel_size=3, dilation_rate=4, paddings=paddings,
                  one_batchnorm=one_batchnorm)
    for ii in range(2):
        layer_num += 1
        x = _resblock(x, mode=mode, layer_num=layer_num, filters=64,
                      kernel_size=3, dilation_rate=4,
                      one_batchnorm=one_batchnorm)

    with tf.variable_scope('conv_1'):
        x = tf.layers.conv3d(x, filters=80, kernel_size=1, padding='SAME')

    # Optional dropout immediately before the logits convolution.
    if params['dropout_rate']:
        x = tf.layers.dropout(x, rate=params['dropout_rate'],
                              training=training)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(x, filters=params['n_classes'],
                                  kernel_size=1, padding='SAME')

    # Use a distinct name for the softmax tensor so the `predictions` dict in
    # the PREDICT branch does not shadow it.
    probabilities = tf.nn.softmax(logits=logits)
    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': probabilities,
            'logits': logits,
        }
        # Outputs for SavedModel.
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions,
            export_outputs=export_outputs)

    onehot_labels = tf.one_hot(labels, params['n_classes'])
    # Tversky loss on softmax probabilities over the spatial axes.
    loss = losses.tversky(labels=onehot_labels, predictions=probabilities,
                          axis=(1, 2, 3))

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss,
            eval_metric_ops={
                'dice': metrics.streaming_dice(
                    labels, predicted_classes, axis=(1, 2, 3)),
            })

    # BUG FIX: the message previously formatted the literal string "mode"
    # instead of the `mode` variable.
    assert mode == tf.estimator.ModeKeys.TRAIN, (
        "unknown mode key {}".format(mode))

    global_step = tf.train.get_global_step()
    # Run batch-norm moving-average updates alongside the optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    # Per-class Dice score averaged over the batch, logged during training.
    dice_coefficients = tf.reduce_mean(
        metrics.dice(
            onehot_labels,
            tf.one_hot(predicted_classes, params['n_classes']),
            axis=(1, 2, 3)),
        axis=0)
    logging_hook = tf.train.LoggingTensorHook(
        {"loss": loss, "dice": dice_coefficients}, every_n_iter=100)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op,
        training_hooks=[logging_hook])
def model_fn(features, labels, mode, params, config=None):
    """HighRes3DNet model function.

    Args:
        features: 5D float `Tensor`, input tensor. This is the first item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Use `NDHWC` format.
        labels: 4D float `Tensor`, labels tensor. This is the second item
            returned from the `input_fn` passed to `train`, `evaluate`, and
            `predict`. Labels should not be one-hot encoded.
        mode: Optional. Specifies if this training, evaluation or prediction.
        params: `dict` of parameters. All parameters below are required.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
            - one_batchnorm_per_resblock: (default false) if true, only apply
              first batch normalization layer in each residually connected
              block. Empirically, only using first batch normalization layer
              allowed the model to be trained on 128**3 float32 inputs.
            - dropout_rate: (default 0), value between 0 and 1, dropout rate
              to be applied immediately before last convolution layer. If 0
              or false, dropout is not applied.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    # Accept either the raw volume tensor or a dict with key 'volume'.
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'one_batchnorm_per_resblock': False,
        'dropout_rate': 0,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    tf.logging.debug("Parameters for model:")
    tf.logging.debug(params)

    # Batch norm and dropout behave differently in training vs inference.
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Stem: conv -> batch norm -> ReLU before the residual blocks.
    with tf.variable_scope('conv_0'):
        conv = tf.layers.conv3d(
            volume, filters=16, kernel_size=3, padding='SAME')
    with tf.variable_scope('batchnorm_0'):
        conv = tf.layers.batch_normalization(
            conv, training=training, fused=FUSED_BATCH_NORM)
    with tf.variable_scope('relu_0'):
        outputs = tf.nn.relu(conv)

    layer_num = 0
    one_batchnorm = params['one_batchnorm_per_resblock']

    # Nine residually connected blocks, numbered 1..9: three each at
    # dilation rates 1, 2, and 4, all with 16 filters.
    for ii in range(3):
        layer_num += 1
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16,
            kernel_size=3, dilation_rate=1, one_batchnorm=one_batchnorm)
    for ii in range(3):
        layer_num += 1
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16,
            kernel_size=3, dilation_rate=2, one_batchnorm=one_batchnorm)
    for ii in range(3):
        layer_num += 1
        outputs = _resblock(
            outputs, mode=mode, layer_num=layer_num, filters=16,
            kernel_size=3, dilation_rate=4, one_batchnorm=one_batchnorm)

    # Optional dropout immediately before the final (logits) convolution.
    if params['dropout_rate']:
        outputs = tf.layers.dropout(
            outputs, rate=params['dropout_rate'], training=training)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            outputs, filters=params['n_classes'], kernel_size=1,
            padding='SAME')

    predicted_classes = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits),
            'logits': logits}
        # Outputs for SavedModel.
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions,
            export_outputs=export_outputs)

    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)

    # Add evaluation metrics for class 1.
    labels = tf.cast(labels, predicted_classes.dtype)
    labels_onehot = tf.one_hot(labels, params['n_classes'])
    predictions_onehot = tf.one_hot(predicted_classes, params['n_classes'])
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predicted_classes),
        'dice': streaming_dice(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
        'hamming': streaming_hamming(
            labels_onehot[..., 1], predictions_onehot[..., 1]),
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_global_step()
    # Run batch-norm moving-average updates alongside the optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op)
def model_fn(features, labels, mode, params, config=None):
    """3D U-Net model function.

    Args:
        features: input volume `Tensor`, or a dict with key 'volume' holding
            it. This is the first item returned from the `input_fn`.
            NOTE(review): assumed to use `NDHWC` format like the other model
            functions in this module — confirm against the `input_fn`.
        labels: labels tensor. Labels should not be one-hot encoded (one-hot
            encoding is applied internally).
        mode: Specifies if this is training, evaluation or prediction.
        params: `dict` of parameters.
            - n_classes: (required) number of classes to classify.
            - optimizer: instance of TensorFlow optimizer. Required if
              training.
            - batchnorm: (default True) whether to use batch normalization
              in the convolution blocks.
        config: configuration object.

    Returns:
        `tf.estimator.EstimatorSpec`

    Raises:
        `ValueError` if required parameters are not in `params`.
    """
    volume = features
    if isinstance(volume, dict):
        volume = features['volume']

    required_keys = {'n_classes'}
    default_params = {
        'optimizer': None,
        'batchnorm': True,
    }
    check_required_params(params=params, required_keys=required_keys)
    set_default_params(params=params, defaults=default_params)
    check_optimizer_for_training(optimizer=params['optimizer'], mode=mode)

    bn = params['batchnorm']

    # Encoder: three conv blocks, each followed by 2x2x2 max pooling. The
    # pre-pool outputs are kept as skip connections for the decoder.
    shortcut_1 = _conv_block(volume, filters1=32, filters2=64, mode=mode,
                             layer_num=0, batchnorm=bn)
    with tf.variable_scope('maxpool_1'):
        x = tf.layers.max_pooling3d(
            inputs=shortcut_1, pool_size=(2, 2, 2), strides=(2, 2, 2),
            padding='same')
    shortcut_2 = _conv_block(x, filters1=64, filters2=128, mode=mode,
                             layer_num=1, batchnorm=bn)
    with tf.variable_scope('maxpool_2'):
        x = tf.layers.max_pooling3d(
            inputs=shortcut_2, pool_size=(2, 2, 2), strides=(2, 2, 2),
            padding='same')
    shortcut_3 = _conv_block(x, filters1=128, filters2=256, mode=mode,
                             layer_num=2, batchnorm=bn)
    with tf.variable_scope('maxpool_3'):
        x = tf.layers.max_pooling3d(
            inputs=shortcut_3, pool_size=(2, 2, 2), strides=(2, 2, 2),
            padding='same')

    # Bottleneck.
    x = _conv_block(x, filters1=256, filters2=512, mode=mode, layer_num=3,
                    batchnorm=bn)

    # Decoder: transpose convolutions upsample; each skip connection is
    # concatenated along the channel axis before the next conv block.
    with tf.variable_scope("upconv_0"):
        x = tf.layers.conv3d_transpose(
            inputs=x, filters=512, kernel_size=(2, 2, 2), strides=(2, 2, 2),
            kernel_regularizer=_regularizer)
    with tf.variable_scope('concat_1'):
        x = tf.concat((shortcut_3, x), axis=-1)
    x = _conv_block(x, filters1=256, filters2=256, mode=mode, layer_num=4,
                    batchnorm=bn)

    with tf.variable_scope("upconv_1"):
        x = tf.layers.conv3d_transpose(
            inputs=x, filters=256, kernel_size=(2, 2, 2), strides=(2, 2, 2),
            kernel_regularizer=_regularizer)
    with tf.variable_scope('concat_2'):
        x = tf.concat((shortcut_2, x), axis=-1)
    x = _conv_block(x, filters1=128, filters2=128, mode=mode, layer_num=5,
                    batchnorm=bn)

    with tf.variable_scope("upconv_2"):
        x = tf.layers.conv3d_transpose(
            inputs=x, filters=128, kernel_size=(2, 2, 2), strides=(2, 2, 2),
            kernel_regularizer=_regularizer)
    with tf.variable_scope('concat_3'):
        x = tf.concat((shortcut_1, x), axis=-1)
    x = _conv_block(x, filters1=64, filters2=64, mode=mode, layer_num=6,
                    batchnorm=bn)

    with tf.variable_scope('logits'):
        logits = tf.layers.conv3d(
            inputs=x, filters=params['n_classes'], kernel_size=(1, 1, 1),
            padding='same', activation=None, kernel_regularizer=_regularizer)

    with tf.variable_scope('predictions'):
        predictions = tf.nn.softmax(logits=logits)
        class_ids = tf.argmax(logits, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': class_ids,
            'probabilities': predictions,
            'logits': logits,
        }
        # Outputs for SavedModel.
        export_outputs = {
            'outputs': tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions,
            export_outputs=export_outputs)

    onehot_labels = tf.one_hot(labels, params['n_classes'])
    loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=onehot_labels, logits=logits)
    # Include the kernel regularization collected by the conv layers above.
    l2_loss = tf.losses.get_regularization_loss()
    loss += l2_loss

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops={
                'dice': metrics.streaming_dice(
                    labels, class_ids, axis=(1, 2, 3)),
            })

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Run batch-norm moving-average updates alongside the optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = params['optimizer'].minimize(
            loss, global_step=tf.train.get_global_step())

    # BUG FIX: `tf.one_hot` was previously called without its required
    # `depth` argument and with a misplaced `axis=(1, 2, 3)`, which would
    # fail at graph construction. One-hot the predicted classes over
    # `n_classes`, compute per-class Dice over the spatial axes, then
    # average over the batch.
    dice_coefficients = tf.reduce_mean(
        metrics.dice(
            onehot_labels,
            tf.one_hot(class_ids, params['n_classes']),
            axis=(1, 2, 3)),
        axis=0)
    logging_hook = tf.train.LoggingTensorHook(
        {"loss": loss, "dice": dice_coefficients}, every_n_iter=100)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op,
        training_hooks=[logging_hook])