Ejemplo n.º 1
0
def build_decoder(input_specs,
                  model_config,
                  l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Creates the decoder network selected by `model_config.decoder`.

    Args:
      input_specs: `dict` input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      model_config: A OneOfConfig. Model config; `decoder.type` picks the
        decoder, `decoder.get()` supplies its settings and `norm_activation`
        supplies the normalization/activation settings.
      l2_regularizer: tf.keras.regularizers.Regularizer instance used as the
        kernel regularizer of the decoder. Default to None.

    Returns:
      tf.keras.Model instance of the decoder, or None for type 'identity'.

    Raises:
      ValueError: If the decoder type is not 'identity', 'fpn', 'nasfpn' or
        'aspp'.
    """
    kind = model_config.decoder.type
    cfg = model_config.decoder.get()
    norm_act = model_config.norm_activation

    def shared_kwargs():
        # Every decoder built here receives the same normalization/activation
        # settings and kernel regularizer; evaluated lazily per branch.
        return dict(
            activation=norm_act.activation,
            use_sync_bn=norm_act.use_sync_bn,
            norm_momentum=norm_act.norm_momentum,
            norm_epsilon=norm_act.norm_epsilon,
            kernel_regularizer=l2_regularizer)

    if kind == 'identity':
        return None

    if kind == 'fpn':
        return decoders.FPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=cfg.num_filters,
            use_separable_conv=cfg.use_separable_conv,
            **shared_kwargs())

    if kind == 'nasfpn':
        return decoders.NASFPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=cfg.num_filters,
            num_repeats=cfg.num_repeats,
            use_separable_conv=cfg.use_separable_conv,
            **shared_kwargs())

    if kind == 'aspp':
        return decoders.ASPP(
            level=cfg.level,
            dilation_rates=cfg.dilation_rates,
            num_filters=cfg.num_filters,
            dropout_rate=cfg.dropout_rate,
            **shared_kwargs())

    raise ValueError('Decoder {!r} not implement'.format(kind))
Ejemplo n.º 2
0
    def test_fpn_decoder_creation(self, num_filters, use_separable_conv):
        """Test creation of FPN decoder.

        Builds an FPN directly and through `factory.build_decoder` with the
        same settings, then checks that both report the same layer config.
        """
        min_level = 3
        max_level = 7
        # Specs cover levels min_level .. max_level - 1 (range excludes
        # max_level); each level halves the 128x128 input spatially.
        input_specs = {
            str(level): tf.TensorShape(
                [1, 128 // (2**level), 128 // (2**level), 3])
            for level in range(min_level, max_level)
        }

        # Pass the levels explicitly: the factory forwards
        # model_config.min_level/max_level, so the config comparison below
        # must not rely on FPN's default levels happening to match them.
        network = decoders.FPN(input_specs=input_specs,
                               min_level=min_level,
                               max_level=max_level,
                               num_filters=num_filters,
                               use_separable_conv=use_separable_conv,
                               use_sync_bn=True)

        model_config = configs.retinanet.RetinaNet()
        model_config.min_level = min_level
        model_config.max_level = max_level
        model_config.num_classes = 10
        model_config.input_size = [None, None, 3]
        model_config.decoder = decoders_cfg.Decoder(
            type='fpn',
            fpn=decoders_cfg.FPN(num_filters=num_filters,
                                 use_separable_conv=use_separable_conv))

        factory_network = factory.build_decoder(input_specs=input_specs,
                                                model_config=model_config)

        network_config = network.get_config()
        factory_network_config = factory_network.get_config()

        self.assertEqual(network_config, factory_network_config)
Ejemplo n.º 3
0
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> Optional[tf.keras.Model]:
    """Builds decoder from a config.

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      model_config: A OneOfConfig. Model config. `decoder.type` selects the
        decoder and `decoder.get()` provides its settings; `min_level`,
        `max_level` and `norm_activation` are read when building an FPN.
      l2_regularizer: A `tf.keras.regularizers.Regularizer` instance used as
        the decoder's kernel regularizer. Default to None.

    Returns:
      A `tf.keras.Model` instance of the decoder, or None when the decoder
      type is 'identity' (no decoder is applied).

    Raises:
      ValueError: If `model_config.decoder.type` is neither 'identity' nor
        'fpn'.
    """
    decoder_type = model_config.decoder.type
    decoder_cfg = model_config.decoder.get()
    norm_activation_config = model_config.norm_activation

    if decoder_type == 'identity':
        decoder = None
    elif decoder_type == 'fpn':
        decoder = decoders.FPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    else:
        raise ValueError('Decoder {!r} not implement'.format(decoder_type))

    return decoder
Ejemplo n.º 4
0
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    norm_activation_config: Optional[hyperparams.Config] = None,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> Optional[tf.keras.Model]:
    """Builds decoder from a config.

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      model_config: A OneOfConfig. Model config. `decoder.type` selects the
        decoder and `decoder.get()` provides its settings.
      norm_activation_config: Optional config providing `activation`,
        `use_sync_bn`, `norm_momentum` and `norm_epsilon` for the decoder.
        When None, `model_config.norm_activation` is used instead.
      l2_regularizer: A `tf.keras.regularizers.Regularizer` instance used as
        the decoder's kernel regularizer. Default to None.

    Returns:
      A `tf.keras.Model` instance of the decoder, or None when the decoder
      type is 'identity' (no decoder is applied).

    Raises:
      ValueError: If `model_config.decoder.type` is not one of 'identity',
        'fpn', 'nasfpn', 'aspp', 'hardnet' or 'pan'.
    """
    decoder_type = model_config.decoder.type
    decoder_cfg = model_config.decoder.get()
    # An explicitly passed norm/activation config overrides the model's own.
    if norm_activation_config is None:
        norm_activation_config = model_config.norm_activation

    if decoder_type == 'identity':
        decoder = None
    elif decoder_type == 'fpn':
        decoder = decoders.FPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'nasfpn':
        decoder = decoders.NASFPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            num_repeats=decoder_cfg.num_repeats,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'aspp':
        decoder = decoders.ASPP(
            level=decoder_cfg.level,
            dilation_rates=decoder_cfg.dilation_rates,
            num_filters=decoder_cfg.num_filters,
            pool_kernel_size=decoder_cfg.pool_kernel_size,
            dropout_rate=decoder_cfg.dropout_rate,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            activation=norm_activation_config.activation,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'hardnet':
        decoder = decoders.HardNetDecoder(
            model_id=decoder_cfg.model_id,
            input_specs=input_specs,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'pan':
        decoder = decoders.PAN(
            input_specs=input_specs,
            # NOTE(review): PAN's `routes` is fed from the config's `levels`
            # field — confirm this mapping against the PAN decoder config.
            routes=decoder_cfg.levels,
            num_filters=decoder_cfg.num_filters,
            num_convs=decoder_cfg.num_convs,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    else:
        raise ValueError('Decoder {!r} not implement'.format(decoder_type))

    return decoder