def bilinear_pool(x1, x2, x3, output_size):
    """
    Computes an approximation of bilinear pooling of x1, x2 and x3.
    For a detailed explanation, see the paper (https://arxiv.org/abs/1511.06062)

    Args:
        x1: A `Tensor` with shape (batch_size, x1_size).
        x2: A `Tensor` with shape (batch_size, x2_size).
        x3: A `Tensor` with shape (batch_size, x3_size).
        output_size: Output projection size. (`int`)

    Returns:
        A `Tensor` with shape (batch_size, output_size).
    """
    p1 = count_sketch(x1, output_size)
    p2 = count_sketch(x2, output_size)
    p3 = count_sketch(x3, output_size)
    pc1 = tf.complex(p1, tf.zeros_like(p1))
    pc2 = tf.complex(p2, tf.zeros_like(p2))
    pc3 = tf.complex(p3, tf.zeros_like(p3))
    # Circular convolution of the three sketches via the FFT. (The original
    # also built an identical second tensor, multiplied by 1, and returned
    # both; the duplicate computation is dropped here.)
    conved = tf.ifft(tf.fft(pc1) * tf.fft(pc2) * tf.fft(pc3))
    return tf.real(conved)
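# bilinear_pool above assumes a `count_sketch` helper. A minimal NumPy sketch
# of that projection (hedged: the name, signature and behaviour here are
# assumptions based on the compact bilinear pooling paper, not the original
# author's code):
import numpy as np

def count_sketch_np(x, output_size, seed=0):
    """Project each row of x (batch, d) to output_size via Count Sketch."""
    rng = np.random.RandomState(seed)
    d = x.shape[1]
    h = rng.randint(output_size, size=d)   # hash bucket per input dimension
    s = rng.choice([-1.0, 1.0], size=d)    # random sign per input dimension
    out = np.zeros((x.shape[0], output_size))
    for j in range(d):
        out[:, h[j]] += s[j] * x[:, j]
    return out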
def deconv(comb_loss, loss_a, loss_b):
    # NOTE: several intermediates below (x, ayfftest, ayest, aest) are
    # computed but never used; only `best` is returned.
    x = tf.cast(comb_loss[:, 0], dtype=tf.complex64)
    y = tf.cast(comb_loss[:, 1], dtype=tf.complex64)
    yfft = tf.fft(y)
    ay = tf.cast(loss_a[:, 1], dtype=tf.complex64)
    by = tf.cast(loss_b[:, 1], dtype=tf.complex64)
    ayfft = tf.fft(ay)
    byfft = tf.fft(by)
    # Deconvolution in the frequency domain: divide the combined spectrum
    # by one factor's spectrum to estimate the other.
    ayfftest = yfft / byfft
    ayest = tf.abs(tf.ifft(ayfftest))
    aest = tf.cast(tf.abs(tf.ifft(ayfft)), dtype=tf.float64)
    best = tf.cast(tf.abs(tf.ifft(byfft)), dtype=tf.float64)
    return best
def istft(spec, overlap=4):
    assert (spec.shape[0] > 1)
    S = tf.placeholder(dtype=tf.complex64, shape=spec.shape)
    # Inverse FFT of every frame, then the magnitude of the concatenated
    # result. (The bare `placeholder` call, the removed `tf.complex_abs`,
    # and the legacy tf.concat argument order are fixed here.)
    X = tf.abs(tf.concat([tf.ifft(frame) for frame in tf.unstack(S)], axis=0))
    sess = tf.Session()
    return sess.run(X, feed_dict={S: spec})
def model(self, sinogram):
    """
    main model for the FDK algorithm

    Parameters
    ----------
    sinogram : ndarray
        The projection data used for reconstruction

    Returns
    -------
    ndarray
        the reconstructed CT data
    """
    sinogram_cos = tf.multiply(sinogram, self.cosine_weight)
    weighted_sino_fft = tf.fft(tf.cast(sinogram_cos, dtype=tf.complex64))
    filtered_sinogram_fft = tf.multiply(
        weighted_sino_fft, tf.cast(self.filter, dtype=tf.complex64))
    filtered_sinogram = tf.real(tf.ifft(filtered_sinogram_fft))
    reconstruction = cone_backprojection3d(filtered_sinogram,
                                           self.geometry,
                                           hardware_interp=True)
    return reconstruction
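# `self.filter` above is presumed to hold a frequency-domain reconstruction
# filter; in FDK this is typically a ramp (Ram-Lak) filter. A minimal NumPy
# sketch of such a filter, purely illustrative and not taken from the
# original code:
import numpy as np

def ramp_filter(n_detector_pixels):
    """Return |f| ramp weights matching np.fft.fft's frequency ordering."""
    freqs = np.fft.fftfreq(n_detector_pixels)
    return np.abs(freqs)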
def _get_w(self, last_w, raw_output, memory):
    with tf.variable_scope('get_w'):
        k, beta, g, s, gamma = tf.split(
            raw_output, [self.M, self.N, 1, self.N, self.N], axis=1)
        # memory_row_norm = tf.reduce_sum(tf.abs(memory), axis=1)
        memory = tf.nn.l2_normalize(memory, axis=1)
        k = tf.nn.l2_normalize(k, axis=1)
        product = tf.squeeze(tf.matmul(tf.expand_dims(k, 1), memory))
        # omit division of `k` since it is canceled after softmax anyway
        beta = tf.nn.softplus(beta, name='beta')
        # w_c = tf.nn.softmax(beta * product / memory_row_norm, name='w_c')
        w_c = tf.nn.softmax(beta * product, name='w_c')
        g = tf.sigmoid(g, name='gate_parameter')
        w_g = tf.add(g * w_c, (1 - g) * last_w, name='w_g')
        w_g = tf.cast(w_g, tf.complex64)
        # s = tf.cast(tf.nn.softmax(s), tf.complex64)
        s = tf.cast(tf.tanh(s, name='shift'), tf.complex64)
        w_tild = tf.real(tf.ifft(tf.fft(w_g) * tf.fft(s)), name='w_tild')
        gamma = tf.add(tf.nn.softplus(gamma), 1, name='gamma')
        # w = tf.pow(w_tild, gamma)
        # w = w / tf.reduce_sum(w)
        # FIXME: tf.log yields lots of NaN here!!
        # w = tf.nn.softmax(gamma * tf.log(tf.nn.softmax(w_tild)))
        w = tf.nn.softmax(gamma * tf.log(tf.nn.softplus(w_tild)),
                          name='new_weight')
        # w = tf.nn.softmax(gamma * tf.log(tf.nn.sigmoid(w_tild)))
        # w = tf.nn.softmax(gamma * w_tild)
        return w
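# A quick NumPy check (illustrative, assumes nothing beyond np.fft) that the
# FFT product used for `w_tild` above really performs a circular shift of the
# attention weighting:
import numpy as np

w = np.array([0.0, 1.0, 0.0, 0.0])   # attention focused on slot 1
s = np.array([0.0, 0.0, 1.0, 0.0])   # shift kernel: move by +2
shifted = np.real(np.fft.ifft(np.fft.fft(w) * np.fft.fft(s)))
# shifted is (up to float error) [0, 0, 0, 1]: slot 1 moved two places.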
def bandPass(signal, low_cut, high_cut, sample_length, sample_rate):
    '''
    band pass filter
    args:
        signal = input signal
        low_cut = filtering bandwidth Hz (lower bound)
        high_cut = filtering bandwidth Hz (upper bound)
        sample_length = total signal length (number of samples)
        sample_rate = sampling rate of the input signal
    return:
        filtered.real = filtered signal
    '''
    ratio = int(sample_length / sample_rate)
    with tf.Graph().as_default():
        signal = tf.Variable(signal, dtype=tf.complex64)
        fft = tf.fft(signal)
        with tf.Session() as sess:
            tf.variables_initializer([signal]).run()
            result = sess.run(fft)
    # Zero out bins above high_cut (both the positive- and negative-frequency
    # halves of the spectrum) ...
    for i in range(high_cut * ratio, len(result) - (high_cut * ratio) + 1):
        result[i] = 0
    # ... and bins below low_cut, again on both halves.
    for i in range(0, (low_cut * ratio) + 1):
        result[i] = 0
    for i in range(len(result) - (low_cut * ratio), len(result)):
        result[i] = 0
    ifft = tf.ifft(result)
    with tf.Session() as sess:
        filtered = sess.run(ifft)
    return filtered.real
def lowPass(signal, cut, sample_length, sample_rate):
    '''
    low pass filter
    args:
        signal = input signal
        cut = filtering bandwidth Hz
        sample_length = total signal length (number of samples)
        sample_rate = sampling rate of the input signal
    return:
        filtered.real = filtered signal
    '''
    ratio = int(sample_length / sample_rate)
    with tf.Graph().as_default():
        signal = tf.Variable(signal, dtype=tf.complex64)
        fft = tf.fft(signal)
        with tf.Session() as sess:
            tf.variables_initializer([signal]).run()
            result = sess.run(fft)
    # The FFT of a real signal is conjugate-symmetric: bin i and bin N-i hold
    # the same frequency. Zeroing the middle range [cut*ratio, N-cut*ratio]
    # therefore removes every component above `cut` Hz on both halves.
    for i in range(cut * ratio, len(result) - (cut * ratio) + 1):
        result[i] = 0
    ifft = tf.ifft(result)
    with tf.Session() as sess:
        filtered = sess.run(ifft)
    return filtered.real
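# Illustrative usage of lowPass (assumes TF 1.x with tf.Session available):
# build a two-tone test signal and keep only the 5 Hz component.
import numpy as np

sample_rate = 100                       # Hz
t = np.arange(0, 1, 1.0 / sample_rate)
sig = np.sin(2 * np.pi * 5 * t) + np.sin(2 * np.pi * 40 * t)
low_only = lowPass(sig, cut=10, sample_length=len(sig),
                   sample_rate=sample_rate)
# low_only should closely match np.sin(2 * np.pi * 5 * t).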
def istft(inp, n_overlap):
    inp_sz = inp.get_shape().as_list()
    if len(inp_sz) > 3:
        inp = tf.reshape(inp, (np.prod(inp_sz[:-2]), inp_sz[-2], inp_sz[-1]))
    batch_size, n_frames, n_freqs = inp.get_shape().as_list()
    n_frames = int(int(float(n_frames) / n_overlap) * n_overlap)
    inp = inp[:, :n_frames, :]
    batch_size, n_frames, n_freqs = inp.get_shape().as_list()
    x = tf.real(tf.ifft(inp))
    x = tf.reshape(x, (batch_size, -1, n_overlap, n_freqs))
    x = tf.transpose(x, (0, 2, 1, 3))
    x = tf.reshape(x, (batch_size, n_overlap, -1))
    x_list = tf.unstack(x, axis=1)
    # Integer division; a float `skip` breaks the slice indices below.
    skip = n_freqs // n_overlap
    for i in range(n_overlap):
        # x_sep[i] = tf.manip.roll(x_sep[i], i*wind_size/4, 2)
        if i == 0:
            x_list[i] = x_list[i][:, (n_overlap - i - 1) * skip:]
        else:
            x_list[i] = x_list[i][:, (n_overlap - i - 1) * skip:-i * skip]
    x = tf.add_n(x_list) / float(n_overlap)
    if len(inp_sz) > 3:
        x_sz = x.get_shape().as_list()
        x = tf.reshape(x, inp_sz[:-2] + x_sz[-1:])
    return x
def forward_proj_domain(self, sinogram):
    """
    the projection domain of the model for processing sinograms

    Parameters
    ----------
    sinogram : ndarray
        The projection data used for processing and reconstruction

    Returns
    -------
    ndarray
        the sinograms after processing
    """
    # U-Net added in the projection domain
    ####################################################################
    sinogram = tf.expand_dims(sinogram, 3)
    sinogram = self.unet_model(sinogram)
    sinogram = tf.squeeze(sinogram, axis=3)
    ####################################################################
    self.sinogram_cosine = tf.multiply(sinogram, self.cosine_weight)
    self.weighted_sinogram_fft = tf.fft(
        tf.cast(self.sinogram_cosine, dtype=tf.complex64))
    self.filtered_sinogram_fft = tf.multiply(
        self.weighted_sinogram_fft,
        tf.cast(self.recon_filter, dtype=tf.complex64))
    self.filtered_sinogram = tf.real(tf.ifft(self.filtered_sinogram_fft))
    return self.filtered_sinogram
def tensor_product(self, P, ch1, Q, ch2):
    P_hat = tf.fft(tf.complex(P, tf.zeros(tf.shape(P), dtype=tf.float32)))
    Q_hat = tf.fft(tf.complex(Q, tf.zeros(tf.shape(Q), dtype=tf.float32)))
    p_hat_list = [tf.squeeze(p) for p in tf.split(P_hat, self.k, axis=-1)]
    q_hat_list = [tf.squeeze(q) for q in tf.split(Q_hat, self.k, axis=-1)]
    if ch1 == 't' and ch2 == 't':
        S_hat = tf.concat([
            tf.expand_dims(tf.matmul(tf.transpose(p_hat),
                                     tf.transpose(q_hat)), axis=-1)
            for (p_hat, q_hat) in zip(p_hat_list, q_hat_list)
        ], axis=-1)
    elif ch1 == 't':
        S_hat = tf.concat([
            tf.expand_dims(tf.matmul(tf.transpose(p_hat), q_hat), axis=-1)
            for (p_hat, q_hat) in zip(p_hat_list, q_hat_list)
        ], axis=-1)
    elif ch2 == 't':
        S_hat = tf.concat([
            tf.expand_dims(tf.matmul(p_hat, tf.transpose(q_hat)), axis=-1)
            for (p_hat, q_hat) in zip(p_hat_list, q_hat_list)
        ], axis=-1)
    else:
        S_hat = tf.concat([
            tf.expand_dims(tf.matmul(p_hat, q_hat), axis=-1)
            for (p_hat, q_hat) in zip(p_hat_list, q_hat_list)
        ], axis=-1)
    return tf.real(tf.ifft(S_hat))
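# The routine above implements the tensor t-product: FFT along the last
# (tube) axis, an ordinary matmul per frontal slice in the Fourier domain,
# then an inverse FFT. A compact NumPy equivalent (illustrative sketch, no
# transpose options):
import numpy as np

def t_product(P, Q):
    """t-product of P: (m, n, k) with Q: (n, p, k) -> (m, p, k)."""
    P_hat = np.fft.fft(P, axis=-1)
    Q_hat = np.fft.fft(Q, axis=-1)
    S_hat = np.einsum('mnk,npk->mpk', P_hat, Q_hat)  # slice-wise matmul
    return np.real(np.fft.ifft(S_hat, axis=-1))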
def est_kernel(blurs, deblurs, nstd=2, ksz=27):
    assert ksz % 2 == 1
    hksz = (ksz - 1) // 2
    if nstd == 0:
        nstd = 1e-6
    blurs = tf.cast(tf.transpose(blurs, [0, 3, 1, 2]), tf.complex64)
    deblurs = tf.cast(tf.transpose(deblurs, [0, 3, 1, 2]), tf.complex64)
    fft_blurs = tf.fft2d(blurs)
    fft_deblurs = tf.fft2d(deblurs)
    # Wiener-style kernel estimate in the frequency domain. The denominator
    # is cast to complex64 so the division type-checks.
    numerator = fft_deblurs * tf.conj(fft_blurs)
    denominator = tf.cast(tf.abs(fft_blurs)**2 + (nstd / 255.)**2.,
                          tf.complex64)
    # The forward transform is 2-D, so the inverse must be tf.ifft2d
    # (tf.ifft would only invert the last axis).
    out = tf.real(tf.ifft2d(numerator / denominator))
    out = tf.transpose(out, [0, 2, 3, 1])
    out1 = tf.concat([out[:, -hksz:, -hksz:], out[:, :hksz + 1, -hksz:]],
                     axis=1)
    out2 = tf.concat([out[:, -hksz:, :hksz + 1], out[:, :hksz + 1, :hksz + 1]],
                     axis=1)
    kernels = tf.concat([out1, out2], axis=2)
    # keepdims=True so the division broadcasts over the spatial axes.
    kernels = kernels / tf.reduce_mean(kernels, axis=[1, 2], keepdims=True)
    return kernels
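# The estimate above is the regularized least-squares (Wiener) ratio
#   K_hat = F(a) * conj(F(b)) / (|F(b)|^2 + eps).
# A tiny 1-D NumPy check of that ratio recovering a known kernel
# (illustrative only):
import numpy as np

x = np.random.randn(64)                          # sharp signal
k = np.zeros(64); k[:3] = [0.25, 0.5, 0.25]      # blur kernel
y = np.real(np.fft.ifft(np.fft.fft(x) * np.fft.fft(k)))   # circular blur
k_hat = np.real(np.fft.ifft(np.fft.fft(y) * np.conj(np.fft.fft(x))
                            / (np.abs(np.fft.fft(x))**2 + 1e-6)))
# k_hat recovers k up to the small regularization term.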
def fft_to_audio(fft):
    # `global` must precede any use of the name; the original declared it
    # after the read below, which is a SyntaxError in Python 3.
    global fft_to_audio_init
    return_values = []
    with tf.variable_scope("fft_to_audio", reuse=fft_to_audio_init):
        input = tf.placeholder(tf.complex64, [fft.shape[1]])
        zeroes = tf.fill([int(fft_size / 2) - int(input.shape[0])],
                         input[-1] * 0)
        fill = tf.concat([input, zeroes], axis=0)
        # Mirror the conjugate spectrum so the ifft of `full` is
        # (approximately) real.
        inverse = tf.reverse(tf.conj(fill), [0])
        full = tf.concat([fill, inverse[-2:-1], inverse[:-1]], axis=-1)
        output = tf.cast(tf.ifft(full), tf.float32)
        sess = tf.get_default_session()
        steps = fft.shape[0]
        for i in range(steps):
            feed_audio = fft[i, :]
            values = sess.run({"output": output},
                              feed_dict={input: feed_audio})
            output_values = values["output"]
            return_values.append(output_values)
    fft_to_audio_init = True
    return np.concatenate(return_values, axis=0)
def call(self, x):
    p1 = self.count_sketch(x[:, :img_dim], self.output_dim, self.h1, self.s1)
    p2 = self.count_sketch(x[:, img_dim:], self.output_dim, self.h2, self.s2)
    pc1 = tf.complex(p1, tf.zeros_like(p1))
    pc2 = tf.complex(p2, tf.zeros_like(p2))
    # Compact bilinear pooling: circular convolution of the two sketches.
    conved = tf.ifft(tf.fft(pc1) * tf.fft(pc2))
    return tf.real(conved)
def ifftc(im, name="ifftc", do_orthonorm=True):
    """Centered iFFT on second to last dimension."""
    with tf.name_scope(name):
        im_out = im
        if do_orthonorm:
            fftscale = tf.sqrt(1.0 * im_out.get_shape().as_list()[-2])
        else:
            fftscale = 1.0
        fftscale = tf.cast(fftscale, dtype=tf.complex64)

        if len(im.get_shape()) == 4:
            im_out = tf.transpose(im_out, [0, 3, 1, 2])
            im_out = fftshift(im_out, axis=3)
        else:
            im_out = tf.transpose(im_out, [2, 0, 1])
            im_out = fftshift(im_out, axis=2)

        with tf.device('/gpu:0'):
            # FFT is only supported on the GPU
            im_out = tf.ifft(im_out) * fftscale

        if len(im.get_shape()) == 4:
            im_out = fftshift(im_out, axis=3)
            im_out = tf.transpose(im_out, [0, 2, 3, 1])
        else:
            im_out = fftshift(im_out, axis=2)
            im_out = tf.transpose(im_out, [1, 2, 0])

    return im_out
def hilbert(xr):
    '''
    Implements the Hilbert transform, a mapping from R to C.
    Args:
        xr: The input sequence.
    Returns:
        xc: A complex sequence of the same length.
    '''
    with tf.variable_scope('hilbert_transform'):
        n = tf.Tensor.get_shape(xr).as_list()[0]
        # tf.fft expects a complex tensor; cast the (real) input first.
        xc_in = tf.cast(xr, tf.complex64)
        # Run the fft on the columns, not the rows.
        x = tf.transpose(tf.fft(tf.transpose(xc_in)))
        h = np.zeros([n])
        if n > 0 and 2 * np.fix(n / 2) == n:
            # even and nonempty
            h[0:int(n / 2 + 1)] = 1
            h[1:int(n / 2)] = 2
        elif n > 0:
            # odd and nonempty
            h[0] = 1
            h[1:int((n + 1) / 2)] = 2
        tf_h = tf.constant(h, name='h', dtype=tf.float32)
        if len(x.shape) == 2:
            reps = tf.Tensor.get_shape(x).as_list()[-1]
            hs = tf.stack([tf_h] * reps, -1)
        elif len(x.shape) == 1:
            hs = tf_h
        else:
            raise NotImplementedError
        tf_hc = tf.complex(hs, tf.zeros_like(hs))
        xc = x * tf_hc
        return tf.transpose(tf.ifft(tf.transpose(xc)))
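# Sanity check for hilbert() above (illustrative; assumes SciPy and a TF 1.x
# session are available). scipy.signal.hilbert returns the analytic signal,
# which is what the h-mask construction above computes.
import numpy as np
from scipy.signal import hilbert as scipy_hilbert

sig = np.cos(2 * np.pi * 4 * np.arange(128) / 128).astype(np.float32)
expected = scipy_hilbert(sig)   # cos + i*sin, the analytic signal
with tf.Session() as sess:
    got = sess.run(hilbert(tf.constant(sig)))
# np.allclose(got, expected, atol=1e-4) should hold.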
def fftc(im,
         data_format='channels_last',
         orthonorm=True,
         transpose=False,
         name='fftc'):
    """Centered FFT on last non-channel dimension."""
    with tf.name_scope(name):
        im_out = im
        if data_format == 'channels_last':
            permute_orig = np.arange(len(im.shape))
            permute = permute_orig.copy()
            permute[-2] = permute_orig[-1]
            permute[-1] = permute_orig[-2]
            im_out = tf.transpose(im_out, permute)
        if orthonorm:
            fftscale = tf.sqrt(tf.cast(im_out.shape[-1], tf.float32))
        else:
            fftscale = 1.0
        fftscale = tf.cast(fftscale, dtype=tf.complex64)

        im_out = fftshift(im_out, axis=-1)
        if transpose:
            im_out = tf.ifft(im_out) * fftscale
        else:
            im_out = tf.fft(im_out) / fftscale
        im_out = fftshift(im_out, axis=-1)

        if data_format == 'channels_last':
            im_out = tf.transpose(im_out, permute)
    return im_out
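# The orthonormal scaling above divides the forward FFT by sqrt(N), making
# the transform unitary so it preserves energy (Parseval). Quick NumPy check
# (illustrative only):
import numpy as np

v = np.random.randn(16) + 1j * np.random.randn(16)
v_hat = np.fft.fft(v) / np.sqrt(16)
assert np.isclose(np.linalg.norm(v), np.linalg.norm(v_hat))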
def forward_proj_domain(self, sinogram):
    """
    the projection domain of the model for processing sinograms

    Parameters
    ----------
    sinogram : ndarray
        The projection data used for processing and reconstruction

    Returns
    -------
    ndarray
        the sinograms after processing
    """
    self.sinogram_cosine = tf.multiply(sinogram, self.cosine_weight)
    self.weighted_sinogram_fft = tf.fft(
        tf.cast(self.sinogram_cosine, dtype=tf.complex64))
    self.filtered_sinogram_fft = tf.multiply(
        self.weighted_sinogram_fft,
        tf.cast(self.recon_filter, dtype=tf.complex64))
    self.filtered_sinogram = tf.real(tf.ifft(self.filtered_sinogram_fft))
    return self.filtered_sinogram
def _matmul(self, X_fft, w):
    w = tf.cast(w, tf.complex64)
    # Reversing w turns the spectral product into a circular correlation;
    # the roll by 1 below compensates for the off-by-one the reversal adds.
    fft_w = tf.fft(w[..., ::-1])
    fft_mul = tf.multiply(X_fft, fft_w)
    ifft_val = tf.ifft(fft_mul)
    mat = tf.cast(tf.real(ifft_val), tf.float32)
    mat = tf.manip.roll(mat, shift=1, axis=1)
    return mat
def circular_cross_correlation(x, y):
    """Periodic correlation, implemented using the FFT.

    x and y must be of the same length.
    """
    return tf.real(
        tf.ifft(
            tf.multiply(tf.conj(tf.fft(tf.cast(x, tf.complex64))),
                        tf.fft(tf.cast(y, tf.complex64)))))
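# Cross-check (illustrative, NumPy only): the FFT identity above equals the
# direct definition corr[m] = sum_n x[n] * y[(n + m) mod N].
import numpy as np

x = np.random.randn(8)
y = np.random.randn(8)
via_fft = np.real(np.fft.ifft(np.conj(np.fft.fft(x)) * np.fft.fft(y)))
direct = np.array([np.dot(x, np.roll(y, -m)) for m in range(8)])
assert np.allclose(via_fft, direct)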
def one_layer(X_real, X_image, W1_real, W1_image, W2_real, W2_image, alph, ii):
    ## High memory consuming
    all_zeros = tf.zeros([k + 1], tf.float32)
    ones = tf.ones([k + 1], tf.bool)
    zeros = tf.zeros([k + 1], tf.bool)
    sliced_zeros = tf.slice(zeros, [ii], [k + 1 - ii])
    sliced_ones = tf.slice(ones, [0], [ii])
    mask = tf.concat((sliced_ones, sliced_zeros), axis=0)
    W1_real = tf.where(mask, all_zeros, W1_real)
    W1_image = tf.where(mask, all_zeros, W1_image)
    W1_real = Weight_Transform(W1_real, k=k, n=N)
    W1_image = Weight_Transform(W1_image, k=k, n=N)
    W2_real = tf.where(mask, all_zeros, W2_real)
    W2_image = tf.where(mask, all_zeros, W2_image)
    W2_real = Weight_Transform(W2_real, k=k, n=N)
    W2_image = Weight_Transform(W2_image, k=k, n=N)
    # W1: circular convolution via the FFT
    X_complex_fft = tf.fft(tf.complex(X_real, X_image))
    W1_complex_fft = tf.fft(tf.complex(W1_real, W1_image))
    X_complex = tf.ifft(X_complex_fft * W1_complex_fft)
    X_real = tf.math.real(X_complex)
    X_image = tf.math.imag(X_complex)
    # Non-linearity: rotate each element by alph * |X|^2. Both outputs must
    # come from the *old* values; the original overwrote X_real first, so
    # X_image was computed from the already-rotated real part.
    S_power = tf.math.add(tf.math.square(X_real), tf.math.square(X_image))
    S_power = tf.math.scalar_mul(alph, S_power)
    sin = tf.math.sin(S_power)
    cos = tf.math.cos(S_power)
    X_real_new = tf.math.subtract(tf.math.multiply(X_real, cos),
                                  tf.math.multiply(X_image, sin))
    X_image_new = tf.math.add(tf.math.multiply(X_image, cos),
                              tf.math.multiply(X_real, sin))
    X_real, X_image = X_real_new, X_image_new
    # W2: second circular convolution
    X_complex_fft = tf.fft(tf.complex(X_real, X_image))
    W2_complex_fft = tf.fft(tf.complex(W2_real, W2_image))
    X_complex = tf.ifft(X_complex_fft * W2_complex_fft)
    out_real = tf.math.real(X_complex)
    out_image = tf.math.imag(X_complex)
    return out_real, out_image
def cconv(x, y):
    # NOTE: despite the name, the conj() below makes this a circular
    # *correlation* rather than a convolution.
    x_fft_ = tf.fft(tf.complex(x, 0.0))
    #e2_fft_ = tf.fft(tf.complex(tf.nn.l2_normalize(self.e2, axis=2),0.0))
    y_fft_ = tf.fft(tf.complex(y, 0.0))
    x_fft = x_fft_  #+ tf.complex(tf.to_float(tf.equal(x_fft_, 0.)),0.)*no_zeros
    y_fft = y_fft_  #+ tf.complex(tf.to_float(tf.equal(y_fft_, 0.)),0.)*no_zeros
    return tf.cast(tf.real(tf.ifft(tf.multiply(tf.conj(x_fft), y_fft))),
                   dtype=tf.float32)
def circular_corr(a, b, name=''):
    name = get_name(name, 'circular_corr')
    with tf.name_scope(name):
        a_fft = tf.conj(tf.fft(tf.complex(a, 0.0)))
        b_fft = tf.fft(tf.complex(b, 0.0))
        ifft = tf.ifft(a_fft * b_fft)
        res = tf.cast(tf.real(ifft), 'float32')
    return res
def triple_linear_pool(x1, x2, x3, output_size):
    p1 = count_sketch(x1, output_size)
    p2 = count_sketch(x2, output_size)
    p3 = count_sketch(x3, output_size)
    pc1 = tf.complex(p1, tf.zeros_like(p1))
    pc2 = tf.complex(p2, tf.zeros_like(p2))
    pc3 = tf.complex(p3, tf.zeros_like(p3))
    conved = tf.ifft(tf.fft(pc1) * tf.fft(pc2) * tf.fft(pc3))
    return tf.real(conved)
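# Illustrative NumPy check that the three-way spectral product above is the
# circular convolution of all three sketches (circular convolution is
# associative, so the triple product equals pairwise convolution):
import numpy as np

a, b, c = (np.random.randn(8) for _ in range(3))
triple = np.real(np.fft.ifft(np.fft.fft(a) * np.fft.fft(b) * np.fft.fft(c)))
ab = np.real(np.fft.ifft(np.fft.fft(a) * np.fft.fft(b)))
pairwise = np.real(np.fft.ifft(np.fft.fft(ab) * np.fft.fft(c)))
assert np.allclose(triple, pairwise)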
def extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N,
                              SAMPLE_FACTOR, pad_factor, rFFT, custom_filts,
                              erb_filter_kwargs, include_all_keys,
                              compression_function):
    """
    Computes the cochlear subbands from the fft of the input signal

    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. 'fft_input' is
        multiplied by the cochlear filters
    SIGNAL_SIZE : int
        the length of the audio signal used for the cochleagram graph
    SR : int
        raw sampling rate in Hz for the audio.
    LOW_LIM : int
        Lower frequency limits for the filters.
    HIGH_LIM : int
        Higher frequency limits for the filters.
    N : int
        Number of filters to uniquely span the frequency space
    SAMPLE_FACTOR : int
        number of times to overcomplete the filters.
    pad_factor : int
        how much padding to add to the signal. Follows conventions of
        pycochleagram (ie pad of 2 doubles the signal length)
    rFFT : Boolean
        If true, cochleagram graph is constructed using rFFT wherever possible
    custom_filts : None, or numpy array
        if not None, a numpy array containing the filters to use for the
        cochleagram generation. If None, uses erb.make_erb_cos_filters from
        pycochleagram to construct the filterbank. If using rFFT, should
        contain the full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS]
    erb_filter_kwargs : dictionary
        contains additional arguments with filter parameters to use with
        erb.make_erb_cos_filters
    include_all_keys : Boolean
        If True, includes the time subbands and the cochleagram in the
        dictionary keys
    compression_function : function
        A partial function that takes in nets and the input and output names
        to apply compression

    Returns
    -------
    nets : dictionary
        updated dictionary containing parts of the cochleagram graph.
    """
    # make the erb filters tensor
    nets['filts_tensor'] = make_filts_tensor(SIGNAL_SIZE, SR, LOW_LIM,
                                             HIGH_LIM, N, SAMPLE_FACTOR,
                                             use_rFFT=rFFT,
                                             pad_factor=pad_factor,
                                             custom_filts=custom_filts,
                                             erb_filter_kwargs=erb_filter_kwargs)
    # make subbands by multiplying filts with fft of input
    nets['subbands'] = tf.multiply(nets['filts_tensor'], nets['fft_input'],
                                   name='mul_subbands')
    # add the time-domain keys to the graph if we are returning all keys
    # (otherwise, only return the subbands in the Fourier domain)
    if include_all_keys:
        if not rFFT:
            nets['subbands_ifft'] = tf.real(tf.ifft(nets['subbands'],
                                                    name='ifft_subbands'),
                                            name='ifft_subbands_r')
        else:
            nets['subbands_ifft'] = tf.spectral.irfft(nets['subbands'],
                                                      name='ifft_subbands')
        nets['subbands_time'] = nets['subbands_ifft']
    return nets
def _fft_solver(self, grad, varname):
    grad_shape = grad.shape.as_list()
    N, coef = self._var_info_dict[varname]
    grad = tf.reshape(grad, shape=[N])
    grad = tf.cast(grad, tf.complex64)
    # `fft`/`ifft` are presumably tf.fft/tf.ifft imported at module level.
    grad = tf.real(ifft(fft(grad) * coef))
    grad = tf.reshape(grad, shape=grad_shape)
    return grad
def update_x_step(z):
    result_x_step = kt_y * u + z
    result_x_step = tf.cast(result_x_step, tf.complex128)
    # Solve the linear step in the frequency domain: divide by `fre_k`,
    # presumably the precomputed spectrum of the convolution kernel.
    result_x_step = tf.fft(result_x_step)
    result_x_step = result_x_step / fre_k
    result_x_step = tf.ifft(result_x_step)
    result_x_step = tf.abs(result_x_step)
    return result_x_step
def backward_fft(x):
    # split the channel axis into real/imaginary halves
    real, imag = tf.split(x, num_or_size_splits=2, axis=1)
    x_complex = tf.complex(real, imag)
    x_complex = tf.ifft(x_complex)
    x_real = tf.cast(tf.abs(x_complex), dtype=tf.float32)
    return x_real
def combine_and_inverse_fft(args):
    abs_branch, angle_branch = args
    # polar (magnitude, phase) -> Cartesian (real, imaginary)
    real = abs_branch * tf.cos(angle_branch)
    imag = abs_branch * tf.sin(angle_branch)
    x_complex = tf.complex(real, imag)
    x_reverse = tf.real(tf.ifft(x_complex))
    return x_reverse
def _ifft(bottom, sequential, compute_size):
    '''
    Return
    ------
    iFFT(bottom), with the same shape as bottom
    '''
    if sequential:
        return sequential_batch_ifft(bottom, compute_size)
    else:
        return tf.ifft(bottom)
def tf_ifft(tensor, shift, axis=0):
    shifted = tf.manip.roll(tensor, shift=shift, axis=axis)
    # inverse FFT
    time_domain_not_shifted = tf.ifft(shifted)
    # shift again
    time_domain = tf.manip.roll(time_domain_not_shifted, shift=shift,
                                axis=axis)
    return time_domain
def op(self):
    # Autocorrelation via the Wiener-Khinchin theorem: |FFT|^2, then inverse.
    xf = tensorflow.fft(self.x)
    x2 = xf * tensorflow.conj(xf)
    xt = tensorflow.ifft(x2)
    # NOTE: tensorflow.log is the natural log; a true dB scale would use
    # 10 * log10(.) instead.
    xr = 10 * tensorflow.log(tensorflow.abs(xt[:, 0:self.aclen]))
    if self.avg:
        # Exponentially weighted moving average over the batch dimension.
        N = tensorflow.shape(xr)[0]
        idx = tensorflow.cast(tensorflow.range(0, N), tensorflow.float32)
        s = tensorflow.reshape(
            self.alpha * tensorflow.pow((1 - self.alpha), idx), [N, 1])
        self.u = (tensorflow.pow((1 - self.alpha),
                                 tensorflow.cast(N, tensorflow.float32))
                  * self.u + tensorflow.reduce_sum(s * xr, 0))
        return self.u
    else:
        return xr
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name="auto_correlation"):
    """Auto correlation along one axis.

    Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto
    correlation `RXX` may be defined as (with `E` expectation and `Conj`
    complex conjugate)

    ```
    RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
    W[n]   := (X[n] - MU) / S,
    MU     := E{ X[0] },
    S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
    ```

    This function takes the viewpoint that `x` is (along one axis) a finite
    sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce
    an estimate of `RXX[m]` as follows:

    After extending `x` from length `L` to `inf` by zero padding, the auto
    correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

    ```
    rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
    w[n]   := (x[n] - mu) / s,
    mu     := L**-1 sum_n x[n],
    s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
    ```

    The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
    users often set `max_lags` small enough so that the entire output is
    meaningful.

    Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we
    divide by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
    correlation contains a slight bias, which goes to zero as
    `len(x) - m --> infinity`.

    Args:
        x: `float32` or `complex64` `Tensor`.
        axis: Python `int`. The axis number along which to compute
            correlation. Other dimensions index different batch members.
        max_lags: Positive `int` tensor. The maximum value of `m` to consider
            (in equation above). If `max_lags >= x.shape[axis]`, we
            effectively re-set `max_lags` to `x.shape[axis] - 1`.
        center: Python `bool`. If `False`, do not subtract the mean estimate
            `mu` from `x[n]` when forming `w[n]`.
        normalize: Python `bool`. If `False`, do not divide by the variance
            estimate `s**2` when forming `w[n]`.
        name: `String` name to prepend to created ops.

    Returns:
        `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]`
            for `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

    Raises:
        TypeError: If `x` is not a supported type.
    """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the
    # end. Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
    # RXX[m].
    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name, values=[x]):
        x = tf.convert_to_tensor(x, name="x")

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = util.prefer_static_rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T "time" steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = util.rotate_transpose(x, shift)

        if center:
            x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation. The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = util.prefer_static_shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.
        # At the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = tf.cast(x_len, np.float64)
        target_length = tf.pow(
            np.float64(2.), tf.ceil(tf.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = tf.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = util.pad(x_rotated, axis=-1, back=True,
                                 count=pad_length)

        dtype = x.dtype
        if not dtype.is_complex:
            if not dtype.is_floating:
                raise TypeError(
                    "Argument x must have either float or complex dtype"
                    " found: {}".format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad, dtype.real_dtype.as_numpy_dtype(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some
        # scaling).
        fft_x_rotated_pad = tf.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set
        # max_lags. Use x_rotated as a reference, because it has the time
        # dimension in the far right, and was created before we performed all
        # sorts of crazy shape manipulations.
        know_static_shape = True
        if not x_rotated.shape.is_fully_defined():
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name="max_lags")
            max_lags_ = tensor_util.constant_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = x_rotated.shape.as_list()
            chopped_shape[-1] = min(x_len, max_lags + 1)
            shifted_product_chopped.set_shape(chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).
        # The other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = tf.cast(x_len, dtype.real_dtype)
        max_lags = tf.cast(max_lags, dtype.real_dtype)
        denominator = x_len - tf.range(0., max_lags + 1.)
        denominator = tf.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return util.rotate_transpose(shifted_product_rotated, -shift)
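# Illustrative NumPy analogue of the zero-pad + FFT/IFFT estimator used above
# (no TF required): estimate rxx[m] for a short centered sequence and compare
# with the direct sum.
import numpy as np

x = np.random.randn(32)
x = x - x.mean()
n_pad = 64                                      # >= 2 * len(x), so no wrap
f = np.fft.fft(x, n_pad)
r = np.real(np.fft.ifft(f * np.conj(f)))[:8]    # R[m] for m = 0..7
rxx = r / (len(x) - np.arange(8))               # unbiased (L - m) scaling
direct = np.array([np.dot(x[m:], x[:len(x) - m]) / (len(x) - m)
                   for m in range(8)])
assert np.allclose(rxx, direct)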
def _cconv(self, a, b):
    # tf.fft needs complex inputs, and TF 1.x tensors have no `.real`
    # attribute, so cast in and take tf.real on the way out.
    a = tf.cast(a, tf.complex64)
    b = tf.cast(b, tf.complex64)
    return tf.real(tf.ifft(tf.fft(a) * tf.fft(b)))
def _ccorr(self, a, b):
    a = tf.cast(a, tf.complex64)
    b = tf.cast(b, tf.complex64)
    return tf.real(tf.ifft(tf.conj(tf.fft(a)) * tf.fft(b)))
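# Identity check for the pair above (illustrative, NumPy only): circular
# correlation equals circular convolution with the index-reversed (involuted)
# first argument, i.e. ccorr(a, b) = cconv(a_inv, b) with
# a_inv[i] = a[(-i) mod n].
import numpy as np

a = np.random.randn(8)
b = np.random.randn(8)
ccorr = np.real(np.fft.ifft(np.conj(np.fft.fft(a)) * np.fft.fft(b)))
a_inv = np.roll(a[::-1], 1)   # a_inv[i] = a[(-i) mod 8]
cconv = np.real(np.fft.ifft(np.fft.fft(a_inv) * np.fft.fft(b)))
assert np.allclose(ccorr, cconv)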