Exemplo n.º 1
0
    def testInstantiateFactors(self):
        """Smoke test: instantiate_factors() on an EmbeddingKFACFB runs cleanly.

        The test makes no assertions; it passes as long as building the block,
        registering one tower and instantiating the factors raises nothing.
        """
        with ops.Graph().as_default():
            random_seed.set_random_seed(200)

            # Build the Fisher block for an embedding with 5 rows.
            num_embeddings = 5
            layer_collection = lc.LayerCollection()
            fisher_block = fb.EmbeddingKFACFB(layer_collection, num_embeddings)

            # Register a single tower: integer ids in, float activations out.
            ids = array_ops.constant([[0, 1], [1, 2], [2, 3]])
            activations = array_ops.constant([[0.], [1.], [2.]])
            fisher_block.register_additional_tower(ids, activations)

            # Stand-in gradients (squared outputs) and zero damping; the
            # gradients are passed as a nested tuple holding a single entry.
            output_grads = activations**2.
            zero_damping = array_ops.constant(0.)
            grads_list = ((output_grads, ), )
            fisher_block.instantiate_factors(grads_list, zero_damping)
Exemplo n.º 2
0
    def testMultiplyInverse(self):
        """Checks multiply_inverse() on a sparse update against a dense solve.

        The block's multiply_inverse() is applied to an IndexedSlices update
        and the touched rows are compared with the matching rows of
        matrix_solve(full_fisher_block(), dense_vector).
        """
        with ops.Graph().as_default(), self.test_session() as sess:
            random_seed.set_random_seed(200)

            # Build the Fisher block for an embedding with 5 rows.
            vocab_size = 5
            layer_collection = lc.LayerCollection()
            fisher_block = fb.EmbeddingKFACFB(layer_collection, vocab_size)

            # Register a single tower: integer ids in, float activations out.
            ids = array_ops.constant([[0, 1], [1, 2], [2, 3]])
            activations = array_ops.constant([[0.], [1.], [2.]])
            fisher_block.register_additional_tower(ids, activations)

            # Instantiate the factors, then their covariance variables.
            output_grads = activations**2.
            zero_damping = array_ops.constant(0.)
            fisher_block.instantiate_factors(((output_grads, ), ),
                                             zero_damping)
            fisher_block._input_factor.instantiate_cov_variables()
            fisher_block._output_factor.instantiate_cov_variables()

            # Register the inverse, then instantiate the inverse variables.
            fisher_block.register_inverse()
            fisher_block._input_factor.instantiate_inv_variables()
            fisher_block._output_factor.instantiate_inv_variables()

            # The same update expressed twice: sparsely (rows 1, 3 and 4 set
            # to one) and as the equivalent dense column vector.
            touched_rows = [1, 3, 4]
            row_indices = array_ops.constant(touched_rows)
            row_values = array_ops.constant([[1.], [1.], [1.]])
            sparse_update = ops.IndexedSlices(row_values,
                                              row_indices,
                                              dense_shape=[vocab_size, 1])
            dense_update = array_ops.reshape([0., 1., 0., 1., 1.],
                                             [vocab_size, 1])

            # Inverse-Fisher-vector product vs. an explicit linear solve.
            actual = fisher_block.multiply_inverse(sparse_update)
            expected = linalg_ops.matrix_solve(
                fisher_block.full_fisher_block(), dense_update)

            sess.run(tf_variables.global_variables_initializer())

            # Entry `slot` of the sparse result corresponds to dense row `row`.
            for slot, row in enumerate(touched_rows):
                self.assertAlmostEqual(sess.run(expected[row]),
                                       sess.run(actual.values[slot]))
    def register_embedding(self,
                           params,
                           inputs,
                           outputs,
                           approx=None,
                           reuse=VARIABLE_SCOPE):
        """Registers an embedding layer.

        Args:
          params: Embedding matrix of shape [vocab_size, embedding_size].
          inputs: Tensor of shape [batch_size, input_size] and dtype int32.
            Indices into embedding matrix.
          outputs: Tensor of shape [batch_size, output_size]. Outputs
            produced by layer.
          approx: str. Must be "kron". If None, the approximation linked to
            'params' is used, falling back to
            self.default_embedding_approximation.
          reuse: bool or str.  If True, reuse an existing FisherBlock. If
            False, create a new FisherBlock.  If "VARIABLE_SCOPE", use
            tf.get_variable_scope().reuse.

        Raises:
          ValueError: For improper value to 'approx'.
          ValueError: If 'params' is a tuple or list (bias is not supported).
          KeyError: If reuse == True but no FisherBlock found for 'params'.
          ValueError: If reuse == True and FisherBlock found but of the wrong
            type.
        """
        # Resolution order: explicit argument, approximation linked to
        # 'params', then the collection-wide default (read only if needed).
        if approx is None:
            linked_approx = self._get_linked_approx(params)
            if linked_approx is None:
                approx = self.default_embedding_approximation
            else:
                approx = linked_approx

        if approx != APPROX_KRONECKER_NAME:
            raise ValueError("Bad value {} for approx.".format(approx))

        # A tuple/list would mean (weights, bias); only a bare matrix is valid.
        if isinstance(params, (tuple, list)):
            raise ValueError("Bias not supported.")

        # The leading dimension of the embedding matrix is the vocabulary size.
        num_rows = int(params.shape[0])
        candidate_block = fb.EmbeddingKFACFB(self, num_rows)
        registered_block = self.register_block(params,
                                               candidate_block,
                                               reuse=reuse)
        registered_block.register_additional_minibatch(inputs, outputs)

        self._add_uses(params, 1)