def network_fn(self, x, is_training):
    keep_prob = self.keep_prob if is_training else 1.0

    # if the input has more than 1 channel it has to be expanded because broadcasting only works for 1 input
    # channel
    input_channels = int(x.get_shape()[-1])
    with tf.variable_scope('vnet/input_layer'):
        if input_channels == 1:
            x = tf.tile(x, [1, 1, 1, 1, self.num_channels])
        else:
            x = self.activation_fn(convolution(x, [5, 5, 5, input_channels, self.num_channels]))

    features = list()

    for l in range(self.num_levels):
        with tf.variable_scope('vnet/encoder/level_' + str(l + 1)):
            x = convolution_block(x, self.num_convolutions[l], keep_prob, activation_fn=self.activation_fn)
            features.append(x)
            with tf.variable_scope('down_convolution'):
                x = self.activation_fn(down_convolution(x, factor=2, kernel_size=[2, 2, 2]))

    with tf.variable_scope('vnet/bottom_level'):
        x = convolution_block(x, self.bottom_convolutions, keep_prob, activation_fn=self.activation_fn)

    for l in reversed(range(self.num_levels)):
        with tf.variable_scope('vnet/decoder/level_' + str(l + 1)):
            f = features[l]
            with tf.variable_scope('up_convolution'):
                x = self.activation_fn(up_convolution(x, tf.shape(f), factor=2, kernel_size=[2, 2, 2]))

            x = convolution_block_2(x, f, self.num_convolutions[l], keep_prob, activation_fn=self.activation_fn)

    with tf.variable_scope('vnet/output_layer'):
        logits = convolution(x, [1, 1, 1, self.num_channels, self.num_classes])

    return logits
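# Hedged usage sketch (not part of the original code): shows how this
# dropout-only variant of network_fn might be wired up. The VNet class name
# and its constructor arguments below are assumptions inferred from the
# attributes the method reads (num_channels, num_levels, num_convolutions,
# bottom_convolutions, keep_prob, num_classes, activation_fn); the real
# signature may differ.
import tensorflow as tf

# 5-D input tensor: [batch, depth, height, width, channels]
image = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1])

net = VNet(num_classes=2,                 # hypothetical constructor
           keep_prob=0.7,
           num_channels=16,
           num_levels=4,
           num_convolutions=(1, 2, 3, 3),
           bottom_convolutions=3,
           activation_fn=tf.nn.relu)

# is_training=True keeps dropout active (keep_prob < 1.0); passing False makes
# network_fn force keep_prob to 1.0 for inference.
logits = net.network_fn(image, is_training=True)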
def network_fn(self, x):
    # keep_prob = self.keep_prob if self.is_training else 1.0
    # use 0.0 for tf 1.15
    # keep_prob = self.keep_prob if self.is_training else 0.0
    keep_prob = self.keep_prob

    # if the input has more than 1 channel it has to be expanded because broadcasting only works for 1 input
    # channel
    input_channels = int(x.get_shape()[-1])
    with tf.variable_scope('vnet/input_layer'):
        if input_channels == 1:
            x = tf.tile(x, [1, 1, 1, 1, self.num_channels])
            x = tf.layers.batch_normalization(x, momentum=0.99, epsilon=0.001, center=True, scale=True,
                                              training=self.train_phase)
        else:
            x = convolution(x, [5, 5, 5, input_channels, self.num_channels])
            x = tf.layers.batch_normalization(x, momentum=0.99, epsilon=0.001, center=True, scale=True,
                                              training=self.train_phase)
            x = self.activation_fn(x)

    features = list()

    for l in range(self.num_levels):
        with tf.variable_scope('vnet/encoder/level_' + str(l + 1)):
            x = convolution_block(x, self.num_convolutions[l], keep_prob,
                                  activation_fn=self.activation_fn, is_training=self.train_phase)
            features.append(x)
            with tf.variable_scope('down_convolution'):
                x = down_convolution(x, factor=2, kernel_size=[2, 2, 2])
                x = tf.layers.batch_normalization(x, momentum=0.99, epsilon=0.001, center=True, scale=True,
                                                  training=self.train_phase)
                x = self.activation_fn(x)

    with tf.variable_scope('vnet/bottom_level'):
        x = convolution_block(x, self.bottom_convolutions, keep_prob,
                              activation_fn=self.activation_fn, is_training=self.train_phase)

    for l in reversed(range(self.num_levels)):
        with tf.variable_scope('vnet/decoder/level_' + str(l + 1)):
            f = features[l]
            with tf.variable_scope('up_convolution'):
                x = up_convolution(x, tf.shape(f), factor=2, kernel_size=[2, 2, 2])
                x = tf.layers.batch_normalization(x, momentum=0.99, epsilon=0.001, center=True, scale=True,
                                                  training=self.train_phase)
                x = self.activation_fn(x)

            x = convolution_block_2(x, f, self.num_convolutions[l], keep_prob,
                                    activation_fn=self.activation_fn, is_training=self.train_phase)

    with tf.variable_scope('vnet/output_layer'):
        logits = convolution(x, [1, 1, 1, self.num_channels, self.num_classes])
        logits = tf.layers.batch_normalization(logits, momentum=0.99, epsilon=0.001, center=True, scale=True,
                                               training=self.train_phase)

    return logits
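# Hedged training sketch (not part of the original code): in TF 1.x,
# tf.layers.batch_normalization(training=...) registers its moving-mean /
# moving-variance updates in tf.GraphKeys.UPDATE_OPS, so the train op must
# explicitly depend on those ops or the statistics used at inference never
# update. The loss, optimizer, and the way train_phase reaches the model are
# assumptions standing in for whatever the surrounding training code does.
import tensorflow as tf

train_phase = tf.placeholder(tf.bool, name='train_phase')   # fed into self.train_phase
image = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1])
labels = tf.placeholder(tf.int32, shape=[None, 64, 64, 64])

net = VNet(num_classes=2, keep_prob=0.7, num_channels=16, num_levels=4,
           num_convolutions=(1, 2, 3, 3), bottom_convolutions=3,
           activation_fn=tf.nn.relu)      # hypothetical constructor, as in the sketch above
net.train_phase = train_phase             # network_fn reads self.train_phase

logits = net.network_fn(image)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# Run the batch-norm update ops together with every optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

# Feed train_phase=True during training and False for validation/inference so
# batch norm switches between batch statistics and the moving averages.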