def test_register_error(self):
  """Checks that a conflicting registration and an unknown lookup both raise."""
  registered = {}

  @registry.register(registered, 'functions/func_0')
  def func_test0():  # pylint: disable=unused-variable
    pass

  # 'functions/func_0' already holds a function, so registering a key nested
  # underneath it must be rejected with a KeyError.
  with self.assertRaises(KeyError):

    @registry.register(registered, 'functions/func_0/sub_func')
    def func_test1():  # pylint: disable=unused-variable
      pass

  # Looking up a key that was never registered must raise a LookupError.
  with self.assertRaises(LookupError):
    registry.lookup(registered, 'non-exist')
def build_backbone(
    input_specs: Union[tf.keras.layers.InputSpec,
                       Sequence[tf.keras.layers.InputSpec]],
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Creates a backbone by dispatching to the builder registered for its type.

  The builder is looked up in the backbone registry under
  `backbone_config.type` and is then called with every argument of this
  function passed by keyword.

  Args:
    input_specs: A (sequence of) `tf.keras.layers.InputSpec` of input.
    backbone_config: A `OneOfConfig` of backbone config. Its `type` attribute
      selects the registered builder.
    norm_activation_config: A config for normalization/activation layer.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
      None.
    **kwargs: Additional keyword args forwarded unchanged to the backbone
      builder.

  Returns:
    A `tf.keras.Model` instance of the backbone, as produced by the builder.
  """
  backbone_type = backbone_config.type
  builder_fn = registry.lookup(_REGISTERED_BACKBONE_CLS, backbone_type)

  return builder_fn(
      input_specs=input_specs,
      backbone_config=backbone_config,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs
) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Creates a decoder by dispatching to the builder registered for its type.

  The result may be a keras.Model, a keras.layers.Layer, or None. A non-None
  decoder consumes the backbone features and produces decoded feature maps.
  A None decoder (e.g. an identity decoder) means decoding is skipped and the
  backbone features are treated as the model output.

  The builder is looked up in the decoder registry under
  `model_config.decoder.type` and is called with keyword arguments.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A `OneOfConfig` of model config. Its `decoder.type`
      attribute selects the registered builder.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
      None.
    **kwargs: Additional keyword args forwarded unchanged to the decoder
      builder.

  Returns:
    An instance of the decoder, as produced by the builder.
  """
  decoder_type = model_config.decoder.type
  builder_fn = registry.lookup(_REGISTERED_DECODER_CLS, decoder_type)

  return builder_fn(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs
) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Creates a decoder using the builder registered for the configured type.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A `OneOfConfig` of model config. Its `decoder.type`
      attribute is the registry key of the builder to use.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
      None.
    **kwargs: Additional keyword args forwarded unchanged to the decoder
      builder.

  Returns:
    An instance of the decoder, as produced by the builder.
  """
  # Resolve the builder and invoke it in a single expression; all arguments
  # are passed by keyword.
  return registry.lookup(_REGISTERED_DECODER_CLS, model_config.decoder.type)(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
def test_register(self):
  """Registers a function and classes, then looks each of them up again."""
  registered = {}

  # Case 1: a function stored under a string key.
  @registry.register(registered, 'functions/func_0')
  def func_test():
    pass

  found_func = registry.lookup(registered, 'functions/func_0')
  self.assertEqual(found_func, func_test)

  # Case 2: a class stored under a string key.
  @registry.register(registered, 'classes/cls_0')
  class ClassRegistryKey:
    pass

  found_cls = registry.lookup(registered, 'classes/cls_0')
  self.assertEqual(found_cls, ClassRegistryKey)

  # Case 3: a class stored with another class (not a string) as the key.
  @registry.register(registered, ClassRegistryKey)
  class ClassRegistryValue:
    pass

  found_by_cls_key = registry.lookup(registered, ClassRegistryKey)
  self.assertEqual(found_by_cls_key, ClassRegistryValue)
def build_backbone(input_specs: tf.keras.layers.InputSpec,
                   model_config,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Creates a backbone using the builder registered for the configured type.

  The builder is looked up in the backbone registry under
  `model_config.backbone.type` and is called with the three arguments of this
  function, in order, as positional arguments.

  Args:
    input_specs: tf.keras.layers.InputSpec.
    model_config: a OneOfConfig. Model config; its `backbone.type` attribute
      is the registry key of the builder to use.
    l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to
      None.

  Returns:
    tf.keras.Model instance of the backbone, as produced by the builder.
  """
  backbone_type = model_config.backbone.type
  builder_fn = registry.lookup(_REGISTERED_BACKBONE_CLS, backbone_type)
  # Arguments stay positional: the builder's parameter names are not fixed by
  # this function.
  return builder_fn(input_specs, model_config, l2_regularizer)
def build_model(model_type: str,
                input_specs: tf.keras.layers.InputSpec,
                model_config: video_classification_cfg.hyperparams.Config,
                num_classes: int,
                l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Creates a model using the builder registered under `model_type`.

  The builder is looked up in the model registry and is called with
  `input_specs`, `model_config`, `num_classes` and `l2_regularizer`, in that
  order, as positional arguments.

  Args:
    model_type: string name of model type. It should be consistent with
      ModelConfig.model_type.
    input_specs: tf.keras.layers.InputSpec.
    model_config: a OneOfConfig. Model config.
    num_classes: number of classes.
    l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to
      None.

  Returns:
    The model returned by the registered builder (expected to be a
    tf.keras.Model instance).
  """
  builder_fn = registry.lookup(_REGISTERED_MODEL_CLS, model_type)
  # Arguments stay positional: the builder's parameter names are not fixed by
  # this function.
  return builder_fn(input_specs, model_config, num_classes, l2_regularizer)
def get_data_loader(data_config):
  """Instantiates the data loader registered for the class of `data_config`.

  Args:
    data_config: The data config object. Its class is the registry key, and
      the object itself is passed as the only argument to the registered
      loader.

  Returns:
    The object returned by calling the registered loader with `data_config`.
  """
  config_cls = data_config.__class__
  loader_factory = registry.lookup(_REGISTERED_DATA_LOADER_CLS, config_cls)
  return loader_factory(data_config)
def get_task_cls(task_config_cls):
  """Returns the entry registered for `task_config_cls` in the task registry.

  Args:
    task_config_cls: The registry key; by its name, the class of a task
      config.

  Returns:
    Whatever was registered under `task_config_cls` (presumably the task
    class). The entry is returned as-is and is not instantiated here.
  """
  return registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
  """Builds the `ExperimentConfig` registered under `exp_name`.

  Args:
    exp_name: Registry key of the experiment.

  Returns:
    The result of calling the registered factory with no arguments.
  """
  config_factory = registry.lookup(_REGISTERED_CONFIGS, exp_name)
  experiment_config = config_factory()
  return experiment_config
def get_exp_config_creater(exp_name: str):
  """Returns the ExperimentConfig factory registered under `exp_name`.

  The factory is returned as-is and is not called here.

  Args:
    exp_name: Registry key of the experiment.

  Returns:
    The callable stored in the config registry for `exp_name`.
  """
  return registry.lookup(_REGISTERED_CONFIGS, exp_name)