def reparam(rdft):
     """Maps the stored `rdft` parameter back to a tensor of `var_shape`.

     The matrix returned by `spectral_ops.irdft_matrix` for the leading
     dimensions `var_shape[:-2]` is multiplied with `rdft`, and the product is
     reshaped to `var_shape`.

     Args:
       rdft: 2-D tensor used as the right-hand operand of the matrix product.

     Returns:
       The product reshaped to `var_shape`.
     """
     basis = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
     if not self.dc:
         # The first column has no counterpart in `rdft` when `dc` is off
         # (presumably it is the DC basis vector -- confirm in spectral_ops).
         basis = basis[:, 1:]
     return tf.reshape(tf.linalg.matmul(basis, rdft), var_shape)
示例#2
0
 def test_irdft3_matrix(self):
   """Checks that `irdft_matrix` times its transpose is the identity (3-D)."""
   for shape in [(3, 4, 2), (6, 3, 1)]:
     num_elements = shape[0] * shape[1] * shape[2]
     basis = spectral_ops.irdft_matrix(shape)
     # An orthonormal matrix multiplied by its transpose gives the identity.
     gram = tf.matmul(basis, tf.transpose(basis))
     with self.test_session() as sess:
       gram_value = sess.run(gram)
       expected = np.identity(num_elements)
       self.assertAllClose(gram_value, expected)
 def test_irdft3_matrix(self):
   """Verifies orthonormality of `irdft_matrix` for two 3-D shapes."""
   for shape in [(3, 4, 2), (6, 3, 1)]:
     total = shape[0] * shape[1] * shape[2]
     irdft = spectral_ops.irdft_matrix(shape)
     # M @ M^T must equal the identity of size `total` if M is orthonormal.
     product = tf.matmul(irdft, tf.transpose(irdft))
     with self.test_session() as sess:
       evaluated = sess.run(product)
       self.assertAllClose(evaluated, np.identity(total))
示例#4
0
    def __call__(self,
                 getter,
                 name,
                 shape,
                 dtype,
                 initializer,
                 regularizer=None):
        """Creates variable `name` reparameterized through `irdft_matrix`.

        When every dimension in `shape[:-2]` equals 1, the call is forwarded
        to `getter` unchanged. Otherwise a variable called `name + "_rdft"`
        is requested from `getter`, and the returned tensor is the matrix
        product of the `spectral_ops.irdft_matrix` result with that variable,
        reshaped to `shape`.

        Args:
          getter: Callable accepting `name`, `shape`, `dtype`, `initializer`
            and `regularizer` keyword arguments and returning the variable.
          name: String name of the requested variable.
          shape: Shape of the requested variable. All dimensions but the last
            two are transformed; the last two are flattened into one axis.
          dtype: Dtype of the requested variable.
          initializer: Callable invoked as `initializer(shape, dtype=...,
            partition_info=...)`, giving the initial value in the original
            (non-transformed) domain.
          regularizer: Optional callable. It is applied to the reparameterized
            tensor of shape `shape`, not to the stored `_rdft` variable.

        Returns:
          The result of `getter` itself when the leading dimensions are all
          1, otherwise the tensor of shape `shape` computed from the `_rdft`
          variable.
        """
        if all(s == 1 for s in shape[:-2]):
            return getter(name=name,
                          shape=shape,
                          dtype=dtype,
                          initializer=initializer,
                          regularizer=regularizer)
        var_shape = shape
        var_dtype = dtype
        # Number of elements spanned by the transformed (leading) dimensions.
        size = var_shape[0]
        for s in var_shape[1:-2]:
            size *= s
        irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2],
                                                 dtype=var_dtype)
        if self.dc:
            rdft_shape = (size, var_shape[-2] * var_shape[-1])
        else:
            # Drop the first column of the matrix (presumably the DC basis
            # vector -- confirm in spectral_ops); the stored parameter then
            # has one row fewer.
            irdft_matrix = irdft_matrix[:, 1:]
            rdft_shape = (size - 1, var_shape[-2] * var_shape[-1])
        rdft_dtype = var_dtype
        rdft_name = name + "_rdft"

        def rdft_initializer(shape, dtype=None, partition_info=None):
            """Initializes in the original domain, then maps to `rdft`."""
            assert tuple(shape) == rdft_shape, shape
            assert dtype == rdft_dtype, dtype
            init = initializer(var_shape,
                               dtype=var_dtype,
                               partition_info=partition_info)
            init = tf.reshape(init, (-1, rdft_shape[-1]))
            # The transposed matrix serves as the forward transform.
            init = tf.linalg.matmul(irdft_matrix, init, transpose_a=True)
            return init

        def reparam(rdft):
            """Maps the stored `rdft` parameter back to `var_shape`."""
            var = tf.linalg.matmul(irdft_matrix, rdft)
            var = tf.reshape(var, var_shape)
            return var

        if regularizer is not None:
            # Keep a separate reference to the caller's regularizer. Rebinding
            # `regularizer` to a lambda that itself looks up `regularizer`
            # would make the lambda call itself and recurse without end.
            base_regularizer = regularizer
            regularizer = lambda rdft: base_regularizer(reparam(rdft))

        rdft = getter(name=rdft_name,
                      shape=rdft_shape,
                      dtype=rdft_dtype,
                      initializer=rdft_initializer,
                      regularizer=regularizer)
        return reparam(rdft)
 def rdft_initializer(shape, dtype=None, partition_info=None):
     """Produces the initial value of the `rdft` parameter.

     Calls the wrapped `initializer` for `var_shape`, flattens the result to
     two dimensions, and multiplies it by the transposed matrix from
     `spectral_ops.irdft_matrix`.

     Args:
       shape: Requested shape; must equal `rdft_shape`.
       dtype: Requested dtype; must equal `rdft_dtype`.
       partition_info: Accepted and discarded for TF 1/2 compatibility.

     Returns:
       The transformed initial value.
     """
     del partition_info  # Unused; kept only for signature compatibility.
     assert tuple(shape) == rdft_shape, shape
     assert dtype == rdft_dtype, dtype
     flat_init = tf.reshape(initializer(var_shape, dtype=var_dtype),
                            (-1, rdft_shape[-1]))
     basis = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
     if not self.dc:
         # Skip the first column so the row count matches `rdft_shape`.
         basis = basis[:, 1:]
     return tf.linalg.matmul(basis, flat_init, transpose_a=True)
示例#6
0
    def __init__(self,
                 initial_value,
                 name=None,
                 dc=True,
                 shape=None,
                 dtype=None):
        """Builds the `rdft` variable from an initial kernel value.

        Args:
          initial_value: `tf.Tensor` (or convertible) or `None`. Initial value
            of the kernel. When `None`, `shape` is required and the kernel
            starts out as zeros of that shape.
          name: String. Name of the kernel; the variable created here is
            named `f"{name}_rdft"` when a name is given.
          dc: Boolean. When `False`, the first column of the transform matrix
            is dropped, so the DC component of the kernel RDFTs is not
            represented. Defaults to `True`.
          shape: `tf.TensorShape` or compatible. Only used when
            `initial_value is None`.
          dtype: `tf.dtypes.DType` or compatible. Passed to the tensor
            conversion / zeros creation of the initial value.

        Raises:
          ValueError: If both `initial_value` and `shape` are `None`.
        """
        super().__init__(name=name)
        self._dc = bool(dc)

        if initial_value is not None:
            kernel = tf.convert_to_tensor(initial_value, dtype=dtype)
        elif shape is not None:
            kernel = tf.zeros(shape, dtype=dtype)
        else:
            raise ValueError(
                "If initial_value is None, shape must be specified.")

        self._shape = kernel.shape
        self._matrix = spectral_ops.irdft_matrix(self.shape[:-2],
                                                 dtype=kernel.dtype)
        if not self.dc:
            self._matrix = self._matrix[:, 1:]

        # Flatten the last two dimensions, then apply the transposed matrix
        # to obtain the value stored in the variable.
        flat_kernel = tf.reshape(kernel,
                                 (-1, self.shape[-2] * self.shape[-1]))
        rdft_init = tf.linalg.matmul(self._matrix,
                                     flat_kernel,
                                     transpose_a=True)
        variable_name = None if name is None else f"{name}_rdft"
        self.rdft = tf.Variable(rdft_init, name=variable_name)
  def __call__(self, getter, name, shape, dtype, initializer, regularizer=None):
    """Creates variable `name` reparameterized through `irdft_matrix`.

    When every dimension in `shape[:-2]` equals 1, the call is forwarded to
    `getter` unchanged. Otherwise a variable called `name + "_rdft"` is
    requested from `getter`, and the returned tensor is the matrix product of
    the `spectral_ops.irdft_matrix` result with that variable, reshaped to
    `shape`.

    Args:
      getter: Callable accepting `name`, `shape`, `dtype`, `initializer` and
        `regularizer` keyword arguments and returning the variable.
      name: String name of the requested variable.
      shape: Shape of the requested variable. All dimensions but the last two
        are transformed; the last two are flattened into one axis.
      dtype: Dtype of the requested variable.
      initializer: Callable invoked as `initializer(shape, dtype=...,
        partition_info=...)`, giving the initial value in the original
        (non-transformed) domain.
      regularizer: Optional callable. It is applied to the reparameterized
        tensor of shape `shape`, not to the stored `_rdft` variable.

    Returns:
      The result of `getter` itself when the leading dimensions are all 1,
      otherwise the tensor of shape `shape` computed from the `_rdft`
      variable.
    """
    if all(s == 1 for s in shape[:-2]):
      return getter(name=name, shape=shape, dtype=dtype,
                    initializer=initializer, regularizer=regularizer)
    var_shape = shape
    var_dtype = dtype
    # Number of elements spanned by the transformed (leading) dimensions.
    size = var_shape[0]
    for s in var_shape[1:-2]:
      size *= s
    irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
    if self.dc:
      rdft_shape = (size, var_shape[-2] * var_shape[-1])
    else:
      # Drop the first column of the matrix (presumably the DC basis vector --
      # confirm in spectral_ops); the stored parameter has one row fewer.
      irdft_matrix = irdft_matrix[:, 1:]
      rdft_shape = (size - 1, var_shape[-2] * var_shape[-1])
    rdft_dtype = var_dtype
    rdft_name = name + "_rdft"

    def rdft_initializer(shape, dtype=None, partition_info=None):
      """Initializes in the original domain, then maps to `rdft`."""
      assert tuple(shape) == rdft_shape, shape
      assert dtype == rdft_dtype, dtype
      init = initializer(
          var_shape, dtype=var_dtype, partition_info=partition_info)
      init = tf.reshape(init, (-1, rdft_shape[-1]))
      # The transposed matrix serves as the forward transform.
      init = tf.linalg.matmul(irdft_matrix, init, transpose_a=True)
      return init

    def reparam(rdft):
      """Maps the stored `rdft` parameter back to `var_shape`."""
      var = tf.linalg.matmul(irdft_matrix, rdft)
      var = tf.reshape(var, var_shape)
      return var

    if regularizer is not None:
      # Keep a separate reference to the caller's regularizer. Rebinding
      # `regularizer` to a lambda that itself looks up `regularizer` would
      # make the lambda call itself and recurse without end.
      base_regularizer = regularizer
      regularizer = lambda rdft: base_regularizer(reparam(rdft))

    rdft = getter(
        name=rdft_name, shape=rdft_shape, dtype=rdft_dtype,
        initializer=rdft_initializer, regularizer=regularizer)
    return reparam(rdft)
示例#8
0
 def test_irdft_matrix_is_orthonormal(self, *shape):
   """Checks that `irdft_matrix(shape)` times its transpose is the identity."""
   basis = spectral_ops.irdft_matrix(shape)
   gram = tf.matmul(basis, tf.transpose(basis))
   # The identity has one row per element of a tensor with the given shape.
   num_elements = tf.TensorShape(shape).num_elements()
   self.assertAllClose(gram, tf.eye(num_elements))