def softmax_index2d(indices, values, reduce=False):
    """Cross-correlate a 2-D softmax of `indices` with `values` via the FFT.

    The last two dimensions of `indices` are flattened, passed through a
    softmax so they sum to one, and restored to their original shape. The
    result is then correlated with `values` by multiplying the conjugate of
    its 2-D spectrum with the 2-D spectrum of `values`.

    Args:
        indices: Real tensor of logits; the softmax is taken jointly over
            its last two dimensions.
        values: Real tensor transformed with the same 2-D FFT (presumably
            the same shape as `indices` -- confirm with callers).
        reduce: When False, return the full 2-D inverse transform. When
            True, apply a 1-D inverse transform over the last dimension and
            average over dimension -2.

    Returns:
        The real part of the inverse transform, with dimension -2 removed
        when `reduce` is True.
    """
    full_shape = shape_list(indices)
    flat_size = full_shape[-1] * full_shape[-2]

    # Softmax over the flattened plane, then back to the original layout.
    flat_logits = tf.reshape(indices, [-1, flat_size])
    weights = tf.reshape(tf.nn.softmax(flat_logits), full_shape)

    # The FFT ops need complex inputs; attach a zero imaginary part.
    weights_c = tf.complex(weights, tf.zeros_like(weights))
    values_c = tf.complex(values, tf.zeros_like(values))

    # conj(F(a)) * F(b) is the spectrum of the circular cross-correlation.
    weights_spectrum = tf.conj(tf.batch_fft2d(weights_c))
    values_spectrum = tf.batch_fft2d(values_c)
    product = weights_spectrum * values_spectrum

    if not reduce:
        return tf.real(tf.batch_ifft2d(product))

    # The 1-D inverse here is deliberate: averaging over the remaining
    # frequency axis is the inverse DFT evaluated at offset 0, so this
    # yields row 0 of the full 2-D result without computing the rest.
    return tf.reduce_mean(tf.real(tf.batch_ifft(product)), -2)
def _inference(self, x, dropout):
    """Build the network: one spectral filtering layer plus a linear classifier.

    Args:
        x: Input batch; reshaped to [-1, 28, 28] images.
        dropout: Accepted but not used by this implementation.

    Returns:
        Tensor of shape [NSAMPLES, NCLASSES] produced by the final
        fully-connected layer (no softmax applied here).
    """
    half = int(NFEATURES / 2)  # only half of the spectrum is filtered

    with tf.name_scope('conv1'):
        # Forward transform of each image.
        # NOTE(review): the forward pass uses tf.batch_fft2d while the
        # inverse below uses tf.ifft2d -- confirm both exist in the
        # TensorFlow version this file targets.
        images = tf.complex(tf.reshape(x, [-1, 28, 28]), 0)
        spectrum = tf.reshape(tf.batch_fft2d(images), [-1, NFEATURES])
        spectrum = tf.expand_dims(spectrum, 1)  # NSAMPLES x 1 x NFEATURES
        spectrum = tf.transpose(spectrum)  # NFEATURES x 1 x NSAMPLES

        # One complex weight per (frequency, filter) pair.
        w_real = self._weight_variable([half, self.F, 1])
        w_imag = self._weight_variable([half, self.F, 1])
        spectral_weights = tf.complex(w_real, w_imag)

        # Batched matmul: an independent [F x 1] @ [1 x NSAMPLES] product
        # for each retained frequency.
        filtered = tf.matmul(spectral_weights, spectrum[:half, :, :])

        # Rebuild a full-length spectrum by appending the conjugate half.
        filtered = tf.concat(axis=0, values=[filtered, tf.conj(filtered)])
        filtered = tf.transpose(filtered)  # NSAMPLES x NFILTERS x NFEATURES
        filtered_2d = tf.reshape(filtered, [-1, 28, 28])

        # Back to the spatial domain, keeping only the real part.
        spatial = tf.real(tf.ifft2d(filtered_2d))
        y = tf.reshape(spatial, [-1, self.F, NFEATURES])

        # One bias per filter, broadcast over samples and features.
        bias = self._bias_variable([1, self.F, 1])
        y = tf.nn.relu(y + bias)  # NSAMPLES x NFILTERS x NFEATURES

    with tf.name_scope('fc1'):
        fc_weights = self._weight_variable([self.F * NFEATURES, NCLASSES])
        fc_bias = self._bias_variable([NCLASSES])
        flat = tf.reshape(y, [-1, self.F * NFEATURES])
        y = tf.matmul(flat, fc_weights) + fc_bias

    return y
def conv2d(self, source, filters, width, height, stride,
           activation='relu', name='conv2d'):
    """Apply a standard spatial convolution and expose the filter's spectrum.

    Args:
        source: 4-D input tensor. Its dimension 3 is read as the channel
            count and it is passed to tf.nn.conv2d with strides
            [1, stride, stride, 1].
        filters: Number of output filters.
        width: Filter width.
        height: Filter height.
        stride: Stride applied to both spatial dimensions.
        activation: 'relu' applies tf.nn.relu to the biased output; any
            other value leaves the output linear.
        name: Variable scope under which the weights are created.

    Returns:
        A tuple (output, spatial_filter, spectral_filter):
        output is the convolved, biased and optionally rectified tensor;
        spatial_filter is the [height, width, in_channels, filters] weight
        variable; spectral_filter is the 2-D FFT of that filter laid out as
        [in_channels, filters, height, width], intended for visualization.
    """
    in_channels = source.get_shape().as_list()[3]

    with tf.variable_scope(name):
        spatial_filter = tf.get_variable(
            "weight",
            [height, width, in_channels, filters],
            initializer=tf.truncated_normal_initializer(0, stddev=0.01),
            dtype=tf.float32)
        b = tf.Variable(tf.constant(0.1, shape=[filters]), name="bias")

        # Move the spatial axes last so the 2-D FFT runs over height/width.
        spatial_filter_for_fft = tf.transpose(spatial_filter, [2, 3, 0, 1])

        # Spectral view of the filter; used for visualization only and not
        # part of the forward computation.
        spectral_filter = tf.batch_fft2d(
            tf.complex(spatial_filter_for_fft, spatial_filter_for_fft * 0.0))

        conv = tf.nn.conv2d(source, spatial_filter,
                            strides=[1, stride, stride, 1], padding='SAME')
        output = tf.nn.bias_add(conv, b)

        # Compare by value: `is` tests object identity and only matched when
        # the caller's string happened to be the same interned object.
        output = tf.nn.relu(output) if activation == 'relu' else output

        return output, spatial_filter, spectral_filter
def random_spatial_to_spectral(self, channels, filters, height, width):
    """Return the 2-D FFT of a random spatial filter bank as concrete values.

    A truncated-normal tensor (mean 0, stddev 0.01) is sampled in the
    spatial domain and transformed, giving an initial value for spectrally
    parameterized filters. The alternative would be to sample directly in
    the spectral domain.

    Args:
        channels: Number of input channels.
        filters: Number of filters.
        height: Filter height.
        width: Filter width.

    Returns:
        The evaluated spectrum of the [channels, filters, height, width]
        sample, obtained by running the op in `self.sess`.
    """
    spatial_shape = [channels, filters, height, width]
    spatial = tf.truncated_normal(spatial_shape, mean=0, stddev=0.01)

    # Attach a zero imaginary part so the sample can be fed to the FFT.
    as_complex = tf.complex(spatial, 0.0 * spatial)
    spectrum = tf.batch_fft2d(as_complex, name='spectral_initializer')

    return spectrum.eval(session=self.sess)
def conv2d(self, source, filters, width, height, stride,
           activation='relu', name='conv2d'):
    """Apply a standard spatial convolution and expose the filter's spectrum.

    Args:
        source: 4-D input tensor. Its dimension 3 is read as the channel
            count and it is passed to tf.nn.conv2d with strides
            [1, stride, stride, 1].
        filters: Number of output filters.
        width: Filter width.
        height: Filter height.
        stride: Stride applied to both spatial dimensions.
        activation: 'relu' applies tf.nn.relu to the biased output; any
            other value leaves the output linear.
        name: Variable scope under which the weights are created.

    Returns:
        A tuple (output, spatial_filter, spectral_filter):
        output is the convolved, biased and optionally rectified tensor;
        spatial_filter is the [height, width, in_channels, filters] weight
        variable; spectral_filter is the 2-D FFT of that filter laid out as
        [in_channels, filters, height, width], intended for visualization.
    """
    in_channels = source.get_shape().as_list()[3]

    with tf.variable_scope(name):
        spatial_filter = tf.get_variable(
            "weight",
            [height, width, in_channels, filters],
            initializer=tf.truncated_normal_initializer(0, stddev=0.01),
            dtype=tf.float32)
        b = tf.Variable(tf.constant(0.1, shape=[filters]), name="bias")

        # Move the spatial axes last so the 2-D FFT runs over height/width.
        spatial_filter_for_fft = tf.transpose(spatial_filter, [2, 3, 0, 1])

        # Spectral view of the filter; used for visualization only and not
        # part of the forward computation.
        spectral_filter = tf.batch_fft2d(
            tf.complex(spatial_filter_for_fft, spatial_filter_for_fft * 0.0))

        conv = tf.nn.conv2d(source, spatial_filter,
                            strides=[1, stride, stride, 1], padding='SAME')
        output = tf.nn.bias_add(conv, b)

        # Compare by value: `is` tests object identity and only matched when
        # the caller's string happened to be the same interned object.
        output = tf.nn.relu(output) if activation == 'relu' else output

        return output, spatial_filter, spectral_filter
def fft(net):
    """Transform a feature map to the spectral domain, replicated per filter.

    Args:
        net: Real-valued 4-D tensor. It is transposed with [0, 3, 1, 2],
            i.e. treated as (batch, height, width, channels).

    Returns:
        Complex tensor laid out as (batch, channels, filters, height,
        width): the 2-D FFT of each channel, tiled `filters` times along a
        new axis 2.

    Note:
        `filters` is not a parameter; it is resolved from the enclosing or
        module scope at call time.
    """
    net = tf.transpose(net, [0, 3, 1, 2])  # batch, channel, height, width
    net_fft = tf.batch_fft2d(tf.complex(net, 0.0 * net))

    # Insert a filter axis and replicate the spectrum once per filter.
    net_fft = tf.expand_dims(net_fft, 2)  # batch, channels, 1, height, width
    net_fft = tf.tile(net_fft, [1, 1, filters, 1, 1])

    # Hand the result back; it was previously built and then discarded, so
    # the function always returned None.
    return net_fft
def fft_conv_pure(self, source, filters, width, height, stride,
                  activation='relu', name='fft_conv'):
    """Convolve by element-wise multiplication in the spectral domain.

    The filter is stored in the spectral domain. The input is transformed
    with a 2-D FFT, multiplied by the conjugate of the zero-padded filter,
    transformed back, and summed over the channel dimension.

    Args:
        source: 4-D input tensor unpacked as
            (batch, input_height, input_width, channels).
        filters: Number of output filters.
        width: Filter width.
        height: Filter height.
        stride: Not used by this implementation; accepted so the signature
            matches the other convolution methods.
        activation: 'relu' applies tf.nn.relu to the biased output; any
            other value leaves the output linear.
        name: Variable scope name. Also the key looked up in
            `self.initialization` for a stored initial value.

    Returns:
        A tuple (output, spatial_filter, w):
        output is laid out as (batch, height, width, filters);
        spatial_filter is the inverse 2-D FFT of `w` transposed to
        (height, width, channels, filters) -- it is complex-valued because
        no real part is taken;
        w is the spectral filter, (channels, filters, height, width).
    """
    _, input_height, input_width, channels = source.get_shape().as_list()

    with tf.variable_scope(name):
        # Start from a random spectral initialization; a stored
        # initialization for this layer name takes precedence.
        init = self.random_spatial_to_spectral(channels, filters, height, width)
        if name in self.initialization:
            init = self.initialization[name]

        # Only the 'free' spectral parameters are stored, which enforces
        # conjugate symmetry (noted by the original author as very slow).
        # Alternatives that were considered:
        #   * over-parameterize with independent real and imaginary
        #     variables initialized from init.real / init.imag;
        #   * keep a spatial-domain weight variable and transform it with
        #     tf.batch_fft2d on every use.
        w = self.spectral_to_variable(init)

        b = tf.Variable(tf.constant(0.1, shape=[filters]))

        # Add batch as a dimension for later broadcasting.
        w = tf.expand_dims(w, 0)  # batch, channels, filters, height, width

        # Prepare the source tensor for the FFT.
        source = tf.transpose(source, [0, 3, 1, 2])  # batch, channel, height, width
        source_fft = tf.batch_fft2d(tf.complex(source, 0.0 * source))

        # Replicate the transformed input once per filter.
        source_fft = tf.expand_dims(source_fft, 2)  # batch, channels, 1, height, width
        source_fft = tf.tile(source_fft, [1, 1, filters, 1, 1])

        # Shift, zero-pad the filter up to the input size, then unshift.
        # NOTE(review): the padding is symmetric, so the padded filter only
        # reaches the input size when (input size - filter size) is even.
        w_shifted = self.batch_fftshift2d(w)
        height_pad = (input_height - height) // 2
        width_pad = (input_width - width) // 2
        w_padded = tf.pad(
            w_shifted,
            [[0, 0], [0, 0], [0, 0],
             [height_pad, height_pad], [width_pad, width_pad]],
            mode='CONSTANT')  # pads with zeros
        w_padded = self.batch_ifftshift2d(w_padded)

        # Multiply in the spectral domain.
        conv = source_fft * tf.conj(w_padded)

        # Back to the spatial domain, then sum out the channel dimension.
        # Summing over channels seems intuitive but is not necessarily
        # theoretically sound (original author's caveat).
        conv = tf.real(tf.batch_ifft2d(conv))
        conv = tf.reduce_sum(conv, reduction_indices=1)  # batch, filters, height, width
        conv = tf.transpose(conv, [0, 2, 3, 1])  # batch, height, width, filters

        # Drop the batch dimension so the returned filter matches the
        # other convolution methods.
        w = tf.squeeze(w, [0])  # channels, filters, height, width

        # Spatial encoding of the filter, for visualization.
        spatial_filter = tf.batch_ifft2d(w)
        spatial_filter = tf.transpose(spatial_filter, [2, 3, 0, 1])  # height, width, channels, filters

        # Add the bias in the spatial domain.
        output = tf.nn.bias_add(conv, b)

        # Compare by value: `is` tests object identity and only matched when
        # the caller's string happened to be the same interned object.
        output = tf.nn.relu(output) if activation == 'relu' else output

        return output, spatial_filter, w
def fft_conv_pure(self, source, filters, width, height, stride,
                  activation='relu', name='fft_conv'):
    """Convolve by element-wise multiplication in the spectral domain.

    The filter is stored in the spectral domain. The input is transformed
    with a 2-D FFT, multiplied by the conjugate of the zero-padded filter,
    transformed back, and summed over the channel dimension.

    Args:
        source: 4-D input tensor unpacked as
            (batch, input_height, input_width, channels).
        filters: Number of output filters.
        width: Filter width.
        height: Filter height.
        stride: Not used by this implementation; accepted so the signature
            matches the other convolution methods.
        activation: 'relu' applies tf.nn.relu to the biased output; any
            other value leaves the output linear.
        name: Variable scope name. Also the key looked up in
            `self.initialization` for a stored initial value.

    Returns:
        A tuple (output, spatial_filter, w):
        output is laid out as (batch, height, width, filters);
        spatial_filter is the inverse 2-D FFT of `w` transposed to
        (height, width, channels, filters) -- it is complex-valued because
        no real part is taken;
        w is the spectral filter, (channels, filters, height, width).
    """
    _, input_height, input_width, channels = source.get_shape().as_list()

    with tf.variable_scope(name):
        # Start from a random spectral initialization; a stored
        # initialization for this layer name takes precedence.
        init = self.random_spatial_to_spectral(channels, filters, height, width)
        if name in self.initialization:
            init = self.initialization[name]

        # Only the 'free' spectral parameters are stored, which enforces
        # conjugate symmetry (noted by the original author as very slow).
        # Alternatives that were considered:
        #   * over-parameterize with independent real and imaginary
        #     variables initialized from init.real / init.imag;
        #   * keep a spatial-domain weight variable and transform it with
        #     tf.batch_fft2d on every use.
        w = self.spectral_to_variable(init)

        b = tf.Variable(tf.constant(0.1, shape=[filters]))

        # Add batch as a dimension for later broadcasting.
        w = tf.expand_dims(w, 0)  # batch, channels, filters, height, width

        # Prepare the source tensor for the FFT.
        source = tf.transpose(source, [0, 3, 1, 2])  # batch, channel, height, width
        source_fft = tf.batch_fft2d(tf.complex(source, 0.0 * source))

        # Replicate the transformed input once per filter.
        source_fft = tf.expand_dims(source_fft, 2)  # batch, channels, 1, height, width
        source_fft = tf.tile(source_fft, [1, 1, filters, 1, 1])

        # Shift, zero-pad the filter up to the input size, then unshift.
        # NOTE(review): the padding is symmetric, so the padded filter only
        # reaches the input size when (input size - filter size) is even.
        w_shifted = self.batch_fftshift2d(w)
        height_pad = (input_height - height) // 2
        width_pad = (input_width - width) // 2
        w_padded = tf.pad(
            w_shifted,
            [[0, 0], [0, 0], [0, 0],
             [height_pad, height_pad], [width_pad, width_pad]],
            mode='CONSTANT')  # pads with zeros
        w_padded = self.batch_ifftshift2d(w_padded)

        # Multiply in the spectral domain.
        conv = source_fft * tf.conj(w_padded)

        # Back to the spatial domain, then sum out the channel dimension.
        # Summing over channels seems intuitive but is not necessarily
        # theoretically sound (original author's caveat).
        conv = tf.real(tf.batch_ifft2d(conv))
        conv = tf.reduce_sum(conv, reduction_indices=1)  # batch, filters, height, width
        conv = tf.transpose(conv, [0, 2, 3, 1])  # batch, height, width, filters

        # Drop the batch dimension so the returned filter matches the
        # other convolution methods.
        w = tf.squeeze(w, [0])  # channels, filters, height, width

        # Spatial encoding of the filter, for visualization.
        spatial_filter = tf.batch_ifft2d(w)
        spatial_filter = tf.transpose(spatial_filter, [2, 3, 0, 1])  # height, width, channels, filters

        # Add the bias in the spatial domain.
        output = tf.nn.bias_add(conv, b)

        # Compare by value: `is` tests object identity and only matched when
        # the caller's string happened to be the same interned object.
        output = tf.nn.relu(output) if activation == 'relu' else output

        return output, spatial_filter, w