def reparam(rdft):
    """Rebuilds the spatial-domain variable from its RDFT coefficients.

    Multiplies `rdft` by the inverse RDFT matrix over all but the last two
    dimensions of `var_shape`, then reshapes the product to `var_shape`.
    """
    basis = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
    if not self.dc:
        # Without a DC coefficient, the first basis column has no partner.
        basis = basis[:, 1:]
    return tf.reshape(tf.linalg.matmul(basis, rdft), var_shape)
def test_irdft3_matrix(self):
    """Checks that `irdft_matrix(shape) @ irdft_matrix(shape)^T` is the identity."""
    for shape in [(3, 4, 2), (6, 3, 1)]:
        size = shape[0] * shape[1] * shape[2]
        matrix = spectral_ops.irdft_matrix(shape)
        # Test that the matrix is orthonormal.
        result = tf.matmul(matrix, tf.transpose(matrix))
        # `self.evaluate` works in both graph and eager mode, whereas the
        # deprecated `test_session()`/`sess.run` pair cannot run eager tensors.
        self.assertAllClose(self.evaluate(result), np.identity(size))
def test_irdft3_matrix(self):
    """Checks that `irdft_matrix(shape) @ irdft_matrix(shape)^T` is the identity."""
    for shape in [(3, 4, 2), (6, 3, 1)]:
        size = shape[0] * shape[1] * shape[2]
        matrix = spectral_ops.irdft_matrix(shape)
        # Test that the matrix is orthonormal.
        result = tf.matmul(matrix, tf.transpose(matrix))
        # `self.evaluate` works in both graph and eager mode, whereas the
        # deprecated `test_session()`/`sess.run` pair cannot run eager tensors.
        self.assertAllClose(self.evaluate(result), np.identity(size))
def __call__(self, getter, name, shape, dtype, initializer,
             regularizer=None):
    """Custom getter that stores a variable as RDFT coefficients.

    If every dimension of `shape` except the last two is 1, the request is
    passed to `getter` unchanged. Otherwise a variable named `name + "_rdft"`
    holding RDFT coefficients is created through `getter`, and the requested
    variable is reconstructed from it with the inverse RDFT matrix.

    Args:
      getter: The underlying variable getter.
      name: Name of the requested variable.
      shape: Shape of the requested variable. The transform runs over all but
        the last two dimensions; the last two are flattened together.
      dtype: DType of the requested variable.
      initializer: Initializer for the spatial-domain variable.
      regularizer: Optional callable, applied to the reconstructed
        spatial-domain variable.

    Returns:
      A tensor of shape `shape` reconstructed from the RDFT variable, or
      whatever `getter` returns in the trivial case.
    """
    if all(s == 1 for s in shape[:-2]):
        return getter(name=name, shape=shape, dtype=dtype,
                      initializer=initializer, regularizer=regularizer)

    var_shape = shape
    var_dtype = dtype
    size = var_shape[0]
    for s in var_shape[1:-2]:
        size *= s

    irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
    if self.dc:
        rdft_shape = (size, var_shape[-2] * var_shape[-1])
    else:
        # Drop the DC coefficient (first column of the inverse RDFT matrix).
        irdft_matrix = irdft_matrix[:, 1:]
        rdft_shape = (size - 1, var_shape[-2] * var_shape[-1])
    rdft_dtype = var_dtype
    rdft_name = name + "_rdft"

    def rdft_initializer(shape, dtype=None, partition_info=None):
        """Initializes the RDFT variable from a spatial-domain initial value."""
        del partition_info  # Ignored for TF 1/2 compatibility.
        assert tuple(shape) == rdft_shape, shape
        assert dtype == rdft_dtype, dtype
        init = initializer(var_shape, dtype=var_dtype)
        init = tf.reshape(init, (-1, rdft_shape[-1]))
        # Map spatial values to coefficients with the transposed matrix.
        return tf.linalg.matmul(irdft_matrix, init, transpose_a=True)

    def reparam(rdft):
        """Maps RDFT coefficients back to a tensor of shape `var_shape`."""
        var = tf.linalg.matmul(irdft_matrix, rdft)
        return tf.reshape(var, var_shape)

    rdft_regularizer = None
    if regularizer is not None:
        # Use a distinct name: rebinding `regularizer` to a lambda that calls
        # `regularizer` makes the lambda call itself and recurse forever.
        def rdft_regularizer(rdft):
            return regularizer(reparam(rdft))

    rdft = getter(name=rdft_name, shape=rdft_shape, dtype=rdft_dtype,
                  initializer=rdft_initializer, regularizer=rdft_regularizer)
    return reparam(rdft)
def rdft_initializer(shape, dtype=None, partition_info=None):
    """Initializer wrapper.

    Draws a spatial-domain value from `initializer` with `var_shape`/`var_dtype`,
    flattens it to `rdft_shape[-1]` columns, and maps it to RDFT coefficients
    using the transposed inverse RDFT matrix.
    """
    del partition_info  # Ignored for TF 1/2 compatibility.
    assert tuple(shape) == rdft_shape, shape
    assert dtype == rdft_dtype, dtype
    spatial = tf.reshape(initializer(var_shape, dtype=var_dtype),
                         (-1, rdft_shape[-1]))
    basis = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
    if not self.dc:
        # No DC coefficient is stored, so its basis column is dropped.
        basis = basis[:, 1:]
    return tf.linalg.matmul(basis, spatial, transpose_a=True)
def __init__(self, initial_value, name=None, dc=True, shape=None,
             dtype=None):
    """Initializer.

    Args:
      initial_value: `tf.Tensor` or `None`. The initial value of the kernel. If
        not provided, its `shape` must be given, and the initial value of the
        parameter will be undefined.
      name: String. The name of the kernel.
      dc: Boolean. If `False`, the DC component of the kernel RDFTs is not
        represented, forcing the filters to be highpass. Defaults to `True`.
      shape: `tf.TensorShape` or compatible. Ignored unless `initial_value is
        None`.
      dtype: `tf.dtypes.DType` or compatible. DType of this parameter. If not
        given, inferred from `initial_value`, or `tf.float32` when
        `initial_value is None`.

    Raises:
      ValueError: If both `initial_value` and `shape` are `None`.
    """
    super().__init__(name=name)
    self._dc = bool(dc)
    if initial_value is None:
        if shape is None:
            raise ValueError(
                "If initial_value is None, shape must be specified.")
        # Passing `dtype=None` explicitly would override `tf.zeros`' float32
        # default (None is not a convertible DType), so fall back explicitly.
        if dtype is None:
            dtype = tf.float32
        initial_value = tf.zeros(shape, dtype=dtype)
    else:
        initial_value = tf.convert_to_tensor(initial_value, dtype=dtype)
    self._shape = initial_value.shape
    # NOTE(review): `self.shape` / `self.dc` are properties defined elsewhere
    # on the class, presumably exposing `_shape` / `_dc` -- confirm.
    self._matrix = spectral_ops.irdft_matrix(
        self.shape[:-2], dtype=initial_value.dtype)
    if not self.dc:
        # Drop the DC column so the DC component is not represented.
        self._matrix = self._matrix[:, 1:]
    initial_value = tf.reshape(
        initial_value, (-1, self.shape[-2] * self.shape[-1]))
    # Store the kernel as RDFT coefficients: the transposed inverse RDFT
    # matrix applied to the flattened spatial kernel.
    initial_value = tf.linalg.matmul(
        self._matrix, initial_value, transpose_a=True)
    if name is not None:
        name = f"{name}_rdft"
    self.rdft = tf.Variable(initial_value, name=name)
def __call__(self, getter, name, shape, dtype, initializer,
             regularizer=None):
    """Custom getter that stores a variable as RDFT coefficients.

    If every dimension of `shape` except the last two is 1, the request is
    passed to `getter` unchanged. Otherwise a variable named `name + "_rdft"`
    holding RDFT coefficients is created through `getter`, and the requested
    variable is reconstructed from it with the inverse RDFT matrix.

    Args:
      getter: The underlying variable getter.
      name: Name of the requested variable.
      shape: Shape of the requested variable. The transform runs over all but
        the last two dimensions; the last two are flattened together.
      dtype: DType of the requested variable.
      initializer: Initializer for the spatial-domain variable.
      regularizer: Optional callable, applied to the reconstructed
        spatial-domain variable.

    Returns:
      A tensor of shape `shape` reconstructed from the RDFT variable, or
      whatever `getter` returns in the trivial case.
    """
    if all(s == 1 for s in shape[:-2]):
        return getter(name=name, shape=shape, dtype=dtype,
                      initializer=initializer, regularizer=regularizer)

    var_shape = shape
    var_dtype = dtype
    size = var_shape[0]
    for s in var_shape[1:-2]:
        size *= s

    irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
    if self.dc:
        rdft_shape = (size, var_shape[-2] * var_shape[-1])
    else:
        # Drop the DC coefficient (first column of the inverse RDFT matrix).
        irdft_matrix = irdft_matrix[:, 1:]
        rdft_shape = (size - 1, var_shape[-2] * var_shape[-1])
    rdft_dtype = var_dtype
    rdft_name = name + "_rdft"

    def rdft_initializer(shape, dtype=None, partition_info=None):
        """Initializes the RDFT variable from a spatial-domain initial value."""
        del partition_info  # Ignored for TF 1/2 compatibility.
        assert tuple(shape) == rdft_shape, shape
        assert dtype == rdft_dtype, dtype
        init = initializer(var_shape, dtype=var_dtype)
        init = tf.reshape(init, (-1, rdft_shape[-1]))
        # Map spatial values to coefficients with the transposed matrix.
        return tf.linalg.matmul(irdft_matrix, init, transpose_a=True)

    def reparam(rdft):
        """Maps RDFT coefficients back to a tensor of shape `var_shape`."""
        var = tf.linalg.matmul(irdft_matrix, rdft)
        return tf.reshape(var, var_shape)

    rdft_regularizer = None
    if regularizer is not None:
        # Use a distinct name: rebinding `regularizer` to a lambda that calls
        # `regularizer` makes the lambda call itself and recurse forever.
        def rdft_regularizer(rdft):
            return regularizer(reparam(rdft))

    rdft = getter(
        name=rdft_name, shape=rdft_shape, dtype=rdft_dtype,
        initializer=rdft_initializer, regularizer=rdft_regularizer)
    return reparam(rdft)
def test_irdft_matrix_is_orthonormal(self, *shape):
    """Asserts `irdft_matrix(shape)` times its transpose equals the identity."""
    irdft = spectral_ops.irdft_matrix(shape)
    gram = tf.linalg.matmul(irdft, tf.transpose(irdft))
    num_elements = tf.TensorShape(shape).num_elements()
    self.assertAllClose(gram, tf.eye(num_elements))