Exemplo n.º 1
0
def _build_pretrainer(
        config: teams.TeamsPretrainerConfig
) -> teams_pretrainer.TeamsPretrainer:
    """Builds a TeamsPretrainer from its config.

    The discriminator encoder is built first. The generator encoder is then
    built on hidden layers obtained from the discriminator (via
    `_get_generator_hidden_layers`). When `config.tie_embeddings` is set, the
    generator also reuses the discriminator's embedding network.

    Args:
      config: TEAMS pretrainer config. Its `generator` and `discriminator`
        fields are the encoder configs for the two networks.

    Returns:
      A `teams_pretrainer.TeamsPretrainer` wired to the two encoders.
    """
    gen_cfg = config.generator
    disc_network = teams.get_encoder(config.discriminator)

    # Derive the generator's hidden layers from the discriminator network.
    # NOTE(review): presumably the first `num_shared_generator_hidden_layers`
    # are shared; confirm in `_get_generator_hidden_layers`.
    shared_hidden_layers = _get_generator_hidden_layers(
        disc_network, gen_cfg.num_layers,
        config.num_shared_generator_hidden_layers)

    encoder_kwargs = {}
    if config.tie_embeddings:
        # Hand the discriminator's embedding network to the generator. The
        # original code notes this is done for easier model serialization.
        encoder_kwargs['embedding_network'] = disc_network.embedding_network
    encoder_kwargs['hidden_layers'] = shared_hidden_layers
    generator_network = teams.get_encoder(gen_cfg, **encoder_kwargs)

    # The MLM head takes its activation and initializer scale from the
    # generator encoder config.
    mlm_activation = tf_utils.get_activation(gen_cfg.hidden_activation)
    mlm_initializer = tf.keras.initializers.TruncatedNormal(
        stddev=gen_cfg.initializer_range)

    return teams_pretrainer.TeamsPretrainer(
        generator_network=generator_network,
        discriminator_mws_network=disc_network,
        num_discriminator_task_agnostic_layers=(
            config.num_discriminator_task_agnostic_layers),
        vocab_size=gen_cfg.vocab_size,
        candidate_size=config.candidate_size,
        mlm_activation=mlm_activation,
        mlm_initializer=mlm_initializer)
Exemplo n.º 2
0
  def test_serialize_deserialize(self):
    """Round-trips a TEAMS pretrainer through get_config/from_config."""
    vocab_size = 100

    # Every constructor argument is spelled out explicitly, so a field
    # dropped or mis-mapped by get_config/from_config shows up as a config
    # mismatch below.
    original_model = teams_pretrainer.TeamsPretrainer(
        generator_network=self._get_network(vocab_size),
        discriminator_mws_network=self._get_network(vocab_size),
        num_discriminator_task_agnostic_layers=2,
        vocab_size=vocab_size,
        candidate_size=2)

    restored_model = teams_pretrainer.TeamsPretrainer.from_config(
        original_model.get_config())

    # The restored config must be JSON-serializable.
    _ = restored_model.to_json()

    # A lossless round trip leaves the config unchanged.
    self.assertAllEqual(original_model.get_config(),
                        restored_model.get_config())
Exemplo n.º 3
0
  def test_teams_trainer_tensor_call(self):
    """Runs a TEAMS pretrainer eagerly on small concrete int32 tensors."""
    vocab_size = 100
    model = teams_pretrainer.TeamsPretrainer(
        generator_network=self._get_network(vocab_size),
        discriminator_mws_network=self._get_network(vocab_size),
        num_discriminator_task_agnostic_layers=2,
        vocab_size=vocab_size,
        candidate_size=2)

    def int_tensor(values):
      return tf.constant(values, dtype=tf.int32)

    # A batch of 2 sequences of length 3, each with 2 masked positions.
    inputs = {
        'input_word_ids': int_tensor([[1, 1, 1], [2, 2, 2]]),
        'input_mask': int_tensor([[1, 1, 1], [1, 0, 0]]),
        'input_type_ids': int_tensor([[1, 1, 1], [2, 2, 2]]),
        'masked_lm_positions': int_tensor([[0, 1], [0, 2]]),
        'masked_lm_ids': int_tensor([[10, 20], [20, 30]]),
    }

    # Outputs are intentionally not inspected: the network is too complex
    # to predict by hand, so this only guards against runtime errors.
    model(inputs)
Exemplo n.º 4
0
  def test_teams_pretrainer(self):
    """Checks the symbolic output shapes of a freshly built TEAMS pretrainer."""
    vocab_size = 100
    candidate_size = 3
    model = teams_pretrainer.TeamsPretrainer(
        generator_network=self._get_network(vocab_size),
        discriminator_mws_network=self._get_network(vocab_size),
        num_discriminator_task_agnostic_layers=1,
        vocab_size=vocab_size,
        candidate_size=candidate_size)

    num_token_predictions = 2
    sequence_length = 128

    def int_input(length):
      # Symbolic int32 input; the batch dimension is left implicit.
      return tf.keras.Input(shape=(length,), dtype=tf.int32)

    inputs = {
        'input_word_ids': int_input(sequence_length),
        'input_mask': int_input(sequence_length),
        'input_type_ids': int_input(sequence_length),
        'masked_lm_positions': int_input(num_token_predictions),
        'masked_lm_ids': int_input(num_token_predictions),
    }

    # Calling the model on symbolic inputs builds its layers.
    outputs = model(inputs)

    # Expected static shape of each named output (None = batch dimension).
    expected_shapes = {
        'lm_outputs': [None, num_token_predictions, vocab_size],
        'disc_rtd_logits': [None, sequence_length],
        'disc_rtd_label': [None, sequence_length],
        'disc_mws_logits': [None, num_token_predictions, candidate_size],
        'disc_mws_label': [None, num_token_predictions],
    }
    for output_name, expected_shape in expected_shapes.items():
      self.assertAllEqual(expected_shape,
                          outputs[output_name].shape.as_list())