def metric_fn(labels_r2, logits_r2):
  """Builds the evaluation metric ops from rank-2 labels and logits.

  Both inputs are reshaped back to `original_shape` (a free variable taken
  from the enclosing scope) before any metric is computed, and bfloat16
  inputs are upcast to float32 first.

  Args:
    labels_r2: labels flattened to rank 2 by the caller.
    logits_r2: logits flattened to rank 2 by the caller.

  Returns:
    Dict mapping metric name to the (value_op, update_op) pair produced by
    the corresponding `tf.metrics` function.
  """
  def _as_float32(tensor):
    # Metrics are evaluated in float32 even if the model ran in bfloat16.
    if tensor.dtype != tf.bfloat16:
      return tensor
    return tf.cast(tensor, tf.float32)

  labels_r2, logits_r2 = _as_float32(labels_r2), _as_float32(logits_r2)
  labels = tf.reshape(labels_r2, original_shape)
  logits = tf.reshape(logits_r2, original_shape)

  probabilities = tf.nn.softmax(logits)
  xent = tf.keras.losses.categorical_crossentropy(
      labels, probabilities, from_logits=False)
  dice = metrics.adaptive_dice32(labels, probabilities)

  # NOTE(review): 'adaptice_dice32' is misspelled, but it is the key under
  # which this metric is reported, so it is intentionally left unchanged.
  metric_ops = {}
  metric_ops['accuracy'] = tf.metrics.accuracy(
      labels=tf.argmax(labels, -1),
      predictions=tf.argmax(probabilities, -1))
  metric_ops['adaptice_dice32'] = tf.metrics.mean(
      dice, name='adaptive_dice32')
  metric_ops['categorical_crossentropy'] = tf.metrics.mean(
      xent, name='categorical_crossentropy')
  return metric_ops
def _unet_model_fn(image, labels, mode, params):
  """Builds the UNet model graph, train op and eval metrics.

  Args:
    image: input image Tensor. Shape [x, y, z, num_channels].
    labels: input label Tensor. Shape [x, y, z, num_classes] (one-hot). In
      TRAIN mode with `params['use_index_label_in_train']` set, labels are
      instead class indices, one rank lower than the logits.
    mode: TRAIN, EVAL or PREDICT.
    params: model parameters dictionary.

  Returns:
    EstimatorSpec or TPUEstimatorSpec.

  Raises:
    ValueError: if the static shapes of `labels` and the network output are
      incompatible with the configured loss.
  """
  with tf.variable_scope('base', reuse=tf.AUTO_REUSE):
    # Both precisions build the same network; only the enclosing scope
    # differs. NOTE(review): the empty variable scope in the float32 branch
    # presumably mirrors the scope opened by bfloat16_scope() so that
    # variable names match across precisions -- confirm before changing.
    if params['use_bfloat16']:
      model_scope = tf.contrib.tpu.bfloat16_scope()
    else:
      model_scope = tf.variable_scope('')
    with model_scope:
      logits = unet3d_base(
          image,
          pool_size=(2, 2, 2),
          n_labels=params['num_classes'],
          deconvolution=params['deconvolution'],
          depth=params['depth'],
          n_base_filters=params['num_base_filters'],
          batch_normalization=params['use_batch_norm'],
          data_format=params['data_format'])

  loss = None
  if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
    with tf.variable_scope('loss', reuse=tf.AUTO_REUSE):
      if params['loss'] == 'adaptive_dice32':
        predictions = tf.nn.softmax(logits)
        label_shape = labels.get_shape().as_list()
        pred_shape = predictions.get_shape().as_list()
        if label_shape != pred_shape:
          raise ValueError(
              'predictions shape {} is not equal to label shape {}'.format(
                  pred_shape, label_shape))
        loss = metrics.adaptive_dice32(labels, predictions)
      else:
        label_shape = labels.get_shape().as_list()
        logit_shape = logits.get_shape().as_list()
        if (mode == tf.estimator.ModeKeys.TRAIN and
            params['use_index_label_in_train']):
          # Labels already hold class indices: one rank lower than logits.
          if len(label_shape) + 1 != len(logit_shape):
            raise ValueError(
                'logits shape {} is not equal to label shape {} plus one'
                .format(logit_shape, label_shape))
          labels_idx = tf.cast(labels, dtype=tf.int32)
        else:
          if label_shape != logit_shape:
            raise ValueError(
                'logits shape {} is not equal to label shape {}'.format(
                    logit_shape, label_shape))
          # Convert the one-hot encoding to label index.
          labels_idx = tf.argmax(labels, axis=-1, output_type=tf.int32)
        # Cross entropy is computed in float32 even if the network ran in
        # bfloat16; the float32 logits are also what EVAL reshapes below.
        logits = tf.cast(logits, dtype=tf.float32)
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=labels_idx, logits=logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.compat.v1.train.get_or_create_global_step()
    learning_rate = tf.compat.v1.train.exponential_decay(
        float(params['init_learning_rate']),
        global_step,
        decay_steps=params['lr_decay_steps'],
        decay_rate=params['lr_decay_rate'])
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    optimizer = create_optimizer(learning_rate, params)
    if params['use_tpu']:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    # The minimize op must be *created* inside the control-dependency scope;
    # otherwise the UPDATE_OPS (typically batch-norm moving statistics) are
    # never executed as part of the training step.
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if params['use_tpu']:

      def host_call_fn(gs, lr):
        """Training host call. Creates scalar summaries for training metrics.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          lr: `Tensor` with shape `[batch]` for the learning_rate.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        with summary.create_file_writer(params['model_dir']).as_default():
          with summary.always_record_summaries():
            summary.scalar('learning_rate', lr[0], step=gs)
            return summary.all_summary_ops()

      # To log the current learning rate for TensorBoard, the summary op
      # needs to run on the host CPU via host_call. host_call expects
      # [batch_size, ...] Tensors, so reshape to introduce a batch dimension;
      # these Tensors are implicitly concatenated across shards.
      gs_t = tf.reshape(global_step, [1])
      lr_t = tf.reshape(learning_rate, [1])
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          train_op=train_op,
          host_call=(host_call_fn, [gs_t, lr_t]))

    # Note: hooks cannot access tensors defined in model_fn in TPUEstimator,
    # so the logging hook is only attached on the non-TPU path.
    logging_hook = tf.train.LoggingTensorHook({'loss': loss}, every_n_iter=10)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        training_hooks=[logging_hook],
        train_op=train_op)

  if mode == tf.estimator.ModeKeys.EVAL:
    # Reshape labels/logits to R2 tensor to avoid TPU padding issue.
    # TPU tends to pad the last dimension to 128x,
    # and the second to last dimension to 8x.
    labels_r2 = tf.reshape(labels, [params['eval_batch_size'], -1])
    logits_r2 = tf.reshape(logits, [params['eval_batch_size'], -1])
    # list(...) accepts input_image_size given as either a list or a tuple.
    original_shape = ([params['eval_batch_size']] +
                      list(params['input_image_size']) + [-1])
    metric_fn = get_metric_fn(original_shape)
    if params['use_tpu']:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, [labels_r2, logits_r2]))
    # EstimatorSpec has no `eval_metrics` argument (that is a TPUEstimatorSpec
    # concept); it takes the already-built metric dict via `eval_metric_ops`.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops=metric_fn(labels_r2, logits_r2))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.identity(tf.math.argmax(logits, axis=-1), 'Classes'),
        'scores': tf.identity(tf.nn.softmax(logits, axis=-1), 'Scores'),
    }
    export_outputs = {
        'classify': tf.estimator.export.PredictOutput(predictions)
    }
    if params['use_tpu']:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions=predictions,
          export_outputs=export_outputs)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs=export_outputs)