コード例 #1
0
ファイル: train.py プロジェクト: Mulugeta/DLTK
def model_fn(features, labels, mode, params):
    """Model function to construct a tf.estimator.EstimatorSpec.

    Builds a 3D residual U-Net on the input features (e.g. from a
    dltk.io.abstract_reader). Outside PREDICT mode it also adds the loss,
    the optimiser and custom TensorBoard summary ops. For additional
    information, please refer to
    https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#model_fn.

    Args:
        features (dict): Input features. The image tensor is read from
            features['x']. The summaries below index it as a rank-5 tensor
            and read channels 0, 1 and 2 of its last axis.
        labels (dict): Training targets. The integer label map is read from
            labels['y'] (indexed as rank 4 below). Unused in PREDICT mode.
        mode (str): One of the tf.estimator.ModeKeys: TRAIN, EVAL or PREDICT.
        params (dict): Hyperparameters. Must contain 'learning_rate' unless
            mode is PREDICT.

    Returns:
        tf.estimator.EstimatorSpec: A custom EstimatorSpec for this experiment.
    """

    # 1. create a model and its outputs
    net_output_ops = residual_unet_3d(
        inputs=features['x'],
        num_classes=NUM_CLASSES,
        num_res_units=2,
        filters=(16, 32, 64, 128),
        strides=((1, 1, 1), (1, 2, 2), (1, 2, 2), (1, 2, 2)),
        mode=mode,
        kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))

    # 1.1 Generate predictions only (for `ModeKeys.PREDICT`)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=net_output_ops,
            export_outputs={'out': tf.estimator.export.PredictOutput(net_output_ops)})

    # 2. set up a loss function: mean cross-entropy plus the weight penalty.
    # Layer regularizers (such as the `kernel_regularizer` passed above) only
    # register their terms in the REGULARIZATION_LOSSES collection. They must
    # be added to the optimised loss explicitly, otherwise they have no
    # effect. get_regularization_loss() is 0.0 when the collection is empty.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=net_output_ops['logits'],
        labels=labels['y'])
    loss = tf.reduce_mean(ce) + tf.losses.get_regularization_loss()

    # 3. define a training op and ops for updating moving averages
    # (i.e. for batch normalisation)
    global_step = tf.train.get_global_step()
    optimiser = tf.train.MomentumOptimizer(
        learning_rate=params["learning_rate"],
        momentum=0.9)

    # Make every optimiser step also run the pending UPDATE_OPS.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimiser.minimize(loss, global_step=global_step)

    # 4.1 (optional) create custom image summaries for tensorboard, taken from
    # the first batch element at index 0 of axis 1.
    my_image_summaries = {
        'feat_t1': features['x'][0, 0, :, :, 0],
        'feat_t1_ir': features['x'][0, 0, :, :, 1],
        'feat_t2_flair': features['x'][0, 0, :, :, 2],
        'labels': tf.cast(labels['y'], tf.float32)[0, 0, :, :],
        'predictions': tf.cast(net_output_ops['y_'], tf.float32)[0, 0, :, :],
    }
    for name, image in my_image_summaries.items():
        # tf.summary.image needs a rank-4 [batch, height, width, channels]
        # tensor. Adding unit batch/channel axes (rather than reshaping to a
        # hard-coded slice size) keeps this working for any patch size.
        tf.summary.image(name, tf.expand_dims(tf.expand_dims(image, 0), -1))

    # 4.2 (optional) create custom metric summaries for tensorboard.
    # `dice` runs in Python via tf.py_func. Its result is presumably one
    # score per class, as the per-class indexing below assumes.
    dice_tensor = tf.py_func(dice, [net_output_ops['y_'],
                                    labels['y'],
                                    tf.constant(NUM_CLASSES)], tf.float32)
    for i in range(NUM_CLASSES):
        tf.summary.scalar('dsc_l{}'.format(i), dice_tensor[i])

    # 5. Return EstimatorSpec object
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=net_output_ops,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=None)
コード例 #2
0
ファイル: train.py プロジェクト: sambuddinc/DLTK
def model_fn(features, labels, mode, params):
    """Model function to construct a tf.estimator.EstimatorSpec.

    Builds a 3D residual U-Net on the input features (e.g. from a
    dltk.io.abstract_reader). Outside PREDICT mode it also adds the loss,
    the optimiser and custom TensorBoard summary ops. For additional
    information, please refer to
    https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#model_fn.

    Args:
        features (dict): Input features. The image tensor is read from
            features['x']. The summaries below index it as a rank-5 tensor
            and read channel 0 of its last axis.
        labels (dict): Training targets. The integer label map is read from
            labels['y'] (indexed as rank 4 below). Unused in PREDICT mode.
        mode (str): One of the tf.estimator.ModeKeys: TRAIN, EVAL or PREDICT.
        params (dict): Hyperparameters. Must contain 'learning_rate' unless
            mode is PREDICT.

    Returns:
        tf.estimator.EstimatorSpec: A custom EstimatorSpec for this experiment.
    """
    print("Setting up U-Net")
    # 1. create a model and its outputs
    net_output_ops = residual_unet_3d(
        inputs=features['x'],
        num_classes=NUM_CLASSES,
        num_res_units=2,
        filters=(16, 32, 64, 128),
        strides=((1, 1, 1), (1, 2, 2), (1, 2, 2), (1, 2, 2)),
        mode=mode,
        kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))

    # 1.1 Generate predictions only (for `ModeKeys.PREDICT`)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=net_output_ops,
            export_outputs={
                'out': tf.estimator.export.PredictOutput(net_output_ops)
            })

    # 2. set up a loss function: mean cross-entropy plus the weight penalty.
    # Layer regularizers (such as the `kernel_regularizer` passed above) only
    # register their terms in the REGULARIZATION_LOSSES collection. They must
    # be added to the optimised loss explicitly, otherwise they have no
    # effect. get_regularization_loss() is 0.0 when the collection is empty.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=net_output_ops['logits'], labels=labels['y'])
    loss = tf.reduce_mean(ce) + tf.losses.get_regularization_loss()

    # 3. define a training op and ops for updating moving averages
    # (i.e. for batch normalisation)
    global_step = tf.train.get_global_step()
    optimiser = tf.train.MomentumOptimizer(
        learning_rate=params["learning_rate"], momentum=0.9)

    # Make every optimiser step also run the pending UPDATE_OPS.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimiser.minimize(loss, global_step=global_step)

    # 4.1 (optional) create custom image summaries for tensorboard, taken from
    # the first batch element at index 0 of axis 1.
    my_image_summaries = {
        'feat_t2': features['x'][0, 0, :, :, 0],
        'labels': tf.cast(labels['y'], tf.float32)[0, 0, :, :],
        'predictions': tf.cast(net_output_ops['y_'], tf.float32)[0, 0, :, :]
    }
    for name, image in my_image_summaries.items():
        # tf.summary.image needs a rank-4 [batch, height, width, channels]
        # tensor. Adding unit batch/channel axes (rather than reshaping to a
        # hard-coded slice size) keeps this working for any patch size.
        tf.summary.image(name, tf.expand_dims(tf.expand_dims(image, 0), -1))

    # 4.2 (optional) create custom metric summaries for tensorboard.
    # `dice` runs in Python via tf.py_func. Its result is presumably one
    # score per class, as the per-class indexing below assumes.
    dice_tensor = tf.py_func(
        dice, [net_output_ops['y_'], labels['y'],
               tf.constant(NUM_CLASSES)], tf.float32)
    for i in range(NUM_CLASSES):
        tf.summary.scalar('dsc_l{}'.format(i), dice_tensor[i])

    # 5. Return EstimatorSpec object
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=net_output_ops,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=None)
コード例 #3
0
    def model_fn(self, features, labels, mode, params):
        """Build the network as a tf.estimator.EstimatorSpec according to the
        hyperparameters chosen by the top-level optimiser.

        Args:
            features (dict): Input features. The image tensor is read from
                features['x'].
            labels (dict): Targets. The label map is read from labels['y'].
                Unused in PREDICT mode.
            mode (str): One of the tf.estimator.ModeKeys: TRAIN, EVAL or
                PREDICT.
            params (dict): Hyperparameters. Always read:
                'filters', 'strides', 'num_residual_units',
                'loss' ('ce' or 'balce') and
                'net' ('fcn', 'unet' or 'asym_unet').
                Also read outside PREDICT mode:
                'opt' ('adam', 'momentum' or 'rmsprop') and 'learning_rate'.

        Returns:
            tf.estimator.EstimatorSpec: In PREDICT mode, predictions and
            export outputs only. Otherwise it also carries the loss, train_op,
            `self.training_metrics` as a training hook and a mean-Dice eval
            metric.

        Raises:
            ValueError: If params['net'] is unsupported or, outside PREDICT
                mode, if params['loss'] or params['opt'] is unsupported.
        """
        from dltk.core.metrics import dice
        from dltk.core.losses import sparse_balanced_crossentropy
        from dltk.networks.segmentation.unet import residual_unet_3d
        from dltk.networks.segmentation.unet import asymmetric_residual_unet_3d
        from dltk.networks.segmentation.fcn import residual_fcn_3d
        from dltk.core.activations import leaky_relu

        filters = params["filters"]
        strides = params["strides"]
        num_residual_units = params["num_residual_units"]
        loss_type = params["loss"]
        net = params["net"]

        def lrelu(x):
            return leaky_relu(x, 0.1)

        # All supported networks share the same constructor signature, so a
        # lookup table replaces the if/elif chain. Unknown names now fail
        # loudly instead of leaving `net_output_ops` unbound.
        networks = {
            'fcn': residual_fcn_3d,
            'unet': residual_unet_3d,
            'asym_unet': asymmetric_residual_unet_3d,
        }
        if net not in networks:
            raise ValueError(
                "Unknown params['net'] {!r}; expected one of {}".format(
                    net, sorted(networks)))

        # 1. create a model and its outputs
        net_output_ops = networks[net](features['x'],
                                       self.NUM_CLASSES,
                                       num_res_units=num_residual_units,
                                       filters=filters,
                                       strides=strides,
                                       activation=lrelu,
                                       mode=mode)

        # 1.1 Generate predictions only (for `ModeKeys.PREDICT`)
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=net_output_ops,
                export_outputs={
                    'out': tf.estimator.export.PredictOutput(net_output_ops)
                })

        # 2. set up a loss function
        if loss_type == 'ce':
            ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=net_output_ops['logits'], labels=labels['y'])
            loss = tf.reduce_mean(ce)
        elif loss_type == 'balce':
            loss = sparse_balanced_crossentropy(net_output_ops['logits'],
                                                labels['y'])
        else:
            raise ValueError(
                "Unknown params['loss'] {!r}; expected 'ce' or 'balce'".format(
                    loss_type))

        # 3. define a training op and ops for updating
        # moving averages (i.e. for batch normalisation)
        global_step = tf.train.get_or_create_global_step()
        opt = params["opt"]
        if opt == 'adam':
            optimiser = tf.train.AdamOptimizer(
                learning_rate=params["learning_rate"], epsilon=1e-5)
        elif opt == 'momentum':
            optimiser = tf.train.MomentumOptimizer(
                learning_rate=params["learning_rate"], momentum=0.9)
        elif opt == 'rmsprop':
            optimiser = tf.train.RMSPropOptimizer(
                learning_rate=params["learning_rate"], momentum=0.9)
        else:
            raise ValueError(
                "Unknown params['opt'] {!r}; expected 'adam', 'momentum' "
                "or 'rmsprop'".format(opt))

        # Make every optimiser step also run the pending UPDATE_OPS.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimiser.minimize(loss, global_step=global_step)

        # 4.1 (optional) create custom image summaries for tensorboard: index 0
        # of axis 1 for every batch element. A unit channel axis is appended
        # because tf.summary.image expects [batch, height, width, channels].
        my_image_summaries = {
            'feat_t1': tf.expand_dims(features['x'][:, 0, :, :, 0], 3),
            'labels': tf.expand_dims(
                tf.cast(labels['y'], tf.float32)[:, 0, :, :], 3),
            'predictions': tf.expand_dims(
                tf.cast(net_output_ops['y_'], tf.float32)[:, 0, :, :], 3),
        }
        for name, image in my_image_summaries.items():
            tf.summary.image(name, image)

        # 4.2 (optional) create custom metric summaries for tensorboard.
        # `dice` runs in Python via tf.py_func. Its result is presumably one
        # score per class, as the per-class indexing below assumes.
        dice_tensor = tf.py_func(
            dice,
            [net_output_ops['y_'], labels['y'],
             tf.constant(self.NUM_CLASSES)], tf.float32)
        for i in range(self.NUM_CLASSES):
            tf.summary.scalar('dsc_l{}'.format(i), dice_tensor[i])

        # average dice over all classes
        metrics = {'dice': tf.metrics.mean(dice_tensor)}

        # 5. Return EstimatorSpec object
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=net_output_ops,
            loss=loss,
            train_op=train_op,
            training_hooks=[self.training_metrics],
            eval_metric_ops=metrics)