def testComputePi(self):
    """Checks utils.compute_pi on a small diagonal / all-ones factor pair."""
    with ops.Graph().as_default(), self.test_session() as sess:
        random_seed.set_random_seed(200)
        # trace(left_factor) = 1 + 2 + 0 + 1 = 4 over a 4x4 matrix.
        left_factor = array_ops.diag([1., 2., 0., 1.])
        # Shape dimensions are integers; trace(right_factor) = 2 over a 2x2
        # matrix.
        right_factor = array_ops.ones([2, 2])
        # pi is the square root of the ratio of the left trace norm to the
        # right trace norm. Assuming compute_pi normalizes each trace by the
        # matrix dimension, both norms are 1 here (4 / 4 and 2 / 2), so the
        # expected pi is 1.
        pi = utils.compute_pi(left_factor, right_factor)
        pi_val = sess.run(pi)
        # pi_val is a floating-point result, so compare with a tolerance.
        self.assertAllClose(1., pi_val)
def _register_damped_input_and_output_inverses(self, damping):
    """Registers damped inverses for both the input and output factors.

    Sets the instance members _input_damping and _output_damping. Requires the
    instance members _input_factor and _output_factor.

    The base damping is split between the two factors as
    ``sqrt(damping) * pi`` for the input factor and ``sqrt(damping) / pi`` for
    the output factor, where ``pi`` is ``utils.compute_pi`` applied to the two
    factors' covariances. The product of the two per-factor dampings therefore
    equals ``damping``.

    Args:
      damping: The base damping factor (float or Tensor) for the damped inverse.
    """
    pi = utils.compute_pi(self._input_factor.get_cov(),
                          self._output_factor.get_cov())

    # Build the square root once and share it between both factors instead of
    # constructing the identical expression (and, for Tensors, op) twice.
    sqrt_damping = damping**0.5
    self._input_damping = sqrt_damping * pi
    self._output_damping = sqrt_damping / pi

    self._input_factor.register_damped_inverse(self._input_damping)
    self._output_factor.register_damped_inverse(self._output_damping)
def _register_damped_input_and_output_inverses(self, damping):
    """Registers damped inverses for both the input and output factors.

    Sets the instance members _input_damping and _output_damping. Requires the
    instance members _input_factor and _output_factor.

    The base damping is split between the two factors as
    ``sqrt(damping) * pi`` for the input factor and ``sqrt(damping) / pi`` for
    the output factor, where ``pi`` is ``utils.compute_pi`` applied to the two
    factors' covariances. The product of the two per-factor dampings therefore
    equals ``damping``.

    Args:
      damping: The base damping factor (float or Tensor) for the damped inverse.
    """
    pi = utils.compute_pi(self._input_factor.get_cov(),
                          self._output_factor.get_cov())

    # Create a single sqrt op and share it between both factors rather than
    # adding two identical sqrt ops to the graph.
    sqrt_damping = math_ops.sqrt(damping)
    self._input_damping = sqrt_damping * pi
    self._output_damping = sqrt_damping / pi

    self._input_factor.register_damped_inverse(self._input_damping)
    self._output_factor.register_damped_inverse(self._output_damping)