Example #1
0
    def test_hidden_to_logits_computesLogitsCorrectly(self):
        """Checks MixtureOfSoftmaxes.hidden_to_logits on a single position.

        Every layer weight is pinned through the mocked initializer, one
        hidden vector is projected to vocabulary logits, and the result is
        compared with a NumPy/SciPy reference: the log of the prior-weighted
        sum of the two per-component softmax distributions.
        """
        length, vocab, width, components = 1, 4, 3, 2

        vocab_dim = mtf.Dimension('vocab', vocab)
        model_dim = mtf.Dimension('model', width)
        length_dim = mtf.Dimension('length', length)

        # The fed hidden vector is [1, 1, 2] divided by width**-0.5 (i.e.
        # multiplied by sqrt(width)). NOTE(review): presumably the layer
        # rescales its input by width**-0.5, so it effectively sees
        # [1, 1, 2] -- confirm against hidden_to_logits.
        hidden = np.array([[1.0, 1.0, 2.0]])
        hidden_tf = tf.constant(hidden / width**-0.5, dtype=tf.float32)
        hidden_mtf = mtf.import_tf_tensor(
            self.mesh,
            hidden_tf,
            shape=mtf.Shape([length_dim, model_dim]))

        # Values returned by the mocked initializer, keyed by variable shape.
        pinned_weights = {
            # Embedding weights (vocab x model); rows 2 and 3 are identical.
            (4, 3): [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
            # Mixture weights (num_softmaxes x model).
            (2, 3): [[1, 0, 0], [0, 1, 1]],
            # Context weights: identity for the first component, the
            # anti-identity (reverses a 3-vector) for the second.
            (2, 3, 3): [
                [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
            ],
        }
        self.initializer_mock.side_effect = initialize_by_shape(pinned_weights)

        layer = vocab_embeddings.MixtureOfSoftmaxes(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            num_softmaxes=components)

        logits_mtf = layer.hidden_to_logits(hidden_mtf, context=None)

        # Lower the mesh graph onto a placement mesh with an empty shape and
        # a single default device, then export the logits as a TF tensor.
        placement = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: placement})
        logits_tf = lowering.export_to_tf_tensor(logits_mtf)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        (actual,) = self.evaluate([logits_tf])

        # Reference values the test expects. Taking h = [1, 1, 2]:
        # [1, 0, 0].h = 1 and [0, 1, 1].h = 3 give the prior logits [1, 3];
        # the contexts are tanh([1, 1, 2]) and tanh([2, 1, 1]); the pinned
        # embedding maps a context c to vocab scores [c0, c1, c2, c2].
        priors = scipy.special.softmax([1, 3])
        per_component = (
            scipy.special.softmax(np.tanh([1, 1, 2, 2])),
            scipy.special.softmax(np.tanh([2, 1, 1, 1])),
        )
        mixture = priors[0] * per_component[0] + priors[1] * per_component[1]
        expected = np.log(mixture)

        self.assertAllClose(actual, [expected])
Example #2
0
    def test_ids_to_embedding_correctlyEmbeds(self):
        """Checks that MixtureOfSoftmaxes.ids_to_embedding looks up rows.

        With a single softmax component and pinned weights, embedding the
        ids 0, 1, 2, 3 is expected to return the four embedding rows as-is.
        """
        length, vocab, width, components = 4, 4, 3, 1

        vocab_dim = mtf.Dimension('vocab', vocab)
        model_dim = mtf.Dimension('model', width)
        length_dim = mtf.Dimension('length', length)

        ids_tf = tf.constant([0, 1, 2, 3], dtype=tf.int32)
        ids_mtf = mtf.import_tf_tensor(
            self.mesh, ids_tf, shape=mtf.Shape([length_dim]))

        # Embedding rows handed out by the mocked initializer; they also
        # serve as the expected lookup result for ids 0..3.
        embedding_rows = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 2]]
        self.initializer_mock.side_effect = initialize_by_shape({
            # Embedding weights (vocab x model).
            (4, 3): embedding_rows,
            # Mixture weights (num_softmaxes x model).
            (1, 3): [[1, 0, 0]],
            # Context weights: a single identity matrix.
            (1, 3, 3): [
                [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            ],
        })

        layer = vocab_embeddings.MixtureOfSoftmaxes(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            num_softmaxes=components)

        embedding_mtf = layer.ids_to_embedding(ids_mtf)

        # Lower onto a placement mesh with an empty shape and one default
        # device, then export the looked-up embeddings as a TF tensor.
        placement = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: placement})
        embedding_tf = lowering.export_to_tf_tensor(embedding_mtf)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        [actual] = self.evaluate([embedding_tf])

        self.assertAllClose(actual, embedding_rows)