def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
  """Checks that a factory-built ASPP decoder has the expected config."""
  input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

  # Reference decoder, instantiated straight from the ASPP class.
  direct_decoder = decoders.ASPP(
      level=level,
      dilation_rates=dilation_rates,
      num_filters=num_filters,
      use_sync_bn=True)

  # The same decoder, described through a segmentation model config.
  aspp_cfg = decoders_cfg.ASPP(
      level=level, dilation_rates=dilation_rates, num_filters=num_filters)
  seg_config = configs.semantic_segmentation.SemanticSegmentationModel()
  seg_config.num_classes = 10
  seg_config.input_size = [None, None, 3]
  seg_config.decoder = decoders_cfg.Decoder(type='aspp', aspp=aspp_cfg)

  built_decoder = factory.build_decoder(
      input_specs=input_specs, model_config=seg_config)

  expected_config = direct_decoder.get_config()
  actual_config = built_decoder.get_config()

  # The ASPP layer calls `super().get_config()`, so the two configs agree on
  # everything except the layer instance name. Copy the name across so that
  # this difference alone does not raise a false alarm.
  actual_config['name'] = expected_config['name']
  self.assertEqual(expected_config, actual_config)
def build_decoder(input_specs,
                  model_config,
                  l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Creates the decoder selected by `model_config.decoder`.

  Args:
    input_specs: `dict` of input specifications, {level: TensorShape}, as
      produced by a backbone.
    model_config: A OneOfConfig. The model config; its `decoder` field selects
      the decoder type and its `norm_activation` field supplies the
      normalization and activation settings.
    l2_regularizer: tf.keras.regularizers.Regularizer instance passed to the
      decoder as its kernel regularizer. Default to None.

  Returns:
    The decoder instance, or None when the decoder type is 'identity'.

  Raises:
    ValueError: If the decoder type is not 'identity', 'fpn', 'nasfpn' or
      'aspp'.
  """
  kind = model_config.decoder.type
  cfg = model_config.decoder.get()
  norm_act = model_config.norm_activation

  # An identity decoder means "no decoder": nothing is built.
  if kind == 'identity':
    return None

  if kind == 'fpn':
    return decoders.FPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=cfg.num_filters,
        use_separable_conv=cfg.use_separable_conv,
        activation=norm_act.activation,
        use_sync_bn=norm_act.use_sync_bn,
        norm_momentum=norm_act.norm_momentum,
        norm_epsilon=norm_act.norm_epsilon,
        kernel_regularizer=l2_regularizer)

  if kind == 'nasfpn':
    return decoders.NASFPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=cfg.num_filters,
        num_repeats=cfg.num_repeats,
        use_separable_conv=cfg.use_separable_conv,
        activation=norm_act.activation,
        use_sync_bn=norm_act.use_sync_bn,
        norm_momentum=norm_act.norm_momentum,
        norm_epsilon=norm_act.norm_epsilon,
        kernel_regularizer=l2_regularizer)

  if kind == 'aspp':
    return decoders.ASPP(
        level=cfg.level,
        dilation_rates=cfg.dilation_rates,
        num_filters=cfg.num_filters,
        dropout_rate=cfg.dropout_rate,
        use_sync_bn=norm_act.use_sync_bn,
        norm_momentum=norm_act.norm_momentum,
        norm_epsilon=norm_act.norm_epsilon,
        activation=norm_act.activation,
        kernel_regularizer=l2_regularizer)

  raise ValueError('Decoder {!r} not implement'.format(kind))
def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
  """Checks that a factory-built ASPP decoder matches a directly built one."""
  input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

  # Reference decoder, instantiated straight from the ASPP class.
  direct_decoder = decoders.ASPP(
      level=level,
      dilation_rates=dilation_rates,
      num_filters=num_filters,
      use_sync_bn=True)

  # The same decoder, described through a segmentation model config.
  aspp_cfg = decoders_cfg.ASPP(
      level=level, dilation_rates=dilation_rates, num_filters=num_filters)
  seg_config = configs.semantic_segmentation.SemanticSegmentationModel()
  seg_config.num_classes = 10
  seg_config.input_size = [None, None, 3]
  seg_config.decoder = decoders_cfg.Decoder(type='aspp', aspp=aspp_cfg)

  built_decoder = factory.build_decoder(
      input_specs=input_specs, model_config=seg_config)

  # Both construction paths must serialize to the same config.
  expected_config = direct_decoder.get_config()
  actual_config = built_decoder.get_config()
  self.assertEqual(expected_config, actual_config)
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    norm_activation_config: Optional[hyperparams.Config] = None,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> Optional[tf.keras.Model]:
  """Builds decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config.
    norm_activation_config: A config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon` for the decoder. When None,
      `model_config.norm_activation` is used instead.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
      None.

  Returns:
    A `tf.keras.Model` instance of the decoder, or None when the decoder type
    is 'identity'.

  Raises:
    ValueError: If the decoder type is not one of 'identity', 'fpn', 'nasfpn',
      'aspp', 'hardnet' or 'pan'.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  # Only fall back to the model-level setting when the caller did not supply
  # an explicit override.
  if norm_activation_config is None:
    norm_activation_config = model_config.norm_activation

  # An identity decoder means "no decoder": nothing is built.
  if decoder_type == 'identity':
    return None

  if decoder_type == 'fpn':
    decoder = decoders.FPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        use_separable_conv=decoder_cfg.use_separable_conv,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
  elif decoder_type == 'nasfpn':
    decoder = decoders.NASFPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        num_repeats=decoder_cfg.num_repeats,
        use_separable_conv=decoder_cfg.use_separable_conv,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
  elif decoder_type == 'aspp':
    decoder = decoders.ASPP(
        level=decoder_cfg.level,
        dilation_rates=decoder_cfg.dilation_rates,
        num_filters=decoder_cfg.num_filters,
        pool_kernel_size=decoder_cfg.pool_kernel_size,
        dropout_rate=decoder_cfg.dropout_rate,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        activation=norm_activation_config.activation,
        kernel_regularizer=l2_regularizer)
  elif decoder_type == 'hardnet':
    decoder = decoders.HardNetDecoder(
        model_id=decoder_cfg.model_id,
        input_specs=input_specs,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
  elif decoder_type == 'pan':
    # The PAN decoder takes its levels under the `routes` argument.
    decoder = decoders.PAN(
        input_specs=input_specs,
        routes=decoder_cfg.levels,
        num_filters=decoder_cfg.num_filters,
        num_convs=decoder_cfg.num_convs,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
  else:
    raise ValueError('Decoder {!r} not implement'.format(decoder_type))
  return decoder