def call(self, inputs):
    """Applies the transition block to a rank-4 input tensor.

    Every sub-layer used here is the one constructed in `__init__`.

    Pipeline:
      1. 1x1 conv -> batch norm -> ReLU -> frequency depthwise conv ->
         sub-spectral normalization. The result is kept as `residual`.
      2. Mean over axis 2 (keepdims=True) -> temporal depthwise conv ->
         batch norm -> swish -> 1x1 conv -> spatial dropout.
      3. Add `residual` and apply ReLU.

    Args:
      inputs: rank-4 tensor. The expected layout is
        [N, Time, Frequency, Channels].

    Returns:
      The tensor produced by the residual addition followed by ReLU.

    Raises:
      ValueError: if `inputs` is not rank 4.
    """
    # expected input: [N, Time, Frequency, Channels]
    if inputs.shape.rank != 4:
        raise ValueError('input_shape.rank:%d must be 4' %
                         inputs.shape.rank)

    net = self.conv1x1_1(inputs)
    net = self.batch_norm1(net)
    net = tf.keras.activations.relu(net)
    net = self.frequency_dw_conv(net)
    # Reuse the layer built in __init__. Previously a new
    # SubSpectralNormalization was instantiated on every call. Because it was
    # not an attribute of this block, its weights were not tracked, and a
    # freshly initialized copy was built each time call() ran.
    net = self.spectral_norm(net)

    residual = net
    # Average over axis 2 (Frequency in the expected layout). keepdims leaves
    # that axis with size 1, so the residual addition below presumably
    # broadcasts the temporal branch back over the frequency axis.
    net = tf.keras.backend.mean(net, axis=2, keepdims=True)
    net = self.temporal_dw_conv(net)
    net = self.batch_norm2(net)
    net = tf.keras.activations.swish(net)
    net = self.conv1x1_2(net)
    # Reuse the dropout layer built in __init__ instead of constructing a new
    # SpatialDropout2D on every call.
    net = self.spatial_drop(net)

    net = net + residual
    net = tf.keras.activations.relu(net)
    return net
  def __init__(self,
               filters=8,
               dilation=1,
               stride=1,
               padding='same',
               dropout=0.5,
               use_one_step=True,
               sub_groups=5,
               **kwargs):
    """Stores the hyper-parameters and constructs the block's sub-layers.

    Args:
      filters: number of output filters of both 1x1 convolutions.
      dilation: `dilation_rate` of both depthwise convolutions.
      stride: `strides` of both depthwise convolutions.
      padding: when 'same', the temporal depthwise conv is a plain
        'same'-padded DepthwiseConv2D. Any other value wraps a 'valid'
        DepthwiseConv2D in `stream.Stream`, with this value passed as
        `pad_time_dim`.
      dropout: rate of the SpatialDropout2D layer.
      use_one_step: forwarded to `stream.Stream`. It only has an effect when
        `padding` is not 'same'.
      sub_groups: sub-group count passed to SubSpectralNormalization.
      **kwargs: forwarded to the base layer constructor.
    """
    super(TransitionBlock, self).__init__(**kwargs)
    self.filters = filters
    self.dilation = dilation
    self.stride = stride
    self.padding = padding
    self.dropout = dropout
    self.use_one_step = use_one_step
    self.sub_groups = sub_groups

    def make_depthwise(kernel, conv_padding):
      # Bias-free depthwise conv that shares this block's stride and dilation.
      return tf.keras.layers.DepthwiseConv2D(
          kernel_size=kernel,
          strides=self.stride,
          dilation_rate=self.dilation,
          padding=conv_padding,
          use_bias=False)

    def make_pointwise():
      # Bias-free 1x1 conv projecting to `self.filters` channels.
      return tf.keras.layers.Conv2D(
          filters=self.filters,
          kernel_size=1,
          strides=1,
          padding='valid',
          use_bias=False)

    # NOTE: sub-layers are created in a fixed order. Keras derives
    # auto-generated layer names (and attribute-tracking order) from creation
    # order, so this sequence should not be rearranged.
    self.frequency_dw_conv = make_depthwise((1, 3), 'same')
    self.temporal_dw_conv = (
        make_depthwise((3, 1), 'same') if self.padding == 'same' else
        stream.Stream(
            cell=make_depthwise((3, 1), 'valid'),
            use_one_step=use_one_step,
            pad_time_dim=self.padding,
            pad_freq_dim='same'))
    self.batch_norm1 = tf.keras.layers.BatchNormalization()
    self.batch_norm2 = tf.keras.layers.BatchNormalization()
    self.conv1x1_1 = make_pointwise()
    self.conv1x1_2 = make_pointwise()
    self.spatial_drop = tf.keras.layers.SpatialDropout2D(rate=self.dropout)
    self.spectral_norm = sub_spectral_normalization.SubSpectralNormalization(
        self.sub_groups)