def testInstantiateFactors(self):
    """Smoke test: instantiating EmbeddingKFACFB factors must not raise."""
    with tf.Graph().as_default():
        tf.set_random_seed(200)
        # Build a Fisher block over a small vocabulary.
        num_tokens = 5
        fisher_block = fb.EmbeddingKFACFB(lc.LayerCollection(), num_tokens)
        # Register a single tower of example inputs/outputs.
        token_ids = tf.constant([[0, 1], [1, 2], [2, 3]])
        activations = tf.constant([[0.], [1.], [2.]])
        fisher_block.register_additional_tower(token_ids, activations)
        # Instantiating the factors' variables should succeed without error.
        loss_grads = activations**2.
        zero_damping = tf.constant(0.)
        fisher_block.instantiate_factors(((loss_grads,),), zero_damping)
def testMultiplyInverse(self):
    """Checks multiply_inverse of a sparse vector against a dense solve.

    The inverse-Fisher-vector product computed by the block on an
    IndexedSlices update is compared, row by row, with an explicit
    tf.matrix_solve against the block's full Fisher matrix.
    """
    with tf.Graph().as_default(), self.test_session() as sess:
        tf.set_random_seed(200)
        # Build a Fisher block over a small vocabulary.
        num_tokens = 5
        fisher_block = fb.EmbeddingKFACFB(lc.LayerCollection(), num_tokens)
        # Register a single tower of example inputs/outputs.
        token_ids = tf.constant([[0, 1], [1, 2], [2, 3]])
        activations = tf.constant([[0.], [1.], [2.]])
        fisher_block.register_additional_tower(token_ids, activations)
        # Instantiate factor variables, then covariance and inverse variables
        # for both the input and the output factor.
        loss_grads = activations**2.
        zero_damping = tf.constant(0.)
        fisher_block.instantiate_factors(((loss_grads,),), zero_damping)
        fisher_block._input_factor.instantiate_cov_variables()
        fisher_block._output_factor.instantiate_cov_variables()
        fisher_block.register_inverse()
        fisher_block._input_factor.instantiate_inv_variables()
        fisher_block._output_factor.instantiate_inv_variables()
        # A sparse update touching rows 1, 3 and 4, plus its dense analogue.
        touched_rows = [1, 3, 4]
        sparse_vector = tf.IndexedSlices(
            tf.constant([[1.], [1.], [1.]]),
            tf.constant(touched_rows),
            dense_shape=[num_tokens, 1])
        dense_vector = tf.reshape([0., 1., 0., 1., 1.], [num_tokens, 1])
        # Compare the inverse-Fisher-vector product against the explicit
        # solve, one touched row at a time.
        result = fisher_block.multiply_inverse(sparse_vector)
        expected_result = tf.matrix_solve(
            fisher_block.full_fisher_block(), dense_vector)
        sess.run(tf.global_variables_initializer())
        for slice_idx, row in enumerate(touched_rows):
            self.assertAlmostEqual(
                sess.run(expected_result[row]),
                sess.run(result.values[slice_idx]))