def build_decoder(input_specs,
                  model_config,
                  l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds decoder from a config.

  Args:
    input_specs: `dict` input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config. Its `decoder` selects the
      decoder type and its `norm_activation` supplies the activation and
      normalization settings.
    l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.

  Returns:
    tf.keras.Model instance of the decoder, or None when the decoder type is
    'identity'.

  Raises:
    ValueError: If the decoder type is not one of 'identity', 'fpn', 'nasfpn'
      or 'aspp'.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  norm_activation_config = model_config.norm_activation

  if decoder_type == 'identity':
    return None

  # Every supported decoder takes the same activation / normalization /
  # regularization arguments; build them once rather than per branch.
  common_kwargs = dict(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  if decoder_type == 'fpn':
    return decoders.FPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        use_separable_conv=decoder_cfg.use_separable_conv,
        **common_kwargs)

  if decoder_type == 'nasfpn':
    return decoders.NASFPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        num_repeats=decoder_cfg.num_repeats,
        use_separable_conv=decoder_cfg.use_separable_conv,
        **common_kwargs)

  if decoder_type == 'aspp':
    return decoders.ASPP(
        level=decoder_cfg.level,
        dilation_rates=decoder_cfg.dilation_rates,
        num_filters=decoder_cfg.num_filters,
        dropout_rate=decoder_cfg.dropout_rate,
        **common_kwargs)

  raise ValueError('Decoder {!r} not implemented'.format(decoder_type))
def test_fpn_decoder_creation(self, num_filters, use_separable_conv):
  """Test creation of FPN decoder.

  Builds an FPN directly and through `factory.build_decoder` from an
  equivalent RetinaNet config, then asserts both report the same config.
  """
  min_level, max_level = 3, 7

  # NOTE(review): range() excludes max_level, so specs cover levels 3..6
  # only — confirm this is intended.
  input_specs = {
      str(level): tf.TensorShape(
          [1, 128 // (2**level), 128 // (2**level), 3])
      for level in range(min_level, max_level)
  }

  direct_fpn = decoders.FPN(
      input_specs=input_specs,
      num_filters=num_filters,
      use_separable_conv=use_separable_conv,
      use_sync_bn=True)

  retinanet_cfg = configs.retinanet.RetinaNet()
  retinanet_cfg.min_level = min_level
  retinanet_cfg.max_level = max_level
  retinanet_cfg.num_classes = 10
  retinanet_cfg.input_size = [None, None, 3]
  retinanet_cfg.decoder = decoders_cfg.Decoder(
      type='fpn',
      fpn=decoders_cfg.FPN(
          num_filters=num_filters, use_separable_conv=use_separable_conv))

  built_fpn = factory.build_decoder(
      input_specs=input_specs, model_config=retinanet_cfg)

  self.assertEqual(direct_fpn.get_config(), built_fpn.get_config())
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> Optional[tf.keras.Model]:
  """Builds decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config. Its `decoder` selects the
      decoder type and its `norm_activation` supplies the activation and
      normalization settings.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
      None.

  Returns:
    A `tf.keras.Model` instance of the decoder, or None when the decoder type
    is 'identity'.

  Raises:
    ValueError: If the decoder type is neither 'identity' nor 'fpn'.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  norm_activation_config = model_config.norm_activation

  if decoder_type == 'identity':
    # No decoder: callers receive None instead of a model.
    return None

  if decoder_type == 'fpn':
    return decoders.FPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        use_separable_conv=decoder_cfg.use_separable_conv,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)

  raise ValueError('Decoder {!r} not implemented'.format(decoder_type))
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    norm_activation_config: Optional[hyperparams.Config] = None,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> Optional[tf.keras.Model]:
  """Builds decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config. Its `decoder` selects the
      decoder type.
    norm_activation_config: A config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon` for the decoder. If None,
      `model_config.norm_activation` is used.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
      None.

  Returns:
    A `tf.keras.Model` instance of the decoder, or None when the decoder type
    is 'identity'.

  Raises:
    ValueError: If the decoder type is not one of 'identity', 'fpn', 'nasfpn',
      'aspp', 'hardnet' or 'pan'.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  if norm_activation_config is None:
    norm_activation_config = model_config.norm_activation

  if decoder_type == 'identity':
    return None

  # Every supported decoder takes the same activation / normalization /
  # regularization arguments; build them once rather than per branch.
  common_kwargs = dict(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  if decoder_type == 'fpn':
    return decoders.FPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        use_separable_conv=decoder_cfg.use_separable_conv,
        **common_kwargs)

  if decoder_type == 'nasfpn':
    return decoders.NASFPN(
        input_specs=input_specs,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_filters=decoder_cfg.num_filters,
        num_repeats=decoder_cfg.num_repeats,
        use_separable_conv=decoder_cfg.use_separable_conv,
        **common_kwargs)

  if decoder_type == 'aspp':
    return decoders.ASPP(
        level=decoder_cfg.level,
        dilation_rates=decoder_cfg.dilation_rates,
        num_filters=decoder_cfg.num_filters,
        pool_kernel_size=decoder_cfg.pool_kernel_size,
        dropout_rate=decoder_cfg.dropout_rate,
        **common_kwargs)

  if decoder_type == 'hardnet':
    return decoders.HardNetDecoder(
        model_id=decoder_cfg.model_id,
        input_specs=input_specs,
        **common_kwargs)

  if decoder_type == 'pan':
    # PAN's `routes` argument is fed from the config's `levels` field.
    return decoders.PAN(
        input_specs=input_specs,
        routes=decoder_cfg.levels,
        num_filters=decoder_cfg.num_filters,
        num_convs=decoder_cfg.num_convs,
        **common_kwargs)

  raise ValueError('Decoder {!r} not implemented'.format(decoder_type))