Ejemplo n.º 1
0
    def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
        """Checks that the factory-built ASPP decoder matches a direct build.

        One ASPP decoder is constructed directly, and another is obtained from
        `factory.build_decoder` using an equivalent semantic-segmentation model
        config. Their `get_config()` outputs must match.
        """
        specs = {'1': tf.TensorShape([1, 128, 128, 3])}

        direct_decoder = decoders.ASPP(
            level=level,
            dilation_rates=dilation_rates,
            num_filters=num_filters,
            use_sync_bn=True)

        seg_config = configs.semantic_segmentation.SemanticSegmentationModel()
        seg_config.num_classes = 10
        seg_config.input_size = [None, None, 3]
        aspp_cfg = decoders_cfg.ASPP(
            level=level,
            dilation_rates=dilation_rates,
            num_filters=num_filters)
        seg_config.decoder = decoders_cfg.Decoder(type='aspp', aspp=aspp_cfg)

        built_decoder = factory.build_decoder(
            input_specs=specs, model_config=seg_config)

        expected = direct_decoder.get_config()
        actual = built_decoder.get_config()
        # The ASPP layer's `get_config()` calls `super().get_config()`, which
        # includes the layer name. The two instances differ only in that name,
        # so copy it over to avoid a false mismatch.
        actual['name'] = expected['name']

        self.assertEqual(expected, actual)
Ejemplo n.º 2
0
def build_decoder(input_specs,
                  model_config,
                  l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Builds decoder from a config.

    Args:
      input_specs: `dict` input specifications. A dictionary consists of
        {level: TensorShape} from a backbone. Forwarded to the 'fpn' and
        'nasfpn' decoders.
      model_config: Model config. Its `decoder` field (a OneOfConfig) selects
        the decoder via `type`. Its `norm_activation` field supplies the
        `activation`, `use_sync_bn`, `norm_momentum` and `norm_epsilon`
        settings. 'fpn'/'nasfpn' also read `min_level`/`max_level` from it.
      l2_regularizer: tf.keras.regularizers.Regularizer instance, passed to
        the decoder as `kernel_regularizer`. Default to None.

    Returns:
      tf.keras.Model instance of the decoder, or None when the decoder type is
      'identity'.

    Raises:
      ValueError: If the decoder type is not one of 'identity', 'fpn',
        'nasfpn' or 'aspp'.
    """
    decoder_type = model_config.decoder.type
    decoder_cfg = model_config.decoder.get()
    norm_activation_config = model_config.norm_activation

    if decoder_type == 'identity':
        # No decoder: callers receive None and use the backbone features.
        decoder = None
    elif decoder_type == 'fpn':
        decoder = decoders.FPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'nasfpn':
        decoder = decoders.NASFPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            num_repeats=decoder_cfg.num_repeats,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'aspp':
        # ASPP takes a single `level` from its own config rather than
        # input_specs / min_level / max_level.
        decoder = decoders.ASPP(
            level=decoder_cfg.level,
            dilation_rates=decoder_cfg.dilation_rates,
            num_filters=decoder_cfg.num_filters,
            dropout_rate=decoder_cfg.dropout_rate,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            activation=norm_activation_config.activation,
            kernel_regularizer=l2_regularizer)
    else:
        raise ValueError('Decoder {!r} not implemented'.format(decoder_type))

    return decoder
Ejemplo n.º 3
0
  def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
    """Checks that `factory.build_decoder` reproduces a directly built ASPP.

    Both decoders are serialized with `get_config()`, and the two configs
    must be equal.
    """
    specs = {'1': tf.TensorShape([1, 128, 128, 3])}

    expected_decoder = decoders.ASPP(
        level=level,
        dilation_rates=dilation_rates,
        num_filters=num_filters,
        use_sync_bn=True)

    seg_config = configs.semantic_segmentation.SemanticSegmentationModel()
    seg_config.num_classes = 10
    seg_config.input_size = [None, None, 3]
    aspp_cfg = decoders_cfg.ASPP(
        level=level,
        dilation_rates=dilation_rates,
        num_filters=num_filters)
    seg_config.decoder = decoders_cfg.Decoder(type='aspp', aspp=aspp_cfg)

    built_decoder = factory.build_decoder(
        input_specs=specs, model_config=seg_config)

    self.assertEqual(expected_decoder.get_config(),
                     built_decoder.get_config())
Ejemplo n.º 4
0
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    norm_activation_config: Optional[hyperparams.Config] = None,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> Optional[tf.keras.Model]:
    """Builds decoder from a config.

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone. Forwarded to the 'fpn',
        'nasfpn', 'hardnet' and 'pan' decoders.
      model_config: Model config. Its `decoder` field (a OneOfConfig) selects
        the decoder via `type`. 'fpn'/'nasfpn' also read `min_level` and
        `max_level` from it.
      norm_activation_config: Optional config providing `activation`,
        `use_sync_bn`, `norm_momentum` and `norm_epsilon`. When None,
        `model_config.norm_activation` is used.
      l2_regularizer: A `tf.keras.regularizers.Regularizer` instance, passed
        to the decoder as `kernel_regularizer`. Default to None.

    Returns:
      A `tf.keras.Model` instance of the decoder, or None when the decoder
      type is 'identity'.

    Raises:
      ValueError: If the decoder type is not one of 'identity', 'fpn',
        'nasfpn', 'aspp', 'hardnet' or 'pan'.
    """
    decoder_type = model_config.decoder.type
    decoder_cfg = model_config.decoder.get()
    if norm_activation_config is None:
        norm_activation_config = model_config.norm_activation

    if decoder_type == 'identity':
        # No decoder: callers receive None and use the backbone features.
        decoder = None
    elif decoder_type == 'fpn':
        decoder = decoders.FPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'nasfpn':
        decoder = decoders.NASFPN(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            num_filters=decoder_cfg.num_filters,
            num_repeats=decoder_cfg.num_repeats,
            use_separable_conv=decoder_cfg.use_separable_conv,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'aspp':
        # ASPP takes a single `level` from its own config rather than
        # input_specs / min_level / max_level.
        decoder = decoders.ASPP(
            level=decoder_cfg.level,
            dilation_rates=decoder_cfg.dilation_rates,
            num_filters=decoder_cfg.num_filters,
            pool_kernel_size=decoder_cfg.pool_kernel_size,
            dropout_rate=decoder_cfg.dropout_rate,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            activation=norm_activation_config.activation,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'hardnet':
        decoder = decoders.HardNetDecoder(
            model_id=decoder_cfg.model_id,
            input_specs=input_specs,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif decoder_type == 'pan':
        # PAN's `routes` argument is populated from the config's `levels`
        # field.
        decoder = decoders.PAN(
            input_specs=input_specs,
            routes=decoder_cfg.levels,
            num_filters=decoder_cfg.num_filters,
            num_convs=decoder_cfg.num_convs,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    else:
        raise ValueError('Decoder {!r} not implemented'.format(decoder_type))

    return decoder