Ejemplo n.º 1
0
def complex_mul_unitary(vec_in, out_size, scope=None):
    """Apply a learned, FFT-parameterized linear map to a batch of vectors.

    Applied right to left, the map is::

        W = D2 * R1 * F^-1 * D1 * P * R0 * F * D0

    where ``D0..D2`` are learned complex diagonals, ``R0``/``R1`` are
    ``refl_c`` transforms with learned vectors normalized by ``normalize_c``,
    ``F``/``F^-1`` are the DFT and inverse DFT scaled to be unitary, and ``P``
    is a fixed random permutation of the vector components.

    NOTE(review): W is unitary only if the diagonals have unit modulus and
    ``refl_c`` is a true (Householder-style) reflection; neither is visible
    here -- TODO confirm in ``get_complex_variable`` / ``refl_c``.

    Args:
      vec_in: batch of vectors as a 2-D tensor. Presumably complex and of
        width ``out_size``, since it is multiplied elementwise by the
        length-``out_size`` complex diagonal ``diag0``.
      out_size: length of every learned parameter vector and of the DFT.
      scope: optional variable-scope name; defaults to ``'ULinear'``.

    Returns:
      The transformed batch.

    Raises:
      ValueError: if ``vec_in`` is not a 2-D tensor.
    """
    shape = vec_in.get_shape().as_list()
    if len(shape) != 2:
        raise ValueError('Argument `vec_in` must be a batch of vectors'
                         ' (2D Tensor)')
    # NOTE(review): the input width shape[1] is not checked against out_size;
    # the elementwise products below rely on broadcasting, so the two are
    # presumably expected to match -- TODO confirm.

    # `batch_fft` is unnormalized, so dividing by sqrt(N) makes it unitary.
    # TensorFlow's `batch_ifft`, however, already divides by N (so that
    # ifft(fft(x)) == x), which makes the unitary inverse DFT
    # `batch_ifft * sqrt(N)`. Scaling the inverse by 1/sqrt(N) as well -- the
    # right convention only for libraries whose inverse FFT is unnormalized --
    # would shrink every output by a factor of N.
    fft_scale = 1.0 / sqrt(out_size)
    ifft_scale = sqrt(out_size)
    with vs.variable_scope(scope or 'ULinear') as _s:
        diag0 = get_complex_variable('diag0', _s, shape=[out_size])
        diag1 = get_complex_variable('diag1', _s, shape=[out_size])
        diag2 = get_complex_variable('diag2', _s, shape=[out_size])
        refl0 = get_complex_variable('refl0', _s, shape=[out_size])
        refl1 = get_complex_variable('refl1', _s, shape=[out_size])
        # NOTE(review): the permutation comes from NumPy's global RNG at
        # graph-construction time and is stored as a constant, not a
        # variable, so it is not part of any checkpoint. Rebuilding the graph
        # in another process yields a different permutation unless NumPy's
        # seed is fixed beforehand.
        perm0 = tf.constant(np.random.permutation(out_size),
                            name='perm0',
                            dtype=tf.int32)
        out_ = vec_in * diag0
        refl0 = normalize_c(refl0)
        refl1 = normalize_c(refl1)
        out_ = refl_c(math_ops.batch_fft(out_) * fft_scale, refl0)
        # Gathering rows of the transposed batch permutes the components of
        # every vector with the same permutation.
        out_ = diag1 * tf.transpose(tf.gather(tf.transpose(out_), perm0))
        out_ = diag2 * refl_c(math_ops.batch_ifft(out_) * ifft_scale, refl1)

        return out_
Ejemplo n.º 2
0
def _BatchIFFTGrad(_, grad):
    """Gradient function for the batched inverse FFT.

    Returns ``batch_fft(grad)`` multiplied by the complex scalar ``1/N + 0j``,
    where N is the transform size reported by ``_FFTSizeForGrad(grad, 1)``.

    Args:
      _: the forward op (unused).
      grad: gradient with respect to the op's output.

    Returns:
      Gradient with respect to the op's input.
    """
    transform_len = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
    inv_len = 1. / transform_len
    spectrum = math_ops.batch_fft(grad)
    # Multiply by 1/N as a complex scalar so the dtype matches `spectrum`.
    return spectrum * math_ops.complex(inv_len, 0.)
Ejemplo n.º 3
0
def _BatchIFFTGrad(_, grad):
  """Backpropagate through a batched inverse FFT.

  The incoming gradient is passed through ``batch_fft`` and scaled by
  ``1/N + 0j``, with N taken from ``_FFTSizeForGrad(grad, 1)``.
  """
  n_float = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
  reciprocal = 1. / n_float
  forward = math_ops.batch_fft(grad)
  complex_scale = math_ops.complex(reciprocal, 0.)
  return forward * complex_scale