def complex_mul_unitary(vec_in, out_size, scope=None):
    """Apply a learned, structured (intended-unitary) transform to a batch.

    Each row vector ``x`` of ``vec_in`` is mapped to

        D2 . R1 . IFFT . D1 . P . R0 . FFT . D0 . x

    (applied right to left), where D0/D1/D2 are learned complex diagonals,
    R0/R1 are built by ``refl_c`` from the learned vectors ``refl0``/``refl1``
    after ``normalize_c``, P is a fixed random permutation of the columns,
    and FFT/IFFT are rescaled so that each is unitary.

    NOTE(review): the composite is only unitary if ``get_complex_variable``
    yields unit-modulus diagonals and ``refl_c`` is a (Householder-style)
    reflection -- neither is visible here; confirm against those helpers.

    Args:
      vec_in: 2-D Tensor, one vector per row. Its inner dimension must
        equal ``out_size`` (the transform is square).
      out_size: length of each vector / size of the transform.
      scope: optional variable-scope name; defaults to ``'ULinear'``.

    Returns:
      Tensor of the same shape as ``vec_in`` holding the transformed rows.

    Raises:
      ValueError: if ``vec_in`` is not 2-D, or its statically known inner
        dimension differs from ``out_size``.
    """
    shape = vec_in.get_shape().as_list()
    if len(shape) != 2:
        raise ValueError('Argument `vec_in` must be a batch of vectors'
                         ' (2D Tensor)')
    # The transform is square, so the input width must match out_size.
    # Without this check a mismatch either fails with an opaque broadcast
    # error or (width 1) silently broadcasts into a non-unitary map.
    in_size = shape[1]
    if in_size is not None and in_size != out_size:
        raise ValueError('Argument `vec_in` inner dimension (%d) must equal'
                         ' `out_size` (%d)' % (in_size, out_size))

    # batch_fft is unnormalized (sum over N terms): scale by 1/sqrt(N).
    fft_scale = 1.0 / sqrt(out_size)
    # batch_ifft already divides by N (its adjoint relation with batch_fft is
    # ifft(x) * N, as used by the FFT gradient). Multiplying by sqrt(N) yields
    # the unitary inverse. The previous code multiplied by 1/sqrt(N), which
    # scaled this stage by N**-1.5 and made the whole map shrink norms by N.
    ifft_scale = sqrt(out_size)

    with vs.variable_scope(scope or 'ULinear') as _s:
        diag0 = get_complex_variable('diag0', _s, shape=[out_size])
        diag1 = get_complex_variable('diag1', _s, shape=[out_size])
        diag2 = get_complex_variable('diag2', _s, shape=[out_size])
        refl0 = get_complex_variable('refl0', _s, shape=[out_size])
        refl1 = get_complex_variable('refl1', _s, shape=[out_size])
        # Fixed (non-trainable) permutation drawn from NumPy's global RNG at
        # graph-construction time; seed NumPy for reproducible graphs.
        perm0 = tf.constant(np.random.permutation(out_size), name='perm0',
                            dtype=tf.int32)

    out_ = vec_in * diag0
    refl0 = normalize_c(refl0)
    refl1 = normalize_c(refl1)
    out_ = refl_c(math_ops.batch_fft(out_) * fft_scale, refl0)
    # Permute columns: gather works on axis 0, so transpose around it.
    out_ = diag1 * tf.transpose(tf.gather(tf.transpose(out_), perm0))
    out_ = diag2 * refl_c(math_ops.batch_ifft(out_) * ifft_scale, refl1)
    return out_
def _BatchFFTGrad(_, grad):
    """Gradient function for a batched (unnormalized) forward FFT.

    The adjoint of the unnormalized forward FFT over N points is
    ``N * ifft``, given that ``batch_ifft`` applies the usual 1/N factor.
    The incoming gradient is therefore inverse-transformed and rescaled by
    N. N comes from ``_FFTSizeForGrad(grad, 1)``; the argument 1 presumably
    selects the single innermost FFT dimension (confirm against that helper).

    Args:
      _: the forward op (unused).
      grad: upstream gradient w.r.t. the FFT output.

    Returns:
      Gradient w.r.t. the FFT input.
    """
    n_points = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
    inverse = math_ops.batch_ifft(grad)
    # Lift the real size to a complex scalar so it can multiply `inverse`.
    return inverse * math_ops.complex(n_points, 0.)