def build_model(self, inputs, training=True, reuse=False):

    with var_storage.model_variable_scope(
            self.model_hparams.model_name,
            reuse=reuse,
            dtype=self.model_hparams.dtype):

        with tf.variable_scope("input_reshape"):
            if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                # Reshape inputs: NHWC => NCHW
                inputs = tf.transpose(inputs, [0, 3, 1, 2])

            elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
                # Reshape inputs: NCHW => NHWC
                inputs = tf.transpose(inputs, [0, 2, 3, 1])

        if self.model_hparams.dtype != inputs.dtype:
            inputs = tf.cast(inputs, self.model_hparams.dtype)

        net = blocks.conv2d_block(
            inputs,
            n_channels=64,
            kernel_size=(7, 7),
            strides=(2, 2),
            mode='SAME',
            use_batch_norm=True,
            activation='relu',
            is_training=training,
            data_format=self.model_hparams.compute_format,
            conv2d_hparams=self.conv2d_hparams,
            batch_norm_hparams=self.batch_norm_hparams,
            name='conv2d')

        net = layers.max_pooling2d(
            net,
            pool_size=(3, 3),
            strides=(2, 2),
            padding='SAME',
            data_format=self.model_hparams.compute_format,
            name="max_pooling2d")

        model_bottlenecks = self.model_hparams.layers_depth

        for block_id, block_bottleneck in enumerate(model_bottlenecks):
            for layer_id in range(self.model_hparams.layers_count[block_id]):
                stride = 2 if (layer_id == 0 and block_id != 0) else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=block_bottleneck * self.model_hparams.expansions,
                    depth_bottleneck=block_bottleneck,
                    cardinality=self.model_hparams.cardinality,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_%d_%d" % (block_id, layer_id),
                    use_se=self.model_hparams.use_se,
                    ratio=self.model_hparams.se_ratio)

        with tf.variable_scope("output"):
            net = layers.reduce_mean(
                net,
                keepdims=False,
                data_format=self.model_hparams.compute_format,
                name='spatial_mean')

            logits = layers.dense(
                inputs=net,
                units=self.model_hparams.n_classes,
                use_bias=True,
                trainable=training,
                kernel_initializer=self.dense_hparams.kernel_initializer,
                bias_initializer=self.dense_hparams.bias_initializer)

            if logits.dtype != tf.float32:
                logits = tf.cast(logits, tf.float32)

            probs = layers.softmax(logits, name="softmax", axis=1)

        return probs, logits
def build_model(self, inputs, training=True, reuse=False):

    with var_storage.model_variable_scope(
            self.model_hparams.model_name,
            reuse=reuse,
            dtype=self.model_hparams.dtype):

        with tf.variable_scope("input_reshape"):
            if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                # Reshape inputs: NHWC => NCHW
                inputs = tf.transpose(inputs, [0, 3, 1, 2])

            elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
                # Reshape inputs: NCHW => NHWC
                inputs = tf.transpose(inputs, [0, 2, 3, 1])

        if self.model_hparams.dtype != inputs.dtype:
            inputs = tf.cast(inputs, self.model_hparams.dtype)

        net = blocks.conv2d_block(
            inputs,
            n_channels=64,
            # n_channels=16,
            kernel_size=(7, 7),
            strides=(2, 2),
            mode='SAME_RESNET',
            use_batch_norm=True,
            activation='relu',
            is_training=training,
            data_format=self.model_hparams.compute_format,
            conv2d_hparams=self.conv2d_hparams,
            batch_norm_hparams=self.batch_norm_hparams,
            name='conv2d')

        net = layers.max_pooling2d(
            net,
            pool_size=(3, 3),
            strides=(2, 2),
            padding='SAME',
            data_format=self.model_hparams.compute_format,
            name="max_pooling2d")

        for block_id, _ in enumerate(range(self.model_hparams.layer_counts[0])):
            net = blocks.bottleneck_block(
                inputs=net,
                depth=256,
                depth_bottleneck=64,
                stride=1,
                training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                block_name="btlnck_block_1_%d" % (block_id + 1))

        for block_id, i in enumerate(range(self.model_hparams.layer_counts[1])):
            stride = 2 if i == 0 else 1

            net = blocks.bottleneck_block(
                inputs=net,
                depth=512,
                depth_bottleneck=128,
                stride=stride,
                training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                block_name="btlnck_block_2_%d" % (block_id + 1))

        for block_id, i in enumerate(range(self.model_hparams.layer_counts[2])):
            stride = 2 if i == 0 else 1

            net = blocks.bottleneck_block(
                inputs=net,
                depth=1024,
                depth_bottleneck=256,
                stride=stride,
                training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                block_name="btlnck_block_3_%d" % (block_id + 1))

        for block_id, i in enumerate(range(self.model_hparams.layer_counts[3])):
            stride = 2 if i == 0 else 1

            net = blocks.bottleneck_block(
                inputs=net,
                depth=2048,
                depth_bottleneck=512,
                stride=stride,
                training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                block_name="btlnck_block_4_%d" % (block_id + 1))

        with tf.variable_scope("output"):
            net = layers.reduce_mean(
                net,
                keepdims=False,
                data_format=self.model_hparams.compute_format,
                name='spatial_mean')

            logits = layers.dense(
                inputs=net,
                units=self.model_hparams.n_classes,
                use_bias=True,
                trainable=training,
                kernel_initializer=self.dense_hparams.kernel_initializer,
                bias_initializer=self.dense_hparams.bias_initializer)

            if logits.dtype != tf.float32:
                logits = tf.cast(logits, tf.float32, name="logits")

            probs = layers.softmax(logits, name="softmax", axis=1)

        return probs, logits
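# Usage sketch (an assumption for illustration, not part of the original sources):
# how either build_model variant above is typically called from an Estimator-style
# model_fn. The `model` argument stands for an instance of the class that defines
# build_model; its construction and the loss/optimizer wiring are intentionally
# omitted. As elsewhere in this file, tensorflow is assumed to be imported as `tf`.
def example_classifier_model_fn(features, labels, mode, model):
    # Forward pass; training-only behaviour (batch-norm updates, trainable
    # dense layer) is enabled only in TRAIN mode.
    probs, logits = model.build_model(
        features,
        training=(mode == tf.estimator.ModeKeys.TRAIN),
        reuse=False)

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': probs,
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    raise NotImplementedError("loss/optimizer wiring omitted in this sketch")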
def _model_fn(features, labels, mode, params):
    """Model function for tf.Estimator.

    Controls how the training is performed by specifying how the
    total_loss is computed and applied in the backward pass.

    Args:
        features (tf.Tensor): Tensor samples
        labels (tf.Tensor): Tensor labels
        mode (tf.estimator.ModeKeys): Indicates if we train, evaluate or predict
        params (dict): Additional parameters supplied to the estimator

    Returns:
        Appropriate tf.estimator.EstimatorSpec for the current mode
    """
    dtype = params['dtype']
    max_steps = params['max_steps']
    lr_init = params['learning_rate']
    momentum = params['momentum']

    device = '/gpu:0'

    global_step = tf.train.get_global_step()
    learning_rate = tf.train.exponential_decay(lr_init, global_step,
                                               decay_steps=max_steps,
                                               decay_rate=0.96)

    with tf.device(device):
        features = tf.cast(features, dtype)

        with model_variable_scope(
                'UNet',
                reuse=tf.AUTO_REUSE,
                dtype=tf.float16,
                debug_mode=False):
            output_map = unet_v1(features, mode)

        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {'logits': tf.nn.softmax(output_map, axis=-1)}
            return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

        n_classes = output_map.shape[-1].value

        flat_logits = tf.reshape(tf.cast(output_map, tf.float32),
                                 [tf.shape(output_map)[0], -1, n_classes])
        flat_labels = tf.reshape(labels,
                                 [tf.shape(output_map)[0], -1, n_classes])

        crossentropy_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                       labels=flat_labels),
            name='cross_loss_ref')

        dice_loss = tf.reduce_mean(1 - dice_coef(flat_logits, flat_labels),
                                   name='dice_loss_ref')

        total_loss = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")

        opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                         momentum=momentum)

        if is_using_hvd():
            opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            deterministic = True
            gate_gradients = (
                tf.train.Optimizer.GATE_OP
                if deterministic
                else tf.train.Optimizer.GATE_NONE)

            train_op = opt.minimize(total_loss,
                                    gate_gradients=gate_gradients,
                                    global_step=global_step)

    return tf.estimator.EstimatorSpec(mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops={})
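# Usage sketch (an assumption for illustration, not part of the original sources):
# handing _model_fn to a tf.estimator.Estimator. The params keys match what
# _model_fn reads above; the concrete values, `train_input_fn`, and `model_dir`
# are placeholders.
def example_train(train_input_fn, model_dir):
    params = {
        'dtype': tf.float32,
        'max_steps': 1000,
        'learning_rate': 1e-4,
        'momentum': 0.99,
    }

    estimator = tf.estimator.Estimator(
        model_fn=_model_fn,
        model_dir=model_dir,
        params=params)

    estimator.train(input_fn=train_input_fn, max_steps=params['max_steps'])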
def vnet_v2(features, labels, mode, params):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    is_eval = (mode == tf.estimator.ModeKeys.EVAL)
    is_predict = (mode == tf.estimator.ModeKeys.PREDICT)

    num_classes = len(params.labels)
    channel_axis = -1

    with model_variable_scope('vnet',
                              reuse=tf.AUTO_REUSE,
                              dtype=tf.float16,
                              debug_mode=False):
        features = tf.reshape(features, [params.batch_size] + params.input_shape + [1])
        if labels is not None:
            labels = tf.reshape(labels, [params.batch_size] + params.input_shape + [1])

        logits = Builder(kernel_size=params.convolution_size,
                         n_classes=num_classes,
                         downscale_blocks=params.downscale_blocks,
                         upscale_blocks=params.upscale_blocks,
                         upsampling=params.upsampling,
                         pooling=params.pooling,
                         normalization=params.normalization_layer,
                         activation=params.activation,
                         mode=mode)(features)

        softmax = tf.nn.softmax(logits=logits, axis=channel_axis)

        if is_predict:
            prediction = tf.argmax(input=softmax, axis=channel_axis)
            predictions = {'prediction': prediction}
            return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

        # Flattened logits and softmax - in FP32
        flattened_softmax = tf.reshape(softmax, [tf.shape(logits)[0], -1, num_classes])
        flattened_softmax = tf.cast(flattened_softmax, tf.float32)

        # One hot encoding
        flattened_labels = tf.layers.flatten(labels)
        one_hot_labels = tf.one_hot(indices=flattened_labels,
                                    depth=num_classes,
                                    dtype=tf.float32)

        with tf.name_scope("loss"):
            if params.loss == 'dice':
                loss = dice_coef(predict=tf.cast(flattened_softmax, tf.float32),
                                 target=one_hot_labels,
                                 dice_type='sorensen')
                total_loss = tf.identity(tf.reduce_sum(1. - loss), name='total_loss_ref')
            else:
                raise NotImplementedError

        train_op = None
        if is_training:
            global_step = tf.train.get_or_create_global_step()

            with tf.name_scope("optimizer"):
                if params.optimizer == 'rmsprop':
                    optimizer = tf.train.RMSPropOptimizer(learning_rate=params.base_lr,
                                                          momentum=params.momentum,
                                                          centered=True)
                else:
                    raise NotImplementedError

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                gradients, variables = zip(*optimizer.compute_gradients(total_loss))
                if params.gradient_clipping == 'global_norm':
                    gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
                    tf.logging.info('clipping: global_norm')
                else:
                    raise NotImplementedError

                optimizer = hvd.DistributedOptimizer(optimizer)

                try:
                    amp_envar_enabled = (int(os.environ['TF_ENABLE_AUTO_MIXED_PRECISION']) == 1)
                except KeyError:
                    amp_envar_enabled = False

                if params.use_amp and not amp_envar_enabled:
                    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                        optimizer,
                        loss_scale='dynamic')

                # Note: minimize() recomputes gradients internally, so the clipped
                # gradients computed above are not the ones applied here.
                train_op = optimizer.minimize(total_loss, global_step=global_step)

        eval_metric_ops = None
        if is_eval:
            dice_loss = dice_coef(predict=tf.cast(flattened_softmax, tf.float32),
                                  target=one_hot_labels,
                                  dice_type='sorensen')
            eval_loss = tf.identity(dice_loss, name='eval_loss_ref')
            eval_metric_ops = {}
            for i in range(num_classes):
                eval_metric_ops['%s dice' % params.labels[str(i)]] = tf.metrics.mean(eval_loss[i])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops)
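# Usage sketch (an assumption for illustration, not part of the original sources):
# vnet_v2 reads its hyperparameters from an attribute-style `params` object
# (params.labels, params.batch_size, ...). A types.SimpleNamespace carrying the
# attributes referenced above is enough to drive it through tf.estimator.Estimator;
# every value below is a placeholder, not a recommended configuration.
import types


def example_vnet_estimator(model_dir):
    params = types.SimpleNamespace(
        labels={'0': 'background', '1': 'foreground'},  # placeholder class map
        batch_size=2,
        input_shape=[32, 128, 128],
        convolution_size=3,
        downscale_blocks=[1, 2, 3],
        upscale_blocks=[3, 3],
        upsampling='transposed_conv',
        pooling='conv_pool',
        normalization_layer='batchnorm',
        activation='relu',
        loss='dice',
        optimizer='rmsprop',
        base_lr=1e-4,
        momentum=0.0,
        gradient_clipping='global_norm',
        use_amp=False)

    return tf.estimator.Estimator(
        model_fn=vnet_v2,
        model_dir=model_dir,
        params=params)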