def testMultiplyInverseAgainstExplicit(self):
        """NaiveDiagonalFB.multiply_inverse must agree with a dense solve.

        The factor's covariance is overwritten with known values and its
        inverse refreshed. The block's output is then compared against
        inv(full_fisher_block() + damping * I) @ v computed in NumPy.
        """
        with ops.Graph().as_default():
            with self.test_session() as session:
                random_seed.set_random_seed(200)
                weights = array_ops.constant([1., 2.])
                bias = array_ops.constant(3.)
                params = (weights, bias)
                fisher_block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
                fisher_block.register_additional_minibatch(32)
                per_param_grads = (weights**2, math_ops.sqrt(bias))
                damping = 0.5
                fisher_block.instantiate_factors((per_param_grads, ), damping)

                # Force a known covariance, then recompute the inverse from it.
                known_values = array_ops.constant([2., 3., 4.])
                known_cov = array_ops.reshape(known_values, [-1, 1])
                session.run(
                    state_ops.assign(fisher_block._factor._cov, known_cov))
                session.run(fisher_block._factor.make_inverse_update_ops())

                rhs = np.array([4., 5., 6.], dtype=np.float32)
                rhs_tensors = utils.column_to_tensors(
                    params, array_ops.constant(rhs))
                solved = fisher_block.multiply_inverse(rhs_tensors)
                actual = session.run(utils.tensors_to_column(solved)).ravel()

                # full_fisher_block() is compared without damping, so add it
                # here before inverting.
                dense = session.run(fisher_block.full_fisher_block())
                damped = dense + damping * np.eye(3)
                expected = np.linalg.inv(damped).dot(rhs)

                self.assertAllClose(actual, expected)
    def testNaiveDiagonalFBInitTensorTuple(self):
        """tensors_to_compute_grads() matches the params tuple given at init."""
        with ops.Graph().as_default():
            random_seed.set_random_seed(200)
            weights = array_ops.constant([1., 2.])
            bias = array_ops.constant(3.)
            params = (weights, bias)
            layers = lc.LayerCollection()
            fisher_block = fb.NaiveDiagonalFB(layers, params, 32)

            grad_targets = fisher_block.tensors_to_compute_grads()
            self.assertAllEqual(params, grad_targets)
    def testInstantiateFactors(self):
        """Smoke test: NaiveDiagonalFB factors instantiate from one grads tuple.

        `instantiate_factors` receives a sequence whose entries are tuples
        holding one gradient per param (presumably one entry per gradient
        source -- confirm against NaiveDiagonalFB). The other tests in this
        file wrap `grads` as `(grads, )`. Doing the same here keeps the two
        per-param gradients together. Otherwise each tensor would be treated
        as a separate entry.
        """
        with ops.Graph().as_default():
            random_seed.set_random_seed(200)
            params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
            block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)

            grads = (params[0]**2, math_ops.sqrt(params[1]))
            # One-element sequence of per-param gradient tuples.
            block.instantiate_factors((grads, ), 0.5)
Example No. 4
0
  def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
    """Registers a generic Fisher block for `params`.

    Args:
      params: A single parameter, or a tuple/list of parameters. A list is
        converted to a tuple so it can be used as a `fisher_blocks` key.
      batch_size: Passed through to the Fisher block constructor.
      approx: APPROX_FULL_NAME builds an `fb.FullFB`. APPROX_DIAGONAL_NAME
        (the default) builds an `fb.NaiveDiagonalFB`.

    Raises:
      ValueError: If `approx` is not one of the two supported names. This is
        checked before any registration state is modified.
    """
    # Validate first so a bad `approx` leaves this collection untouched.
    if approx == APPROX_FULL_NAME:
      block_class = fb.FullFB
    elif approx == APPROX_DIAGONAL_NAME:
      block_class = fb.NaiveDiagonalFB
    else:
      raise ValueError("Bad value {} for approx.".format(approx))

    # Lists are unhashable and cannot key `fisher_blocks`; normalize to tuple.
    params = tuple(params) if isinstance(params, (tuple, list)) else (params,)
    self._generic_registrations |= set(params)

    # Generic registrations do not need special registration rules because we do
    # not care about multiple generic registrations. Add them to the
    # fisher_block dictionary manually rather than going through the logic in
    # self.register_block.
    self.fisher_blocks[params] = block_class(self, params, batch_size)
    def testMultiplyInverseNotTuple(self):
        """multiply_inverse works when params is a single tensor, not a tuple.

        With damping 0.5 the expected output is vector * 2 / 3, i.e. the
        vector divided by 1.5. NOTE(review): this implies the initialized
        covariance is 1; that is inferred from the expected value, so confirm.
        """
        with ops.Graph().as_default():
            with self.test_session() as session:
                random_seed.set_random_seed(200)
                param = array_ops.constant([[1.], [2.]])
                fisher_block = fb.NaiveDiagonalFB(
                    lc.LayerCollection(), param, 32)
                squared = param**2
                fisher_block.instantiate_factors((squared, ), 0.5)

                # Initialize variables and refresh the inverse; with damping
                # the result is not the identity.
                session.run(tf_variables.global_variables_initializer())
                session.run(fisher_block._factor.make_inverse_update_ops())
                rhs = array_ops.ones(2) * 2
                solved = fisher_block.multiply_inverse(rhs)

                expected = session.run(rhs * 2 / 3.)
                actual = session.run(solved)
                self.assertAllClose(expected, actual)
    def testMultiplyInverseTuple(self):
        """multiply_inverse works when params is a tuple of tensors.

        The batch size is registered via register_additional_minibatch
        rather than the constructor. The params hold 3 scalars in total, so
        the vector has 3 entries. The expected output is vector * 2 / 3.
        """
        with ops.Graph().as_default():
            with self.test_session() as session:
                random_seed.set_random_seed(200)
                weights = array_ops.constant([1., 2.])
                bias = array_ops.constant(3.)
                params = (weights, bias)
                fisher_block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
                fisher_block.register_additional_minibatch(32)
                per_param_grads = (weights**2, math_ops.sqrt(bias))
                fisher_block.instantiate_factors((per_param_grads, ), 0.5)

                # Initialize variables and refresh the inverse; with damping
                # the result is not the identity.
                session.run(tf_variables.global_variables_initializer())
                session.run(fisher_block._factor.make_inverse_update_ops())

                rhs = array_ops.ones(3) * 2
                solved = fisher_block.multiply_inverse(rhs)

                expected = session.run(rhs * 2 / 3.)
                actual = session.run(solved)
                self.assertAllClose(expected, actual)