def testInstantiateFactors(self):
  """Smoke test: factor variables for an embedding block can be built."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(200)
    # Build a K-FAC Fisher block for an embedding layer.
    vocab_size = 5
    fisher_block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
    # Register one tower of example data with the block.
    token_ids = array_ops.constant([[0, 1], [1, 2], [2, 3]])
    layer_outputs = array_ops.constant([[0.], [1.], [2.]])
    fisher_block.register_additional_tower(token_ids, layer_outputs)
    # Instantiating the factors should not raise.
    output_grads = layer_outputs**2.
    damping_term = array_ops.constant(0.)
    fisher_block.instantiate_factors(((output_grads,),), damping_term)
def testMultiplyInverse(self):
  """Checks multiply_inverse on a sparse update against matrix_solve."""
  with ops.Graph().as_default(), self.test_session() as sess:
    random_seed.set_random_seed(200)
    # Build a K-FAC Fisher block for an embedding layer.
    vocab_size = 5
    fisher_block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
    # Register one tower of example data with the block.
    token_ids = array_ops.constant([[0, 1], [1, 2], [2, 3]])
    layer_outputs = array_ops.constant([[0.], [1.], [2.]])
    fisher_block.register_additional_tower(token_ids, layer_outputs)
    # Instantiating the factors should not raise.
    output_grads = layer_outputs**2.
    damping_term = array_ops.constant(0.)
    fisher_block.instantiate_factors(((output_grads,),), damping_term)
    fisher_block._input_factor.instantiate_cov_variables()
    fisher_block._output_factor.instantiate_cov_variables()
    fisher_block.register_inverse()
    fisher_block._input_factor.instantiate_inv_variables()
    fisher_block._output_factor.instantiate_inv_variables()
    # A sparse update touching rows 1, 3, 4 of the embedding matrix,
    # alongside its dense equivalent.
    row_indices = array_ops.constant([1, 3, 4])
    row_values = array_ops.constant([[1.], [1.], [1.]])
    sparse_vector = ops.IndexedSlices(
        row_values, row_indices, dense_shape=[vocab_size, 1])
    dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
    # Compare the block's inverse-vector product against an explicit solve
    # with the full Fisher block.
    result = fisher_block.multiply_inverse(sparse_vector)
    expected_result = linalg_ops.matrix_solve(
        fisher_block.full_fisher_block(), dense_vector)
    sess.run(tf_variables.global_variables_initializer())
    # Each touched dense row must match the corresponding sparse value row.
    for sparse_row, dense_row in enumerate([1, 3, 4]):
      self.assertAlmostEqual(
          sess.run(expected_result[dense_row]),
          sess.run(result.values[sparse_row]))
def register_embedding(self,
                       params,
                       inputs,
                       outputs,
                       approx=None,
                       reuse=VARIABLE_SCOPE):
  """Registers an embedding layer.

  Fix over previous revision: the docstring claimed this registers a
  "fully connnected layer"; it registers an embedding lookup layer.

  Args:
    params: Embedding matrix of shape [vocab_size, embedding_size].
    inputs: Tensor of shape [batch_size, input_size] and dtype int32.
      Indices into embedding matrix.
    outputs: Tensor of shape [batch_size, output_size]. Outputs
      produced by layer.
    approx: str. Must be "kron".
    reuse: bool or str. If True, reuse an existing FisherBlock. If False,
      create a new FisherBlock. If "VARIABLE_SCOPE", use
      tf.get_variable_scope().reuse.

  Raises:
    ValueError: For improper value to 'approx'.
    KeyError: If reuse == True but no FisherBlock found for 'params'.
    ValueError: If reuse == True and FisherBlock found but of the wrong type.
  """
  # Resolve the approximation: explicit argument wins, then any approximation
  # linked to these params, then the collection-wide embedding default.
  if approx is None:
    approx = self._get_linked_approx(params)
    if approx is None:
      approx = self.default_embedding_approximation

  # Only the Kronecker-factored approximation is supported for embeddings.
  if approx != APPROX_KRONECKER_NAME:
    raise ValueError("Bad value {} for approx.".format(approx))

  # Embeddings have no bias term, so (weights, bias) tuples are rejected.
  if isinstance(params, (tuple, list)):
    raise ValueError("Bias not supported.")

  vocab_size = int(params.shape[0])
  block = self.register_block(
      params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)

  # NOTE(review): sibling tests in this file call
  # block.register_additional_tower(...); confirm whether
  # register_additional_minibatch is the current FisherBlock API.
  block.register_additional_minibatch(inputs, outputs)

  self._add_uses(params, 1)