Example #1
def _model_fn(features, labels, mode, params):
    """ Model function for tf.Estimator

    Controls how the training is performed by specifying how the
    total_loss is computed and applied in the backward pass.

    Args:
        features (tf.Tensor): Tensor samples
        labels (tf.Tensor): Tensor labels
        mode (tf.estimator.ModeKeys): Indicates whether we train, evaluate or predict
        params (dict): Additional parameters supplied to the estimator

    Returns:
        Appropriate tf.estimator.EstimatorSpec for the current mode

    """
    dtype = params['dtype']
    max_steps = params['max_steps']
    lr_init = params['learning_rate']
    momentum = params['momentum']

    device = '/gpu:0'

    global_step = tf.train.get_global_step()
    # Decay the learning rate smoothly over the run:
    # learning_rate = lr_init * 0.96 ** (global_step / max_steps)
    learning_rate = tf.train.exponential_decay(lr_init,
                                               global_step,
                                               decay_steps=max_steps,
                                               decay_rate=0.96)

    with tf.device(device):
        features = tf.cast(features, dtype)

        # model_variable_scope and unet_v1 are project-specific helpers
        # (not shown in this snippet).
        with model_variable_scope('UNet',
                                  reuse=tf.AUTO_REUSE,
                                  dtype=tf.float16,
                                  debug_mode=False):
            output_map = unet_v1(features, mode)

            if mode == tf.estimator.ModeKeys.PREDICT:
                predictions = {'logits': tf.nn.softmax(output_map, axis=-1)}
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  predictions=predictions)

            n_classes = output_map.shape[-1].value

            # Flatten the spatial dimensions so both loss terms can be
            # computed per pixel.
            flat_logits = tf.reshape(tf.cast(output_map, tf.float32),
                                     [tf.shape(output_map)[0], -1, n_classes])
            flat_labels = tf.reshape(labels,
                                     [tf.shape(output_map)[0], -1, n_classes])

            crossentropy_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                           labels=flat_labels),
                name='cross_loss_ref')
            # Dice is defined on class probabilities, so run the logits
            # through a softmax before computing the coefficient.
            dice_loss = tf.reduce_mean(
                1 - dice_coef(tf.nn.softmax(flat_logits, axis=-1),
                              flat_labels),
                name='dice_loss_ref')

            total_loss = tf.add(crossentropy_loss,
                                dice_loss,
                                name="total_loss_ref")

            opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                             momentum=momentum)

            if is_using_hvd():
                # Average gradients across workers when running under Horovod.
                opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')

            # Run the ops registered in UPDATE_OPS (e.g. batch-norm moving
            # averages) before applying the weight update. Gating gradients
            # per op (GATE_OP) fixes their computation order, which helps
            # reproducibility.
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                deterministic = True
                gate_gradients = (tf.train.Optimizer.GATE_OP if deterministic
                                  else tf.train.Optimizer.GATE_NONE)

                train_op = opt.minimize(total_loss,
                                        gate_gradients=gate_gradients,
                                        global_step=global_step)

    return tf.estimator.EstimatorSpec(mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops={})
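
For context, a model_fn like the one above is not called directly: it is handed to a tf.estimator.Estimator, which invokes it once per mode. Below is a minimal sketch of that wiring under the TF 1.x API; the hyperparameter values, model_dir and train_input_fn are hypothetical placeholders, not taken from the example.

import tensorflow as tf

# Hypothetical hyperparameters matching the keys that _model_fn reads.
params = {
    'dtype': tf.float32,
    'max_steps': 1000,
    'learning_rate': 0.01,
    'momentum': 0.9,
}

estimator = tf.estimator.Estimator(model_fn=_model_fn,
                                   model_dir='/tmp/unet',  # hypothetical
                                   params=params)

# train_input_fn is assumed to yield (features, labels) batches.
# estimator.train(input_fn=train_input_fn, max_steps=params['max_steps'])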
Example #2
def unet_fn(features, labels, mode, params):
    """ Model function for tf.Estimator

    Controls how the training is performed by specifying how the
    total_loss is computed and applied in the backward pass.

    Args:
        features (tf.Tensor): Tensor samples
        labels (tf.Tensor): Tensor labels
        mode (tf.estimator.ModeKeys): Indicates whether we train, evaluate or predict
        params (dict): Additional parameters supplied to the estimator

    Returns:
        Appropriate tf.estimator.EstimatorSpec for the current mode

    """
    dtype = tf.float32

    device = '/gpu:0'

    global_step = tf.compat.v1.train.get_global_step()

    if mode == tf.estimator.ModeKeys.TRAIN:
        lr_init = params['learning_rate']

    with tf.device(device):
        features = tf.cast(features, dtype)

        output_map = unet_v1(features=features, mode=mode)

        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {'logits': tf.nn.softmax(output_map, axis=-1)}
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions)

        n_classes = output_map.shape[-1].value

        flat_logits = tf.reshape(tf.cast(output_map, tf.float32),
                                 [tf.shape(output_map)[0], -1, n_classes])
        flat_labels = tf.reshape(labels,
                                 [tf.shape(output_map)[0], -1, n_classes])

        crossentropy_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                       labels=flat_labels),
            name='cross_loss_ref')
        dice_loss = tf.reduce_mean(1 - dice_coef(
            tf.keras.activations.softmax(flat_logits, axis=-1), flat_labels),
                                   name='dice_loss_ref')
        total_loss = tf.add(crossentropy_loss,
                            dice_loss,
                            name="total_loss_ref")

        if mode == tf.estimator.ModeKeys.EVAL:
            # Report each loss term separately; dice_loss also serves as the
            # headline evaluation loss.
            eval_metric_ops = {
                "eval_ce_loss": tf.compat.v1.metrics.mean(crossentropy_loss),
                "eval_dice_loss": tf.compat.v1.metrics.mean(dice_loss),
                "eval_total_loss": tf.compat.v1.metrics.mean(total_loss),
                "eval_dice_score": tf.compat.v1.metrics.mean(1.0 - dice_loss)
            }
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=dice_loss,
                                              eval_metric_ops=eval_metric_ops)

        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lr_init)

        if is_using_hvd():
            opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')

        with tf.control_dependencies(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.UPDATE_OPS)):
            deterministic = True
            gate_gradients = (tf.compat.v1.train.Optimizer.GATE_OP
                              if deterministic else
                              tf.compat.v1.train.Optimizer.GATE_NONE)

            train_op = opt.minimize(total_loss,
                                    gate_gradients=gate_gradients,
                                    global_step=global_step)

    return tf.estimator.EstimatorSpec(mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops={})
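
Both examples depend on a dice_coef helper that is not part of the snippet. The sketch below is a minimal soft Dice coefficient consistent with how it is called above (inputs of shape [batch, pixels, classes]); the eps smoothing constant is an assumption to keep the division stable, not taken from the original code.

import tensorflow as tf

def dice_coef(predict, target, axis=1, eps=1e-6):
    # Soft Dice per class: 2*sum(P*T) / (sum(P^2) + sum(T^2)), summed over
    # the flattened spatial axis; eps (assumed) avoids division by zero.
    intersection = tf.reduce_sum(predict * target, axis=axis)
    union = tf.reduce_sum(predict * predict + target * target, axis=axis)
    return (2.0 * intersection + eps) / (union + eps)

With probabilities in [0, 1], the coefficient approaches 1 at perfect overlap, so 1 - dice_coef(...) acts as a loss, matching its use in both examples.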