Example #1
    def _resample_feature_map(self,
                              inputs,
                              input_level,
                              target_level,
                              target_num_filters=256):
        x = inputs
        _, _, _, input_num_filters = x.get_shape().as_list()
        if input_num_filters != target_num_filters:
            x = self._conv_op(filters=target_num_filters,
                              kernel_size=1,
                              padding='same',
                              **self._conv_kwargs)(x)
            x = self._norm_op(**self._norm_kwargs)(x)

        if input_level < target_level:
            stride = int(2**(target_level - input_level))
            return tf.keras.layers.MaxPool2D(pool_size=stride,
                                             strides=stride,
                                             padding='same')(x)
        if input_level > target_level:
            scale = int(2**(input_level - target_level))
            return spatial_transform_ops.nearest_upsampling(x, scale=scale)

        # Force output x to be the same dtype as mixed precision policy. This avoids
        # dtype mismatch when one input (by default float32 dtype) does not meet all
        # the above conditions and is output unchanged, while other inputs are
        # processed to have different dtype, e.g., using bfloat16 on TPU.
        compute_dtype = tf.keras.layers.Layer().dtype_policy.compute_dtype
        if (compute_dtype is not None) and (x.dtype != compute_dtype):
            return tf.cast(x, dtype=compute_dtype)
        else:
            return x
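A minimal, self-contained sketch of the spatial part of this resampling (the helper name is hypothetical, and `spatial_transform_ops.nearest_upsampling` is replaced by Keras nearest-neighbor upsampling so the snippet runs on plain TensorFlow):

import tensorflow as tf

def resample_spatial(x, input_level, target_level):
    # A higher level means a coarser resolution: downsample by max-pooling...
    if input_level < target_level:
        stride = int(2 ** (target_level - input_level))
        return tf.keras.layers.MaxPool2D(pool_size=stride, strides=stride,
                                         padding='same')(x)
    # ...and upsample by nearest-neighbor when moving to a finer level.
    if input_level > target_level:
        scale = int(2 ** (input_level - target_level))
        return tf.keras.layers.UpSampling2D(size=scale,
                                            interpolation='nearest')(x)
    return x

x = tf.random.normal([1, 64, 64, 256])   # a level-3 feature map
print(resample_spatial(x, 3, 5).shape)   # (1, 16, 16, 256)
print(resample_spatial(x, 3, 2).shape)   # (1, 128, 128, 256)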
Example #2
    def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
                                 Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
        """Forward pass of the segmentation head.

    It accepts a tuple of either two tensors or two dictionaries. The first
    element is the backbone endpoints and the second is the decoder
    endpoints. When the inputs are tensors, each comes from a single level
    of feature maps. When the inputs are dictionaries, each contains
    multiple levels of feature maps, keyed by the level of the feature map.

    Args:
      inputs: A tuple of 2 feature map tensors of shape
        [batch, height_l, width_l, channels] or 2 dictionaries of tensors:
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
        The first is backbone endpoints, and the second is decoder endpoints.
    Returns:
      segmentation prediction mask: A `tf.Tensor` of the segmentation mask
        scores predicted from input features.
    """

        backbone_output = inputs[0]
        decoder_output = inputs[1]
        if self._config_dict['feature_fusion'] == 'deeplabv3plus':
            # deeplabv3+ feature fusion
            x = (decoder_output[str(self._config_dict['level'])]
                 if isinstance(decoder_output, dict) else decoder_output)
            y = (backbone_output[str(self._config_dict['low_level'])]
                 if isinstance(backbone_output, dict) else backbone_output)
            y = self._dlv3p_norm(self._dlv3p_conv(y))
            y = self._activation(y)

            x = tf.image.resize(x,
                                tf.shape(y)[1:3],
                                method=tf.image.ResizeMethod.BILINEAR)
            x = tf.cast(x, dtype=y.dtype)
            x = tf.concat([x, y], axis=self._bn_axis)
        elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
            if not isinstance(decoder_output, dict):
                raise ValueError('Only support dictionary decoder_output.')
            x = nn_layers.pyramid_feature_fusion(decoder_output,
                                                 self._config_dict['level'])
        elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
            x = self._panoptic_fpn_fusion(decoder_output)
        else:
            x = (decoder_output[str(self._config_dict['level'])]
                 if isinstance(decoder_output, dict) else decoder_output)

        for conv, norm in zip(self._convs, self._norms):
            x = conv(x)
            x = norm(x)
            x = self._activation(x)
        if self._config_dict['upsample_factor'] > 1:
            x = spatial_transform_ops.nearest_upsampling(
                x, scale=self._config_dict['upsample_factor'])

        return self._classifier(x)
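As a rough illustration of the deeplabv3+ branch above, the fusion step resizes the coarse decoder feature to the low-level feature's spatial size, matches dtypes, and concatenates along channels. The shapes below are hypothetical, and the concat axis assumes channels-last:

import tensorflow as tf

decoder_feat = tf.random.normal([1, 16, 16, 256])    # coarse decoder output
low_level_feat = tf.random.normal([1, 64, 64, 48])   # projected backbone feature
x = tf.image.resize(decoder_feat, tf.shape(low_level_feat)[1:3],
                    method=tf.image.ResizeMethod.BILINEAR)
x = tf.cast(x, low_level_feat.dtype)                 # guard against mixed precision
x = tf.concat([x, low_level_feat], axis=-1)
print(x.shape)  # (1, 64, 64, 304)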
Example #3
  def _resample_with_sepconv(self, inputs, input_width, target_width,
                             target_num_filters):
    """Matches resolution and feature dimension."""
    x = inputs
    # Spatial resampling.
    if input_width > target_width:
      while input_width > target_width:
        x = layers.DepthwiseConv2D(
            kernel_size=3,
            strides=2,
            padding='SAME',
            use_bias=False,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer)(x)
        x = self._norm(
            axis=self._bn_axis,
            momentum=self._norm_momentum,
            epsilon=self._norm_epsilon)(x)
        x = tf_utils.get_activation(
            self._activation, use_keras_layer=True)(x)
        input_width /= 2
    elif input_width < target_width:
      scale = target_width // input_width
      x = spatial_transform_ops.nearest_upsampling(
          x, scale=scale, use_keras_layer=self._use_keras_upsampling_2d)

    # Last 1x1 conv to match filter size.
    x = layers.Conv2D(
        filters=target_num_filters,
        kernel_size=1,
        strides=1,
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(x)
    x = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)(x)
    return x
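The downsampling loop above halves the spatial width once per iteration, so it assumes the ratio of input to target width is a power of two. A stripped-down sketch with the norm, activation, and regularizers omitted:

import tensorflow as tf

x = tf.random.normal([1, 64, 64, 32])
input_width, target_width = 64, 16
while input_width > target_width:
    # Each stride-2 depthwise conv halves height and width.
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=2,
                                        padding='same', use_bias=False)(x)
    input_width //= 2
print(x.shape)  # (1, 16, 16, 32) after two stride-2 convs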
Example #4
    def call(self, backbone_output: Mapping[str, tf.Tensor],
             decoder_output: Mapping[str, tf.Tensor]):
        """Forward pass of the segmentation head.

    Args:
      backbone_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
      decoder_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
    Returns:
      segmentation prediction mask: A `tf.Tensor` of the segmentation mask
        scores predicted from input features.
    """
        if self._config_dict['feature_fusion'] == 'deeplabv3plus':
            # deeplabv3+ feature fusion
            x = decoder_output[str(self._config_dict['level'])]
            y = backbone_output[str(self._config_dict['low_level'])]
            y = self._dlv3p_norm(self._dlv3p_conv(y))
            y = self._activation(y)

            x = tf.image.resize(x,
                                tf.shape(y)[1:3],
                                method=tf.image.ResizeMethod.BILINEAR)
            x = tf.cast(x, dtype=y.dtype)
            x = tf.concat([x, y], axis=self._bn_axis)
        elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
            x = nn_layers.pyramid_feature_fusion(decoder_output,
                                                 self._config_dict['level'])
        else:
            x = decoder_output[str(self._config_dict['level'])]

        for conv, norm in zip(self._convs, self._norms):
            x = conv(x)
            x = norm(x)
            x = self._activation(x)
        if self._config_dict['upsample_factor'] > 1:
            x = spatial_transform_ops.nearest_upsampling(
                x, scale=self._config_dict['upsample_factor'])

        return self._classifier(x)
Example #5
    def call(self, features):
        """Forward pass of the segmentation head.

    Args:
      features: a dict of tensors
        - key: `str`, the level of the multilevel features.
        - values: `Tensor`, the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
    Returns:
      segmentation prediction mask: `Tensor`, the segmentation mask scores
        predicted from input feature.
    """
        x = features[str(self._config_dict['level'])]
        for conv, norm in zip(self._convs, self._norms):
            x = conv(x)
            x = norm(x)
            x = self._activation(x)
        x = spatial_transform_ops.nearest_upsampling(
            x, scale=self._config_dict['upsample_factor'])
        return self._classifier(x)
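`spatial_transform_ops.nearest_upsampling` is not shown in this listing; for intuition, Keras nearest-neighbor upsampling has the same effect on shapes, e.g. with `upsample_factor=4` (shapes and the 19-class channel count below are hypothetical):

import tensorflow as tf

x = tf.random.normal([1, 32, 32, 19])   # per-pixel class scores, 19 classes
up = tf.keras.layers.UpSampling2D(size=4, interpolation='nearest')(x)
print(up.shape)  # (1, 128, 128, 19)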
Example #6
    def _resample_feature_map(self,
                              inputs,
                              input_level,
                              target_level,
                              target_num_filters=256):
        x = inputs
        _, _, _, input_num_filters = x.get_shape().as_list()
        if input_num_filters != target_num_filters:
            x = self._conv_op(filters=target_num_filters,
                              kernel_size=1,
                              padding='same',
                              **self._conv_kwargs)(x)
            x = self._norm_op(**self._norm_kwargs)(x)

        if input_level < target_level:
            stride = int(2**(target_level - input_level))
            x = tf.keras.layers.MaxPool2D(pool_size=stride,
                                          strides=stride,
                                          padding='same')(x)
        elif input_level > target_level:
            scale = int(2**(input_level - target_level))
            x = spatial_transform_ops.nearest_upsampling(x, scale=scale)

        return x
Example #7
    def _resample_with_alpha(self,
                             inputs,
                             input_width,
                             input_block_fn,
                             target_width,
                             target_num_filters,
                             target_block_fn,
                             alpha=0.5):
        """Matches resolution and feature dimension."""
        _, _, _, input_num_filters = inputs.get_shape().as_list()
        if input_block_fn == 'bottleneck':
            input_num_filters /= 4
        new_num_filters = int(input_num_filters * alpha)

        x = layers.Conv2D(filters=new_num_filters,
                          kernel_size=1,
                          strides=1,
                          use_bias=False,
                          kernel_initializer=self._kernel_initializer,
                          kernel_regularizer=self._kernel_regularizer,
                          bias_regularizer=self._bias_regularizer)(inputs)
        x = self._norm(axis=self._bn_axis,
                       momentum=self._norm_momentum,
                       epsilon=self._norm_epsilon)(x)
        x = tf_utils.get_activation(self._activation_fn)(x)

        # Spatial resampling.
        if input_width > target_width:
            x = layers.Conv2D(filters=new_num_filters,
                              kernel_size=3,
                              strides=2,
                              padding='SAME',
                              use_bias=False,
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=self._bn_axis,
                           momentum=self._norm_momentum,
                           epsilon=self._norm_epsilon)(x)
            x = tf_utils.get_activation(self._activation_fn)(x)
            input_width /= 2
            while input_width > target_width:
                x = layers.MaxPool2D(pool_size=3, strides=2, padding='SAME')(x)
                input_width /= 2
        elif input_width < target_width:
            scale = target_width // input_width
            x = spatial_transform_ops.nearest_upsampling(x, scale=scale)

        # Last 1x1 conv to match filter size.
        if target_block_fn == 'bottleneck':
            target_num_filters *= 4
        x = layers.Conv2D(filters=target_num_filters,
                          kernel_size=1,
                          strides=1,
                          use_bias=False,
                          kernel_initializer=self._kernel_initializer,
                          kernel_regularizer=self._kernel_regularizer,
                          bias_regularizer=self._bias_regularizer)(x)
        x = self._norm(axis=self._bn_axis,
                       momentum=self._norm_momentum,
                       epsilon=self._norm_epsilon)(x)
        return x
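The filter arithmetic above is easy to trace with concrete (hypothetical) numbers: bottleneck blocks expose 4x their nominal width, so the input channel count is divided by 4, scaled by `alpha`, and the target width is multiplied back by 4 for a bottleneck target:

input_num_filters = 1024        # channels of a bottleneck input tensor
input_num_filters /= 4          # nominal width: 256.0
new_num_filters = int(input_num_filters * 0.5)   # alpha=0.5 -> 128
target_num_filters = 256        # requested nominal target width
target_num_filters *= 4         # bottleneck target -> 1024 output channels
print(new_num_filters, target_num_filters)       # 128 1024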
Example #8
  def call(self, backbone_output, decoder_output):
    """Forward pass of the segmentation head.

    Args:
      backbone_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
      decoder_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
    Returns:
      segmentation prediction mask: A `tf.Tensor` of the segmentation mask
        scores predicted from input features.
    """
    if self._config_dict['feature_fusion'] == 'deeplabv3plus':
      # deeplabv3+ feature fusion
      x = decoder_output[str(self._config_dict['level'])]
      y = backbone_output[str(
          self._config_dict['low_level'])]
      y = self._dlv3p_norm(self._dlv3p_conv(y))
      y = self._activation(y)

      x = tf.image.resize(
          x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
      x = tf.cast(x, dtype=y.dtype)
      x = tf.concat([x, y], axis=self._bn_axis)
    elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
      x = nn_layers.pyramid_feature_fusion(decoder_output,
                                           self._config_dict['level'])
    elif self._config_dict['feature_fusion'] == 'deeplabv2':
      # deeplabv2 feature fusion: classify each decoder level separately.
      # The per-level score maps are summed after the shared convs below.
      for k in decoder_output.keys():
        decoder_output[k] = self._activation(
            self._dlv2_norm(self._dlv2_conv(decoder_output[k])))
        decoder_output[k] = self._classifier(decoder_output[k])
    elif self._config_dict['feature_fusion'] == 'deeplabv1_msc':
      # deeplabv1 feature fusion for multi-scale prediction: build one
      # 3x3 + 1x1 conv branch per low-level backbone feature.
      msc_outputs = []
      for i in range(self._config_dict['low_level'] + 1):
        x = backbone_output[str(i)]
        x = self._dlv1_msc_norms[i](self._dlv1_msc_convs33[i](x))
        x = self._activation(x)
        x = self._dlv1_msc_norms[i](self._dlv1_msc_convs11[i](x))
        x = self._activation(x)
        msc_outputs.append(x)

      # The decoder feature goes through the last conv pair, one index past
      # the per-level branches above.
      n = self._config_dict['low_level'] + 1
      x = decoder_output[str(self._config_dict['level'])]
      x = self._dlv1_msc_norms[n](self._dlv1_msc_convs33[n](x))
      x = self._activation(x)
      x = self._dlv1_msc_norms[n](self._dlv1_msc_convs11[n](x))
      x = self._activation(x)
      msc_outputs.append(x)

      x = tf.concat(msc_outputs, axis=self._bn_axis)
    else:
      x = decoder_output[str(self._config_dict['level'])]

    for conv, norm in zip(self._convs, self._norms):
      x = conv(x)
      x = norm(x)
      x = self._activation(x)

    if self._config_dict['feature_fusion'] == 'deeplabv2':
      # Element-wise sum of the per-level deeplabv2 score maps.
      out = sum(decoder_output.values())
    else:
      x = spatial_transform_ops.nearest_upsampling(
          x, scale=self._config_dict['upsample_factor'])
      out = self._classifier(x)

    return out
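For the deeplabv2 branch, the head classifies each decoder level separately and element-wise sums the resulting score maps. A toy sketch with hypothetical levels, shapes, and class count; ASPP-style branches share one spatial size, so the sum is well defined:

import tensorflow as tf

per_level_logits = {
    '3': tf.random.normal([1, 64, 64, 21]),
    '4': tf.random.normal([1, 64, 64, 21]),
}
out = sum(per_level_logits.values())   # element-wise sum of score maps
print(out.shape)  # (1, 64, 64, 21)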
Example #9
    def __init__(self,
                 input_specs: Mapping[str, tf.TensorShape],
                 min_level: int = 3,
                 max_level: int = 7,
                 num_filters: int = 256,
                 fusion_type: str = 'sum',
                 use_separable_conv: bool = False,
                 activation: str = 'relu',
                 use_sync_bn: bool = False,
                 norm_momentum: float = 0.99,
                 norm_epsilon: float = 0.001,
                 kernel_initializer: str = 'VarianceScaling',
                 kernel_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 bias_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 **kwargs):
        """Initializes a Feature Pyramid Network (FPN).

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      min_level: An `int` of minimum level in FPN output feature maps.
      max_level: An `int` of maximum level in FPN output feature maps.
      num_filters: An `int` number of filters in FPN layers.
      fusion_type: A `str`, either `sum` or `concat`, specifying whether to
        sum or concatenate features during fusion.
      use_separable_conv: A `bool`. If True, use separable convolutions in
        the FPN layers.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
      **kwargs: Additional keyword arguments to be passed.
    """
        self._config_dict = {
            'input_specs': input_specs,
            'min_level': min_level,
            'max_level': max_level,
            'num_filters': num_filters,
            'fusion_type': fusion_type,
            'use_separable_conv': use_separable_conv,
            'activation': activation,
            'use_sync_bn': use_sync_bn,
            'norm_momentum': norm_momentum,
            'norm_epsilon': norm_epsilon,
            'kernel_initializer': kernel_initializer,
            'kernel_regularizer': kernel_regularizer,
            'bias_regularizer': bias_regularizer,
        }
        if use_separable_conv:
            conv2d = tf.keras.layers.SeparableConv2D
        else:
            conv2d = tf.keras.layers.Conv2D
        if use_sync_bn:
            norm = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            norm = tf.keras.layers.BatchNormalization
        activation_fn = tf.keras.layers.Activation(
            tf_utils.get_activation(activation))

        # Build input feature pyramid.
        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        # Get input feature pyramid from backbone.
        logging.info('FPN input_specs: %s', input_specs)
        inputs = self._build_input_pyramid(input_specs, min_level)
        backbone_max_level = min(int(max(inputs.keys())), max_level)

        # Build lateral connections.
        feats_lateral = {}
        for level in range(min_level, backbone_max_level + 1):
            feats_lateral[str(level)] = conv2d(
                filters=num_filters,
                kernel_size=1,
                padding='same',
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer)(inputs[str(level)])

        # Build top-down path.
        feats = {
            str(backbone_max_level): feats_lateral[str(backbone_max_level)]
        }
        for level in range(backbone_max_level - 1, min_level - 1, -1):
            feat_a = spatial_transform_ops.nearest_upsampling(
                feats[str(level + 1)], 2)
            feat_b = feats_lateral[str(level)]

            if fusion_type == 'sum':
                feats[str(level)] = feat_a + feat_b
            elif fusion_type == 'concat':
                feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
            else:
                raise ValueError(
                    'Fusion type {} not supported.'.format(fusion_type))

        # TODO(xianzhi): consider removing the bias in conv2d.
        # Build post-hoc 3x3 convolution kernel.
        for level in range(min_level, backbone_max_level + 1):
            feats[str(level)] = conv2d(filters=num_filters,
                                       strides=1,
                                       kernel_size=3,
                                       padding='same',
                                       kernel_initializer=kernel_initializer,
                                       kernel_regularizer=kernel_regularizer,
                                       bias_regularizer=bias_regularizer)(
                                           feats[str(level)])

        # TODO(xianzhi): consider removing the bias in conv2d.
        # Build coarser FPN levels introduced for RetinaNet.
        for level in range(backbone_max_level + 1, max_level + 1):
            feats_in = feats[str(level - 1)]
            if level > backbone_max_level + 1:
                feats_in = activation_fn(feats_in)
            feats[str(level)] = conv2d(
                filters=num_filters,
                strides=2,
                kernel_size=3,
                padding='same',
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer)(feats_in)

        # Apply batch norm layers.
        for level in range(min_level, max_level + 1):
            feats[str(level)] = norm(axis=bn_axis,
                                     momentum=norm_momentum,
                                     epsilon=norm_epsilon)(feats[str(level)])

        self._output_specs = {
            str(level): feats[str(level)].get_shape()
            for level in range(min_level, max_level + 1)
        }

        super(FPN, self).__init__(inputs=inputs, outputs=feats, **kwargs)
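The core of the top-down path above is a 2x nearest-neighbor upsample of the coarser level fused with the lateral projection. A standalone sketch of one step with `fusion_type='sum'`, where Keras upsampling stands in for `spatial_transform_ops.nearest_upsampling` and the shapes are hypothetical:

import tensorflow as tf

feat_p5 = tf.random.normal([1, 8, 8, 256])       # coarser FPN level
lateral_p4 = tf.random.normal([1, 16, 16, 256])  # 1x1-projected backbone level
feat_a = tf.keras.layers.UpSampling2D(2, interpolation='nearest')(feat_p5)
feat_p4 = feat_a + lateral_p4                    # 'concat' would use tf.concat
print(feat_p4.shape)  # (1, 16, 16, 256)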
Example #10
  def __init__(self,
               input_specs,
               min_level=3,
               max_level=7,
               num_filters=256,
               use_separable_conv=False,
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """FPN initialization function.

    Args:
      input_specs: `dict` input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      min_level: `int` minimum level in FPN output feature maps.
      max_level: `int` maximum level in FPN output feature maps.
      num_filters: `int` number of filters in FPN layers.
      use_separable_conv: `bool`, if True use separable convolution for
        convolution in FPN layers.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    self._config_dict = {
        'input_specs': input_specs,
        'min_level': min_level,
        'max_level': max_level,
        'num_filters': num_filters,
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    if use_separable_conv:
      conv2d = tf.keras.layers.SeparableConv2D
    else:
      conv2d = tf.keras.layers.Conv2D
    if use_sync_bn:
      norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      norm = tf.keras.layers.BatchNormalization
    activation_fn = tf.keras.layers.Activation(
        tf_utils.get_activation(activation))

    # Build input feature pyramid.
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Get input feature pyramid from backbone.
    inputs = self._build_input_pyramid(input_specs, min_level)
    backbone_max_level = min(int(max(inputs.keys())), max_level)

    # Build lateral connections.
    feats_lateral = {}
    for level in range(min_level, backbone_max_level + 1):
      feats_lateral[str(level)] = conv2d(
          filters=num_filters,
          kernel_size=1,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)(inputs[str(level)])

    # Build top-down path.
    feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
    for level in range(backbone_max_level - 1, min_level - 1, -1):
      feats[str(level)] = spatial_transform_ops.nearest_upsampling(
          feats[str(level + 1)], 2) + feats_lateral[str(level)]

    # TODO(xianzhi): consider removing the bias in conv2d.
    # Build post-hoc 3x3 convolution kernel.
    for level in range(min_level, backbone_max_level + 1):
      feats[str(level)] = conv2d(
          filters=num_filters,
          strides=1,
          kernel_size=3,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)(feats[str(level)])

    # TODO(xianzhi): consider removing the bias in conv2d.
    # Build coarser FPN levels introduced for RetinaNet.
    for level in range(backbone_max_level + 1, max_level + 1):
      feats_in = feats[str(level - 1)]
      if level > backbone_max_level + 1:
        feats_in = activation_fn(feats_in)
      feats[str(level)] = conv2d(
          filters=num_filters,
          strides=2,
          kernel_size=3,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)(feats_in)

    # Apply batch norm layers.
    for level in range(min_level, max_level + 1):
      feats[str(level)] = norm(
          axis=bn_axis, momentum=norm_momentum,
          epsilon=norm_epsilon)(feats[str(level)])

    self._output_specs = {
        str(level): feats[str(level)].get_shape()
        for level in range(min_level, max_level + 1)
    }

    super(FPN, self).__init__(inputs=inputs, outputs=feats, **kwargs)
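A hypothetical `input_specs` for this constructor, shaped as the docstring describes ({level: TensorShape} from a backbone with endpoints at levels 3-5; levels 6 and 7 are then synthesized by the stride-2 convs above):

import tensorflow as tf

input_specs = {
    '3': tf.TensorShape([None, 64, 64, 512]),
    '4': tf.TensorShape([None, 32, 32, 1024]),
    '5': tf.TensorShape([None, 16, 16, 2048]),
}
# backbone_max_level = min(int(max(input_specs.keys())), max_level) -> 5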
Example #11
  def __init__(
      self,
      min_level: int = 2,
      max_level: int = 5,
      target_level: int = 2,
      num_filters: int = 128,
      num_fpn_filters: int = 256,
      activation: str = 'relu',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):

    """Initializes panoptic FPN feature fusion layer.

    Args:
      min_level: An `int` of minimum level to use in feature fusion.
      max_level: An `int` of maximum level to use in feature fusion.
      target_level: An `int` of the target feature level for feature fusion.
      num_filters: An `int` number of filters in conv2d layers.
      num_fpn_filters: An `int` number of filters in the FPN outputs.
      activation: A `str` name of the activation function.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
      **kwargs: Additional keyword arguments to be passed.

    The fused features produced by this layer form a `float` `tf.Tensor` of
    shape [batch_size, feature_height, feature_width, feature_channel].
    """
    if target_level > max_level:
      raise ValueError(
          'target_level should be less than or equal to max_level.')

    self._config_dict = {
        'min_level': min_level,
        'max_level': max_level,
        'target_level': target_level,
        'num_filters': num_filters,
        'num_fpn_filters': num_fpn_filters,
        'activation': activation,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    norm = tfa.layers.GroupNormalization
    conv2d = tf.keras.layers.Conv2D
    activation_fn = tf_utils.get_activation(activation)
    if tf.keras.backend.image_data_format() == 'channels_last':
      norm_axis = -1
    else:
      norm_axis = 1
    inputs = self._build_inputs(num_fpn_filters, min_level, max_level)

    upscaled_features = []
    for level in range(min_level, max_level + 1):
      num_conv_layers = max(1, level - target_level)
      x = inputs[str(level)]
      for i in range(num_conv_layers):
        x = conv2d(
            filters=num_filters,
            kernel_size=3,
            padding='same',
            kernel_initializer=tf.keras.initializers.VarianceScaling(),
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer)(x)
        x = norm(groups=32, axis=norm_axis)(x)
        x = activation_fn(x)
        if level != target_level:
          x = spatial_transform_ops.nearest_upsampling(x, scale=2)
      upscaled_features.append(x)

    fused_features = tf.math.add_n(upscaled_features)
    self._output_specs = {str(target_level): fused_features.get_shape()}

    super(PanopticFPNFusion, self).__init__(
        inputs=inputs, outputs=fused_features, **kwargs)
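Worked through with the defaults above (min_level=2, max_level=5, target_level=2): level l gets max(1, l - target_level) conv blocks, and every non-target level is upsampled 2x after each block, so each branch lands at the level-2 resolution before tf.math.add_n:

min_level, max_level, target_level = 2, 5, 2
for level in range(min_level, max_level + 1):
    num_conv_layers = max(1, level - target_level)
    print(level, num_conv_layers)   # 2->1, 3->1, 4->2, 5->3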