def up_transition_3d(self, scope, X, F, K, S, concat_var=None, padding='VALID'):
    """
    Transposed 3D convolution (deconvolution) used for DenseUnet upsampling.
    :param scope: variable scope name
    :param X: input tensor, 5-D [batch, dim1, dim2, dim3, channels]
    :param F: filter (kernel) spatial size
    :param K: number of output channels
    :param S: stride along each spatial dimension
    :param concat_var: optional skip-connection tensor, concatenated on the channel axis
    :param padding: 'VALID' or 'SAME'
    :return: the upsampled (and optionally concatenated) tensor
    """
    with tf.variable_scope(scope) as scope:

        # Set channel size based on input depth
        C = X.get_shape().as_list()[-1]

        # He init. conv3d_transpose kernel shape is [F, F, F, out_channels, in_channels]
        kernel = tf.get_variable('Weights', shape=[F, F, F, K, C],
                                 initializer=tf.contrib.layers.variance_scaling_initializer())

        # Define the biases
        bias = tf.get_variable('Bias', shape=[K], initializer=tf.constant_initializer(0.0))

        # Add to the weights collection
        tf.add_to_collection('weights', kernel)
        tf.add_to_collection('biases', bias)

        # Define the output shape: double every spatial dimension, set channels to K.
        # Bug fix: the channel entry is index -1 (4), not 3 — the old code overwrote
        # the doubled third spatial dimension with K instead.
        out_shape = X.get_shape().as_list()
        out_shape[1] *= 2
        out_shape[2] *= 2
        out_shape[3] *= 2
        out_shape[-1] = K

        # Perform the deconvolution. Bug fix: a 5-D input with 5-element strides
        # requires conv3d_transpose; conv2d_transpose only accepts 4-D inputs.
        conv = tf.nn.conv3d_transpose(X, kernel, output_shape=out_shape,
                                      strides=[1, S, S, S, 1], padding=padding)

        # Add in bias
        conv = tf.nn.bias_add(conv, bias)

        # Concatenate the skip connection only when one is provided
        # (the default concat_var=None would otherwise crash tf.concat).
        if concat_var is not None:
            conv = tf.concat([concat_var, conv], axis=-1)

        # Create a histogram summary and summary of sparsity
        if self.summary:
            self._activation_summary(conv)

        return conv
def up_transition(self, scope, X, F, K, S, concat_var=None, padding='SAME', res=True):
    """
    Upsample X with a 2-D transposed convolution and merge in a skip connection.
    :param scope: variable scope name
    :param X: input tensor
    :param F: filter (kernel) spatial size
    :param K: number of output channels
    :param S: stride
    :param concat_var: the skip connection; its shape defines the deconv output shape
    :param padding: 'SAME' or 'VALID'. In general, use VALID for 3D skip connections
    :param res: True adds the skip connection residually, False concatenates it
    :return: batch-normalized, biased, ReLU-activated output tensor
    """
    with tf.variable_scope(scope) as scope:

        # Input channel count drives the kernel's last dimension
        in_channels = X.get_shape().as_list()[-1]

        # He-initialized deconvolution weights: [F, F, out_channels, in_channels]
        weights = tf.get_variable('Weights', shape=[F, F, K, in_channels],
                                  initializer=tf.contrib.layers.variance_scaling_initializer())
        biases = tf.get_variable('Bias', shape=[K], initializer=tf.constant_initializer(0.0))

        # Register both variables in their collections
        tf.add_to_collection('weights', weights)
        tf.add_to_collection('biases', biases)

        # Target shape mirrors the skip connection, with K output channels
        target_shape = concat_var.get_shape().as_list()
        target_shape[-1] = K

        # Transposed convolution up to the skip connection's spatial size
        upsampled = tf.nn.conv2d_transpose(X, weights, output_shape=target_shape,
                                           strides=[1, S, S, 1], padding=padding)

        # Merge the skip connection: residual add or channel-wise concat
        merged = tf.add(upsampled, concat_var) if res else tf.concat([concat_var, upsampled], axis=-1)

        # Batch normalization (weight updates during the training phase only),
        # then bias and ReLU, in that order
        normed = self.batch_normalization(merged, self.phase_train, scope)
        activated = tf.nn.relu(tf.nn.bias_add(normed, biases), name=scope.name)

        # Create a histogram summary and summary of sparsity
        if self.summary:
            self._activation_summary(activated)

        return activated