Esempio n. 1
0
    def network_fn(self, x, is_training):
        """Build the encoder/decoder network on ``x`` and return the logits.

        The graph is laid out under four groups of variable scopes:
        ``vnet/input_layer``, ``vnet/encoder/level_<k>``,
        ``vnet/bottom_level``, ``vnet/decoder/level_<k>`` and
        ``vnet/output_layer`` (levels are numbered from 1).

        Args:
            x: Input tensor. It is tiled with five multiples, so it has to be
                rank 5, and its last dimension is read as the (statically
                known) number of input channels.
            is_training: Evaluated with plain Python truthiness. When truthy
                the convolution blocks receive ``self.keep_prob``; otherwise
                they receive ``1.0``.

        Returns:
            The tensor produced by the final ``convolution`` call, whose
            filter shape is ``[1, 1, 1, self.num_channels, self.num_classes]``.
        """
        if is_training:
            keep_prob = self.keep_prob
        else:
            keep_prob = 1.0

        in_channels = int(x.get_shape()[-1])

        # Bring the input to ``self.num_channels`` channels: a multi-channel
        # input goes through a convolution + activation, while a
        # single-channel input is simply replicated along the last axis.
        with tf.variable_scope('vnet/input_layer'):
            if in_channels != 1:
                projected = convolution(
                    x, [5, 5, 5, in_channels, self.num_channels])
                x = self.activation_fn(projected)
            else:
                x = tf.tile(x, [1, 1, 1, 1, self.num_channels])

        # Output of every encoder block, kept for the matching decoder level.
        skip_connections = []

        # Encoder: convolution block, remember its output, then downsample.
        for level in range(1, self.num_levels + 1):
            with tf.variable_scope('vnet/encoder/level_' + str(level)):
                x = convolution_block(
                    x,
                    self.num_convolutions[level - 1],
                    keep_prob,
                    activation_fn=self.activation_fn,
                )
                skip_connections.append(x)
                with tf.variable_scope('down_convolution'):
                    downsampled = down_convolution(
                        x, factor=2, kernel_size=[2, 2, 2])
                    x = self.activation_fn(downsampled)

        with tf.variable_scope('vnet/bottom_level'):
            x = convolution_block(
                x,
                self.bottom_convolutions,
                keep_prob,
                activation_fn=self.activation_fn,
            )

        # Decoder: walk the levels from deepest to shallowest, upsampling to
        # the dynamic shape of the stored encoder output before merging.
        for level in range(self.num_levels, 0, -1):
            with tf.variable_scope('vnet/decoder/level_' + str(level)):
                skip = skip_connections[level - 1]
                with tf.variable_scope('up_convolution'):
                    upsampled = up_convolution(
                        x, tf.shape(skip), factor=2, kernel_size=[2, 2, 2])
                    x = self.activation_fn(upsampled)

                x = convolution_block_2(
                    x,
                    skip,
                    self.num_convolutions[level - 1],
                    keep_prob,
                    activation_fn=self.activation_fn,
                )

        with tf.variable_scope('vnet/output_layer'):
            return convolution(
                x, [1, 1, 1, self.num_channels, self.num_classes])
Esempio n. 2
0
    def network_fn(self, x):
        """Build the batch-normalised encoder/decoder network on ``x``.

        Every convolution-like step (input layer, down/up convolutions and
        the output layer) is followed by ``tf.layers.batch_normalization``
        driven by ``self.train_phase``; the same flag is forwarded to the
        convolution blocks as ``is_training``.

        Args:
            x: Input tensor. It is tiled with five multiples, so it has to be
                rank 5, and its last dimension is read as the (statically
                known) number of input channels.

        Returns:
            The batch-normalised output of the final ``convolution`` call,
            whose filter shape is
            ``[1, 1, 1, self.num_channels, self.num_classes]``.
        """

        def batch_norm(tensor):
            # Single place for the normalisation settings shared by every
            # layer; ``self.train_phase`` is read at call time.
            return tf.layers.batch_normalization(
                tensor,
                momentum=0.99,
                epsilon=0.001,
                center=True,
                scale=True,
                training=self.train_phase,
            )

        # The keep probability is forwarded unchanged in every phase.
        # NOTE(review): earlier revisions switched it to 1.0 (or 0.0 "for
        # tf 1.15") when not training -- confirm that convolution_block
        # handles the inference case through its is_training argument.
        keep_prob = self.keep_prob

        in_channels = int(x.get_shape()[-1])

        # Bring the input to ``self.num_channels`` channels. A single-channel
        # input is replicated and normalised (no activation); any other
        # channel count goes through convolution -> batch norm -> activation.
        with tf.variable_scope('vnet/input_layer'):
            if in_channels == 1:
                x = batch_norm(
                    tf.tile(x, [1, 1, 1, 1, self.num_channels]))
            else:
                x = convolution(
                    x, [5, 5, 5, in_channels, self.num_channels])
                x = self.activation_fn(batch_norm(x))

        # Output of every encoder block, kept for the matching decoder level.
        skip_connections = []

        # Encoder: convolution block, remember its output, then downsample.
        for level in range(1, self.num_levels + 1):
            with tf.variable_scope('vnet/encoder/level_' + str(level)):
                x = convolution_block(
                    x,
                    self.num_convolutions[level - 1],
                    keep_prob,
                    activation_fn=self.activation_fn,
                    is_training=self.train_phase,
                )
                skip_connections.append(x)
                with tf.variable_scope('down_convolution'):
                    downsampled = down_convolution(
                        x, factor=2, kernel_size=[2, 2, 2])
                    x = self.activation_fn(batch_norm(downsampled))

        with tf.variable_scope('vnet/bottom_level'):
            x = convolution_block(
                x,
                self.bottom_convolutions,
                keep_prob,
                activation_fn=self.activation_fn,
                is_training=self.train_phase,
            )

        # Decoder: walk the levels from deepest to shallowest, upsampling to
        # the dynamic shape of the stored encoder output before merging.
        for level in range(self.num_levels, 0, -1):
            with tf.variable_scope('vnet/decoder/level_' + str(level)):
                skip = skip_connections[level - 1]
                with tf.variable_scope('up_convolution'):
                    upsampled = up_convolution(
                        x, tf.shape(skip), factor=2, kernel_size=[2, 2, 2])
                    x = self.activation_fn(batch_norm(upsampled))

                x = convolution_block_2(
                    x,
                    skip,
                    self.num_convolutions[level - 1],
                    keep_prob,
                    activation_fn=self.activation_fn,
                    is_training=self.train_phase,
                )

        with tf.variable_scope('vnet/output_layer'):
            raw_logits = convolution(
                x, [1, 1, 1, self.num_channels, self.num_classes])
            return batch_norm(raw_logits)