Example #1
0
    def up_transition_3d(self,
                         scope,
                         X,
                         F,
                         K,
                         S,
                         concat_var=None,
                         padding='VALID'):
        """
        3D transposed convolution (deconvolution) for the DenseUnet
        :param scope: Name of the variable scope that holds this layer's variables
        :param X: Input tensor, 5-D with channels on the last axis. Its static
                  shape must be fully defined, since the output shape is
                  computed from it in Python.
        :param F: Filter size, used for all three spatial kernel dimensions
        :param K: Number of output feature maps
        :param S: Stride, used for all three spatial dimensions
        :param concat_var: Optional skip connection. When given, it is placed in
                  front of the deconvolution result along the channel axis;
                  when None, the deconvolution result is returned on its own.
        :param padding: 'SAME' or 'VALID'
        :return: The upsampled tensor, concatenated with concat_var if provided
        """

        with tf.variable_scope(scope):

            # Static input shape; the last entry is the input channel count
            in_shape = X.get_shape().as_list()
            C = in_shape[-1]

            # He init. Transposed-conv kernels are laid out as
            # [depth, height, width, output_channels, input_channels]
            kernel = tf.get_variable(
                'Weights',
                shape=[F, F, F, K, C],
                initializer=tf.contrib.layers.variance_scaling_initializer())

            # Define the biases
            bias = tf.get_variable('Bias',
                                   shape=[K],
                                   initializer=tf.constant_initializer(0.0))

            # Add to the weights collection
            tf.add_to_collection('weights', kernel)
            tf.add_to_collection('biases', bias)

            # Define the output shape: every spatial axis grows by the stride,
            # and VALID padding adds the part of the filter that overhangs the
            # stride. For S == 2 with SAME padding (or VALID with F <= 2) this
            # is the plain doubling. The channel axis is the LAST axis.
            overhang = max(F - S, 0) if padding == 'VALID' else 0
            spatial = [dim * S + overhang for dim in in_shape[1:4]]
            out_shape = [in_shape[0]] + spatial + [K]

            # Perform the deconvolution. A 5-D input with 5-element strides
            # needs the 3D op; conv2d_transpose only accepts 4-D tensors.
            conv = tf.nn.conv3d_transpose(X,
                                          kernel,
                                          output_shape=out_shape,
                                          strides=[1, S, S, S, 1],
                                          padding=padding)

            # Add in bias
            conv = tf.nn.bias_add(conv, bias)

            # Concatenate with the skip connection, if there is one
            if concat_var is not None:
                conv = tf.concat([concat_var, conv], axis=-1)

            # Create a histogram summary and summary of sparsity
            if self.summary: self._activation_summary(conv)

            return conv
Example #2
0
    def up_transition(self,
                      scope,
                      X,
                      F,
                      K,
                      S,
                      concat_var=None,
                      padding='SAME',
                      res=True):
        """
        Performs a 2D upsampling procedure: transposed convolution, merge with
        the skip connection, batch normalization, bias add and ReLU
        :param scope: Name of the variable scope that holds this layer's variables
        :param X: Inputs, with channels on the last axis
        :param F: Filter size, used for both spatial kernel dimensions
        :param K: Number of output feature maps
        :param S: Stride size, used for both spatial dimensions
        :param concat_var: The skip connection. Required despite the None
                  default: its static shape defines the deconvolution output shape.
        :param padding: SAME or VALID. In general, use VALID for 3D skip connections
        :param res: If True, add the skip connection to the result (residual);
                  if False, concatenate it in front along the channel axis
        :return: The activated output tensor
        :raises ValueError: If concat_var is None
        """

        # Fail early with a clear message instead of an AttributeError on None
        # once the variable scope has already been opened
        if concat_var is None:
            raise ValueError(
                'up_transition requires concat_var: the skip connection '
                'defines the output shape of the deconvolution')

        with tf.variable_scope(scope) as scope:

            # Set channel size based on input depth
            C = X.get_shape().as_list()[-1]

            # He init. Transposed-conv kernels are laid out as
            # [height, width, output_channels, input_channels]
            kernel = tf.get_variable(
                'Weights',
                shape=[F, F, K, C],
                initializer=tf.contrib.layers.variance_scaling_initializer())

            # Define the biases
            bias = tf.get_variable('Bias',
                                   shape=[K],
                                   initializer=tf.constant_initializer(0.0))

            # Add to the weights collection
            tf.add_to_collection('weights', kernel)
            tf.add_to_collection('biases', bias)

            # Define the output shape based on shape of skip connection,
            # swapping in K as the channel count
            out_shape = concat_var.get_shape().as_list()
            out_shape[-1] = K

            # Perform the deconvolution. output_shape: A 1-D Tensor representing the output shape of the deconvolution op.
            conv = tf.nn.conv2d_transpose(X,
                                          kernel,
                                          output_shape=out_shape,
                                          strides=[1, S, S, 1],
                                          padding=padding)

            # Merge with the skip connection: residual add or channel concat
            if res:
                conv = tf.add(conv, concat_var)
            else:
                conv = tf.concat([concat_var, conv], axis=-1)

            # Apply the batch normalization. Updates weights during training phase only
            conv = self.batch_normalization(conv, self.phase_train, scope)

            # Add in bias
            conv = tf.nn.bias_add(conv, bias)

            # Relu
            conv = tf.nn.relu(conv, name=scope.name)

            # Create a histogram summary and summary of sparsity
            if self.summary: self._activation_summary(conv)

            return conv