Example #1
0
  def test_register_error(self):
    """Checks the failure modes of `registry.register` and `registry.lookup`.

    The test expects two things. First, once 'functions/func_0' is registered,
    registering 'functions/func_0/sub_func' beneath it raises KeyError. Second,
    looking up a key that was never registered raises LookupError.
    """
    registered = {}

    @registry.register(registered, 'functions/func_0')
    def leaf_fn():  # pylint: disable=unused-variable
      pass

    # 'functions/func_0' is already a registered entry, so the nested key
    # underneath it is expected to be rejected.
    with self.assertRaises(KeyError):

      @registry.register(registered, 'functions/func_0/sub_func')
      def nested_fn():  # pylint: disable=unused-variable
        pass

    with self.assertRaises(LookupError):
      registry.lookup(registered, 'non-exist')
Example #2
0
def build_backbone(
    input_specs: Union[tf.keras.layers.InputSpec,
                       Sequence[tf.keras.layers.InputSpec]],
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Creates the backbone selected by `backbone_config.type`.

  Looks up the builder registered in `_REGISTERED_BACKBONE_CLS` under the
  config's `type`. It then calls that builder, passing every argument by
  keyword.

  Args:
    input_specs: A `tf.keras.layers.InputSpec`, or a sequence of them,
      describing the backbone input.
    backbone_config: A `OneOfConfig` backbone config. Its `type` field selects
      the registered builder.
    norm_activation_config: Config for the normalization/activation layers.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Defaults to
      None.
    **kwargs: Extra keyword arguments, forwarded unchanged to the builder.

  Returns:
    The `tf.keras.Model` backbone produced by the registered builder.

  Raises:
    LookupError: Propagated from `registry.lookup` if nothing is registered
      under `backbone_config.type`.
  """
  make_backbone = registry.lookup(
      _REGISTERED_BACKBONE_CLS, backbone_config.type)
  return make_backbone(
      input_specs=input_specs,
      backbone_config=backbone_config,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
Example #3
0
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs
) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Creates the decoder selected by `model_config.decoder.type`.

  The decoder may be a `tf.keras.Model`, a `tf.keras.layers.Layer`, or None.
  A non-None decoder takes the backbone features as input and produces
  decoded feature maps. A None decoder (for example, an identity decoder)
  means the decoding stage is skipped, and the backbone features are used
  directly as the model output.

  Args:
    input_specs: A `dict` of input specifications of the form
      {level: TensorShape}, as produced by a backbone.
    model_config: A `OneOfConfig` model config. `model_config.decoder.type`
      selects the registered decoder builder.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Defaults to
      None.
    **kwargs: Extra keyword arguments, forwarded unchanged to the builder.

  Returns:
    The decoder returned by the registered builder, or None.

  Raises:
    LookupError: Propagated from `registry.lookup` if nothing is registered
      under `model_config.decoder.type`.
  """
  make_decoder = registry.lookup(
      _REGISTERED_DECODER_CLS, model_config.decoder.type)
  return make_decoder(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
Example #4
0
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs
) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Creates the decoder selected by `model_config.decoder.type`.

  Args:
    input_specs: A `dict` of input specifications of the form
      {level: TensorShape}, as produced by a backbone.
    model_config: A `OneOfConfig` model config. `model_config.decoder.type`
      selects the registered decoder builder.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Defaults to
      None.
    **kwargs: Extra keyword arguments, forwarded unchanged to the builder.

  Returns:
    The decoder returned by the registered builder. Per the return
    annotation, this is a `tf.keras.Model`, a `tf.keras.layers.Layer`, or None.

  Raises:
    LookupError: Propagated from `registry.lookup` if nothing is registered
      under `model_config.decoder.type`.
  """
  make_decoder = registry.lookup(
      _REGISTERED_DECODER_CLS, model_config.decoder.type)
  return make_decoder(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
Example #5
0
  def test_register(self):
    """Registered functions and classes can be looked up by their keys.

    Covers slash-separated string keys and a class object used as a key.
    """
    registered = {}

    @registry.register(registered, 'functions/func_0')
    def sample_fn():
      pass

    found_fn = registry.lookup(registered, 'functions/func_0')
    self.assertEqual(found_fn, sample_fn)

    @registry.register(registered, 'classes/cls_0')
    class KeyClass:
      pass

    found_cls = registry.lookup(registered, 'classes/cls_0')
    self.assertEqual(found_cls, KeyClass)

    # A class object, not only a string, can serve as the registry key.
    @registry.register(registered, KeyClass)
    class ValueClass:
      pass

    found_value = registry.lookup(registered, KeyClass)
    self.assertEqual(found_value, ValueClass)
Example #6
0
def build_backbone(input_specs: tf.keras.layers.InputSpec,
                   model_config,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Creates the backbone selected by `model_config.backbone.type`.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` describing the backbone input.
    model_config: A `OneOfConfig` model config. `model_config.backbone.type`
      selects the registered backbone builder.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Defaults to
      None.

  Returns:
    The `tf.keras.Model` backbone produced by the registered builder.

  Raises:
    LookupError: Propagated from `registry.lookup` if nothing is registered
      under `model_config.backbone.type`.
  """
  make_backbone = registry.lookup(
      _REGISTERED_BACKBONE_CLS, model_config.backbone.type)
  # Registered builders are called positionally, in this exact order.
  return make_backbone(input_specs, model_config, l2_regularizer)
Example #7
0
def build_model(model_type: str,
                input_specs: tf.keras.layers.InputSpec,
                model_config: video_classification_cfg.hyperparams.Config,
                num_classes: int,
                l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds a model from a config.

  Looks up the model builder registered under `model_type` and calls it with
  the remaining arguments.

  Args:
    model_type: String name of the model type. It should be consistent with
      `ModelConfig.model_type`.
    input_specs: A `tf.keras.layers.InputSpec` describing the model input.
    model_config: A `OneOfConfig` model config.
    num_classes: Number of classes.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Defaults to
      None.

  Returns:
    The `tf.keras.Model` instance built by the registered model builder.

  Raises:
    LookupError: Propagated from `registry.lookup` if no builder is
      registered under `model_type`.
  """
  model_builder = registry.lookup(_REGISTERED_MODEL_CLS, model_type)
  # Arguments are passed positionally, so every registered builder must
  # accept them in this order.
  return model_builder(input_specs, model_config, num_classes, l2_regularizer)
Example #8
0
def get_data_loader(data_config):
  """Creates a data loader for `data_config`.

  The loader class is looked up in `_REGISTERED_DATA_LOADER_CLS` using the
  class of `data_config` as the key. It is then instantiated with
  `data_config`.
  """
  loader_cls = registry.lookup(_REGISTERED_DATA_LOADER_CLS,
                               data_config.__class__)
  return loader_cls(data_config)
def get_task_cls(task_config_cls):
  """Returns the task class registered for the config class `task_config_cls`.

  The class itself is returned; it is not instantiated.
  """
  return registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
  """Builds the `ExperimentConfig` registered under `exp_name`.

  The factory stored in `_REGISTERED_CONFIGS` is called with no arguments,
  and its result is returned.
  """
  config_factory = registry.lookup(_REGISTERED_CONFIGS, exp_name)
  experiment_config = config_factory()
  return experiment_config
def get_exp_config_creater(exp_name: str):
  """Returns the `ExperimentConfig` factory registered under `exp_name`.

  Unlike `get_exp_config`, the factory is returned uncalled.
  """
  return registry.lookup(_REGISTERED_CONFIGS, exp_name)