def bilinear_pool(x1, x2, x3, output_size):
    """Approximate multilinear pooling of x1, x2 and x3 with count sketches.

    Follows the compact bilinear pooling trick (https://arxiv.org/abs/1511.06062),
    here applied to three inputs: each input is projected with ``count_sketch``
    and the three sketches are circularly convolved by multiplying their FFTs.

    Args:
      x1: A `Tensor` with shape (batch_size, x1_size).
      x2: A `Tensor` with shape (batch_size, x2_size).
      x3: A `Tensor`, presumably (batch_size, x3_size) like x1/x2 -- confirm.
      output_size: Output projection size. (`int`)

    Returns:
      A pair of identical Tensors, each with shape (batch_size, output_size).
      The pair is kept so that existing callers that unpack two values still
      work.
    """
    p1 = count_sketch(x1, output_size)
    p2 = count_sketch(x2, output_size)
    p3 = count_sketch(x3, output_size)

    pc1 = tf.complex(p1, tf.zeros_like(p1))
    pc2 = tf.complex(p2, tf.zeros_like(p2))
    pc3 = tf.complex(p3, tf.zeros_like(p3))

    # Circular convolution of the sketches == product of their spectra.
    conved = tf.real(tf.ifft(tf.fft(pc1) * tf.fft(pc2) * tf.fft(pc3)))

    # The second output was previously an identical recomputation (`* 1`);
    # reuse the first result instead of building the FFT graph twice.
    return conved, conved
Beispiel #2
0
def deconv(comb_loss, loss_a, loss_b):
    """Return |IFFT(FFT(loss_b[:, 1]))| as float64, i.e. |loss_b[:, 1]|.

    NOTE(review): despite its name, this does not return a deconvolution.
    The original body also built FFT(comb_loss[:, 1]) / FFT(loss_b[:, 1]) and
    an FFT round trip of loss_a[:, 1], but never returned them. Those dead
    computations were removed. Confirm which quantity callers actually need.

    Args:
      comb_loss: tensor indexed as ``[:, 1]``; unused by the returned value.
      loss_a: tensor indexed as ``[:, 1]``; unused by the returned value.
      loss_b: tensor whose column 1 is transformed.

    Returns:
      float64 tensor holding the magnitude of the FFT round trip of
      ``loss_b[:, 1]``.
    """
    by = tf.cast(loss_b[:, 1], dtype=tf.complex64)
    byfft = tf.fft(by)
    return tf.cast(tf.abs(tf.ifft(byfft)), dtype=tf.float64)
Beispiel #3
0
def istft(spec, overlap=4):
    """Inverse-FFT every frame of ``spec`` and concatenate the magnitudes.

    NOTE(review): despite the name, no overlap-add is performed. Each frame
    (slice along axis 0) is inverse-transformed independently, and the
    magnitudes are concatenated end to end. ``overlap`` is accepted but unused.

    Args:
      spec: complex array with more than one frame along axis 0.
      overlap: unused; kept for API compatibility.

    Returns:
      numpy array of the concatenated |IFFT(frame)| values.
    """
    assert (spec.shape[0] > 1)
    # Build in a private graph so repeated calls do not grow the default
    # graph, and let the context manager close the session.
    with tf.Graph().as_default():
        S = tf.placeholder(dtype=tf.complex64, shape=spec.shape)
        frames = [tf.ifft(frame) for frame in tf.unstack(S)]
        X = tf.abs(tf.concat(frames, axis=0))
        with tf.Session() as sess:
            return sess.run(X, feed_dict={S: spec})
    def model(self, sinogram):
        """
        Main model for the FDK algorithm.

        The projections are multiplied by ``self.cosine_weight``, filtered by a
        point-wise product with ``self.filter`` in the Fourier domain, and then
        cone-beam back-projected with ``self.geometry``.

        Parameters
        ----------
        sinogram : ndarray
            The projection data used for reconstruction

        Returns
        -------
        ndarray
            the reconstructed CT data
        """
        weighted = tf.multiply(sinogram, self.cosine_weight)

        spectrum = tf.fft(tf.cast(weighted, dtype=tf.complex64))
        filter_c = tf.cast(self.filter, dtype=tf.complex64)
        filtered = tf.real(tf.ifft(tf.multiply(spectrum, filter_c)))

        return cone_backprojection3d(filtered,
                                     self.geometry,
                                     hardware_interp=True)
Beispiel #5
0
 def _get_w(self, last_w, raw_output, memory):
     """Compute new addressing weights from the controller output.

     Steps: content addressing (cosine similarity of key ``k`` with
     ``memory``, scaled by ``beta``, softmax), interpolation with ``last_w``
     through gate ``g``, circular convolution with shift ``s`` via FFT, and
     sharpening with ``gamma``.

     Args:
       last_w: previous weights; presumably (batch, N) -- TODO confirm.
       raw_output: controller output, split on axis 1 into widths
         [M, N, 1, N, N] for k, beta, g, s, gamma.
       memory: memory tensor; matmul with the (batch, 1, M) key implies
         presumably (batch, M, N) -- TODO confirm.

     Returns:
       The new weights (softmax over the last axis).
     """
     with tf.variable_scope('get_w'):
         k, beta, g, s, gamma = tf.split(
             raw_output, [self.M, self.N, 1, self.N, self.N], axis=1)
         # Cosine similarity: normalise key and memory along the M axis.
         memory = tf.nn.l2_normalize(memory, axis=1)
         k = tf.nn.l2_normalize(k, axis=1)
         # Squeeze only the singleton axis added by expand_dims; a bare
         # squeeze would also drop the batch axis when batch_size == 1.
         product = tf.squeeze(tf.matmul(tf.expand_dims(k, 1), memory), axis=1)
         beta = tf.nn.softplus(beta, name='beta')
         w_c = tf.nn.softmax(beta * product, name='w_c')
         g = tf.sigmoid(g, name='gate_parameter')
         w_g = tf.add(g * w_c, (1 - g) * last_w, name='w_g')
         w_g = tf.cast(w_g, tf.complex64)
         s = tf.cast(tf.tanh(s, name='shift'), tf.complex64)
         # Circular convolution of the gated weights with the shift vector.
         w_tild = tf.real(tf.ifft(tf.fft(w_g) * tf.fft(s)), name='w_tild')
         gamma = tf.add(tf.nn.softplus(gamma), 1, name='gamma')
         # Sharpening: softmax(gamma * log(v)) == v**gamma / sum(v**gamma).
         # softplus keeps the log argument positive (an earlier variant using
         # log(softmax(w_tild)) was noted to produce NaNs).
         w = tf.nn.softmax(gamma * tf.log(tf.nn.softplus(w_tild)),
                           name='new_weight')
         return w
def bandPass(signal, low_cut, high_cut, sample_length, sample_rate):
    """Band-pass filter ``signal`` by zeroing FFT bins outside the band.

    Args:
      signal: input signal, presumably 1-D and real.
      low_cut: lower band edge in Hz; bins at or below it are zeroed.
      high_cut: upper band edge in Hz; bins at or above it are zeroed.
      sample_length: total signal length (number of samples).
      sample_rate: sampling rate of the input signal in Hz.

    Returns:
      filtered.real: the real part of the filtered signal (numpy array).
    """
    # FFT bin k sits at k * sample_rate / sample_length Hz. Truncating
    # sample_length / sample_rate to an int first (as before) misplaced the
    # cut-offs for non-integer durations and gave 0 for signals < 1 s.
    bins_per_hz = float(sample_length) / sample_rate
    low_bin = int(round(low_cut * bins_per_hz))
    high_bin = int(round(high_cut * bins_per_hz))
    with tf.Graph().as_default():
        spectrum_op = tf.fft(tf.constant(signal, dtype=tf.complex64))
        with tf.Session() as sess:
            result = sess.run(spectrum_op)
            n = len(result)
            # Bin n - k mirrors bin k. Zero high_bin .. n - high_bin (inclusive)
            # so both the positive and negative high frequencies go.
            result[high_bin:max(n - high_bin + 1, 0)] = 0
            # Zero DC .. low_bin and their negative-frequency mirrors.
            result[:low_bin + 1] = 0
            result[max(n - low_bin, 0):n] = 0
            filtered = sess.run(tf.ifft(result))
    return filtered.real
def lowPass(signal, cut, sample_length, sample_rate):
    """Low-pass filter ``signal`` by zeroing FFT bins at or above ``cut`` Hz.

    Args:
      signal: input signal, presumably 1-D and real.
      cut: cut-off frequency in Hz.
      sample_length: total signal length (number of samples).
      sample_rate: sampling rate of the input signal in Hz.

    Returns:
      filtered.real: the real part of the filtered signal (numpy array).
    """
    # FFT bin k sits at k * sample_rate / sample_length Hz; don't truncate
    # the ratio to an int before scaling (wrong for non-integer durations).
    cut_bin = int(round(cut * float(sample_length) / sample_rate))
    with tf.Graph().as_default():
        spectrum_op = tf.fft(tf.constant(signal, dtype=tf.complex64))
        with tf.Session() as sess:
            result = sess.run(spectrum_op)
            n = len(result)
            # Bins 1..n/2 are positive frequencies and bin n - k is the
            # negative-frequency mirror of bin k. So the contiguous range
            # cut_bin .. n - cut_bin holds every frequency with |f| >= cut.
            # Zeroing it removes both halves symmetrically, and the inverse
            # FFT stays (numerically) real.
            result[cut_bin:max(n - cut_bin + 1, 0)] = 0
            filtered = sess.run(tf.ifft(result))
    return filtered.real
Beispiel #8
0
def istft(inp, n_overlap):
    """Inverse STFT: inverse-FFT each frame, then overlap-add.

    Frames are hopped by ``n_freqs // n_overlap`` samples. Trailing frames
    are dropped so the frame count is a multiple of ``n_overlap``.

    Args:
      inp: complex tensor of shape (..., n_frames, n_freqs) with a fully
        defined static shape; extra leading dims are flattened and restored.
      n_overlap: number of frames overlapping at each output sample.

    Returns:
      Real tensor of shape (leading dims..., n_samples).
    """
    inp_sz = inp.get_shape().as_list()
    if len(inp_sz) > 3:
        inp = tf.reshape(inp, (np.prod(inp_sz[:-2]), inp_sz[-2], inp_sz[-1]))

    n_frames = inp.get_shape().as_list()[1]
    inp = inp[:, :(n_frames // n_overlap) * n_overlap, :]
    batch_size, n_frames, n_freqs = inp.get_shape().as_list()

    x = tf.real(tf.ifft(inp))
    # Group frames i, i + n_overlap, ...: they abut, so each group is one
    # contiguous stream starting at offset i * hop.
    x = tf.reshape(x, (batch_size, -1, n_overlap, n_freqs))
    x = tf.transpose(x, (0, 2, 1, 3))
    x = tf.reshape(x, (batch_size, n_overlap, -1))

    x_list = tf.unstack(x, axis=1)
    # Integer hop: `/` yields a float on Python 3 and breaks the slices below.
    skip = n_freqs // n_overlap
    for i in range(n_overlap):
        # Trim each stream so all of them start at global sample
        # (n_overlap - 1) * skip and have equal length.
        start = (n_overlap - i - 1) * skip
        if i == 0:
            x_list[i] = x_list[i][:, start:]
        else:
            x_list[i] = x_list[i][:, start:-i * skip]

    x = tf.add_n(x_list) / float(n_overlap)

    if len(inp_sz) > 3:
        x_sz = x.get_shape().as_list()
        x = tf.reshape(x, inp_sz[:-2] + x_sz[-1:])

    return x
    def forward_proj_domain(self, sinogram):
        """
        Projection-domain part of the model.

        The sinogram is passed through ``self.unet_model`` (with an extra axis
        3 added around the call, presumably a channel axis). It is then
        multiplied by ``self.cosine_weight`` and filtered with
        ``self.recon_filter`` in the Fourier domain. Intermediate tensors are
        stored on ``self``.

        Parameters
        ----------
        sinogram : ndarray
            The projection data used for processing and reconstruction

        Returns
        -------
        ndarray
            the sinograms after processing
        """
        # U-Net added in the projection domain.
        unet_input = tf.expand_dims(sinogram, 3)
        unet_output = self.unet_model(unet_input)
        corrected = tf.squeeze(unet_output, axis=3)

        self.sinogram_cosine = tf.multiply(corrected, self.cosine_weight)

        as_complex = tf.cast(self.sinogram_cosine, dtype=tf.complex64)
        self.weighted_sinogram_fft = tf.fft(as_complex)
        filter_c = tf.cast(self.recon_filter, dtype=tf.complex64)
        self.filtered_sinogram_fft = tf.multiply(self.weighted_sinogram_fft,
                                                 filter_c)
        self.filtered_sinogram = tf.real(tf.ifft(self.filtered_sinogram_fft))

        return self.filtered_sinogram
Beispiel #10
0
    def tensor_product(self, P, ch1, Q, ch2):
        """Slice-wise product of P and Q in the Fourier domain.

        P and Q are FFT-ed along their last axis and split into ``self.k``
        slices. Each pair of slices is matrix-multiplied, optionally
        transposing either operand. The stacked products are inverse-FFT-ed.

        Args:
          P: real tensor whose last axis has size ``self.k``; presumably
            (n1, n2, k) -- TODO confirm.
          ch1: 't' to transpose each slice of P before multiplying.
          Q: real tensor whose last axis has size ``self.k``.
          ch2: 't' to transpose each slice of Q before multiplying.

        Returns:
          Real part of the inverse FFT of the stacked slice products.
        """
        P_hat = tf.fft(tf.complex(P, tf.zeros_like(P)))
        Q_hat = tf.fft(tf.complex(Q, tf.zeros_like(Q)))
        # Squeeze only the split axis; a bare squeeze would also remove any
        # other size-1 dimension and break the matmul.
        p_hat_list = [tf.squeeze(p, axis=-1)
                      for p in tf.split(P_hat, self.k, axis=-1)]
        q_hat_list = [tf.squeeze(q, axis=-1)
                      for q in tf.split(Q_hat, self.k, axis=-1)]

        # transpose_a / transpose_b are plain (non-conjugating) transposes,
        # matching tf.transpose on these 2-D slices.
        S_hat = tf.stack([
            tf.matmul(p_hat, q_hat,
                      transpose_a=(ch1 == 't'),
                      transpose_b=(ch2 == 't'))
            for p_hat, q_hat in zip(p_hat_list, q_hat_list)
        ], axis=-1)

        return tf.real(tf.ifft(S_hat))
Beispiel #11
0
def est_kernel(blurs, deblurs, nstd=2, ksz=27):
    """Estimate ksz x ksz kernels by regularised division in the 2-D Fourier domain.

    Computes IFFT2( D * conj(B) / (|B|^2 + (nstd / 255)^2) ), where B and D
    are the 2-D FFTs of ``blurs`` and ``deblurs``. It then crops the
    ksz x ksz window centred on the origin (with wrap-around) and divides by
    the per-image, per-channel spatial mean.

    Args:
      blurs: 4-D tensor with channels last (axis 3 is moved to axis 1 so the
        2-D FFT runs over axes 1 and 2).
      deblurs: tensor with the same layout as ``blurs``.
      nstd: noise level on a 0-255 scale used for regularisation; 0 is
        replaced by 1e-6.
      ksz: odd kernel size.

    Returns:
      Kernels of shape (batch, ksz, ksz, channels).
    """
    assert ksz % 2 == 1
    hksz = (ksz - 1) // 2
    if nstd == 0:
        nstd = 1e-6

    blurs = tf.cast(tf.transpose(blurs, [0, 3, 1, 2]), tf.complex64)
    deblurs = tf.cast(tf.transpose(deblurs, [0, 3, 1, 2]), tf.complex64)

    fft_blurs = tf.fft2d(blurs)
    fft_deblurs = tf.fft2d(deblurs)

    numerator = fft_deblurs * tf.conj(fft_blurs)
    # TF does not auto-promote float32 -> complex64 in division.
    denominator = tf.cast(tf.abs(fft_blurs)**2 + (nstd / 255.)**2.,
                          tf.complex64)
    # The forward transform was 2-D, so the inverse must be 2-D as well.
    out = tf.real(tf.ifft2d(numerator / denominator))
    out = tf.transpose(out, [0, 2, 3, 1])

    # Gather the wrap-around neighbourhood of the origin into a centred window.
    out1 = tf.concat([out[:, -hksz:, -hksz:], out[:, :hksz + 1, -hksz:]],
                     axis=1)
    out2 = tf.concat([out[:, -hksz:, :hksz + 1], out[:, :hksz + 1, :hksz + 1]],
                     axis=1)
    kernels = tf.concat([out1, out2], axis=2)
    # keepdims so (B, ksz, ksz, C) divides by (B, 1, 1, C) rather than
    # mis-broadcasting against (B, C).
    kernels = kernels / tf.reduce_mean(kernels, axis=[1, 2], keepdims=True)

    return kernels
Beispiel #12
0
def fft_to_audio(fft):
    """Convert half-spectrum frames to time-domain audio, one frame at a time.

    Each row of ``fft`` is zero-padded to ``fft_size / 2`` bins. It is
    extended with its reversed complex conjugate to ``fft_size`` bins,
    inverse-transformed, and its real part kept. The frames are concatenated.

    Relies on module-level ``fft_size`` and ``fft_to_audio_init`` and on an
    active default TensorFlow session.

    Args:
      fft: 2-D numpy array of complex frames (steps, n_bins), with
        n_bins <= fft_size / 2.

    Returns:
      1-D numpy array of the concatenated time-domain frames.
    """
    # Must come before any use of the name: Python 3 raises SyntaxError for
    # a `global` statement that follows a read of the same name.
    global fft_to_audio_init

    return_values = []
    with tf.variable_scope("fft_to_audio", reuse=fft_to_audio_init):
        frame = tf.placeholder(tf.complex64, [fft.shape[1]])

        zeroes = tf.fill([int(fft_size / 2) - int(frame.shape[0])],
                         frame[-1] * 0)
        fill = tf.concat([frame, zeroes], axis=0)

        inverse = tf.reverse(tf.conj(fill), [0])
        # NOTE(review): index fft_size / 2 receives conj(fill[1]); a strictly
        # conjugate-symmetric spectrum would need a real value there.
        full = tf.concat([fill, inverse[-2:-1], inverse[:-1]], axis=-1)
        # Casting complex64 -> float32 keeps the real part.
        output = tf.cast(tf.ifft(full), tf.float32)

        sess = tf.get_default_session()
        for row in fft:
            return_values.append(sess.run(output, feed_dict={frame: row}))

    fft_to_audio_init = True

    return np.concatenate(return_values, axis=0)
Beispiel #13
0
 def call(self, x):
     """Compact bilinear pooling of the two column groups of ``x``.

     Columns [:img_dim] and [img_dim:] are each count-sketched to
     ``self.output_dim`` (with hashes h1/s1 and h2/s2). The two sketches are
     circularly convolved via FFT, and the real part is returned.
     """
     sketch_a = self.count_sketch(x[:, :img_dim], self.output_dim,
                                  self.h1, self.s1)
     sketch_b = self.count_sketch(x[:, img_dim:], self.output_dim,
                                  self.h2, self.s2)
     complex_a = tf.complex(sketch_a, tf.zeros_like(sketch_a))
     complex_b = tf.complex(sketch_b, tf.zeros_like(sketch_b))
     spectrum = tf.fft(complex_a) * tf.fft(complex_b)
     return tf.real(tf.ifft(spectrum))
Beispiel #14
0
def ifftc(im, name="ifftc", do_orthonorm=True):
    """Centered iFFT on the second-to-last dimension.

    The transformed axis is rotated to the last position and fftshift-ed.
    It is then inverse-transformed, optionally multiplied by sqrt(N) so the
    transform is orthonormal, fftshift-ed again, and rotated back. Rank-4
    inputs use perm [0, 3, 1, 2]; other inputs use the rank-3 perm [2, 0, 1].
    """
    with tf.name_scope(name):
        if do_orthonorm:
            scale = tf.sqrt(1.0 * im.get_shape().as_list()[-2])
        else:
            scale = 1.0
        scale = tf.cast(scale, dtype=tf.complex64)

        if len(im.get_shape()) == 4:
            to_last, from_last, axis = [0, 3, 1, 2], [0, 2, 3, 1], 3
        else:
            to_last, from_last, axis = [2, 0, 1], [1, 2, 0], 2

        out = fftshift(tf.transpose(im, to_last), axis=axis)
        with tf.device('/gpu:0'):
            # Pinned to the GPU per the original note that FFT is only
            # supported there. NOTE(review): confirm this still applies to
            # the TF version in use.
            out = tf.ifft(out) * scale
        out = tf.transpose(fftshift(out, axis=axis), from_last)

    return out
Beispiel #15
0
def hilbert(xr):
    '''
    Compute the analytic signal of ``xr`` (Hilbert transform), mapping R to C.

    Uses the same frequency mask as ``scipy.signal.hilbert``, applied along
    axis 0: keep DC (and Nyquist for even n), double the positive
    frequencies, zero the negative ones.

    Args:
        xr: The input sequence, 1-D of shape (n,) or 2-D of shape (n, m);
            the transform runs along axis 0. Cast to complex64 before the FFT.
    Returns:
        xc: A complex64 sequence of the same shape; its imaginary part is the
            Hilbert transform of xr.
    '''
    with tf.variable_scope('hilbert_transform'):
        n = xr.get_shape().as_list()[0]
        # tf.fft requires complex input and transforms the innermost axis, so
        # cast and transpose to run it along axis 0 (columns, not rows).
        x = tf.transpose(tf.fft(tf.transpose(tf.cast(xr, tf.complex64))))
        h = np.zeros([n])
        if n > 0 and n % 2 == 0:
            # even and nonempty
            h[0] = h[n // 2] = 1
            h[1:n // 2] = 2
        elif n > 0:
            # odd and nonempty
            h[0] = 1
            h[1:(n + 1) // 2] = 2
        tf_h = tf.constant(h, name='h', dtype=tf.float32)
        if len(x.shape) == 2:
            reps = x.get_shape().as_list()[-1]
            hs = tf.stack([tf_h] * reps, -1)
        elif len(x.shape) == 1:
            hs = tf_h
        else:
            raise NotImplementedError
        tf_hc = tf.complex(hs, tf.zeros_like(hs))
        xc = x * tf_hc
        return tf.transpose(tf.ifft(tf.transpose(xc)))
Beispiel #16
0
def fftc(im,
         data_format='channels_last',
         orthonorm=True,
         transpose=False,
         name='fftc'):
    """Centered FFT on the last non-channel dimension.

    Args:
      im: complex64 tensor.
      data_format: 'channels_last' swaps the last two axes so the transform
        runs over the second-to-last axis; any other value transforms the
        last axis.
      orthonorm: divide the forward transform by sqrt(N), or multiply the
        inverse by sqrt(N), so the transform is orthonormal.
      transpose: compute the inverse (adjoint) transform instead.
      name: name scope for the created ops.

    Returns:
      The transformed tensor, in the same layout as ``im``.
    """
    with tf.name_scope(name):
        im_out = im
        if data_format == 'channels_last':
            # Swap the last two axes; the swap is its own inverse and is
            # reused below to restore the layout.
            permute = np.arange(len(im.shape))
            permute[-2], permute[-1] = permute[-1], permute[-2]
            im_out = tf.transpose(im_out, permute)

        if orthonorm:
            # tf.shape also works when the static size of the axis is unknown.
            fftscale = tf.sqrt(tf.cast(tf.shape(im_out)[-1], tf.float32))
        else:
            fftscale = 1.0
        fftscale = tf.cast(fftscale, dtype=tf.complex64)

        im_out = fftshift(im_out, axis=-1)
        if transpose:
            im_out = tf.ifft(im_out) * fftscale
        else:
            im_out = tf.fft(im_out) / fftscale
        im_out = fftshift(im_out, axis=-1)

        if data_format == 'channels_last':
            im_out = tf.transpose(im_out, permute)

    return im_out
    def forward_proj_domain(self, sinogram):
        """
        Projection-domain part of the model: cosine weighting followed by
        filtering with ``self.recon_filter`` in the Fourier domain.

        Intermediate tensors are stored on ``self`` as ``sinogram_cosine``,
        ``weighted_sinogram_fft``, ``filtered_sinogram_fft`` and
        ``filtered_sinogram``.

        Parameters
        ----------
        sinogram : ndarray
            The projection data used for processing and reconstruction

        Returns
        -------
        ndarray
            the sinograms after processing
        """
        self.sinogram_cosine = tf.multiply(sinogram, self.cosine_weight)

        as_complex = tf.cast(self.sinogram_cosine, dtype=tf.complex64)
        self.weighted_sinogram_fft = tf.fft(as_complex)
        filter_c = tf.cast(self.recon_filter, dtype=tf.complex64)
        self.filtered_sinogram_fft = tf.multiply(self.weighted_sinogram_fft,
                                                 filter_c)
        spatial = tf.ifft(self.filtered_sinogram_fft)
        self.filtered_sinogram = tf.real(spatial)

        return self.filtered_sinogram
Beispiel #18
0
 def _matmul(self, X_fft, w):
     """Combine the spectrum ``X_fft`` with ``w`` via FFT and return a float32 result.

     ``w`` is reversed along its last axis before its FFT. The product is
     inverse-transformed, and the real part is rolled by one along axis 1.
     Presumably this multiplies by a circulant matrix built from ``w``
     (assuming the FFT axis is axis 1) -- confirm against callers.
     """
     reversed_w = tf.cast(w, tf.complex64)[..., ::-1]
     spectrum = tf.multiply(X_fft, tf.fft(reversed_w))
     real_part = tf.cast(tf.real(tf.ifft(spectrum)), tf.float32)
     return tf.manip.roll(real_part, shift=1, axis=1)
Beispiel #19
0
def circular_cross_correlation(x, y):
    """Periodic cross-correlation of x and y, computed with the FFT.

    Returns real(IFFT(conj(FFT(x)) * FFT(y))). x and y must be of the same
    length along the last axis.
    """
    x_spectrum = tf.fft(tf.cast(x, tf.complex64))
    y_spectrum = tf.fft(tf.cast(y, tf.complex64))
    cross_spectrum = tf.multiply(tf.conj(x_spectrum), y_spectrum)
    return tf.real(tf.ifft(cross_spectrum))
Beispiel #20
0
def one_layer(X_real, X_image, W1_real, W1_image, W2_real, W2_image, alph, ii):
    """Filter with W1, apply an intensity-dependent phase, filter with W2.

    X = X_real + i*X_image is circularly convolved (via FFT) with W1 and
    multiplied by exp(i * alph * |X|^2). It is then circularly convolved with
    W2. Before use, the first ``ii`` entries of each length-(k + 1) weight
    vector are zeroed and the vectors go through ``Weight_Transform``.

    Relies on module-level ``k``, ``N`` and ``Weight_Transform``.

    Args:
      X_real, X_image: real and imaginary parts of the input.
      W1_real, W1_image, W2_real, W2_image: length-(k + 1) weight vectors.
      alph: strength of the phase nonlinearity.
      ii: number of leading weight entries to zero.

    Returns:
      (out_real, out_image): real and imaginary parts of the output.
    """
    # High memory consuming (original note).
    # Mask is True for the first `ii` entries, which are zeroed.
    all_zeros = tf.zeros([k + 1], tf.float32)
    ones = tf.ones([k + 1], tf.bool)
    zeros = tf.zeros([k + 1], tf.bool)
    sliced_zeros = tf.slice(zeros, [ii], [k + 1 - ii])
    sliced_ones = tf.slice(ones, [0], [ii])

    mask = tf.concat((sliced_ones, sliced_zeros), axis=0)
    W1_real = tf.where(mask, all_zeros, W1_real)
    W1_image = tf.where(mask, all_zeros, W1_image)

    W1_real = Weight_Transform(W1_real, k=k, n=N)
    W1_image = Weight_Transform(W1_image, k=k, n=N)

    W2_real = tf.where(mask, all_zeros, W2_real)
    W2_image = tf.where(mask, all_zeros, W2_image)

    W2_real = Weight_Transform(W2_real, k=k, n=N)
    W2_image = Weight_Transform(W2_image, k=k, n=N)

    # W1
    X_complex_fft = tf.fft(tf.complex(X_real, X_image))
    W1_complex_fft = tf.fft(tf.complex(W1_real, W1_image))
    X_complex = tf.ifft(X_complex_fft * W1_complex_fft)
    X_real = tf.math.real(X_complex)
    X_image = tf.math.imag(X_complex)

    # Nonlinearity: X <- X * exp(i * alph * |X|^2).
    S_power = tf.math.add(tf.math.square(X_real), tf.math.square(X_image))
    S_power = tf.math.scalar_mul(alph, S_power)
    sin = tf.math.sin(S_power)
    cos = tf.math.cos(S_power)

    # Both parts must use the pre-rotation values; the previous version
    # overwrote X_real before computing X_image.
    rotated_real = tf.math.subtract(tf.math.multiply(X_real, cos),
                                    tf.math.multiply(X_image, sin))
    rotated_image = tf.math.add(tf.math.multiply(X_image, cos),
                                tf.math.multiply(X_real, sin))
    X_real, X_image = rotated_real, rotated_image

    # W2
    X_complex_fft = tf.fft(tf.complex(X_real, X_image))
    W2_complex_fft = tf.fft(tf.complex(W2_real, W2_image))
    X_complex = tf.ifft(X_complex_fft * W2_complex_fft)
    out_real = tf.math.real(X_complex)
    out_image = tf.math.imag(X_complex)

    return out_real, out_image
Beispiel #21
0
def cconv(x, y):
    """Return real(IFFT(conj(FFT(x)) * FFT(y))) as float32.

    NOTE(review): despite the name, conjugating FFT(x) makes this a circular
    cross-correlation, not a convolution.
    """
    spec_x = tf.fft(tf.complex(x, 0.0))
    spec_y = tf.fft(tf.complex(y, 0.0))
    correlated = tf.ifft(tf.multiply(tf.conj(spec_x), spec_y))
    return tf.cast(tf.real(correlated), dtype=tf.float32)
Beispiel #22
0
def circular_corr(a, b, name=''):
    """Circular cross-correlation of ``a`` and ``b`` via FFT, cast to float32.

    Ops are created under the name scope ``get_name(name, 'circular_corr')``.
    """
    with tf.name_scope(get_name(name, 'circular_corr')):
        conj_spec_a = tf.conj(tf.fft(tf.complex(a, 0.0)))
        spec_b = tf.fft(tf.complex(b, 0.0))
        res = tf.cast(tf.real(tf.ifft(conj_spec_a * spec_b)), 'float32')
    return res
Beispiel #23
0
def triple_linear_pool(x1, x2, x3, output_size):
    """Approximate trilinear pooling of x1, x2 and x3 using count sketches.

    Each input is count-sketched to ``output_size``. The three sketches are
    circularly convolved by multiplying their FFTs, and the real part of the
    inverse FFT is returned.
    """
    sketches = [count_sketch(x, output_size) for x in (x1, x2, x3)]
    spectrum = None
    for sketch in sketches:
        freq = tf.fft(tf.complex(sketch, tf.zeros_like(sketch)))
        spectrum = freq if spectrum is None else spectrum * freq
    return tf.real(tf.ifft(spectrum))
def extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor,  rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function):
    """
    Computes the cochlear subbands from the fft of the input signal.

    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. 'fft_input' is multiplied by the cochlear filters
    SIGNAL_SIZE : int
        the length of the audio signal used for the cochleagram graph
    SR : int
        raw sampling rate in Hz for the audio.
    LOW_LIM : int
        Lower frequency limits for the filters.
    HIGH_LIM : int
        Higher frequency limits for the filters.
    N : int
        Number of filters to uniquely span the frequency space
    SAMPLE_FACTOR : int
        number of times to overcomplete the filters.
    pad_factor : int
        how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
    rFFT : Boolean
        If true, cochleagram graph is constructed using rFFT wherever possible
    custom_filts : None, or numpy array
        if not None, a numpy array containing the filters to use for the cochleagram generation. If None, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. If using rFFT, should contain the full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS]
    erb_filter_kwargs : dictionary
        contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
    include_all_keys : Boolean
        If True, also adds the time-domain subbands to the dictionary under 'subbands_ifft' and 'subbands_time'
    compression_function : function
        A partial function that takes in nets and the input and output names to apply compression.
        NOTE(review): not used by this function.

    Returns
    -------
    nets : dictionary
        updated dictionary containing parts of the cochleagram graph ('filts_tensor', 'subbands', and optionally the time-domain keys).
    """

    # make the erb filters tensor
    nets['filts_tensor'] = make_filts_tensor(
        SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR,
        use_rFFT=rFFT, pad_factor=pad_factor, custom_filts=custom_filts,
        erb_filter_kwargs=erb_filter_kwargs)

    # make subbands by multiplying filts with fft of input
    nets['subbands'] = tf.multiply(nets['filts_tensor'], nets['fft_input'], name='mul_subbands')

    # Add the time-domain subbands to the graph dict only when returning all
    # keys; otherwise only the Fourier-domain subbands are returned.
    if include_all_keys:
        if not rFFT:
            nets['subbands_ifft'] = tf.real(tf.ifft(nets['subbands'], name='ifft_subbands'), name='ifft_subbands_r')
        else:
            nets['subbands_ifft'] = tf.spectral.irfft(nets['subbands'], name='ifft_subbands')
        nets['subbands_time'] = nets['subbands_ifft']

    return nets
Beispiel #25
0
 def _fft_solver(self, grad, varname):
     """Apply a Fourier-domain multiplier to ``grad`` and restore its shape.

     ``self._var_info_dict[varname]`` supplies the flat length N and the
     coefficient ``coef``. The flattened gradient is FFT-ed, multiplied by
     ``coef`` and inverse-FFT-ed; the real part is reshaped back.
     """
     original_shape = grad.shape.as_list()
     flat_len, coef = self._var_info_dict[varname]
     flat = tf.cast(tf.reshape(grad, shape=[flat_len]), tf.complex64)
     solved = tf.real(ifft(fft(flat) * coef))
     return tf.reshape(solved, shape=original_shape)
Beispiel #26
0
        def update_x_step(z):
            """Return |IFFT(FFT(kt_y * u + z) / fre_k)| in complex128 precision.

            ``kt_y``, ``u`` and ``fre_k`` come from the enclosing scope.
            """
            rhs = tf.cast(kt_y * u + z, tf.complex128)
            return tf.abs(tf.ifft(tf.fft(rhs) / fre_k))
Beispiel #27
0
        def backward_fft(x):
            """Inverse FFT of a tensor whose axis 1 packs real and imaginary parts.

            The first half of axis 1 is the real part and the second half the
            imaginary part. Returns the float32 magnitude of the inverse FFT.
            """
            # (A leftover debug print of x.shape was removed.)
            real, imag = tf.split(x, num_or_size_splits=2, axis=1)
            x_complex = tf.ifft(tf.complex(real, imag))
            return tf.cast(tf.abs(x_complex), dtype=tf.float32)
Beispiel #28
0
        def combine_and_inverse_fft(args):
            """Rebuild magnitude * exp(i * phase) from ``args`` and return real(IFFT).

            ``args`` is a (magnitude, phase) pair.
            """
            magnitude, phase = args
            spectrum = tf.complex(magnitude * tf.cos(phase),
                                  magnitude * tf.sin(phase))
            return tf.real(tf.ifft(spectrum))
Beispiel #29
0
def _ifft(bottom, sequential, compute_size):
    '''
    Inverse FFT of ``bottom``, optionally computed in sequential batches of
    ``compute_size`` via ``sequential_batch_ifft``.

    Return
    ------
    iFFT(bottom), has the same shape of bottom
    '''
    if not sequential:
        return tf.ifft(bottom)
    return sequential_batch_ifft(bottom, compute_size)
def tf_ifft(tensor, shift, axis=0):
    """Inverse FFT surrounded by two rolls of ``shift`` along ``axis``.

    The input is rolled, inverse-transformed with ``tf.ifft`` (which always
    acts on the innermost axis), and rolled again by the same amount.
    NOTE(review): the rolls use ``axis`` while the transform uses the last
    axis; they coincide only for 1-D input or axis=-1 -- confirm intent.
    """
    pre_rolled = tf.manip.roll(tensor, shift=shift, axis=axis)
    # inverse FFT (the original comment mislabelled this as "fft")
    unrolled_time = tf.ifft(pre_rolled)
    return tf.manip.roll(unrolled_time, shift=shift, axis=axis)
Beispiel #31
0
    def op(self):
        """FFT-based autocorrelation of ``self.x`` on a log scale, optionally averaged.

        Computes IFFT(FFT(x) * conj(FFT(x))), keeps lags [0, self.aclen) of
        each row, and maps them through 10 * ln(|.|). When ``self.avg`` is set,
        row i is weighted by alpha * (1 - alpha)**i and summed into
        ``self.u``, whose previous value decays by (1 - alpha)**num_rows. The
        updated ``self.u`` is returned.
        """
        spectrum = tensorflow.fft(self.x)
        power = spectrum * tensorflow.conj(spectrum)
        autocorr = tensorflow.ifft(power)
        # NOTE(review): tensorflow.log is the natural log, so this is not dB.
        log_ac = 10*tensorflow.log(tensorflow.abs(autocorr[:, 0:self.aclen]))

        if not self.avg:
            return log_ac

        n_rows = tensorflow.shape(log_ac)[0]
        row_idx = tensorflow.cast(tensorflow.range(0, n_rows), tensorflow.float32)
        weights = tensorflow.reshape(
            self.alpha * tensorflow.pow((1-self.alpha), row_idx), [n_rows, 1])
        decayed = tensorflow.pow((1-self.alpha),
                                 tensorflow.cast(n_rows, tensorflow.float32))*self.u
        self.u = decayed + tensorflow.reduce_sum(weights*log_ac, 0)
        return self.u
Beispiel #32
0
def auto_correlation(
    x,
    axis=-1,
    max_lags=None,
    center=True,
    normalize=True,
    name="auto_correlation"):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider
      (in equation above).  If `max_lags >= x.shape[axis]`, we effectively
      re-set `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name="x")

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = util.prefer_static_rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T "time" steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = util.prefer_static_shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment it is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    # Padding to >= 2 * x_len also keeps the circular correlation from
    # wrapping lags around, so R_m below only sums genuine overlaps.
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.), tf.ceil(tf.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype.is_complex:
      if not dtype.is_floating:
        raise TypeError("Argument x must have either float or complex dtype"
                        " found: {}".format(dtype))
      x_rotated_pad = tf.complex(x_rotated_pad,
                                 dtype.real_dtype.as_numpy_dtype(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not x_rotated.shape.is_fully_defined():
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name="max_lags")
      max_lags_ = tensor_util.constant_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = x_rotated.shape.as_list()
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype.real_dtype)
    max_lags = tf.cast(max_lags, dtype.real_dtype)
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide every lag by lag 0 (the variance estimate), so rxx[0] == 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return util.rotate_transpose(shifted_product_rotated, -shift)
Beispiel #33
0
	def _cconv(self, a, b):
		"""Circular convolution of ``a`` and ``b`` via the FFT (real part).

		Inputs are cast to complex64 because ``tf.fft`` only accepts complex
		tensors. The real part is taken with ``tf.real``: graph-mode
		``tf.Tensor`` objects have no ``.real`` attribute.
		"""
		a = tf.cast(a, tf.complex64)
		b = tf.cast(b, tf.complex64)
		return tf.real(tf.ifft(tf.fft(a) * tf.fft(b)))
Beispiel #34
0
	def _ccorr(self, a, b):
		"""Circular cross-correlation: real(IFFT(conj(FFT(a)) * FFT(b)))."""
		spec_a = tf.fft(tf.cast(a, tf.complex64))
		spec_b = tf.fft(tf.cast(b, tf.complex64))
		return tf.real(tf.ifft(tf.conj(spec_a) * spec_b))