def testMultiplyInverseTuple(self):
        """Check multiply_inverse on a (weights, bias) tuple, damping 0.5."""
        with ops.Graph().as_default(), self.test_session() as sess:
            random_seed.set_random_seed(200)
            layer_in = array_ops.constant([[1., 2., 3.], [3., 4., 5.],
                                           [5., 6., 7.]])
            layer_out = array_ops.constant([[3., 4.], [5., 6.]])
            fisher_block = fb.FullyConnectedKFACBasicFB(
                lc.LayerCollection(), has_bias=False)
            fisher_block.register_additional_minibatch(layer_in, layer_out)
            squared_grads = layer_out**2
            fisher_block.instantiate_factors(([squared_grads], ), 0.5)

            # Initialize variables, then run the inverse updates for the
            # input factor followed by the output factor so that the inverse
            # used below is something other than the identity.
            sess.run(tf_variables.global_variables_initializer())
            for factor in (fisher_block._input_factor,
                           fisher_block._output_factor):
                sess.run(factor.make_inverse_update_ops())

            weights_np = np.arange(2, 6).reshape(2, 2).astype(np.float32)
            bias_np = np.arange(1, 3).reshape(2, 1).astype(np.float32)
            weights_t = array_ops.constant(weights_np)
            bias_t = array_ops.constant(bias_np)
            inverse_op = fisher_block.multiply_inverse((weights_t, bias_t))
            result = sess.run(inverse_op)

            self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
                                result[0])
            self.assertAllClose([0.343146, 0.686291], result[1])
    def testMultiplyInverseAgainstExplicit(self):
        """Compare multiply_inverse with an explicit inverse of the full block.

        The factor covariances are overwritten with ``_make_psd`` matrices
        before the inverse update ops run. ``multiply_inverse`` applied to a
        flattened parameter vector is then compared against
        ``inv(full_fisher_block() + damping * I) @ v``.
        """
        with ops.Graph().as_default(), self.test_session() as sess:
            random_seed.set_random_seed(200)
            input_dim, output_dim = 3, 2
            # Length of the flattened parameter vector. It is derived from the
            # layer dims instead of being hard-coded, so that every size used
            # below stays consistent if the dims are changed.
            num_params = input_dim * output_dim
            inputs = array_ops.zeros([32, input_dim])
            outputs = array_ops.zeros([32, output_dim])
            params = array_ops.zeros([input_dim, output_dim])
            block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(),
                                                 has_bias=False)
            block.register_additional_minibatch(inputs, outputs)
            grads = outputs**2
            damping = 0.  # This test is only valid without damping.
            block.instantiate_factors(([grads], ), damping)

            # Inputs/outputs are all zeros, so overwrite the factor
            # covariances with _make_psd matrices (presumably positive
            # semi-definite, per the helper's name) before computing inverses.
            sess.run(state_ops.assign(block._input_factor._cov,
                                      _make_psd(input_dim)))
            sess.run(state_ops.assign(block._output_factor._cov,
                                      _make_psd(output_dim)))
            sess.run(block._input_factor.make_inverse_update_ops())
            sess.run(block._output_factor.make_inverse_update_ops())

            v_flat = np.arange(num_params, dtype=np.float32)
            vector = utils.column_to_tensors(params,
                                             array_ops.constant(v_flat))
            output = block.multiply_inverse(vector)
            output_flat = sess.run(utils.tensors_to_column(output)).ravel()

            full = sess.run(block.full_fisher_block())
            explicit = np.dot(
                np.linalg.inv(full + damping * np.eye(num_params)), v_flat)

            self.assertAllClose(output_flat, explicit)
Example #3
0
    def testMultiplyInverseNotTuple(self):
        """Check multiply_inverse on a bare weight matrix, damping 0.5."""
        with ops.Graph().as_default(), self.cached_session() as sess:
            random_seed.set_random_seed(200)
            layer_in = array_ops.constant([[1., 2.], [3., 4.]])
            layer_out = array_ops.constant([[3., 4.], [5., 6.]])
            fisher_block = fb.FullyConnectedKFACBasicFB(
                lc.LayerCollection(), has_bias=False)
            fisher_block.register_additional_tower(layer_in, layer_out)
            squared_grads = layer_out**2
            fisher_block.instantiate_factors(((squared_grads, ), ), 0.5)

            # Covariance variables for input then output factor, then the
            # inverse registration, then the inverse variables. The factor
            # tuple is rebuilt each time so attributes are read fresh.
            for factor in (fisher_block._input_factor,
                           fisher_block._output_factor):
                factor.instantiate_cov_variables()
            fisher_block.register_inverse()
            for factor in (fisher_block._input_factor,
                           fisher_block._output_factor):
                factor.instantiate_inv_variables()

            # Initialize variables and run the inverse updates so the inverse
            # used below is something other than the identity.
            sess.run(tf_variables.global_variables_initializer())
            for factor in (fisher_block._input_factor,
                           fisher_block._output_factor):
                sess.run(factor.make_inverse_update_ops())

            weights_t = array_ops.constant(
                np.arange(2, 6).reshape(2, 2).astype(np.float32))
            inverse_op = fisher_block.multiply_inverse(weights_t)

            expected = [[0.686291, 1.029437], [1.372583, 1.715729]]
            self.assertAllClose(expected, sess.run(inverse_op))
    def testFullyConnectedKFACBasicFBInit(self):
        """Check that tensors_to_compute_grads() yields the registered outputs."""
        with ops.Graph().as_default():
            random_seed.set_random_seed(200)
            layer_in = array_ops.constant([1., 2.])
            layer_out = array_ops.constant([3., 4.])
            fisher_block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
            fisher_block.register_additional_minibatch(layer_in, layer_out)

            expected = [layer_out]
            self.assertAllEqual(expected,
                                fisher_block.tensors_to_compute_grads())
  def testInstantiateFactorsNoBias(self):
    """Smoke test: instantiate_factors runs for a block built without bias."""
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      layer_in = array_ops.constant([[1., 2.], [3., 4.]])
      layer_out = array_ops.constant([[3., 4.], [5., 6.]])
      fisher_block = fb.FullyConnectedKFACBasicFB(
          lc.LayerCollection(), has_bias=False)
      fisher_block.register_additional_minibatch(layer_in, layer_out)

      squared_grads = layer_out**2
      # No assertion: the test passes as long as no exception is raised.
      fisher_block.instantiate_factors(((squared_grads,),), 0.5)
Example #6
0
  def testInstantiateFactorsHasBias(self):
    """Smoke test: instantiate_factors runs for a block built with bias."""
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      layer_in = array_ops.constant([[1., 2.], [3., 4.]])
      layer_out = array_ops.constant([[3., 4.], [5., 6.]])
      fisher_block = fb.FullyConnectedKFACBasicFB(
          lc.LayerCollection(), layer_in, layer_out, has_bias=True)

      squared_grads = layer_out**2
      # No assertion: the test passes as long as no exception is raised.
      fisher_block.instantiate_factors((squared_grads,), 0.5)
Example #7
0
 def register_fully_connected(self,
                              params,
                              inputs,
                              outputs,
                              approx=APPROX_KRONECKER_NAME):
   """Registers a Fisher block for a fully connected layer.

   Args:
     params: Layer parameters. If this is a tuple or list, the block is
       built with has_bias=True; otherwise with has_bias=False.
     inputs: Layer inputs, passed through to the block constructor.
     outputs: Layer outputs, passed through to the block constructor.
     approx: APPROX_KRONECKER_NAME selects fb.FullyConnectedKFACBasicFB;
       APPROX_DIAGONAL_NAME selects fb.FullyConnectedDiagonalFB.

   Raises:
     ValueError: If approx matches neither supported name.
   """
   if approx == APPROX_KRONECKER_NAME:
     block_cls = fb.FullyConnectedKFACBasicFB
   elif approx == APPROX_DIAGONAL_NAME:
     block_cls = fb.FullyConnectedDiagonalFB
   else:
     raise ValueError("Bad value {} for approx.".format(approx))

   has_bias = isinstance(params, (tuple, list))
   self.register_block(params, block_cls(self, inputs, outputs, has_bias))