def testMultiplyInverseDense(self):
    """Check `multiply_inverse` on a dense vector against an explicit solve.

    Two uses are registered in one tower, the second with ``transpose=True``.
    The output of ``block.multiply_inverse`` must match ``tf.matrix_solve`` of
    the explicitly assembled ``block.full_fisher_block()`` against the same
    dense vector.
    """
    # NOTE(review): unlike the other tests of this block type, no vocab_size
    # is passed here and the first use's inputs are float rather than integer
    # ids; the (2, 1) vector shape looks like a fully-connected weight. Confirm
    # EmbeddingKFACMultiIndepFB is really the block intended for this test.
    with tf.Graph().as_default(), self.test_session() as sess:
        tf.set_random_seed(200)
        block = fb.EmbeddingKFACMultiIndepFB(lc.LayerCollection())

        inputs = [tf.constant([[0., 1], [1, 2], [2, 3]]),
                  tf.constant([[0.1], [0.], [0.]])]
        outputs = [tf.constant([[0.], [1.], [2.]]),
                   tf.constant([[0., 0], [0, 0], [0, 4]])]
        block.register_additional_tower(inputs, outputs, transpose=[False, True])

        grads = [output**2 for output in outputs]
        damping = tf.constant(0.)
        block.instantiate_factors(((grads,),), damping)
        block._input_factor.instantiate_cov_variables()
        block._output_factor.instantiate_cov_variables()
        block.register_inverse()
        block._input_factor.instantiate_inv_variables()
        block._output_factor.instantiate_inv_variables()

        # Create a dense update.
        dense_vector = tf.constant([[0.5], [0.5]])

        # Compare the inverse-Fisher-vector product against an explicit solve
        # of the full Fisher matrix.
        result = block.multiply_inverse(dense_vector)
        expected_result = tf.matrix_solve(block.full_fisher_block(), dense_vector)

        sess.run(tf.global_variables_initializer())
        # Evaluate both sides in one run and compare them in full, instead of
        # element-wise assertAlmostEqual on 1-element arrays.
        expected, actual = sess.run([expected_result, result])
        self.assertAllClose(expected, actual)
def testInstantiateFactorsSingleTensors(self):
    """Factor instantiation works when single tensors are given to a multi-use block.

    The block is built with ``num_uses=2`` but is registered with one input
    tensor and one output tensor (not per-use lists). The test passes as long
    as ``instantiate_factors`` does not raise.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(200)
        vocab_size = 5
        layer_collection = lc.LayerCollection()
        block = fb.EmbeddingKFACMultiIndepFB(
            layer_collection, vocab_size, num_uses=2)

        # Integer inputs (presumably token ids) and float outputs.
        id_tensor = tf.constant([[0, 1], [1, 2], [2, 3]])
        out_tensor = tf.constant([[0.], [1.], [2.]])
        block.register_additional_tower(id_tensor, out_tensor)

        squared = out_tensor**2
        zero_damping = tf.constant(0.)
        nested_grads = ((squared,),)
        block.instantiate_factors(nested_grads, zero_damping)
def testInstantiateFactors(self):
    """Factor instantiation works for a block registered with two uses.

    One tower is registered with a list of two input tensors and a list of two
    output tensors. The test passes as long as ``instantiate_factors`` does not
    raise.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(200)
        vocab_size = 5
        block = fb.EmbeddingKFACMultiIndepFB(lc.LayerCollection(), vocab_size)

        first_ids = tf.constant([[0, 1], [1, 2], [2, 3]])
        second_ids = tf.constant([[0, 0], [0, 0], [0, 4]])
        first_out = tf.constant([[0.], [1.], [2.]])
        second_out = tf.constant([[0.1], [0.], [0.]])
        per_use_outputs = [first_out, second_out]
        block.register_additional_tower([first_ids, second_ids], per_use_outputs)

        # One squared "gradient" per use.
        squared = [out**2 for out in per_use_outputs]
        zero_damping = tf.constant(0.)
        block.instantiate_factors(((squared,),), zero_damping)
def testInstantiateFactorsTransposeConsistency(self):
    """Later towers must not change the transpose pattern of the first tower.

    Two towers registered with ``transpose=[False, True]`` are accepted.
    Afterwards, registering with the default ``transpose`` or with
    ``[True, False]`` must raise ``ValueError``.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(200)
        vocab_size = 5
        block = fb.EmbeddingKFACMultiIndepFB(lc.LayerCollection(), vocab_size)
        inputs = [tf.constant([[0, 1], [1, 2], [2, 3]]),
                  tf.constant([[0.1], [0.], [0.]])]
        outputs = [tf.constant([[0.], [1.], [2.]]),
                   tf.constant([[0, 0], [0, 0], [0, 4]])]

        # The same pattern twice is fine (a fresh list is built each time).
        for _ in range(2):
            block.register_additional_tower(
                inputs, outputs, transpose=[False, True])

        # Omitting transpose, or flipping it, conflicts with the earlier towers.
        mismatched = ({}, {"transpose": [True, False]})
        for extra_kwargs in mismatched:
            with self.assertRaises(ValueError):
                block.register_additional_tower(inputs, outputs, **extra_kwargs)
def testMultiplyInverse(self):
    """Check `multiply_inverse` on a sparse (IndexedSlices) vector.

    The sparse vector holds ones at rows 1, 3 and 4 of a ``(vocab_size, 1)``
    tensor. The returned slices must keep those indices, and their values must
    equal the matching rows of ``tf.matrix_solve(block.full_fisher_block(),
    dense_vector)``, where ``dense_vector`` is the dense equivalent.
    """
    with tf.Graph().as_default(), self.test_session() as sess:
        tf.set_random_seed(200)
        vocab_size = 5
        block = fb.EmbeddingKFACMultiIndepFB(lc.LayerCollection(), vocab_size)
        inputs = [tf.constant([[0, 1], [1, 2], [2, 3]]),
                  tf.constant([[0, 0], [0, 0], [0, 4]])]
        outputs = [tf.constant([[0.], [1.], [2.]]),
                   tf.constant([[0.1], [0.], [0.]])]
        block.register_additional_tower(inputs, outputs)

        grads = [output**2 for output in outputs]
        damping = tf.constant(0.)
        block.instantiate_factors(((grads,),), damping)
        block._input_factor.instantiate_cov_variables()
        block._output_factor.instantiate_cov_variables()
        block.register_inverse()
        block._input_factor.instantiate_inv_variables()
        block._output_factor.instantiate_inv_variables()

        # Create a sparse update touching rows 1, 3 and 4, plus the equivalent
        # dense vector (keep the two in sync).
        sparse_rows = [1, 3, 4]
        indices = tf.constant(sparse_rows)
        values = tf.constant([[1.], [1.], [1.]])
        sparse_vector = tf.IndexedSlices(
            values, indices, dense_shape=[vocab_size, 1])
        dense_vector = tf.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])

        # Compare the inverse-Fisher-vector product against an explicit solve
        # of the full Fisher matrix.
        result = block.multiply_inverse(sparse_vector)
        expected_result = tf.matrix_solve(block.full_fisher_block(), dense_vector)

        sess.run(tf.global_variables_initializer())
        expected, actual_indices, actual_values = sess.run(
            [expected_result, result.indices, result.values])
        # The values are compared row-by-row with the dense solution at
        # sparse_rows, so the returned slices must keep the input index order.
        self.assertAllEqual(sparse_rows, actual_indices)
        self.assertAllClose(expected[sparse_rows], actual_values)