예제 #1
0
  def test_undecorated_encoder(self):
    """Simple test to verify that we can register encoders at runtime."""
    # Register the encoder by calling `register_encoder(...)` directly on the
    # class instead of applying it as a decorator.
    encoder_registry.register_encoder('undecorated_encoder')(UnDecoratedEncoder)

    # Retrieve it. The registry returns the encoder class itself.
    encoder_class = encoder_registry.get_registered_encoder(
        'undecorated_encoder')
    self.assertIsNotNone(encoder_class)
    self.assertEqual(encoder_class.__name__, 'UnDecoratedEncoder')

    # Verify that we can still access previously registered decorated encoders.
    encoder_class = encoder_registry.get_registered_encoder('decorated_encoder')
    self.assertIsNotNone(encoder_class)
    self.assertEqual(encoder_class.__name__, 'DecoratedEncoder')
예제 #2
0
    def make_output_postprocess_fn(
        cls,
        config: ml_collections.ConfigDict
    ) -> Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]]:
        """Postprocess task samples (input and output). See BaseTask.

        Chains the encoder-specific output postprocessing in front of the
        generic `BaseTask` postprocessing.

        Args:
          config: experiment config. `config.model_config.encoder_name` selects
            the registered encoder class whose `make_output_postprocess_fn` is
            used.

        Returns:
          A function `(batch, auxiliary_output) -> dict`. It adds the
          encoder-specific features to a shallow copy of `auxiliary_output`,
          then passes the result to the base postprocess function.
        """
        base_postprocess_fn = base_task.BaseTask.make_output_postprocess_fn(
            config)

        encoder_name = config.model_config.encoder_name
        encoder_class = encoder_registry.get_registered_encoder(encoder_name)
        encoder_postprocess_fn = encoder_class.make_output_postprocess_fn(
            config)

        def postprocess_fn(batch: Dict[str, Any],
                           auxiliary_output: Dict[str, Any]) -> Dict[str, Any]:
            """Prepare model input and output for serialization."""
            # Shallow copy: the caller's top-level dict is not mutated here.
            new_auxiliary_output = dict(auxiliary_output)
            encoder_specific_features = encoder_postprocess_fn(
                batch, new_auxiliary_output)
            # Encoder-specific features override same-named auxiliary entries.
            new_auxiliary_output.update(encoder_specific_features)
            return base_postprocess_fn(batch, new_auxiliary_output)

        return postprocess_fn
예제 #3
0
    def setup(self):
        """Create the encoder, the optional MLP/dropout layers and classifier."""
        self.encoder = encoder_registry.get_registered_encoder(
            self.encoder_name)(**self.encoder_config)

        if self.apply_mlp:
            # `dtype` must be passed by keyword: the second positional field of
            # `nn.Dense` is `use_bias`, so passing it positionally would drop
            # the dtype and set `use_bias` instead.
            self.mlp = nn.Dense(
                self.encoder_config.hidden_size, dtype=self.dtype)
            self.dropout = nn.Dropout(self.encoder_config.dropout_rate)
        # Output layer with `vocab_size` features.
        self.linear_classifier = nn.Dense(self.vocab_size, dtype=self.dtype)
예제 #4
0
    def load_weights(config: ml_collections.ConfigDict) -> Dict[str, Any]:
        """Load model weights by delegating to the configured encoder class.

        Args:
          config: experiment config. `config.model_config.encoder_name` selects
            the registered encoder class whose `load_weights(config)` is called.

        Returns:
          Dictionary with the same top-level keys as the encoder's variables.
          Each value is re-nested as `{'encoder': <encoder variables>}`.
        """
        encoder_name = config.model_config.encoder_name
        encoder_class = encoder_registry.get_registered_encoder(encoder_name)
        encoder_variables = encoder_class.load_weights(config)
        # Scope each variable group under 'encoder'. This presumably matches
        # the model's `encoder` submodule name; confirm against the model.
        return {
            group_key: {'encoder': encoder_variables[group_key]}
            for group_key in encoder_variables
        }
예제 #5
0
    def setup(self):
        """Build the encoder, a stack of MLP blocks and a linear classifier."""
        encoder_cls = encoder_registry.get_registered_encoder(
            self.encoder_name)
        self.encoder = encoder_cls(**self.encoder_config)

        # Every MLP block in the stack is built with the same hyperparameters.
        block_kwargs = dict(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            layer_norm_epsilon=self.layer_norm_epsilon,
        )
        self.classification_mlp_layers = [
            mlp.MLPBlock(**block_kwargs) for _ in range(self.num_layers)
        ]

        # Final dense layer with `num_classes` output features.
        self.linear_classifier = nn.Dense(self.num_classes, dtype=self.dtype)
예제 #6
0
    def load_weights(cls, config: ml_collections.ConfigDict) -> Dict[str, Any]:
        """Load model weights by delegating to the configured encoder class.

        We assume that `config.model_config.encoder_name` is specified. The
        encoder class registered under that name loads the encoder weights.

        Args:
          config: experiment config.

        Returns:
          Dictionary of model weights with the same top-level keys as the
          encoder's variables. Each value is re-nested as
          `{'encoder': <encoder variables>}`.
        """
        encoder_name = config.model_config.encoder_name
        encoder_class = encoder_registry.get_registered_encoder(encoder_name)
        encoder_variables = encoder_class.load_weights(config)
        # Scope each variable group under 'encoder'. This presumably matches
        # the model's `encoder` submodule name; confirm against the model.
        return {
            group_key: {'encoder': encoder_variables[group_key]}
            for group_key in encoder_variables
        }
예제 #7
0
    def load_weights(cls, config: ml_collections.ConfigDict) -> Dict[str, Any]:
        """Load model weights by delegating to the task's encoder class.

        We assume that MentionEncoderTasks specify an encoder name as a class
        attribute (`cls.encoder_name`). The encoder class registered under that
        name loads the encoder weights.

        Args:
          config: experiment config, forwarded to the encoder's `load_weights`.

        Returns:
          Dictionary of model weights with the same top-level keys as the
          encoder's variables. Each value is re-nested as
          `{'encoder': <encoder variables>}`.
        """
        encoder_class = encoder_registry.get_registered_encoder(
            cls.encoder_name)
        encoder_variables = encoder_class.load_weights(config)
        # Scope each variable group under 'encoder'. This presumably matches
        # the model's `encoder` submodule name; confirm against the model.
        return {
            group_key: {'encoder': encoder_variables[group_key]}
            for group_key in encoder_variables
        }
예제 #8
0
 def test_decorated_encoder(self):
   """Checks that an encoder registered via the decorator is retrievable."""
   # The registry hands back the registered class, not an instance.
   decorated_cls = encoder_registry.get_registered_encoder('decorated_encoder')
   self.assertIsNotNone(decorated_cls)
   self.assertEqual(decorated_cls.__name__, 'DecoratedEncoder')
예제 #9
0
 def setup(self):
     """Build the registered encoder and a linear output layer."""
     encoder_cls = encoder_registry.get_registered_encoder(self.encoder_name)
     self.encoder = encoder_cls(**self.encoder_config)
     # Dense layer producing `num_classes` output features.
     self.linear_classifier = nn.Dense(self.num_classes, dtype=self.dtype)
예제 #10
0
 def setup(self):
     """Instantiate the encoder registered under `self.encoder_name`."""
     registered_encoder = encoder_registry.get_registered_encoder(
         self.encoder_name)
     self.encoder = registered_encoder(**self.encoder_config)