Example #1
0
    def test_hidden_to_logits_computesLogitsCorrectly(self):
        """Checks hidden_to_logits against hand-derived adaptive logits.

        The 5-token vocabulary is split into two clusters: tokens 0-1 with
        embedding size 2, and tokens 2-4 with embedding size 1. The mocked
        initializer fills every variable according to its shape, so the
        expected logits can be derived by hand. For the four hidden vectors
        h = [[1, 0], [0, 1], [1, 1], [2, 1]]:
          * tokens 0-1: h @ [[0, 1], [2, 0]]^T
          * tokens 2-4: (h @ [1, 2]^T) * [1, 2, 3]
            (the (1, 2) variable presumably projects the size-1 cluster
            embedding to the model dimension)
        and the concatenated result is scaled by model_size**-0.5.
        """
        seq_len = 4
        vocab_size = 5
        model_size = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        # One hidden vector of size model_size per sequence position.
        embeddings = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]],
                                 dtype=tf.float32)
        mtf_embeddings = mtf.import_tf_tensor(self.mesh,
                                              embeddings,
                                              shape=mtf.Shape(
                                                  [length_dim, model_dim]))

        # Variables are initialized by shape. Each literal is written with the
        # same nesting as its key (the (1, 2) entry was previously spelled as
        # a 2x1 literal, which only worked if the helper reshaped it).
        self.initializer_mock.side_effect = initialize_by_shape({
            (2, 2): [[0, 1], [2, 0]],
            (3, 1): [[1], [2], [3]],
            (1, 2): [[1, 2]],
        })

        vocab_embedding = vocab_embeddings.AdaptiveVocabEmbedding(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            clusters=[{
                'token_count': 2,
                'embedding_size': 2
            }, {
                'token_count': 3,
                'embedding_size': 1
            }])

        mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings,
                                                      context=None)

        # Lower the mtf graph onto a placement mesh with an empty mesh shape
        # and a single device so the result can be exported as a tf.Tensor.
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_logits = lowering.export_to_tf_tensor(mtf_logits)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate(actual_logits)

        # Columns 0-1: first cluster; columns 2-4: second cluster.
        unscaled_logits = np.array([[0, 2, 1, 2, 3],
                                    [1, 0, 2, 4, 6],
                                    [1, 2, 3, 6, 9],
                                    [1, 4, 4, 8, 12]])
        self.assertAllClose(actual, model_size**-0.5 * unscaled_logits)
Example #2
0
    def test_ids_to_embedding_correctlyEmbeds(self):
        """Checks ids_to_embedding against hand-derived adaptive embeddings.

        The 5-token vocabulary is split into two clusters: tokens 0-1 with
        embedding size 2, and tokens 2-4 with embedding size 1. With the
        mocked by-shape initializer, the expected outputs are:
          * tokens 0-1: rows of [[0, 1], [2, 0]]
          * tokens 2-4: [1, 2, 3][id - 2] * [1, 2]
            (the (1, 2) variable presumably projects the size-1 cluster
            embedding to the model dimension)
        Every id appears at least once, and id 0 appears twice.
        """
        seq_len = 6
        vocab_size = 5
        model_size = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        ids = tf.constant([0, 1, 2, 3, 4, 0], dtype=tf.int32)
        mtf_ids = mtf.import_tf_tensor(self.mesh,
                                       ids,
                                       shape=mtf.Shape([length_dim]))

        # Variables are initialized by shape. Each literal is written with the
        # same nesting as its key (the (1, 2) entry was previously spelled as
        # a 2x1 literal, which only worked if the helper reshaped it).
        self.initializer_mock.side_effect = initialize_by_shape({
            (2, 2): [[0, 1], [2, 0]],
            (3, 1): [[1], [2], [3]],
            (1, 2): [[1, 2]],
        })

        vocab_embedding = vocab_embeddings.AdaptiveVocabEmbedding(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            clusters=[{
                'token_count': 2,
                'embedding_size': 2
            }, {
                'token_count': 3,
                'embedding_size': 1
            }])

        mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids)

        # Lower the mtf graph onto a placement mesh with an empty mesh shape
        # and a single device so the result can be exported as a tf.Tensor.
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate(actual_embedding)

        self.assertAllClose(actual,
                            [[0, 1], [2, 0], [1, 2], [2, 4], [3, 6], [0, 1]])
Example #3
0
  def test_constructor_tokenCountsDontSumToVocabSize_raisesValueError(self):
    """Constructor rejects clusters whose token counts miss the vocab size.

    The clusters hold 3 + 3 = 6 tokens while the vocabulary has only 5, so
    building the embedding is expected to raise ValueError.
    """
    mismatched_clusters = [
        {'token_count': 3, 'embedding_size': 2},
        {'token_count': 3, 'embedding_size': 1},
    ]
    vocab = mtf.Dimension('vocab', 5)
    model = mtf.Dimension('model', 2)

    with self.assertRaises(ValueError):
      vocab_embeddings.AdaptiveVocabEmbedding(
          self.mesh,
          vocab,
          output_dim=model,
          variable_dtype=self.variable_dtype,
          name='embedding',
          ensemble_dim=None,
          clusters=mismatched_clusters)