def test_undecorated_encoder(self):
  """Checks that an encoder class can be registered at runtime."""
  # `register_encoder(name)` returns a decorator; apply it by hand to a class
  # that was not decorated at definition time.
  register = encoder_registry.register_encoder('undecorated_encoder')
  register(UnDecoratedEncoder)

  # The runtime-registered encoder must now be retrievable by its name.
  undecorated_cls = encoder_registry.get_registered_encoder(
      'undecorated_encoder')
  self.assertIsNotNone(undecorated_cls)
  self.assertEqual(undecorated_cls.__name__, 'UnDecoratedEncoder')

  # Encoders registered earlier through the decorator must remain available.
  decorated_cls = encoder_registry.get_registered_encoder('decorated_encoder')
  self.assertIsNotNone(decorated_cls)
  self.assertEqual(decorated_cls.__name__, 'DecoratedEncoder')
def make_output_postprocess_fn(
    cls, config: ml_collections.ConfigDict
) -> Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]]:
  """Postprocess task samples (input and output). See BaseTask.

  Composes two postprocessing steps:
    1. the `make_output_postprocess_fn` of the encoder class registered under
       `config.model_config.encoder_name`, whose returned features are merged
       into the auxiliary output; then
    2. `base_task.BaseTask.make_output_postprocess_fn(config)`.

  Args:
    config: experiment config.

  Returns:
    A function `postprocess_fn(batch, auxiliary_output)` that returns the
    result of the BaseTask postprocessing function applied to `batch` and
    the encoder-augmented auxiliary output.
  """
  base_postprocess_fn = base_task.BaseTask.make_output_postprocess_fn(config)
  encoder_class = encoder_registry.get_registered_encoder(
      config.model_config.encoder_name)
  encoder_postprocess_fn = encoder_class.make_output_postprocess_fn(config)

  def postprocess_fn(batch: Dict[str, Any],
                     auxiliary_output: Dict[str, Any]) -> Dict[str, Any]:
    """Function that prepares model's input and output for serialization."""
    # Work on a shallow copy so merging in the encoder-specific features does
    # not mutate the caller's `auxiliary_output`.
    new_auxiliary_output = dict(auxiliary_output)
    encoder_specific_features = encoder_postprocess_fn(
        batch, new_auxiliary_output)
    new_auxiliary_output.update(encoder_specific_features)
    return base_postprocess_fn(batch, new_auxiliary_output)

  return postprocess_fn
def setup(self):
  """Builds the submodules: encoder, optional dense + dropout, classifier.

  The encoder class is looked up in the encoder registry by
  `self.encoder_name` and constructed with `self.encoder_config` as keyword
  arguments. When `self.apply_mlp` is set, a dense layer of width
  `self.encoder_config.hidden_size` and a dropout layer with rate
  `self.encoder_config.dropout_rate` are also created. The output layer
  projects to `self.vocab_size`.
  """
  self.encoder = encoder_registry.get_registered_encoder(
      self.encoder_name)(**self.encoder_config)
  if self.apply_mlp:
    # `dtype` must be passed by keyword: the second positional field of
    # `nn.Dense` is `use_bias`, not `dtype`.
    self.mlp = nn.Dense(self.encoder_config.hidden_size, dtype=self.dtype)
    self.dropout = nn.Dropout(self.encoder_config.dropout_rate)
  self.linear_classifier = nn.Dense(self.vocab_size, dtype=self.dtype)
def load_weights(config: ml_collections.ConfigDict) -> Dict[str, Any]:
  """Load model weights.

  Delegates loading to the encoder class registered under
  `config.model_config.encoder_name`, then nests each variable group it
  returns under an 'encoder' key.

  Args:
    config: experiment config.

  Returns:
    Dictionary mapping each variable group key returned by the encoder to
    `{'encoder': <that group's encoder variables>}`.
  """
  registered_encoder = encoder_registry.get_registered_encoder(
      config.model_config.encoder_name)
  loaded = registered_encoder.load_weights(config)
  return {group: {'encoder': loaded[group]} for group in loaded}
def setup(self):
  """Builds the encoder, a stack of MLP blocks and the output classifier.

  The encoder class is looked up in the encoder registry by
  `self.encoder_name` and constructed with `self.encoder_config` as keyword
  arguments. `self.num_layers` identically configured `mlp.MLPBlock`s are
  created, followed by a dense layer projecting to `self.num_classes`.
  """
  encoder_cls = encoder_registry.get_registered_encoder(self.encoder_name)
  self.encoder = encoder_cls(**self.encoder_config)

  blocks = []
  for _ in range(self.num_layers):
    blocks.append(
        mlp.MLPBlock(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            layer_norm_epsilon=self.layer_norm_epsilon,
        ))
  self.classification_mlp_layers = blocks

  self.linear_classifier = nn.Dense(self.num_classes, dtype=self.dtype)
def load_weights(cls, config: ml_collections.ConfigDict) -> Dict[str, Any]:
  """Load model weights from file.

  The encoder is identified by `config.model_config.encoder_name`; the
  registered encoder class for that name performs the actual loading, and
  every variable group it returns is wrapped under an 'encoder' key.

  Args:
    config: experiment config.

  Returns:
    Dictionary of model weights.
  """
  loader_cls = encoder_registry.get_registered_encoder(
      config.model_config.encoder_name)
  encoder_weights = loader_cls.load_weights(config)
  model_weights = {
      key: {'encoder': encoder_weights[key]} for key in encoder_weights
  }
  return model_weights
def load_weights(cls, config: ml_collections.ConfigDict) -> Dict[str, Any]:
  """Load model weights from file.

  The encoder is identified by the `encoder_name` class attribute of `cls`
  (MentionEncoderTasks are expected to define it). The registered encoder
  class for that name loads the weights, and every variable group it
  returns is wrapped under an 'encoder' key.

  Args:
    config: experiment config.

  Returns:
    Dictionary of model weights.
  """
  registered = encoder_registry.get_registered_encoder(cls.encoder_name)
  per_group = registered.load_weights(config)
  return {
      group_name: {'encoder': per_group[group_name]}
      for group_name in per_group
  }
def test_decorated_encoder(self):
  """Checks that encoders registered via the decorator are retrievable."""
  decorated_cls = encoder_registry.get_registered_encoder('decorated_encoder')
  # The lookup must succeed and return the decorated class itself.
  self.assertIsNotNone(decorated_cls)
  self.assertEqual(decorated_cls.__name__, 'DecoratedEncoder')
def setup(self):
  """Builds the registered encoder and a dense classification layer.

  The encoder class is looked up by `self.encoder_name` and constructed with
  `self.encoder_config` as keyword arguments; the classifier projects to
  `self.num_classes`.
  """
  encoder_cls = encoder_registry.get_registered_encoder(self.encoder_name)
  self.encoder = encoder_cls(**self.encoder_config)
  self.linear_classifier = nn.Dense(self.num_classes, dtype=self.dtype)
def setup(self):
  """Instantiates the encoder registered under `self.encoder_name`.

  The encoder is constructed with `self.encoder_config` as keyword arguments.
  """
  encoder_cls = encoder_registry.get_registered_encoder(self.encoder_name)
  self.encoder = encoder_cls(**self.encoder_config)