Code example #1
    def build_model(self, inputs, training=True, reuse=False):
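        """Build the classification network and return (probs, logits)."""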
        
        with var_storage.model_variable_scope(
            self.model_hparams.model_name,
            reuse=reuse,
            dtype=self.model_hparams.dtype):

            with tf.variable_scope("input_reshape"):
                if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                    # Reshape inputs: NHWC => NCHW
                    inputs = tf.transpose(inputs, [0, 3, 1, 2])

                elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
                    # Reshape inputs: NCHW => NHWC
                    inputs = tf.transpose(inputs, [0, 2, 3, 1])

            if self.model_hparams.dtype != inputs.dtype:
                inputs = tf.cast(inputs, self.model_hparams.dtype)

            net = blocks.conv2d_block(
                inputs,
                n_channels=64,
                kernel_size=(7, 7),
                strides=(2, 2),
                mode='SAME',
                use_batch_norm=True,
                activation='relu',
                is_training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                name='conv2d'
            )

            net = layers.max_pooling2d(
                net,
                pool_size=(3, 3),
                strides=(2, 2),
                padding='SAME',
                data_format=self.model_hparams.compute_format,
                name="max_pooling2d",
            )

            model_bottlenecks = self.model_hparams.layers_depth
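            # Stack the bottleneck stages: the first block of every stage
            # after the first downsamples with stride 2.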
            for block_id, block_bottleneck in enumerate(model_bottlenecks):
                for layer_id in range(self.model_hparams.layers_count[block_id]):
                    stride = 2 if (layer_id == 0 and block_id != 0) else 1

                    net = blocks.bottleneck_block(
                        inputs=net,
                        depth=block_bottleneck * self.model_hparams.expansions,
                        depth_bottleneck=block_bottleneck,
                        cardinality=self.model_hparams.cardinality,
                        stride=stride,
                        training=training,
                        data_format=self.model_hparams.compute_format,
                        conv2d_hparams=self.conv2d_hparams,
                        batch_norm_hparams=self.batch_norm_hparams,
                        block_name="btlnck_block_%d_%d" % (block_id, layer_id),
                        use_se=self.model_hparams.use_se,
                        ratio=self.model_hparams.se_ratio)

            with tf.variable_scope("output"):
                net = layers.reduce_mean(
                    net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')

                logits = layers.dense(
                    inputs=net,
                    units=self.model_hparams.n_classes,
                    use_bias=True,
                    trainable=training,
                    kernel_initializer=self.dense_hparams.kernel_initializer,
                    bias_initializer=self.dense_hparams.bias_initializer)

                if logits.dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32)

                probs = layers.softmax(logits, name="softmax", axis=1)

            return probs, logits
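
For reference, blocks.bottleneck_block is not shown in this example. The sketch below is a minimal, illustrative ResNet-style bottleneck (1x1 reduce, 3x3, 1x1 expand, with a projection shortcut when the shape changes), written against the TF 1.x layers API via tensorflow.compat.v1. The function name and defaults are assumptions for illustration, not the repository's implementation, and it omits the cardinality and squeeze-and-excitation options the snippet passes in.

import tensorflow.compat.v1 as tf

def bottleneck_block_sketch(inputs, depth, depth_bottleneck, stride,
                            training, data_format='NHWC'):
    # Illustrative bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.
    axis = 1 if data_format == 'NCHW' else 3
    df = 'channels_first' if data_format == 'NCHW' else 'channels_last'

    shortcut = inputs
    if stride != 1 or inputs.shape[axis] != depth:
        # Projection shortcut when the spatial size or channel count changes.
        shortcut = tf.layers.conv2d(inputs, depth, 1, strides=stride,
                                    use_bias=False, data_format=df)
        shortcut = tf.layers.batch_normalization(shortcut, axis=axis,
                                                 training=training)

    net = tf.layers.conv2d(inputs, depth_bottleneck, 1, use_bias=False,
                           data_format=df)
    net = tf.layers.batch_normalization(net, axis=axis, training=training)
    net = tf.nn.relu(net)

    # Downsampling, when requested, happens in the 3x3 convolution.
    net = tf.layers.conv2d(net, depth_bottleneck, 3, strides=stride,
                           padding='same', use_bias=False, data_format=df)
    net = tf.layers.batch_normalization(net, axis=axis, training=training)
    net = tf.nn.relu(net)

    net = tf.layers.conv2d(net, depth, 1, use_bias=False, data_format=df)
    net = tf.layers.batch_normalization(net, axis=axis, training=training)

    return tf.nn.relu(net + shortcut)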
Code example #2
    def build_model(self, inputs, training=True, reuse=False):
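        """Build the ResNet-style classification network and return (probs, logits)."""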

        with var_storage.model_variable_scope(self.model_hparams.model_name,
                                              reuse=reuse,
                                              dtype=self.model_hparams.dtype):

            with tf.variable_scope("input_reshape"):

                if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                    # Reshape inputs: NHWC => NCHW
                    inputs = tf.transpose(inputs, [0, 3, 1, 2])

                elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':

                    # Reshape inputs: NCHW => NHWC
                    inputs = tf.transpose(inputs, [0, 2, 3, 1])

            if self.model_hparams.dtype != inputs.dtype:
                inputs = tf.cast(inputs, self.model_hparams.dtype)

            net = blocks.conv2d_block(
                inputs,
                n_channels=64,
                kernel_size=(7, 7),
                strides=(2, 2),
                mode='SAME_RESNET',
                use_batch_norm=True,
                activation='relu',
                is_training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                name='conv2d')

            net = layers.max_pooling2d(
                net,
                pool_size=(3, 3),
                strides=(2, 2),
                padding='SAME',
                data_format=self.model_hparams.compute_format,
                name="max_pooling2d",
            )

            for block_id in range(self.model_hparams.layer_counts[0]):
                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=256,
                    depth_bottleneck=64,
                    stride=1,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_1_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[1]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=512,
                    depth_bottleneck=128,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_2_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[2]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=1024,
                    depth_bottleneck=256,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_3_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[3]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=2048,
                    depth_bottleneck=512,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_4_%d" % (block_id + 1))

            with tf.variable_scope("output"):

                net = layers.reduce_mean(
                    net,
                    keepdims=False,
                    data_format=self.model_hparams.compute_format,
                    name='spatial_mean')

                logits = layers.dense(
                    inputs=net,
                    units=self.model_hparams.n_classes,
                    use_bias=True,
                    trainable=training,
                    kernel_initializer=self.dense_hparams.kernel_initializer,
                    bias_initializer=self.dense_hparams.bias_initializer)

                if logits.dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32, name="logits")

                probs = layers.softmax(logits, name="softmax", axis=1)

            return probs, logits
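
The four loops above correspond to the conv2_x through conv5_x stages of a standard bottleneck ResNet. The snippet does not show how layer_counts is configured; the values below are the standard stage sizes from the ResNet paper (He et al., 2015), listed here as an assumption for illustration:

# Standard bottleneck-ResNet stage sizes (assumed, not taken from the snippet).
LAYER_COUNTS = {
    'resnet50':  [3, 4, 6, 3],
    'resnet101': [3, 4, 23, 3],
    'resnet152': [3, 8, 36, 3],
}

def total_depth(layer_counts):
    # 1 stem conv + 3 convs per bottleneck block + 1 final dense layer
    return 1 + 3 * sum(layer_counts) + 1

assert total_depth(LAYER_COUNTS['resnet50']) == 50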
Code example #3
def _model_fn(features, labels, mode, params):
    """ Model function for tf.Estimator

    Controls how training is performed by specifying how the
    total_loss is computed and applied in the backward pass.

    Args:
        features (tf.Tensor): Tensor samples
        labels (tf.Tensor): Tensor labels
        mode (tf.estimator.ModeKeys): Indicates whether we train, evaluate, or predict
        params (dict): Additional parameters supplied to the estimator

    Returns:
        Appropriate tf.estimator.EstimatorSpec for the current mode

    """
    dtype = params['dtype']
    max_steps = params['max_steps']
    lr_init = params['learning_rate']
    momentum = params['momentum']

    device = '/gpu:0'

    global_step = tf.train.get_global_step()
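    # Exponential LR schedule: lr = lr_init * 0.96 ** (global_step / max_steps)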
    learning_rate = tf.train.exponential_decay(lr_init, global_step,
                                               decay_steps=max_steps,
                                               decay_rate=0.96)

    with tf.device(device):
        features = tf.cast(features, dtype)

        with model_variable_scope(
                'UNet',
                reuse=tf.AUTO_REUSE,
                dtype=dtype,
                debug_mode=False
        ):
            output_map = unet_v1(features, mode)

            if mode == tf.estimator.ModeKeys.PREDICT:
                predictions = {'logits': tf.nn.softmax(output_map, axis=-1)}
                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

            n_classes = output_map.shape[-1].value

            flat_logits = tf.reshape(tf.cast(output_map, tf.float32),
                                     [tf.shape(output_map)[0], -1, n_classes])
            flat_labels = tf.reshape(labels,
                                     [tf.shape(output_map)[0], -1, n_classes])

            crossentropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                                                          labels=flat_labels),
                                               name='cross_loss_ref')
            dice_loss = tf.reduce_mean(1 - dice_coef(flat_logits, flat_labels), name='dice_loss_ref')

            total_loss = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")

            opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum)

            if is_using_hvd():
                opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')

            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                deterministic = True
                gate_gradients = (
                    tf.train.Optimizer.GATE_OP
                    if deterministic
                    else tf.train.Optimizer.GATE_NONE)

                train_op = opt.minimize(total_loss, gate_gradients=gate_gradients, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op,
                                      eval_metric_ops={})
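
A minimal sketch of how a model function like this is typically handed to tf.estimator.Estimator. The param keys match what _model_fn reads above; the model_dir, tensor shapes, and the synthetic input_fn are placeholders for illustration, not the repository's actual data pipeline or hyperparameters.

import numpy as np
import tensorflow.compat.v1 as tf

def _synthetic_input_fn():
    # Placeholder data; shapes are illustrative only.
    images = np.zeros((8, 572, 572, 1), dtype=np.float32)
    labels = np.zeros((8, 388, 388, 2), dtype=np.float32)
    return tf.data.Dataset.from_tensor_slices((images, labels)).batch(2).repeat()

estimator = tf.estimator.Estimator(
    model_fn=_model_fn,
    model_dir='/tmp/unet_sketch',
    params={'dtype': tf.float16,
            'max_steps': 1000,
            'learning_rate': 1e-4,
            'momentum': 0.9})
estimator.train(input_fn=_synthetic_input_fn, max_steps=1000)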
Code example #4
def vnet_v2(features, labels, mode, params):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    is_eval = (mode == tf.estimator.ModeKeys.EVAL)
    is_predict = (mode == tf.estimator.ModeKeys.PREDICT)
    num_classes = len(params.labels)
    channel_axis = -1

    with model_variable_scope('vnet',
                              reuse=tf.AUTO_REUSE,
                              dtype=tf.float16,
                              debug_mode=False):
        features = tf.reshape(features,
                              [params.batch_size] + params.input_shape + [1])
        if labels is not None:
            labels = tf.reshape(labels,
                                [params.batch_size] + params.input_shape + [1])

        logits = Builder(kernel_size=params.convolution_size,
                         n_classes=num_classes,
                         downscale_blocks=params.downscale_blocks,
                         upscale_blocks=params.upscale_blocks,
                         upsampling=params.upsampling,
                         pooling=params.pooling,
                         normalization=params.normalization_layer,
                         activation=params.activation,
                         mode=mode)(features)

        softmax = tf.nn.softmax(logits=logits, axis=channel_axis)

        if is_predict:
            prediction = tf.argmax(input=softmax, axis=channel_axis)
            predictions = {'prediction': prediction}
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions)

        # Flatten the softmax output and cast it to FP32 for the loss
        flattened_softmax = tf.reshape(softmax,
                                       [tf.shape(logits)[0], -1, num_classes])
        flattened_softmax = tf.cast(flattened_softmax, tf.float32)

        # One hot encoding
        flattened_labels = tf.layers.flatten(labels)
        one_hot_labels = tf.one_hot(indices=flattened_labels,
                                    depth=num_classes,
                                    dtype=tf.float32)

        with tf.name_scope("loss"):
            if params.loss == 'dice':
                loss = dice_coef(predict=flattened_softmax,
                                 target=one_hot_labels,
                                 dice_type='sorensen')
                total_loss = tf.identity(tf.reduce_sum(1. - loss),
                                         name='total_loss_ref')
            else:
                raise NotImplementedError

        train_op = None
        if is_training:
            global_step = tf.train.get_or_create_global_step()

            with tf.name_scope("optimizer"):
                if params.optimizer == 'rmsprop':
                    optimizer = tf.train.RMSPropOptimizer(
                        learning_rate=params.base_lr,
                        momentum=params.momentum,
                        centered=True)
                else:
                    raise NotImplementedError

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                # Wrap the optimizer before computing gradients so that
                # Horovod's allreduce and the AMP loss-scaling rewrite are
                # actually applied to the gradients clipped below.
                optimizer = hvd.DistributedOptimizer(optimizer)

                try:
                    amp_envar_enabled = (int(
                        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION']) == 1)
                except KeyError:
                    amp_envar_enabled = False

                if params.use_amp and not amp_envar_enabled:
                    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                        optimizer, loss_scale='dynamic')

                gradients, variables = zip(
                    *optimizer.compute_gradients(total_loss))
                if params.gradient_clipping == 'global_norm':
                    gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
                    tf.logging.info('clipping: global_norm')
                else:
                    raise NotImplementedError

                # Apply the clipped gradients; calling minimize() here would
                # recompute gradients and silently discard the clipping.
                train_op = optimizer.apply_gradients(
                    zip(gradients, variables), global_step=global_step)

        eval_metric_ops = None
        if is_eval:
            dice_loss = dice_coef(predict=flattened_softmax,
                                  target=one_hot_labels,
                                  dice_type='sorensen')
            eval_loss = tf.identity(dice_loss, name='eval_loss_ref')
            eval_metric_ops = {}
            for i in range(num_classes):
                label_name = params.labels[str(i)]
                eval_metric_ops['%s dice' % label_name] = tf.metrics.mean(eval_loss[i])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops)
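
dice_coef is referenced in both segmentation examples but defined in neither. A common Sørensen-Dice formulation over [batch, n_pixels, n_classes] tensors, returning one score per class, might look like the sketch below; this is an assumption for illustration, not the repository's exact function.

import tensorflow.compat.v1 as tf

def dice_coef_sketch(predict, target, dice_type='sorensen', eps=1e-6):
    # Per-class Dice over tensors shaped [batch, n_pixels, n_classes].
    intersection = tf.reduce_sum(predict * target, axis=(0, 1))
    if dice_type == 'sorensen':
        union = (tf.reduce_sum(predict, axis=(0, 1)) +
                 tf.reduce_sum(target, axis=(0, 1)))
    else:
        # Squared-denominator variant sometimes used for soft Dice.
        union = (tf.reduce_sum(tf.square(predict), axis=(0, 1)) +
                 tf.reduce_sum(tf.square(target), axis=(0, 1)))
    # eps guards against empty classes in both numerator and denominator.
    return (2.0 * intersection + eps) / (union + eps)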