Пример #1
0
def matmul_real_with_complex(real_input, complex_matrix):
    """Multiply a real tensor by a complex matrix.

    `tf.matmul` is applied separately to the real and the imaginary
    component of `complex_matrix`; the two products are then recombined
    with `tf.complex`.

    Args:
      real_input: Left-hand operand handed to `tf.matmul`.
      complex_matrix: Right-hand operand, split with `tf.math.real` and
        `tf.math.imag`.

    Returns:
      `tf.complex(real_product, imag_product)`.
    """
    def _times(component):
        # Same left operand for both halves of the complex matrix.
        return tf.matmul(real_input, component)

    return tf.complex(
        _times(tf.math.real(complex_matrix)),
        _times(tf.math.imag(complex_matrix)))
Пример #2
0
  def _testMVN(self,
               base_distribution_class,
               base_distribution_kwargs,
               batch_shape=(),
               event_shape=(),
               not_implemented_message=None):
    """Checks the distribution returned by `self._cls()` against a reference MVN.

    Two instances are built from the same base distribution and the same
    `tfb.Affine(shift=self._shift, scale_tril=self._tril)` bijector: one gets
    `batch_shape`/`event_shape` as Python values ("static"), the other gets
    them through placeholders whose static shape is unknown ("dynamic").
    Sample moments, `log_prob`, `prob`, `mean` and `entropy` of both are then
    compared with values computed via `stats.multivariate_normal`.

    The assertions below hard-code an event shape of `[3]`, a batch shape of
    `[2]` and a sample shape of `[5, 2, 3]`, so callers have to supply
    arguments that produce exactly those shapes.

    Args:
      base_distribution_class: Callable invoked with `validate_args=True` and
        `**base_distribution_kwargs` to create the base distribution.
      base_distribution_kwargs: Keyword arguments for
        `base_distribution_class`.
      batch_shape: Batch shape override passed to the distribution under test.
      event_shape: Event shape override passed to the distribution under test.
      not_implemented_message: Regex expected in the `NotImplementedError`
        raised by `log_cdf`, `cdf`, `survival_function` and
        `log_survival_function`.
    """
    # Overriding shapes must be compatible w/bijector; most bijectors are
    # batch_shape agnostic and only care about event_ndims.
    # In the case of `Affine`, if we got it wrong then it would fire an
    # exception due to incompatible dimensions.
    batch_shape_pl = tf1.placeholder_with_default(
        input=np.int32(batch_shape), shape=None, name='dynamic_batch_shape')
    event_shape_pl = tf1.placeholder_with_default(
        input=np.int32(event_shape), shape=None, name='dynamic_event_shape')
    # Variant whose override shapes are only known at run time.
    fake_mvn_dynamic = self._cls()(
        distribution=base_distribution_class(
            validate_args=True, **base_distribution_kwargs),
        bijector=tfb.Affine(shift=self._shift, scale_tril=self._tril),
        batch_shape=batch_shape_pl,
        event_shape=event_shape_pl,
        validate_args=True)

    # Variant whose override shapes are plain Python values.
    fake_mvn_static = self._cls()(
        distribution=base_distribution_class(
            validate_args=True, **base_distribution_kwargs),
        bijector=tfb.Affine(shift=self._shift, scale_tril=self._tril),
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=True)

    # Reference parameters: the mean is the shift repeated for both batch
    # members and the covariance is tril @ tril^T per batch member.
    actual_mean = np.tile(self._shift, [2, 1])  # Affine elided this tile.
    actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))

    def actual_mvn_log_prob(x):
      # Evaluate the reference log-density one batch member at a time, then
      # transpose so the batch axis comes last.
      return np.concatenate([[  # pylint: disable=g-complex-comprehension
          stats.multivariate_normal(actual_mean[i],
                                    actual_cov[i]).logpdf(x[:, i, :])
      ] for i in range(len(actual_cov))]).T

    # Reference entropy, one value per batch member.
    actual_mvn_entropy = np.concatenate(
        [[stats.multivariate_normal(actual_mean[i], actual_cov[i]).entropy()]
         for i in range(len(actual_cov))])

    self.assertAllEqual([3], fake_mvn_static.event_shape)
    self.assertAllEqual([2], fake_mvn_static.batch_shape)

    # The fully-unknown static shape is only asserted outside eager mode.
    if not tf.executing_eagerly():
      self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.event_shape)
      self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.batch_shape)

    # The CDF-type methods are expected to raise NotImplementedError.
    x = self.evaluate(fake_mvn_static.sample(5, seed=tfp_test_util.test_seed()))
    for unsupported_fn in (fake_mvn_static.log_cdf, fake_mvn_static.cdf,
                           fake_mvn_static.survival_function,
                           fake_mvn_static.log_survival_function):
      with self.assertRaisesRegexp(NotImplementedError,
                                   not_implemented_message):
        unsupported_fn(x)

    # Kept as a float because it is also the divisor of the sample
    # covariance; it is cast to int where a sample count is required.
    num_samples = 7e3
    for fake_mvn in [fake_mvn_static, fake_mvn_dynamic]:
      # Ensure sample works by checking first, second moments.
      y = fake_mvn.sample(int(num_samples), seed=tfp_test_util.test_seed())
      x = y[0:5, ...]
      sample_mean = tf.reduce_mean(input_tensor=y, axis=0)
      # Move the sample axis last so that the matmul contracts over samples.
      centered_y = tf.transpose(a=y - sample_mean, perm=[1, 2, 0])
      sample_cov = tf.matmul(
          centered_y, centered_y, transpose_b=True) / num_samples
      # Everything is fetched in one evaluate call so that all values derive
      # from the same sample draw.
      [
          sample_mean_,
          sample_cov_,
          x_,
          fake_event_shape_,
          fake_batch_shape_,
          fake_log_prob_,
          fake_prob_,
          fake_mean_,
          fake_entropy_,
      ] = self.evaluate([
          sample_mean,
          sample_cov,
          x,
          fake_mvn.event_shape_tensor(),
          fake_mvn.batch_shape_tensor(),
          fake_mvn.log_prob(x),
          fake_mvn.prob(x),
          fake_mvn.mean(),
          fake_mvn.entropy(),
      ])

      self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
      self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)

      # Ensure all other functions work as intended.
      self.assertAllEqual([5, 2, 3], x_.shape)
      self.assertAllEqual([3], fake_event_shape_)
      self.assertAllEqual([2], fake_batch_shape_)
      self.assertAllClose(
          actual_mvn_log_prob(x_), fake_log_prob_, atol=0., rtol=1e-6)
      self.assertAllClose(
          np.exp(actual_mvn_log_prob(x_)), fake_prob_, atol=0., rtol=1e-5)
      self.assertAllClose(actual_mean, fake_mean_, atol=0., rtol=1e-6)
      self.assertAllClose(actual_mvn_entropy, fake_entropy_, atol=0., rtol=1e-6)
Пример #3
0
def sample_lkj(
    num_samples,
    dimension,
    concentration,
    cholesky_space=False,
    seed=None,
    name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  The sampler builds the Cholesky factor of the correlation matrix one row at
  a time, and only multiplies it out at the end when `cholesky_space` is
  false.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    dimension: Python `int`. The dimension of correlation matrices.
    concentration: `Tensor` representing the concentration of the LKJ
      distribution. Must have a floating dtype.
    cholesky_space: Python `bool`. Whether to take samples from LKJ or
      Chol(LKJ).
    seed: Python integer seed for RNG
    name: Python `str` name prefixed to Ops created by this function.
      Defaults to `'sample_lkj'` when `None` or empty.

  Returns:
    samples: A Tensor of correlation matrices (or Cholesky factors of
      correlation matrices if `cholesky_space = True`) with shape
      `[n] + B + [D, D]`, where `B` is the shape of the `concentration`
      parameter, and `D` is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
    TypeError: If `concentration` does not have a floating dtype.
  """
  if dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)
  # NOTE(review): the comments below cite "reference [1]", which is not listed
  # in this docstring -- presumably the LKJ sampling paper; confirm and cite.

  # We need 1 seed for beta corr12, and 2 per loop iter.
  num_seeds = 1 + 2 * max(0, dimension - 2)
  seeds = list(samplers.split_seed(seed, n=num_seeds, salt='sample_lkj'))
  # Honor a caller-supplied scope name; fall back to the default otherwise.
  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    concentration = _replicate(num_samples, concentration)
    concentration_shape = tf.shape(concentration)
    if dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = tf.concat([
          concentration_shape, [dimension, dimension]], axis=0)
      return tf.ones(shape=shape, dtype=concentration.dtype)
    beta_conc = concentration + (dimension - 2.) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc, concentration0=beta_conc)

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.

    # This is the correlation coefficient between the first two dimensions.
    # This is also `r` in reference [1].
    corr12 = 2. * beta_dist.sample(seed=seeds.pop()) - 1.

    # Below we construct the Cholesky of the initial 2x2 correlation matrix,
    # which is of the form:
    # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
    # first two dimensions.
    # This is the top-left corner of the cholesky of the final sample.
    first_row = tf.concat([
        tf.ones_like(corr12)[..., tf.newaxis],
        tf.zeros_like(corr12)[..., tf.newaxis]], axis=-1)
    second_row = tf.concat([
        corr12[..., tf.newaxis],
        tf.sqrt(1 - corr12**2)[..., tf.newaxis]], axis=-1)

    chol_result = tf.concat([
        first_row[..., tf.newaxis, :],
        second_row[..., tf.newaxis, :]], axis=-2)

    for n in range(2, dimension):
      # Loop invariant: on entry, result has shape B + [n, n]
      beta_conc = beta_conc - 0.5
      # norm is y in reference [1].
      norm = beta.Beta(
          concentration1=n/2.,
          concentration0=beta_conc
      ).sample(seed=seeds.pop())
      # distance shape: B + [1] for broadcast
      distance = tf.sqrt(norm)[..., tf.newaxis]
      # direction is u in reference [1].
      # direction shape: B + [n]
      direction = _uniform_unit_norm(
          n, concentration_shape, concentration.dtype,
          seed=seeds.pop())
      # raw_correlation is w in reference [1].
      raw_correlation = distance * direction  # shape: B + [n]

      # This is the next row in the cholesky of the result,
      # which differs from the construction in reference [1].
      # In the reference, the new row `z` = chol_result @ raw_correlation^T
      # = C @ raw_correlation^T (where as short hand we use C = chol_result).
      # We prove that the below equation is the right row to add to the
      # cholesky, by showing equality with reference [1].
      # Let S be the sample constructed so far, and let `z` be as in
      # reference [1]. Then at this iteration, the new sample S' will be
      # [[S z^T]
      #  [z 1]]
      # In our case we have the cholesky decomposition factor C, so
      # we want our new row x (same size as z) to satisfy:
      #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
      #   [z 1]] =  [x k]]    [0     k]]  =       [xC^t   xx^T + k**2]]
      # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
      # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
      # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
      # distance**2).
      new_row = tf.concat(
          [raw_correlation, tf.sqrt(1. - norm[..., tf.newaxis])], axis=-1)

      # Finally add this new row, by growing the cholesky of the result.
      chol_result = tf.concat([
          chol_result,
          tf.zeros_like(chol_result[..., 0][..., tf.newaxis])], axis=-1)

      chol_result = tf.concat(
          [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

    # Internal invariant: every pre-split seed must have been consumed. The
    # count is formatted explicitly because `str + int` would itself raise.
    assert not seeds, 'Did not use all seeds: {}'.format(len(seeds))
    if cholesky_space:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually set
    # these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=tf.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
    # fail. Specifically, as documented in b/116828694, around 2% of trials
    # of 900,000 5x5 matrices (distributed according to 9 different
    # concentration parameter values) contained at least one matrix on which
    # the Cholesky decomposition failed.
    return result
    def testDenseLocalReparameterization(self):
        """Checks `DenseLocalReparameterization` against a hand-built reference.

        The expected output is a sample from a Normal over the pre-activations
        (loc `inputs @ loc`, scale `sqrt(inputs**2 @ scale**2)`) plus the bias
        sample. The test also checks that the layer reports the KL terms
        produced by the mock divergence functions and that those functions
        were called with the expected argument objects.
        """
        batch_size, in_size, out_size = 2, 3, 4
        with self.cached_session() as sess:
            tf1.set_random_seed(9068)
            (kernel_posterior, kernel_prior, kernel_divergence, bias_posterior,
             bias_prior, bias_divergence, layer, inputs, outputs,
             kl_penalty) = self._testDenseSetUp(
                 tfp.layers.DenseLocalReparameterization, batch_size, in_size,
                 out_size)

            # The graph-level seed is reset to the value used before the layer
            # was built; presumably so the reference draw below matches the
            # layer's own draw -- confirm against `_testDenseSetUp`.
            tf1.set_random_seed(9068)
            expected_kernel_posterior_affine = tfd.Normal(
                loc=tf.matmul(inputs, kernel_posterior.result_loc),
                scale=tf.matmul(inputs**2.,
                                kernel_posterior.result_scale**2)**0.5)
            expected_kernel_posterior_affine_tensor = (
                expected_kernel_posterior_affine.sample(seed=42))
            expected_outputs = (expected_kernel_posterior_affine_tensor +
                                bias_posterior.result_sample)

            # Fetch expected/actual pairs in a single run so they come from
            # the same execution of the graph.
            [
                expected_outputs_,
                actual_outputs_,
                expected_kernel_divergence_,
                actual_kernel_divergence_,
                expected_bias_,
                actual_bias_,
                expected_bias_divergence_,
                actual_bias_divergence_,
            ] = sess.run([
                expected_outputs,
                outputs,
                kernel_divergence.result,
                kl_penalty[0],
                bias_posterior.result_sample,
                layer.bias_posterior_tensor,
                bias_divergence.result,
                kl_penalty[1],
            ])

            self.assertAllClose(expected_bias_,
                                actual_bias_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_outputs_,
                                actual_outputs_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_kernel_divergence_,
                                actual_kernel_divergence_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_bias_divergence_,
                                actual_bias_divergence_,
                                rtol=1e-6,
                                atol=0.)

            # No kernel sample is expected as the third divergence argument,
            # hence the trailing `None`.
            expected_args = [kernel_posterior, kernel_prior, None]
            # We expect that there was one call to kernel_divergence, with the above
            # args; MockKLDivergence appends the list of args to a list, so the above
            # args should be in the 0th position of that list.
            actual_args = kernel_divergence.args[0]
            # Test for identity with 'is'. TensorFlowTestCase.assertAllEqual actually
            # coerces the inputs to numpy arrays, so we can't use that to assert that
            # the arguments (which are a mixture of Distributions and Tensors) are
            # equal.
            for a, b in zip(expected_args, actual_args):
                self.assertIs(a, b)

            # Same story as above.
            expected_args = [
                bias_posterior, bias_prior, bias_posterior.result_sample
            ]
            actual_args = bias_divergence.args[0]
            for a, b in zip(expected_args, actual_args):
                self.assertIs(a, b)
Пример #5
0
 def _matmul(self, inputs, kernel):
     """Multiply `inputs` by `kernel`, choosing the op by the rank of `inputs`.

     Inputs whose static rank is at most 2 go through `tf.matmul`; anything
     of higher rank is contracted over its last axis with `tf.tensordot`.
     """
     if inputs.shape.ndims <= 2:
         product = tf.matmul(inputs, kernel)
     else:
         # `tensordot` is required here so that the leading axes of `inputs`
         # are handled (broadcast) against the kernel.
         product = tf.tensordot(inputs, kernel, axes=[[-1], [0]])
     return product
Пример #6
0
 def _covariance(self):
   """Build the covariance matrix from the probability vector.

   The off-diagonal entries are the negated outer product of the
   probabilities with themselves; the diagonal is then overwritten with
   `self._variance(probs)`.
   """
   probs = self._probs_parameter_no_checks()
   as_column = probs[..., None]
   as_row = probs[..., None, :]
   negated_outer = -tf.matmul(as_column, as_row)
   return tf.linalg.set_diag(negated_outer, self._variance(probs))
Пример #7
0
    def _sample_n(self, n, seed):
        """Draw `n` samples, as matrices or as their lower-triangular factors.

        A lower-triangular matrix is filled with standard normal draws below
        the diagonal and square roots of gamma draws on the diagonal, then
        left-multiplied by `self.scale_operator`. Unless
        `self.input_output_cholesky` is set, the factor is multiplied by its
        own adjoint before being returned.

        Args:
          n: Number of samples; becomes the leading axis of the result.
          seed: Seed forwarded to `SeedStream` (salt "Wishart").

        Returns:
          Tensor of shape `[n] + batch_shape + event_shape`.
        """
        batch_dims = self.batch_shape_tensor()
        event_dims = self.event_shape_tensor()
        num_batch_axes = tf.shape(input=batch_dims)[0]

        # One sample axis plus two event (matrix) axes.
        total_rank = num_batch_axes + 3
        draw_shape = tf.concat([[n], batch_dims, event_dims], 0)
        seeds = seed_stream.SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        normal_draws = tf.random.normal(
            shape=draw_shape,
            mean=0.,
            stddev=1.,
            dtype=self.dtype,
            seed=seeds())

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        df_full = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(),
            dtype=dtype_util.base_dtype(self.df.dtype))
        gamma_alpha = self._multi_gamma_sequence(0.5 * df_full, self.dimension)
        gamma_draws = tf.random.gamma(
            shape=[n],
            alpha=gamma_alpha,
            beta=0.5,
            dtype=self.dtype,
            seed=seeds())

        # Keep only the lower triangle. Complexity: O(nbk**2)
        lower = tf.linalg.band_part(normal_draws, -1, 0)

        # Replace the diagonal with sqrt of the gamma draws. Complexity: O(nbk)
        factor = tf.linalg.set_diag(lower, tf.sqrt(gamma_draws))

        # Make batch-op ready: rotate the sample axis to the end and fold it
        # into the last matrix axis. Complexity: O(nbk**2)
        sample_axis_last = tf.concat([tf.range(1, total_rank), [0]], 0)
        factor = tf.transpose(a=factor, perm=sample_axis_last)
        folded_shape = tf.concat(
            [batch_dims, [event_dims[0]], [event_dims[1] * n]], 0)
        factor = tf.reshape(factor, folded_shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving
        # a vector system. For LinearOperatorLowerTriangular, each matmul is
        # O(k^3) so this step has complexity O(nbk^3).
        factor = self.scale_operator.matmul(factor)

        # Undo make batch-op ready: unfold the sample axis and rotate it back
        # to the front. Complexity: O(nbk**2)
        unfolded_shape = tf.concat([batch_dims, event_dims, [n]], 0)
        factor = tf.reshape(factor, unfolded_shape)
        sample_axis_first = tf.concat(
            [[total_rank - 1], tf.range(0, total_rank - 1)], 0)
        factor = tf.transpose(a=factor, perm=sample_axis_first)

        if self.input_output_cholesky:
            return factor

        # Complexity: O(nbk**3)
        return tf.matmul(factor, factor, adjoint_b=True)
Пример #8
0
 def matmul_broadcast_singleton_dimension(self, lhs, rhs):
     """Return `tf.matmul(lhs, rhs)`.

     NOTE(review): judging by the name, callers presumably pass operands
     whose batch dimensions include a singleton to broadcast -- confirm.
     """
     product = tf.matmul(lhs, rhs)
     return product
Пример #9
0
 def matmul_high_rank_batch(self, lhs, rhs):
     """Return `tf.matmul(lhs, rhs)`.

     NOTE(review): judging by the name, callers presumably pass operands
     with several leading batch dimensions -- confirm.
     """
     product = tf.matmul(lhs, rhs)
     return product
Пример #10
0
 def basic_matmul(self, lhs, rhs):
     """Return the matrix product `tf.matmul(lhs, rhs)`."""
     product = tf.matmul(lhs, rhs)
     return product
Пример #11
0
 def matmul_rhs_batch(self, lhs, rhs):
     """Return `tf.matmul(lhs, rhs)`.

     NOTE(review): judging by the name, callers presumably pass a batched
     `rhs` with an unbatched `lhs` -- confirm.
     """
     product = tf.matmul(lhs, rhs)
     return product
Пример #12
0
 def _random_chol(self, *shape):
     """Create a random lower-triangular factor and the matrix it generates.

     Values are drawn with `self._rng.rand(*shape)`, the diagonal is passed
     through softplus, and everything above the diagonal is zeroed.

     Returns:
       Tuple `(chol, sigma)` of evaluated values, where `sigma` is `chol`
       multiplied by its own adjoint.
     """
     raw = self._rng.rand(*shape)
     softplus_diag = tfd.matrix_diag_transform(raw, transform=tf.math.softplus)
     lower = tf.linalg.band_part(softplus_diag, -1, 0)
     gram = tf.matmul(lower, lower, adjoint_b=True)
     chol_value = self.evaluate(lower)
     sigma_value = self.evaluate(gram)
     return chol_value, sigma_value
Пример #13
0
    def testSampleLarge(self):
        """Compares sample and analytical statistics of `MultivariateNormalTriL`.

        Mean, covariance, variance, stddev and the dense scale of a 2-D
        distribution are checked against NumPy ground truth, both from the
        analytical methods (tight tolerance) and from 10,000 samples (loose
        tolerance). A sampled estimate of the KL divergence to a second
        distribution is also compared with `tfd.kl_divergence`.
        """
        mu = np.array([-1., 1], dtype=np.float32)
        scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3.

        # Ground truth computed directly with NumPy.
        true_mean = mu
        true_scale = scale_tril
        true_covariance = np.matmul(true_scale, true_scale.T)
        true_variance = np.diag(true_covariance)
        true_stddev = np.sqrt(true_variance)

        dist = tfd.MultivariateNormalTriL(loc=mu,
                                          scale_tril=scale_tril,
                                          validate_args=True)

        # The following distributions will test the KL divergence calculation.
        mvn_chol = tfd.MultivariateNormalTriL(loc=np.array([0.5, 1.2],
                                                           dtype=np.float32),
                                              scale_tril=np.array(
                                                  [[3., 0], [1, 2]],
                                                  dtype=np.float32),
                                              validate_args=True)

        n = int(10e3)
        samps = dist.sample(n, seed=tfp_test_util.test_seed())
        sample_mean = tf.reduce_mean(input_tensor=samps, axis=0)
        # Centered samples; x^T x / n is the sample covariance.
        x = samps - sample_mean
        sample_covariance = tf.matmul(x, x, transpose_a=True) / n

        # Monte Carlo estimate of KL(dist || mvn_chol) using samples of dist.
        sample_kl_chol = tf.reduce_mean(input_tensor=dist.log_prob(samps) -
                                        mvn_chol.log_prob(samps),
                                        axis=0)
        analytical_kl_chol = tfd.kl_divergence(dist, mvn_chol)

        scale = dist.scale.to_dense()

        # Fetch everything in one call so all sample-based values come from
        # the same draw.
        [
            sample_mean_,
            analytical_mean_,
            sample_covariance_,
            analytical_covariance_,
            analytical_variance_,
            analytical_stddev_,
            sample_kl_chol_,
            analytical_kl_chol_,
            scale_,
        ] = self.evaluate([
            sample_mean,
            dist.mean(),
            sample_covariance,
            dist.covariance(),
            dist.variance(),
            dist.stddev(),
            sample_kl_chol,
            analytical_kl_chol,
            scale,
        ])

        sample_variance_ = np.diag(sample_covariance_)
        sample_stddev_ = np.sqrt(sample_variance_)

        # Verbose-level logging of every compared quantity, for debugging.
        tf1.logging.vlog(2, "true_mean:\n{}  ".format(true_mean))
        tf1.logging.vlog(2, "sample_mean:\n{}".format(sample_mean_))
        tf1.logging.vlog(2, "analytical_mean:\n{}".format(analytical_mean_))

        tf1.logging.vlog(2, "true_covariance:\n{}".format(true_covariance))
        tf1.logging.vlog(2,
                         "sample_covariance:\n{}".format(sample_covariance_))
        tf1.logging.vlog(
            2, "analytical_covariance:\n{}".format(analytical_covariance_))

        tf1.logging.vlog(2, "true_variance:\n{}".format(true_variance))
        tf1.logging.vlog(2, "sample_variance:\n{}".format(sample_variance_))
        tf1.logging.vlog(
            2, "analytical_variance:\n{}".format(analytical_variance_))

        tf1.logging.vlog(2, "true_stddev:\n{}".format(true_stddev))
        tf1.logging.vlog(2, "sample_stddev:\n{}".format(sample_stddev_))
        tf1.logging.vlog(2,
                         "analytical_stddev:\n{}".format(analytical_stddev_))

        tf1.logging.vlog(2, "true_scale:\n{}".format(true_scale))
        tf1.logging.vlog(2, "scale:\n{}".format(scale_))

        tf1.logging.vlog(
            2, "kl_chol:      analytical:{}  sample:{}".format(
                analytical_kl_chol_, sample_kl_chol_))

        # Sample-based checks use percent-level tolerances; analytical checks
        # use rtol=1e-6.
        self.assertAllClose(true_mean, sample_mean_, atol=0., rtol=0.03)
        self.assertAllClose(true_mean, analytical_mean_, atol=0., rtol=1e-6)

        self.assertAllClose(true_covariance,
                            sample_covariance_,
                            atol=0.,
                            rtol=0.03)
        self.assertAllClose(true_covariance,
                            analytical_covariance_,
                            atol=0.,
                            rtol=1e-6)

        self.assertAllClose(true_variance,
                            sample_variance_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(true_variance,
                            analytical_variance_,
                            atol=0.,
                            rtol=1e-6)

        self.assertAllClose(true_stddev, sample_stddev_, atol=0., rtol=0.01)
        self.assertAllClose(true_stddev,
                            analytical_stddev_,
                            atol=0.,
                            rtol=1e-6)

        self.assertAllClose(true_scale, scale_, atol=0., rtol=1e-6)

        self.assertAllClose(sample_kl_chol_,
                            analytical_kl_chol_,
                            atol=0.,
                            rtol=0.02)
Пример #14
0
 def _random_pd_matrix(self, *shape):
     """Return an evaluated matrix of the form `L @ adjoint(L)`.

     `L` is drawn with `rng.rand(*shape)`, has softplus applied to its
     diagonal and is truncated to its lower triangle.
     """
     raw = rng.rand(*shape)
     softplus_diagonal = tfb.TransformDiagonal(tfb.Softplus())
     lower = tf.linalg.band_part(softplus_diagonal(raw), -1, 0)
     gram = tf.matmul(lower, lower, adjoint_b=True)
     return self.evaluate(gram)
Пример #15
0
 def call(self, inputs):
   """Return `inputs` multiplied by this object's `kernel` via `tf.matmul`."""
   projected = tf.matmul(inputs, self.kernel)
   return projected
Пример #16
0
 def matmul_dynamic(self, lhs, rhs):
     """Return `tf.matmul(lhs, rhs)`.

     NOTE(review): judging by the name, callers presumably pass operands
     with dynamic (unknown) dimensions -- confirm.
     """
     product = tf.matmul(lhs, rhs)
     return product
Пример #17
0
 def exported_function(x):
     """Populate `root.x`, `root.y`, `root.z` and return `root.z * x`.

     `root.x` is a fixed 2x2 constant, `root.y` its product with `root.w`,
     and `root.z` the result of `tf.linalg.diag` applied to `root.y`.
     """
     root.x = constant_op.constant([[37.0, -23.0], [1.0, 4.0]])
     product = tf.matmul(root.x, root.w)
     root.y = product
     # unsupported op: linalg.diag
     root.z = tf.linalg.diag(root.y)
     scaled = root.z * x
     return scaled
Пример #18
0
 def matmul_dynamic_lhs_batch(self, lhs, rhs):
     """Return `tf.matmul(lhs, rhs)`.

     NOTE(review): judging by the name, callers presumably pass an `lhs`
     with a dynamic batch dimension -- confirm.
     """
     product = tf.matmul(lhs, rhs)
     return product
Пример #19
0
def positive_definite(x):
    """Map `x` to `symmetric(x @ x^T + 0.1 * I)`.

    The identity is sized from the last dimension of `x` and batched over
    all but its last two dimensions, as read from the static shape.
    """
    dims = tensorshape_util.as_list(x.shape)
    gram = tf.matmul(x, x, transpose_b=True)
    jitter = .1 * tf.linalg.eye(dims[-1], batch_shape=dims[:-2])
    return symmetric(gram + jitter)
Пример #20
0
 def loss():
     """Return the elementwise square of `lookup([var0], [0]) @ x + var1`.

     `var0`, `var1` and `x` come from the enclosing scope.
     """
     looked_up = tf.compat.v1.nn.embedding_lookup([var0], [0])  # pylint: disable=cell-var-from-loop
     residual = tf.matmul(looked_up, x)  # pylint: disable=cell-var-from-loop
     residual += var1  # pylint: disable=cell-var-from-loop
     return residual * residual
Пример #21
0
    def testLangevin3DNormalDynamicVolatility(self):
        """Sampling from a 3-D Multivariate Normal distribution.

        Runs `MetropolisAdjustedLangevinAlgorithm` with a state-dependent
        volatility function on a two-part state and checks that the sample
        mean and covariance over all chains and results are close to the
        target's.
        """
        dtype = np.float32
        true_mean = dtype([1, 2, 7])
        true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
        num_results = 500
        num_chains = 500

        # Target distribution is defined through the Cholesky decomposition
        chol = tf.linalg.cholesky(true_cov)
        target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

        # The state is passed as two parts `x` and `y` (see `init_state`
        # below). The target log-density is then defined as follows:
        def target_log_prob(x, y):
            # Stack the input tensors together
            z = tf.concat([x, y], axis=-1)
            return target.log_prob(z)

        # Here we define the volatility function to be non-constant
        def volatility_fn(x, y):
            # One volatility entry per state part, each depending on the
            # current state.
            return [
                1. / (0.5 + 0.1 * tf.abs(x + y)), 1. / (0.5 + 0.1 * tf.abs(y))
            ]

        # Initial state of the chain: a 2-column part and a 1-column part,
        # which together make up the 3-D event.
        init_state = [
            np.ones([num_chains, 2], dtype=dtype),
            np.ones([num_chains, 1], dtype=dtype)
        ]

        # Run the Metropolis-adjusted Langevin algorithm for `num_results`
        # iterations for `num_chains` independent chains:
        states, _ = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=init_state,
            kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
                target_log_prob_fn=target_log_prob,
                volatility_fn=volatility_fn,
                step_size=.1,
                seed=42),
            num_burnin_steps=200,
            num_steps_between_results=1,
            parallel_iterations=1)

        # Re-join the two state parts along the last axis, then average over
        # the first two axes (results and chains).
        states = tf.concat(states, axis=-1)
        sample_mean = tf.reduce_mean(states, axis=[0, 1])
        # Trailing axis added so that x @ x^T is an outer product per sample.
        x = (states - sample_mean)[..., tf.newaxis]
        sample_cov = tf.reduce_mean(tf.matmul(x, x, transpose_b=True),
                                    axis=[0, 1])

        sample_mean_, sample_cov_ = self.evaluate([sample_mean, sample_cov])

        self.assertAllClose(np.squeeze(sample_mean_),
                            true_mean,
                            atol=0.1,
                            rtol=0.1)
        self.assertAllClose(np.squeeze(sample_cov_),
                            true_cov,
                            atol=0.1,
                            rtol=0.1)
Пример #22
0
 def loss():
     """Return the elementwise square of `lookup([var0], [0]) @ [[4], [5]]`.

     `var0` and `dtype` come from the enclosing scope.
     """
     column = tf.constant([[4.0], [5.0]], dtype=dtype)
     looked_up = tf.compat.v1.nn.embedding_lookup([var0], [0])
     product = tf.matmul(looked_up, column)
     return product * product
Пример #23
0
    def testDenseFlipout(self):
        """Checks `DenseFlipout` against a hand-built Flipout reference.

        The expected output is `inputs @ loc` plus a perturbation term plus
        the bias sample. The perturbation is a zero-mean kernel sample
        multiplied by the sign-flipped inputs and then by output signs, with
        the signs drawn from a `SeedStream` seeded with `layer.seed` and salt
        'DenseFlipout'. The test also checks the reported KL terms and the
        arguments the mock divergence functions were called with.
        """
        batch_size, in_size, out_size = 2, 3, 4
        with self.cached_session() as sess:
            tf1.set_random_seed(9069)
            (kernel_posterior, kernel_prior, kernel_divergence, bias_posterior,
             bias_prior, bias_divergence, layer, inputs, outputs,
             kl_penalty) = self._testDenseSetUp(tfp.layers.DenseFlipout,
                                                batch_size,
                                                in_size,
                                                out_size,
                                                seed=44)

            # The graph-level seed is reset to the value used before the layer
            # was built; presumably so the reference draw below matches the
            # layer's own draw -- confirm against `_testDenseSetUp`.
            tf1.set_random_seed(9069)
            # Zero-mean perturbation distribution with the posterior's scale.
            expected_kernel_posterior_affine = tfd.Normal(
                loc=tf.zeros_like(kernel_posterior.result_loc),
                scale=kernel_posterior.result_scale)
            expected_kernel_posterior_affine_tensor = (
                expected_kernel_posterior_affine.sample(seed=42))

            stream = tfp.util.SeedStream(layer.seed, salt='DenseFlipout')

            # Integer draws from {0, 1} are mapped to {-1, +1} by `2 * s - 1`.
            sign_input = tf.random.uniform([batch_size, in_size],
                                           minval=0,
                                           maxval=2,
                                           dtype=tf.int64,
                                           seed=stream())
            sign_input = tf.cast(2 * sign_input - 1, inputs.dtype)
            sign_output = tf.random.uniform([batch_size, out_size],
                                            minval=0,
                                            maxval=2,
                                            dtype=tf.int64,
                                            seed=stream())
            sign_output = tf.cast(2 * sign_output - 1, inputs.dtype)
            perturbed_inputs = tf.matmul(
                inputs * sign_input, expected_kernel_posterior_affine_tensor)
            perturbed_inputs *= sign_output

            # Mean term plus perturbation plus bias.
            expected_outputs = tf.matmul(inputs, kernel_posterior.result_loc)
            expected_outputs += perturbed_inputs
            expected_outputs += bias_posterior.result_sample

            # Fetch expected/actual pairs in a single run so they come from
            # the same execution of the graph.
            [
                expected_outputs_,
                actual_outputs_,
                expected_kernel_divergence_,
                actual_kernel_divergence_,
                expected_bias_,
                actual_bias_,
                expected_bias_divergence_,
                actual_bias_divergence_,
            ] = sess.run([
                expected_outputs,
                outputs,
                kernel_divergence.result,
                kl_penalty[0],
                bias_posterior.result_sample,
                layer.bias_posterior_tensor,
                bias_divergence.result,
                kl_penalty[1],
            ])

            self.assertAllClose(expected_bias_,
                                actual_bias_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_outputs_,
                                actual_outputs_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_kernel_divergence_,
                                actual_kernel_divergence_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_bias_divergence_,
                                actual_bias_divergence_,
                                rtol=1e-6,
                                atol=0.)

            # No kernel sample is expected as the third divergence argument,
            # hence the trailing `None`.
            expected_args = [kernel_posterior, kernel_prior, None]
            # We expect that there was one call to kernel_divergence, with the above
            # args; MockKLDivergence appends the list of args to a list, so the above
            # args should be in the 0th position of that list.
            actual_args = kernel_divergence.args[0]
            # Test for identity with 'is'. TensorFlowTestCase.assertAllEqual actually
            # coerces the inputs to numpy arrays, so we can't use that to assert that
            # the arguments (which are a mixture of Distributions and Tensors) are
            # equal.
            for a, b in zip(expected_args, actual_args):
                self.assertIs(a, b)

            # Same story as above.
            expected_args = [
                bias_posterior, bias_prior, bias_posterior.result_sample
            ]
            actual_args = bias_divergence.args[0]
            for a, b in zip(expected_args, actual_args):
                self.assertIs(a, b)
Пример #24
0
        def step_fn(inputs):
            """Runs one training step on a single replica.

            Unpacks `inputs` into images and labels, tiles them according to the
            ensemble / multi-sample flags, evaluates the loss (negative
            log-likelihood + L2 + annealed KL) under a gradient tape, applies the
            gradients through `optimizer`, and updates the training metrics.

            Args:
              inputs: A pair `(images, labels)`.
            """
            images, labels = inputs

            # These flag combinations are consulted several times below.
            use_ensemble = FLAGS.version2 and FLAGS.ensemble_size > 1
            multi_sample = FLAGS.num_train_samples > 1

            def split_leading(tensor, num_groups):
                # Reshape so that the leading axis becomes
                # `[num_groups, -1]`, keeping all trailing axes.
                target_shape = tf.concat(
                    [[num_groups, -1], tensor.shape[1:]], 0)
                return tf.reshape(tensor, target_shape)

            if use_ensemble:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
                # Labels are tiled only when the probabilities are not later
                # collapsed over the ensemble axis.
                if not (FLAGS.member_sampling or FLAGS.expected_probs):
                    labels = tf.tile(labels, [FLAGS.ensemble_size])

            if multi_sample:
                images = tf.tile(images, [FLAGS.num_train_samples, 1, 1, 1])

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                probs = tf.nn.softmax(logits)

                if use_ensemble:
                    # Diversity evaluation.
                    diversity_results = ed.metrics.average_pairwise_diversity(
                        split_leading(probs, FLAGS.ensemble_size),
                        FLAGS.ensemble_size)

                if multi_sample:
                    # Average the predictions over the training samples.
                    probs = tf.reduce_mean(
                        split_leading(probs, FLAGS.num_train_samples), 0)

                if use_ensemble and FLAGS.member_sampling:
                    # Pick one ensemble member uniformly at random by
                    # multiplying with a one-hot row vector.
                    member = tf.random.uniform([],
                                               maxval=FLAGS.ensemble_size,
                                               dtype=tf.int64)
                    selector = tf.expand_dims(
                        tf.one_hot(member,
                                   FLAGS.ensemble_size,
                                   dtype=probs.dtype), 0)
                    original_shape = probs.shape
                    flat_probs = tf.reshape(probs, [FLAGS.ensemble_size, -1])
                    picked = tf.matmul(selector, flat_probs)
                    probs = tf.reshape(
                        picked, tf.concat([[-1], original_shape[1:]], 0))
                elif use_ensemble and FLAGS.expected_probs:
                    # Average the predictions over the ensemble members.
                    probs = tf.reduce_mean(
                        split_leading(probs, FLAGS.ensemble_size), 0)

                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, probs))

                # L2 is applied to the slow weights and bias terms only. BN
                # parameters and fast-weight approximate posterior/prior
                # parameters are excluded; the selection relies on the
                # variables' naming scheme.
                filtered_variables = [
                    tf.reshape(var, (-1, ))
                    for var in model.trainable_variables
                    if 'kernel' in var.name or 'bias' in var.name
                ]
                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))

                # KL term with a linear warm-up capped at 1.
                kl = sum(model.losses) / train_dataset_size
                kl_scale = tf.cast(optimizer.iterations + 1, kl.dtype)
                kl_scale = kl_scale / FLAGS.kl_annealing_steps
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl

                # Scale the loss given the TPUStrategy will reduce sum all
                # gradients.
                loss = negative_log_likelihood + l2_loss + kl_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier == 1.0:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
            else:
                # Gradients of the fast-weight approximate posterior/prior
                # parameters get a different effective learning rate. BN
                # parameters and slow weights are left unscaled; again this
                # relies on the naming scheme.
                grad_list = [
                    (grad, var)
                    if ('batch_norm' in var.name or 'kernel' in var.name)
                    else (grad * FLAGS.fast_weight_lr_multiplier, var)
                    for grad, var in zip(grads, model.trainable_variables)
                ]
                optimizer.apply_gradients(grad_list)

            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            if use_ensemble:
                for key, value in diversity_results.items():
                    training_diversity['train/' + key].update_state(value)
Пример #25
0
    def testSampleMarginals(self):
        # The off-diagonal marginals of an LKJ-distributed matrix should follow
        # a (scaled) Beta distribution. LKJ samples are produced here by
        # running HMC on a CholeskyLKJ target through the CorrelationCholesky
        # bijector.
        num_chains = 10
        num_total_samples = 30000
        dim = 4
        concentration = np.array(2.5, dtype=np.float64)

        beta_concentration = np.array(.5 * dim + concentration - 1, np.float64)
        beta_dist = beta.Beta(concentration0=beta_concentration,
                              concentration1=beta_concentration)

        target = cholesky_lkj.CholeskyLKJ(dimension=dim,
                                          concentration=concentration)
        kernel = transformed_kernel.TransformedTransitionKernel(
            inner_kernel=hmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target.log_prob,
                num_leapfrog_steps=3,
                step_size=0.3),
            bijector=tfb.CorrelationCholesky())

        # The sample budget must exceed what the DKWM CDF test needs to detect
        # a wrong sampler at this discrepancy level.
        min_samples_needed = self.evaluate(
            st.min_num_samples_for_dkwm_cdf_test(discrepancy=0.04,
                                                 false_fail_rate=1e-9,
                                                 false_pass_rate=1e-9))
        self.assertLess(min_samples_needed, num_total_samples)

        @tf.function  # Ensure that MCMC sampling is done efficiently.
        def sample_mcmc_chain():

            def trace_is_accepted(_, pkr):
                return pkr.inner_results.is_accepted

            initial_state = tf.eye(dim,
                                   batch_shape=[num_chains],
                                   dtype=tf.float64)
            return sample.sample_chain(
                num_results=num_total_samples // num_chains,
                num_burnin_steps=1000,
                current_state=initial_state,
                trace_fn=trace_is_accepted,
                kernel=kernel,
                seed=test_util.test_seed())

        # Run the HMC chains.
        chol_lkj_samples, is_accepted = self.evaluate(sample_mcmc_chain())

        # Every chain must accept often enough.
        per_chain_acceptance = np.mean(is_accepted, axis=0)
        self.assertAllGreater(per_chain_acceptance, 0.8)

        # Cholesky factors -> correlation matrices, with the chain and draw
        # axes merged into a single sample axis.
        lkj_samples = tf.reshape(
            tf.matmul(chol_lkj_samples, chol_lkj_samples, adjoint_b=True),
            shape=[num_total_samples, dim, dim])

        # OutputToUnconstrained keeps only the strictly-below-diagonal
        # entries; those are then mapped from [-1, 1] onto [0, 1].
        scaled_lkj_samples = .5 * (
            OutputToUnconstrained().forward(lkj_samples) + 1)

        # Each off-diagonal marginal is compared against the Beta CDF.
        num_marginals = dim * (dim - 1) // 2
        for i in range(num_marginals):
            check = st.assert_true_cdf_equal_by_dkwm(
                scaled_lkj_samples[..., i],
                cdf=beta_dist.cdf,
                false_fail_rate=1e-9)
            self.evaluate(check)
Пример #26
0
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """Rebuilds a matrix from its LU factors: `X == lu_reconstruct(*tf.linalg.lu(X))`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_reconstruct').

    Returns:
      x: The original input to `tf.linalg.lu`, i.e., `x` as in,
        `lu_reconstruct(*tf.linalg.lu(x))`.

    #### Examples

    ```python
    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
    tf.debugging.assert_near(x, x_reconstructed)
    # ==> True
    ```

    """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(
            lower_upper, dtype_hint=tf.float32, name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        checks = lu_reconstruct_assertions(lower_upper, perm, validate_args)
        if checks:
            # Route both inputs through identity ops so that the assertions
            # run before anything that consumes them.
            with tf.control_dependencies(checks):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        dynamic_shape = tf.shape(lower_upper)

        # L: lower triangle of `lower_upper` with its diagonal replaced by ones.
        lower_band = tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0)
        unit_diag = tf.ones(dynamic_shape[:-1], dtype=lower_upper.dtype)
        lower = tf.linalg.set_diag(lower_band, unit_diag)
        # U: upper triangle of `lower_upper`, diagonal included.
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        static_rank = tensorshape_util.rank(lower_upper.shape)
        if static_rank == 2:
            # Single matrix: undo the row permutation with one gather.
            x = tf.gather(x, tf.math.invert_permutation(perm))
        else:
            # We either don't know the batch rank or there are >0 batch dims.
            # Flatten the batch dims, permute rows per matrix, then restore
            # the original shape.
            batch_size = tf.reduce_prod(dynamic_shape[:-2])
            d = dynamic_shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            row_indices = tf.stack([batch_indices, perm], axis=-1)
            x = tf.gather_nd(x, row_indices)
            x = tf.reshape(x, dynamic_shape)

        tensorshape_util.set_shape(x, lower_upper.shape)
        return x
Пример #27
0
def positive_definite(x):
    """Returns `x @ x^T + 0.1 * I`.

    The product `x @ x^T` has as many rows and columns as `x` has rows, so
    the identity is sized from the second-to-last dimension of `x`. For
    square `x` this is the same as using the last dimension.

    Args:
      x: Tensor whose rank is statically known and at least 2; the leading
        dimensions are passed to `tf.linalg.eye` as the batch shape.

    Returns:
      Tensor `matmul(x, x, transpose_b=True) + .1 * eye`.
    """
    shp = x.shape.as_list()
    # Fall back to the last dimension when the row count is not statically
    # known, which keeps square inputs with a partially unknown shape working.
    num_rows = shp[-2] if shp[-2] is not None else shp[-1]
    gram = tf.matmul(x, x, transpose_b=True)
    jitter = .1 * tf.linalg.eye(num_rows, batch_shape=shp[:-2])
    return gram + jitter
Пример #28
0
 def simple_matmul(self, a, b):
     """Returns the matrix product `tf.matmul(a, b)`."""
     product = tf.matmul(a, b)
     return product
Пример #29
0
  def solve_nu_zeta(self,
                    dataset: dataset_lib.OffpolicyDataset,
                    target_policy: tf_policy.TFPolicy,
                    regularizer: float = 1e-6):
    """Performs one update of the tabular `nu`/`zeta` estimates.

    On the first call (detected by the absence of `self._td_mat`) the method
    reads all episodes from `dataset`, builds the per-sample feature vectors,
    bias and index tensors, and caches them on `self`. Every call then:

      1. evaluates a saddle-point loss from the current `nu`, `nu2`, `zeta`
         and `zeta2`;
      2. bisects on `self._alpha` and computes sample weights from that loss;
      3. takes a regularized second-order step on `self._nu` and `self._nu2`;
      4. moves `self._zeta` and `self._zeta2` part of the way toward values
         recomputed from the updated `nu` and `nu2`.

    Args:
      dataset: The dataset to sample experience from. Only read on the first
        call.
      target_policy: The policy whose value we want to estimate. Only read on
        the first call.
      regularizer: A small constant; `regularizer * I` is added to the
        Hessians before the linear solves.

    Returns:
      A list `[avg_saddle_loss * self._algae_alpha_sign,
      avg_loss * self._algae_alpha_sign, divergence]`.
      NOTE(review): an earlier version of this docstring described the return
      value as the estimated average per-step reward of the target policy;
      confirm how callers interpret the first two entries.
    """

    # One-time setup: everything in this branch depends only on the dataset
    # and the target policy, so it is computed once and cached on `self`.
    if not hasattr(self, '_td_mat'):
      # Set up env_steps.
      episodes, valid_steps = dataset.get_all_episodes(
          limit=self._limit_episodes)
      total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
      num_episodes = tf.shape(valid_steps)[0]
      num_samples = num_episodes * total_num_steps_per_episode
      valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
      # Flat indices of the transitions that are kept.
      valid_indices = tf.squeeze(
          tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

      # The first step of each episode, repeated once per transition.
      initial_env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(
                  tf.repeat(
                      t[:, 0:1, ...],
                      axis=1,
                      repeats=total_num_steps_per_episode), [num_samples, -1])),
          episodes)
      initial_env_step = tf.nest.map_structure(
          lambda t: tf.gather(t, valid_indices), initial_env_step)
      tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
          initial_env_step)

      # Current step of every transition (all steps except the last one).
      env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                         [num_samples, -1])), episodes)
      env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                       env_step)
      tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

      # Next step of every transition (all steps except the first one).
      next_env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                         [num_samples, -1])), episodes)
      next_env_step = tf.nest.map_structure(
          lambda t: tf.gather(t, valid_indices), next_env_step)
      tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
          next_env_step)

      # get probabilities
      initial_target_probs = target_policy.distribution(
          tfagents_initial_env_step).action.probs_parameter()
      next_target_probs = target_policy.distribution(
          tfagents_next_env_step).action.probs_parameter()

      # First, get the nu_loss and data weights
      #current_nu_loss = self._get_nu_loss(initial_env_step, env_step,
      #                                    next_env_step, target_policy)
      #data_weight, _ = self._get_weights(current_nu_loss)

      # # debug only and to reproduce dual dice result, DELETE
      # data_weight = tf.ones_like(data_weight)

      state_action_count = self._get_state_action_counts(env_step)
      # NOTE(review): `counts` is not used anywhere below in this method.
      counts = tf.reduce_sum(tf.one_hot(state_action_count, self._dimension), 0)
      gamma_sample = tf.pow(self._gamma, tf.cast(env_step.step_num, tf.float32))

      # # debug only and to reproduce dual dice result, DELETE
      # gamma_sample = tf.ones_like(gamma_sample)

      # now we need to expand_dims to include action space in extra dimensions
      #data_weights = tf.reshape(data_weight, [-1, self._num_limits])
      # both are data sample weights for L2 problem, needs to be normalized later
      #gamma_data_weights = tf.reshape(gamma_sample, [-1, 1]) * data_weights

      # Pair every initial observation with every action index.
      initial_states = tf.tile(
          tf.reshape(initial_env_step.observation, [-1, 1]),
          [1, self._num_actions])
      initial_actions = tf.tile(
          tf.reshape(tf.range(self._num_actions), [1, -1]),
          [initial_env_step.observation.shape[0], 1])
      initial_nu_indices = self._get_index(initial_states, initial_actions)

      # linear term w.r.t. initial distribution
      #b_vec_2 = tf.stack([
      #    tf.reduce_sum(
      #        tf.reshape(
      #            data_weights[:, itr] / tf.reduce_sum(data_weights[:, itr]),
      #            [-1, 1]) * tf.reduce_sum(
      #                tf.one_hot(initial_nu_indices, self._dimension) *
      #                (1 - self._gamma) *
      #                tf.expand_dims(initial_target_probs, axis=-1),
      #                axis=1),
      #        axis=0) for itr in range(self._num_limits)
      #],
      #                   axis=0)

      # Pair every next observation with every action index.
      next_states = tf.tile(
          tf.reshape(next_env_step.observation, [-1, 1]),
          [1, self._num_actions])
      next_actions = tf.tile(
          tf.reshape(tf.range(self._num_actions), [1, -1]),
          [next_env_step.observation.shape[0], 1])
      next_nu_indices = self._get_index(next_states, next_actions)
      # Absorbing next steps get index -1.
      # NOTE(review): presumably so that the `tf.one_hot` below yields an
      # all-zero row for them -- confirm.
      next_nu_indices = tf.where(
          tf.expand_dims(next_env_step.is_absorbing(), -1),
          -1 * tf.ones_like(next_nu_indices), next_nu_indices)

      nu_indices = self._get_index(env_step.observation, env_step.action)

      target_log_probabilities = target_policy.distribution(
          tfagents_env_step).action.log_prob(env_step.action)
      if not self._solve_for_state_action_ratio:
        policy_ratio = tf.exp(target_log_probabilities -
                              env_step.get_log_probability())
      else:
        policy_ratio = tf.ones([
            target_log_probabilities.shape[0],
        ])
      policy_ratios = tf.tile(
          tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

      # the tabular feature vector
      a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
          self._gamma *
          tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
          tf.one_hot(next_nu_indices, self._dimension),
          axis=1)

      # linear term w.r.t. reward
      #b_vec_1 = tf.stack([
      #    tf.reduce_sum(
      #        tf.reshape(
      #            (gamma_data_weights[:, itr] /
      #             tf.reduce_sum(gamma_data_weights[:, itr])) * self._reward_fn(env_step), #/
      #            #tf.cast(state_action_count, tf.float32),
      #            [-1, 1]) * a_vec,
      #        axis=0) for itr in range(self._num_limits)
      #],
      #                   axis=0)
      # quadratic term of feature
      # Get weighted outer product by using einsum to save computing resource!
      #a_mat = tf.stack([
      #    tf.einsum(
      #        'ai, a, aj -> ij', a_vec,
      #        #1.0 / tf.cast(state_action_count, tf.float32),
      #        gamma_data_weights[:, itr] /
      #        tf.reduce_sum(gamma_data_weights[:, itr]),
      #        a_vec)
      #    for itr in range(self._num_limits)
      #],
      #                 axis=0)

      # Sum over samples of one_hot(index) x a_vec, each sample weighted by
      # 1 / state_action_count.
      td_mat = tf.einsum('ai, a, aj -> ij',
                         tf.one_hot(nu_indices, self._dimension),
                         1.0 / tf.cast(state_action_count, tf.float32), a_vec)

      weighted_rewards = policy_ratio * self._reward_fn(env_step)

      # Per-index sum of the weighted rewards, each sample weighted by
      # 1 / state_action_count.
      bias = tf.reduce_sum(
          tf.one_hot(nu_indices, self._dimension) *
          tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
          tf.cast(state_action_count, tf.float32)[:, None],
          axis=0)

      # Initialize
      self._nu = np.ones_like(self._nu) * bias[:, None]
      self._nu2 = np.ones_like(self._nu2) * bias[:, None]

      self._a_vec = a_vec
      self._td_mat = td_mat
      self._bias = bias
      self._weighted_rewards = weighted_rewards
      self._state_action_count = state_action_count
      self._nu_indices = nu_indices
      self._initial_nu_indices = initial_nu_indices
      self._initial_target_probs = initial_target_probs
      self._gamma_sample = gamma_sample
      # NOTE(review): the assignment above is immediately overwritten with
      # all ones, so the step-dependent discount weights are not used below.
      self._gamma_sample = tf.ones_like(gamma_sample)

    # Per-sample saddle-point loss for the first (nu, zeta) pair.
    saddle_bellman_residuals = (
        tf.matmul(self._a_vec, self._nu) - self._weighted_rewards[:, None])
    saddle_bellman_residuals *= -1 * self._algae_alpha_sign
    saddle_zetas = tf.gather(self._zeta, self._nu_indices)
    saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                           self._algae_alpha_sign)

    # Same quantities for the second (nu2, zeta2) pair, with opposite signs.
    saddle_bellman_residuals2 = (
        tf.matmul(self._a_vec, self._nu2) - self._weighted_rewards[:, None])
    saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
    saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
    saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu2, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 * -1 *
                            self._algae_alpha_sign)

    # Half the difference between the two pairs' saddle objectives.
    saddle_loss = 0.5 * (
        saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
        -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
        -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2 +
        tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))
    # Binary search to find best alpha.
    # 16 bisection steps over [-8, 32] per component: `left` moves up to
    # `mid` where the divergence exceeds `self._two_sided_limit`, otherwise
    # `right` moves down to `mid`.
    left = tf.constant([-8., -8.])
    right = tf.constant([32., 32.])
    for _ in range(16):
      mid = 0.5 * (left + right)
      self._alpha.assign(mid)
      weights, log_weights = self._get_weights(saddle_loss *
                                               self._gamma_sample[:, None])

      divergence = self._compute_divergence(weights, log_weights)
      divergence_violation = divergence - self._two_sided_limit
      left = tf.where(divergence_violation > 0., mid, left)
      right = tf.where(divergence_violation > 0., right, mid)
    # Use the midpoint of the final bracket and recompute the weights with it.
    self._alpha.assign(0.5 * (left + right))
    weights, log_weights = self._get_weights(saddle_loss *
                                             self._gamma_sample[:, None])

    gamma_data_weights = tf.stop_gradient(weights * self._gamma_sample[:, None])
    #print(tf.concat([gamma_data_weights, saddle_loss], axis=-1))
    # Weighted average of the saddle loss over samples.
    avg_saddle_loss = (
        tf.reduce_sum(gamma_data_weights * saddle_loss, axis=0) /
        tf.reduce_sum(gamma_data_weights, axis=0))

    # Total weight landing on each index, gathered back per sample.
    weighted_state_action_count = tf.reduce_sum(
        tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
        weights[:, None, :],
        axis=0)
    weighted_state_action_count = tf.gather(weighted_state_action_count,
                                            self._nu_indices)
    # Weighted analogue of the cached `td_mat`, one matrix per weight column.
    my_td_mat = tf.einsum(
        'ai, ab, ab, aj -> bij',
        tf.one_hot(self._nu_indices, self._dimension),
        #1.0 / tf.cast(self._state_action_count, tf.float32),
        1.0 / weighted_state_action_count,
        weights,
        self._a_vec)
    # Weighted analogue of the cached `bias`, one vector per weight column.
    my_bias = tf.reduce_sum(
        tf.transpose(weights)[:, :, None] *
        tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
        tf.reshape(self._weighted_rewards, [1, -1, 1]) *
        #1.0 / tf.cast(self._state_action_count, tf.float32)[None, :, None],
        1.0 / tf.transpose(weighted_state_action_count)[:, :, None],
        axis=1)

    #print('hello', saddle_initial_nu_values[:1], saddle_zetas[:3],
    #      self._nu[:2], my_bias[:, :2], saddle_loss[:4])

    # The tape is persistent because it is queried several times: two
    # `gradient` calls inside the context and two `jacobian` calls after it.
    # The gradients are taken inside the context so that they are themselves
    # recorded and can be differentiated again.
    with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape:
      tape.watch([self._nu, self._nu2, self._alpha])
      bellman_residuals = tf.matmul(
          my_td_mat,
          tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
      bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
      bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
      initial_nu_values = tf.reduce_sum(  # Average over actions.
          self._initial_target_probs[:, :, None] *
          tf.gather(self._nu, self._initial_nu_indices),
          axis=1)

      bellman_residuals *= self._algae_alpha_sign

      init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                      self._algae_alpha_sign)

      nu_loss = (
          tf.math.square(bellman_residuals) / 2.0 +
          tf.math.abs(self._algae_alpha) * init_nu_loss)

      # Per-sample loss, normalized by the total weight in each column.
      loss = (
          gamma_data_weights * nu_loss /
          tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

      # Same construction for `nu2`, with the sign flipped.
      bellman_residuals2 = tf.matmul(
          my_td_mat,
          tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
      bellman_residuals2 = tf.transpose(tf.squeeze(bellman_residuals2, -1))
      bellman_residuals2 = tf.gather(bellman_residuals2, self._nu_indices)
      initial_nu_values2 = tf.reduce_sum(  # Average over actions.
          self._initial_target_probs[:, :, None] *
          tf.gather(self._nu2, self._initial_nu_indices),
          axis=1)

      bellman_residuals2 *= -1 * self._algae_alpha_sign

      init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                       self._algae_alpha_sign)

      nu_loss2 = (
          tf.math.square(bellman_residuals2) / 2.0 +
          tf.math.abs(self._algae_alpha) * init_nu_loss2)

      loss2 = (
          gamma_data_weights * nu_loss2 /
          tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

      divergence = self._compute_divergence(weights, log_weights)
      divergence_violation = divergence - self._two_sided_limit

      # NOTE(review): `alpha_loss` is only referenced by the commented-out
      # optimizer code further down.
      alpha_loss = (-tf.exp(self._alpha) *
                    tf.stop_gradient(divergence_violation))

      # Extra penalty on the squared last row of `nu` / `nu2`.
      extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
      extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))
      nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
      nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]
    avg_loss = tf.reduce_sum(
        0.5 * (loss - loss2) / tf.math.abs(self._algae_alpha), axis=0)
    # Second derivatives: for each limit index i, keep the block of the
    # Jacobian that pairs column i of the gradient with column i of `nu`.
    nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
    nu_hess = tf.stack([nu_jacob[:, i, :, i] for i in range(self._num_limits)],
                       axis=0)

    nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
    nu_hess2 = tf.stack(
        [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

    for idx, div in enumerate(divergence):
      tf.summary.scalar('divergence%d' % idx, div)

    #alpha_grads = tape.gradient(alpha_loss, [self._alpha])
    #alpha_grad_op = self._alpha_optimizer.apply_gradients(
    #    zip(alpha_grads, [self._alpha]))
    #self._alpha.assign(tf.minimum(8., tf.maximum(-8., self._alpha)))

    #print(self._alpha, tf.concat([weights, nu_loss], -1))
    #regularizer = 0.1
    # Regularized second-order step: solve (H + regularizer * I) d = -grad
    # and move `nu` by 0.1 * d.
    nu_transformed = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
    self._nu = self._nu + 0.1 * nu_transformed
    nu_transformed2 = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess2 + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
    self._nu2 = self._nu2 + 0.1 * nu_transformed2

    # Diagnostic output, printed on every call.
    print(avg_loss * self._algae_alpha_sign,
          avg_saddle_loss * self._algae_alpha_sign, self._nu[:2], divergence)
    #print(init_nu_loss[:8], init_nu_loss[-8:])
    #print(bellman_residuals[:8])
    #print(self._nu[:3], self._zeta[:3])

    # Recompute zeta from the updated `nu`, then move the stored value 10% of
    # the way toward it.
    zetas = tf.matmul(my_td_mat,
                      tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
    zetas = tf.transpose(tf.squeeze(zetas, -1))
    zetas *= -self._algae_alpha_sign
    zetas /= tf.math.abs(self._algae_alpha)
    self._zeta = self._zeta + 0.1 * (zetas - self._zeta)

    # Same update for `zeta2` from `nu2`, with the opposite sign.
    zetas2 = tf.matmul(my_td_mat,
                       tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :,
                                                                      None]
    zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
    zetas2 *= 1 * self._algae_alpha_sign
    zetas2 /= tf.math.abs(self._algae_alpha)
    self._zeta2 = self._zeta2 + 0.1 * (zetas2 - self._zeta2)

    #self._zeta = (
    #    tf.einsum('ij,ja-> ia', self._td_mat, self._nu) -
    #    tf.transpose(my_bias))
    #self._zeta *= -tf.reshape(self._algae_alpha_sign, [1, self._num_limits])
    #self._zeta /= tf.math.abs(self._algae_alpha)
    return [
        avg_saddle_loss * self._algae_alpha_sign,
        avg_loss * self._algae_alpha_sign, divergence
    ]
Пример #30
0
def calc_spectrograms(waves,
                      window_lengths,
                      spectral_diffs=(0, 1),
                      window_name='hann',
                      use_mel_scale=True,
                      proj_method='matmul',
                      num_spec_bins=256,
                      random_crop=True):
    """Calculate spectrograms with multiple window sizes for list of input waves.

    Args:
      waves: List of float tensors. `tf.squeeze(w, axis=-1)` is applied to
        every wave unconditionally.
        NOTE(review): this presumably requires shape [batch, length, 1];
        confirm whether [batch, length] inputs are really supported.
      window_lengths: List of Int. Window sizes (frame lengths) to use for
        computing the spectrograms. Frames overlap by half a window.
      spectral_diffs: Iterable of Int. Orders of finite difference to apply
        to the waves before computing spectrograms.
      window_name: Str. Name of the window to use when computing the
        spectrograms. Supports 'hann' and None.
      use_mel_scale: Bool. Whether or not to project to mel-scale frequencies.
      proj_method: Str. Spectral projection method implementation to use.
        Supported are 'fft' and 'matmul'.
      num_spec_bins: Int. Number of bins in the spectrogram.
      random_crop: Bool. Take random crop or not.

    Returns:
      The lazy iterator produced by `zip`, yielding one tuple per input wave.
      Entry `j` of the tuple for wave `i` is the magnitude spectrogram of wave
      `i` for the `j`-th (spectral_diff, window_length) combination, in loop
      order (window lengths vary fastest).

    Raises:
      ValueError: If `window_name` or `proj_method` is not supported.
    """
    waves = [tf.squeeze(w, axis=-1) for w in waves]

    if window_name == 'hann':
        windows = [
            tf.reshape(tf.signal.hann_window(wl, periodic=False), [1, 1, -1])
            for wl in window_lengths
        ]
    elif window_name is None:
        windows = [None] * len(window_lengths)
    else:
        raise ValueError('Unknown window function (%s).' % window_name)

    def sq_mag(x):
        # Squared magnitude of a complex tensor.
        return tf.square(tf.math.real(x)) + tf.square(tf.math.imag(x))

    spec_len_wave = []
    for d in spectral_diffs:
        # The d-th order finite difference depends only on `d`, so it is
        # computed once here rather than once per window length.
        diffed_waves = waves
        for _ in range(d):
            diffed_waves = [w[:, 1:] - w[:, :-1] for w in diffed_waves]

        for length, window in zip(window_lengths, windows):
            wave_crops = diffed_waves
            if random_crop:
                wave_crops = aligned_random_crop(wave_crops, length)

            frames = [
                tf.signal.frame(wc, length, length // 2) for wc in wave_crops
            ]
            if window is not None:
                frames = [f * window for f in frames]

            if proj_method == 'fft':
                # The first (DC) frequency bin is dropped.
                ffts = [tf.signal.rfft(f)[:, :, 1:] for f in frames]
            elif proj_method == 'matmul':
                mat = get_spectral_matrix(length,
                                          num_spec_bins=num_spec_bins,
                                          use_mel_scale=use_mel_scale)
                ffts = [matmul_real_with_complex(f, mat) for f in frames]
            else:
                # Without this branch an unsupported method would only fail
                # later with an unbound `ffts`.
                raise ValueError('Unknown projection method (%s).' %
                                 proj_method)

            specs_sq = [sq_mag(f) for f in ffts]

            if use_mel_scale and proj_method == 'fft':
                # The 'matmul' path receives `use_mel_scale` via
                # `get_spectral_matrix`; the 'fft' path projects here.
                sample_rate = 24000  # Hard-coded sample rate in Hz.
                upper_edge_hertz = sample_rate / 2.
                lower_edge_hertz = sample_rate / length
                # The first row is dropped to match the bin dropped from the
                # FFT output above.
                lin_to_mel = tf.signal.linear_to_mel_weight_matrix(
                    num_mel_bins=num_spec_bins,
                    num_spectrogram_bins=length // 2 + 1,
                    sample_rate=sample_rate,
                    lower_edge_hertz=lower_edge_hertz,
                    upper_edge_hertz=upper_edge_hertz,
                    dtype=tf.dtypes.float32)[1:]
                specs_sq = [tf.matmul(s, lin_to_mel) for s in specs_sq]

            # EPSILON is added before the square root.
            specs = [tf.sqrt(s + EPSILON) for s in specs_sq]
            spec_len_wave.append(specs)

    # Transpose from [combination][wave] to [wave][combination].
    spec_wave_len = zip(*spec_len_wave)
    return spec_wave_len