# This section assumes `tensorflow` is imported as `tf`, Horovod as `hvd`, and that
# the repo-local helpers `unet_v1`, `dice_coef`, `is_using_hvd` and
# `model_variable_scope` are imported elsewhere in the module.
def _model_fn(features, labels, mode, params):
    """ Model function for tf.Estimator

    Controls how the training is performed by specifying how the
    total_loss is computed and applied in the backward pass.

    Args:
        features (tf.Tensor): Tensor samples
        labels (tf.Tensor): Tensor labels
        mode (tf.estimator.ModeKeys): Indicates if we train, evaluate or predict
        params (dict): Additional parameters supplied to the estimator

    Returns:
        Appropriate tf.estimator.EstimatorSpec for the current mode
    """
    # Hyperparameters supplied through the estimator's `params`.
    dtype = params['dtype']
    max_steps = params['max_steps']
    lr_init = params['learning_rate']
    momentum = params['momentum']

    device = '/gpu:0'

    # Exponentially decay the learning rate over the course of training.
    global_step = tf.train.get_global_step()
    learning_rate = tf.train.exponential_decay(lr_init, global_step,
                                               decay_steps=max_steps,
                                               decay_rate=0.96)

    with tf.device(device):
        features = tf.cast(features, dtype)

        # Forward pass of the U-Net backbone; variables are stored as float16
        # under this custom variable scope.
        with model_variable_scope('UNet',
                                  reuse=tf.AUTO_REUSE,
                                  dtype=tf.float16,
                                  debug_mode=False):
            output_map = unet_v1(features, mode)

            # For inference we only need the per-pixel class probabilities.
            if mode == tf.estimator.ModeKeys.PREDICT:
                predictions = {'logits': tf.nn.softmax(output_map, axis=-1)}
                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

            # Flatten the spatial dimensions so the losses operate on
            # (batch, pixels, classes) tensors.
            n_classes = output_map.shape[-1].value
            flat_logits = tf.reshape(tf.cast(output_map, tf.float32),
                                     [tf.shape(output_map)[0], -1, n_classes])
            flat_labels = tf.reshape(labels,
                                     [tf.shape(output_map)[0], -1, n_classes])

            # Total loss = cross-entropy + (1 - DICE coefficient).
            crossentropy_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                           labels=flat_labels),
                name='cross_loss_ref')
            dice_loss = tf.reduce_mean(1 - dice_coef(flat_logits, flat_labels),
                                       name='dice_loss_ref')
            total_loss = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")

            opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                             momentum=momentum)

            # Wrap the optimizer for multi-GPU training when Horovod is active.
            if is_using_hvd():
                opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')

            # Make sure update ops (e.g. batch-norm statistics) run before the
            # weight update.
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                deterministic = True
                gate_gradients = (tf.train.Optimizer.GATE_OP
                                  if deterministic
                                  else tf.train.Optimizer.GATE_NONE)
                train_op = opt.minimize(total_loss,
                                        gate_gradients=gate_gradients,
                                        global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops={})
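# A minimal usage sketch, not part of the original module: it shows how a
# model_fn like `_model_fn` above is typically handed to tf.estimator.Estimator.
# The helper name, model_dir default and every hyperparameter value below are
# illustrative assumptions; only the params keys mirror what `_model_fn` reads.
def _build_estimator_sketch(model_dir='/tmp/unet_fp16', train_input_fn=None):
    estimator = tf.estimator.Estimator(
        model_fn=_model_fn,
        model_dir=model_dir,
        params={'dtype': tf.float16,     # assumed compute dtype
                'max_steps': 1000,       # assumed decay horizon / training length
                'learning_rate': 1e-4,   # assumed initial learning rate
                'momentum': 0.99})       # assumed momentum for MomentumOptimizer
    if train_input_fn is not None:
        estimator.train(input_fn=train_input_fn, max_steps=1000)
    return estimator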
def unet_fn(features, labels, mode, params):
    """ Model function for tf.Estimator

    Controls how the training is performed by specifying how the
    total_loss is computed and applied in the backward pass.

    Args:
        features (tf.Tensor): Tensor samples
        labels (tf.Tensor): Tensor labels
        mode (tf.estimator.ModeKeys): Indicates if we train, evaluate or predict
        params (dict): Additional parameters supplied to the estimator

    Returns:
        Appropriate tf.estimator.EstimatorSpec for the current mode
    """
    dtype = tf.float32

    device = '/gpu:0'

    global_step = tf.compat.v1.train.get_global_step()

    # The learning rate is only needed when we actually train.
    if mode == tf.estimator.ModeKeys.TRAIN:
        lr_init = params.learning_rate

    with tf.device(device):
        features = tf.cast(features, dtype)

        # Forward pass of the U-Net backbone.
        output_map = unet_v1(features=features, mode=mode)

        # For inference we only need the per-pixel class probabilities.
        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {'logits': tf.nn.softmax(output_map, axis=-1)}
            return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

        # Flatten the spatial dimensions so the losses operate on
        # (batch, pixels, classes) tensors.
        n_classes = output_map.shape[-1].value
        flat_logits = tf.reshape(tf.cast(output_map, tf.float32),
                                 [tf.shape(output_map)[0], -1, n_classes])
        flat_labels = tf.reshape(labels,
                                 [tf.shape(output_map)[0], -1, n_classes])

        # Total loss = cross-entropy + (1 - DICE coefficient); the DICE term is
        # computed on softmax probabilities rather than raw logits.
        crossentropy_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                       labels=flat_labels),
            name='cross_loss_ref')
        dice_loss = tf.reduce_mean(
            1 - dice_coef(tf.keras.activations.softmax(flat_logits, axis=-1),
                          flat_labels),
            name='dice_loss_ref')
        total_loss = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")

        # Evaluation reports the individual loss terms and the DICE score.
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = {
                "eval_ce_loss": tf.compat.v1.metrics.mean(crossentropy_loss),
                "eval_dice_loss": tf.compat.v1.metrics.mean(dice_loss),
                "eval_total_loss": tf.compat.v1.metrics.mean(total_loss),
                "eval_dice_score": tf.compat.v1.metrics.mean(1.0 - dice_loss)
            }
            return tf.estimator.EstimatorSpec(mode=mode, loss=dice_loss,
                                              eval_metric_ops=eval_metric_ops)

        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lr_init)

        # Wrap the optimizer for multi-GPU training when Horovod is active.
        if is_using_hvd():
            opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')

        # Make sure update ops (e.g. batch-norm statistics) run before the
        # weight update.
        with tf.control_dependencies(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
            deterministic = True
            gate_gradients = (tf.compat.v1.train.Optimizer.GATE_OP
                              if deterministic
                              else tf.compat.v1.train.Optimizer.GATE_NONE)
            train_op = opt.minimize(total_loss,
                                    gate_gradients=gate_gradients,
                                    global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops={})
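# A minimal evaluation sketch, not part of the original module: `unet_fn` reads
# params.learning_rate via attribute access, so a simple namespace object works
# here (Estimator passes `params` through to the model_fn unchanged). The helper
# name, model_dir default and learning-rate value are illustrative assumptions.
def _evaluate_unet_sketch(model_dir='/tmp/unet_tf1', eval_input_fn=None):
    from types import SimpleNamespace
    run_params = SimpleNamespace(learning_rate=1e-4)  # assumed initial learning rate
    estimator = tf.estimator.Estimator(
        model_fn=unet_fn,
        model_dir=model_dir,
        params=run_params)
    if eval_input_fn is not None:
        # Returns the eval_metric_ops dict built inside unet_fn
        # (cross-entropy, DICE loss, total loss, DICE score).
        return estimator.evaluate(input_fn=eval_input_fn)
    return estimator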