Esempio n. 1
0
    def build(self, input_shape):
        """Creates the sub-layers used by this layer.

        Instantiates, in order: the position-embedding layer (only when
        `self._add_pos_embed` is truthy), the dropout layer,
        `self._num_layers` transformer encoder blocks and the final layer
        normalization, then delegates to the parent class's `build`.

        Args:
          input_shape: The input shape, forwarded unchanged to the parent
            class's `build`.
        """
        if self._add_pos_embed:
            posemb_initializer = tf.keras.initializers.RandomNormal(
                stddev=0.02)
            self._pos_embed = AddPositionEmbs(
                posemb_init=posemb_initializer,
                posemb_origin_shape=self._pos_embed_origin_shape,
                posemb_target_shape=self._pos_embed_target_shape,
                name='posembed_input')
        self._dropout = layers.Dropout(rate=self._dropout_rate)

        # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
        # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
        layer_norm_epsilon = 1e-6
        total_layers = self._num_layers

        def make_encoder_block(layer_index):
            # Stochastic depth rate is derived from the 1-based position of the
            # block among all encoder blocks.
            return nn_blocks.TransformerEncoderBlock(
                inner_activation=activations.gelu,
                num_attention_heads=self._num_heads,
                inner_dim=self._mlp_dim,
                output_dropout=self._dropout_rate,
                attention_dropout=self._attention_dropout_rate,
                kernel_regularizer=self._kernel_regularizer,
                kernel_initializer=self._kernel_initializer,
                norm_first=True,
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, layer_index + 1,
                    total_layers),
                norm_epsilon=layer_norm_epsilon)

        self._encoder_layers = [
            make_encoder_block(layer_index)
            for layer_index in range(total_layers)
        ]
        self._norm = layers.LayerNormalization(epsilon=layer_norm_epsilon)
        super().build(input_shape)
Esempio n. 2
0
    def _build_model(self, inputs):
        """Builds model architecture.

        Applies the stem, a 3D max pool, and then one block group per entry
        of `RESNET_SPECS[self._model_id]`, recording the output of every
        block group.

        Args:
          inputs: the keras input spec.

        Returns:
          endpoints: A dictionary of backbone endpoint features, keyed by the
            stringified level ('2', '3', ...).

        Raises:
          ValueError: If the temporal specs do not have one entry per block
            group, if `self._use_self_gating` is set but has fewer entries
            than there are block groups, or if a block type is not supported.
        """
        # Build stem.
        x = self._build_stem(inputs, stem_type=self._stem_type)

        # A temporal stride of 1 means no temporal pooling, so the temporal
        # pooling window collapses to 1 as well.
        temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
        x = layers.MaxPool3D(pool_size=[temporal_kernel_size, 3, 3],
                             strides=[self._stem_pool_temporal_stride, 2, 2],
                             padding='same')(x)

        # Build intermediate blocks and endpoints.
        resnet_specs = RESNET_SPECS[self._model_id]
        num_groups = len(resnet_specs)
        if (len(self._temporal_strides) != num_groups or
                len(self._temporal_kernel_sizes) != num_groups):
            raise ValueError(
                'Number of blocks in temporal specs should equal to resnet_specs.'
            )
        # Fail early with a clear message instead of an IndexError raised from
        # inside the loop after part of the graph has already been built.
        if self._use_self_gating and len(self._use_self_gating) < num_groups:
            raise ValueError(
                'use_self_gating should specify a value for every block group '
                'in resnet_specs.')

        endpoints = {}
        for i, resnet_spec in enumerate(resnet_specs):
            if resnet_spec[0] == 'bottleneck3d':
                block_fn = nn_blocks_3d.BottleneckBlock3D
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    resnet_spec[0]))

            level = i + 2
            # Self-gating is disabled for all groups when no spec is given.
            use_self_gating = (self._use_self_gating[i]
                               if self._use_self_gating else False)
            x = self._block_group(
                inputs=x,
                filters=resnet_spec[1],
                temporal_kernel_sizes=self._temporal_kernel_sizes[i],
                temporal_strides=self._temporal_strides[i],
                # Only the first block group keeps the spatial resolution.
                spatial_strides=(1 if i == 0 else 2),
                block_fn=block_fn,
                block_repeats=resnet_spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, level, 5),
                use_self_gating=use_self_gating,
                name='block_group_l{}'.format(level))
            endpoints[str(level)] = x

        return endpoints
Esempio n. 3
0
    def __init__(self,
                 model_id: int,
                 input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
                     shape=[None, None, None, 3]),
                 depth_multiplier: float = 1.0,
                 stem_type: str = 'v0',
                 resnetd_shortcut: bool = False,
                 replace_stem_max_pool: bool = False,
                 se_ratio: Optional[float] = None,
                 init_stochastic_depth_rate: float = 0.0,
                 scale_stem: bool = True,
                 activation: str = 'relu',
                 use_sync_bn: bool = False,
                 norm_momentum: float = 0.99,
                 norm_epsilon: float = 0.001,
                 kernel_initializer: str = 'VarianceScaling',
                 kernel_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 bias_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 bn_trainable: bool = True,
                 **kwargs):
        """Initializes a ResNet model.

        Builds the stem, the stem down-sampling step and one block group per
        entry of `RESNET_SPECS[model_id]`, then initializes the parent model
        with the input tensor and the dictionary of block group outputs.

        Args:
          model_id: An `int` of the depth of ResNet backbone model.
          input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
          depth_multiplier: A `float` of the depth multiplier to uniformly
            scale up all layers in channel size. This argument is also
            referred to as `width_multiplier` in
            (https://arxiv.org/abs/2103.07579).
          stem_type: A `str` of stem type of ResNet. Default to `v0`. If set
            to `v1`, use ResNet-D type stem
            (https://arxiv.org/abs/1812.01187).
          resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
            downsampling blocks.
          replace_stem_max_pool: A `bool` of whether to replace the max pool
            in stem with a stride-2 conv.
          se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation
            layer.
          init_stochastic_depth_rate: A `float` of initial stochastic depth
            rate.
          scale_stem: A `bool` of whether to scale stem layers.
          activation: A `str` name of the activation function.
          use_sync_bn: If True, use synchronized batch normalization.
          norm_momentum: A `float` of normalization momentum for the moving
            average.
          norm_epsilon: A small `float` added to variance to avoid dividing
            by zero.
          kernel_initializer: A str for kernel initializer of convolutional
            layers.
          kernel_regularizer: A `tf.keras.regularizers.Regularizer` object
            for Conv2D. Default to None.
          bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
            Conv2D. Default to None.
          bn_trainable: A `bool` that indicates whether batch norm layers
            should be trainable. Default to True.
          **kwargs: Additional keyword arguments to be passed.

        Raises:
          ValueError: If `stem_type` is neither `v0` nor `v1`, or if a block
            type in the spec is not supported.
        """
        # Keep the configuration on the instance.
        self._model_id = model_id
        self._input_specs = input_specs
        self._depth_multiplier = depth_multiplier
        self._stem_type = stem_type
        self._resnetd_shortcut = resnetd_shortcut
        self._replace_stem_max_pool = replace_stem_max_pool
        self._se_ratio = se_ratio
        self._init_stochastic_depth_rate = init_stochastic_depth_rate
        self._scale_stem = scale_stem
        self._use_sync_bn = use_sync_bn
        self._activation = activation
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        self._norm = (layers.experimental.SyncBatchNormalization
                      if use_sync_bn else layers.BatchNormalization)
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer
        self._bn_trainable = bn_trainable

        channels_last = (
            tf.keras.backend.image_data_format() == 'channels_last')
        bn_axis = -1 if channels_last else 1

        def conv_norm_act(tensor, filters, kernel_size, strides):
            """Applies a bias-free 'same' conv, normalization and activation."""
            tensor = layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                use_bias=False,
                padding='same',
                kernel_initializer=self._kernel_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer)(tensor)
            tensor = self._norm(axis=bn_axis,
                                momentum=norm_momentum,
                                epsilon=norm_epsilon,
                                trainable=bn_trainable)(tensor)
            return tf_utils.get_activation(
                activation, use_keras_layer=True)(tensor)

        # Build ResNet.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])

        # The stem is only widened by the depth multiplier when `scale_stem`
        # is set.
        stem_depth_multiplier = self._depth_multiplier if scale_stem else 1.0
        if stem_type == 'v0':
            # Single 7x7 stride-2 conv.
            x = conv_norm_act(inputs, int(64 * stem_depth_multiplier), 7, 2)
        elif stem_type == 'v1':
            # Three 3x3 convs; only the first one is strided.
            x = inputs
            for base_filters, stride in ((32, 2), (32, 1), (64, 1)):
                x = conv_norm_act(
                    x, int(base_filters * stem_depth_multiplier), 3, stride)
        else:
            raise ValueError('Stem type {} not supported.'.format(stem_type))

        # Second stride-2 step of the stem: strided conv or max pool. Note the
        # conv variant always uses the full depth multiplier.
        if replace_stem_max_pool:
            x = conv_norm_act(x, int(64 * self._depth_multiplier), 3, 2)
        else:
            x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

        endpoints = {}
        for level, spec in enumerate(RESNET_SPECS[model_id], start=2):
            block_type = spec[0]
            if block_type == 'residual':
                block_fn = nn_blocks.ResidualBlock
            elif block_type == 'bottleneck':
                block_fn = nn_blocks.BottleneckBlock
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    block_type))
            x = self._block_group(
                inputs=x,
                filters=int(spec[1] * self._depth_multiplier),
                # Only the first block group keeps the spatial resolution.
                strides=(1 if level == 2 else 2),
                block_fn=block_fn,
                block_repeats=spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, level, 5),
                name='block_group_l{}'.format(level))
            endpoints[str(level)] = x

        self._output_specs = {
            level: feature.get_shape()
            for level, feature in endpoints.items()
        }

        super(ResNet, self).__init__(inputs=inputs,
                                     outputs=endpoints,
                                     **kwargs)
Esempio n. 4
0
    def __init__(self,
                 model_id: int,
                 temporal_strides: List[int],
                 temporal_kernel_sizes: List[Tuple[int]],
                 use_self_gating: Optional[List[int]] = None,
                 input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
                     shape=[None, None, None, None, 3]),
                 stem_type: str = 'v0',
                 stem_conv_temporal_kernel_size: int = 5,
                 stem_conv_temporal_stride: int = 2,
                 stem_pool_temporal_stride: int = 2,
                 init_stochastic_depth_rate: float = 0.0,
                 activation: str = 'relu',
                 se_ratio: Optional[float] = None,
                 use_sync_bn: bool = False,
                 norm_momentum: float = 0.99,
                 norm_epsilon: float = 0.001,
                 kernel_initializer: str = 'VarianceScaling',
                 kernel_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 bias_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 **kwargs):
        """Initializes a 3D ResNet model.

        Builds the stem, a 3D max pool and one block group per entry of
        `RESNET_SPECS[model_id]`, then initializes the parent model with the
        input tensor and the dictionary of block group outputs.

        Args:
          model_id: An `int` of depth of ResNet backbone model.
          temporal_strides: A list of integers that specifies the temporal
            strides for all 3d blocks. Must have one entry per block group.
          temporal_kernel_sizes: A list of tuples that specifies the temporal
            kernel sizes for all 3d blocks in different block groups. Must
            have one entry per block group.
          use_self_gating: A list of booleans to specify applying self-gating
            module or not in each block group. If None (or empty),
            self-gating is not applied. When given, it must have at least one
            entry per block group.
          input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
          stem_type: A `str` of stem type of ResNet. Default to `v0`. If set
            to `v1`, use ResNet-D type stem
            (https://arxiv.org/abs/1812.01187).
          stem_conv_temporal_kernel_size: An `int` of temporal kernel size
            for the first conv layer.
          stem_conv_temporal_stride: An `int` of temporal stride for the
            first conv layer.
          stem_pool_temporal_stride: An `int` of temporal stride for the
            first pool layer.
          init_stochastic_depth_rate: A `float` of initial stochastic depth
            rate.
          activation: A `str` of name of the activation function.
          se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation
            layer.
          use_sync_bn: If True, use synchronized batch normalization.
          norm_momentum: A `float` of normalization momentum for the moving
            average.
          norm_epsilon: A `float` added to variance to avoid dividing by
            zero.
          kernel_initializer: A str for kernel initializer of convolutional
            layers.
          kernel_regularizer: A `tf.keras.regularizers.Regularizer` object
            for Conv3D. Default to None.
          bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
            Conv3D. Default to None.
          **kwargs: Additional keyword arguments to be passed.

        Raises:
          ValueError: If `stem_type` is neither `v0` nor `v1`, if the
            temporal specs do not have one entry per block group, if
            `use_self_gating` has fewer entries than there are block groups,
            or if a block type in the spec is not supported.
        """
        self._model_id = model_id
        self._temporal_strides = temporal_strides
        self._temporal_kernel_sizes = temporal_kernel_sizes
        self._input_specs = input_specs
        self._stem_type = stem_type
        self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
        self._stem_conv_temporal_stride = stem_conv_temporal_stride
        self._stem_pool_temporal_stride = stem_pool_temporal_stride
        self._use_self_gating = use_self_gating
        self._se_ratio = se_ratio
        self._init_stochastic_depth_rate = init_stochastic_depth_rate
        self._use_sync_bn = use_sync_bn
        self._activation = activation
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        if use_sync_bn:
            self._norm = layers.experimental.SyncBatchNormalization
        else:
            self._norm = layers.BatchNormalization
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer
        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        def conv_norm_act(tensor, filters, kernel_size, strides):
            """Applies a bias-free 'same' 3D conv, normalization, activation."""
            tensor = layers.Conv3D(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                use_bias=False,
                padding='same',
                kernel_initializer=self._kernel_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer)(tensor)
            tensor = self._norm(axis=bn_axis,
                                momentum=norm_momentum,
                                epsilon=norm_epsilon)(tensor)
            return tf_utils.get_activation(activation)(tensor)

        # Build ResNet3D backbone.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])

        # Build stem. Only the first conv of either stem is strided and uses
        # the configurable temporal kernel size.
        first_strides = [stem_conv_temporal_stride, 2, 2]
        if stem_type == 'v0':
            x = conv_norm_act(
                inputs, 64, [stem_conv_temporal_kernel_size, 7, 7],
                first_strides)
        elif stem_type == 'v1':
            x = conv_norm_act(
                inputs, 32, [stem_conv_temporal_kernel_size, 3, 3],
                first_strides)
            x = conv_norm_act(x, 32, [1, 3, 3], [1, 1, 1])
            x = conv_norm_act(x, 64, [1, 3, 3], [1, 1, 1])
        else:
            raise ValueError(f'Stem type {stem_type} not supported.')

        # A temporal stride of 1 means no temporal pooling, so the temporal
        # pooling window collapses to 1 as well.
        temporal_kernel_size = 1 if stem_pool_temporal_stride == 1 else 3
        x = layers.MaxPool3D(pool_size=[temporal_kernel_size, 3, 3],
                             strides=[stem_pool_temporal_stride, 2, 2],
                             padding='same')(x)

        # Build intermediate blocks and endpoints.
        resnet_specs = RESNET_SPECS[model_id]
        num_groups = len(resnet_specs)
        if (len(temporal_strides) != num_groups or
                len(temporal_kernel_sizes) != num_groups):
            raise ValueError(
                'Number of blocks in temporal specs should equal to resnet_specs.'
            )
        # Fail early with a clear message instead of an IndexError raised from
        # inside the loop after part of the graph has already been built.
        if use_self_gating and len(use_self_gating) < num_groups:
            raise ValueError(
                'use_self_gating should specify a value for every block group '
                'in resnet_specs.')

        endpoints = {}
        for i, resnet_spec in enumerate(resnet_specs):
            if resnet_spec[0] == 'bottleneck3d':
                block_fn = nn_blocks_3d.BottleneckBlock3D
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    resnet_spec[0]))

            level = i + 2
            x = self._block_group(
                inputs=x,
                filters=resnet_spec[1],
                temporal_kernel_sizes=temporal_kernel_sizes[i],
                temporal_strides=temporal_strides[i],
                # Only the first block group keeps the spatial resolution.
                spatial_strides=(1 if i == 0 else 2),
                block_fn=block_fn,
                block_repeats=resnet_spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, level, 5),
                # Self-gating is disabled for all groups when no spec is given.
                use_self_gating=use_self_gating[i]
                if use_self_gating else False,
                name='block_group_l{}'.format(level))
            endpoints[str(level)] = x

        self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

        super(ResNet3D, self).__init__(inputs=inputs,
                                       outputs=endpoints,
                                       **kwargs)
Esempio n. 5
0
  def _build_scale_permuted_network(self,
                                    net,
                                    input_width,
                                    weighted_fusion=False):
    """Builds scale-permuted network.

    For every entry of `self._block_specs`, resamples the two parent features
    selected by `block_spec.input_offsets`, sums them (optionally scaled by
    normalized fusion weights), applies the activation and a block group, and
    appends the result to `net`.

    Args:
      net: A list of feature tensors to build on. It is extended in place with
        the output of every new block. Existing entries are treated as having
        width `ceil(input_width / 4)`.
      input_width: The input width; the target width of a block at level `l`
        is computed as `ceil(input_width / 2**l)`.
      weighted_fusion: If True, each parent is multiplied by a ReLU-ed
        `tf.Variable` weight (initialized to 1.0) normalized by the sum of all
        weights before the parents are added.

    Returns:
      A dictionary mapping the stringified level of every output block to its
      feature tensor.

    Raises:
      ValueError: If two output blocks share the same level.
    """
    net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
    net_block_fns = [self._init_block_fn] * len(net)
    num_outgoing_connections = [0] * len(net)

    endpoints = {}
    for i, block_spec in enumerate(self._block_specs):
      # Find out specs for the target block.
      target_width = int(math.ceil(input_width / 2**block_spec.level))
      target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
                               self._filter_size_scale)
      target_block_fn = block_spec.block_fn

      # Resample then merge input0 and input1.
      parents = []
      input_offsets = (block_spec.input_offsets[0],
                       block_spec.input_offsets[1])
      for offset in input_offsets:
        parents.append(
            self._resample_with_alpha(
                inputs=net[offset],
                input_width=net_sizes[offset],
                input_block_fn=net_block_fns[offset],
                target_width=target_width,
                target_num_filters=target_num_filters,
                target_block_fn=target_block_fn,
                alpha=self._resample_alpha))
        num_outgoing_connections[offset] += 1

      # Merge 0 outdegree blocks to the output block. Only features whose
      # dims 2 and 3 match the resampled parents can be added directly.
      if block_spec.is_output:
        resampled_dim3 = parents[0].shape[3]
        for j, (j_feat,
                j_connections) in enumerate(zip(net, num_outgoing_connections)):
          if j_connections == 0 and (j_feat.shape[2] == target_width and
                                     j_feat.shape[3] == resampled_dim3):
            parents.append(j_feat)
            num_outgoing_connections[j] += 1

      # pylint: disable=g-direct-tensorflow-import
      if weighted_fusion:
        dtype = parents[0].dtype
        # ReLU keeps every fusion weight non-negative.
        parent_weights = [
            tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
                i, j)), dtype=dtype)) for j in range(len(parents))]
        weights_sum = tf.add_n(parent_weights)
        # The small constant avoids dividing by zero when all weights are 0.
        parents = [
            parent * weight / (weights_sum + 0.0001)
            for parent, weight in zip(parents, parent_weights)
        ]

      # Fuse all parent nodes then build a new block.
      x = tf_utils.get_activation(self._activation_fn)(tf.add_n(parents))
      x = self._block_group(
          inputs=x,
          filters=target_num_filters,
          strides=1,
          block_fn_cand=target_block_fn,
          block_repeats=self._block_repeats,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, len(self._block_specs)),
          name='scale_permuted_block_{}'.format(i + 1))

      net.append(x)
      net_sizes.append(target_width)
      net_block_fns.append(target_block_fn)
      num_outgoing_connections.append(0)

      # Save output feats.
      if block_spec.is_output:
        # Endpoints are keyed by the stringified level, so the duplicate check
        # must use the same key type; checking the int level never matches.
        level_key = str(block_spec.level)
        if level_key in endpoints:
          raise ValueError('Duplicate feats found for output level {}.'.format(
              block_spec.level))
        if (block_spec.level < self._min_level or
            block_spec.level > self._max_level):
          logging.warning(
              'SpineNet output level %s out of range [min_level, max_level] = '
              '[%s, %s] will not be used for further processing.',
              block_spec.level, self._min_level, self._max_level)
        endpoints[level_key] = x

    return endpoints
    def __init__(self,
                 model_id: int,
                 output_stride: int,
                 input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
                     shape=[None, None, None, 3]),
                 stem_type: str = 'v0',
                 se_ratio: Optional[float] = None,
                 init_stochastic_depth_rate: float = 0.0,
                 multigrid: Optional[Tuple[int]] = None,
                 last_stage_repeats: int = 1,
                 activation: str = 'relu',
                 use_sync_bn: bool = False,
                 norm_momentum: float = 0.99,
                 norm_epsilon: float = 0.001,
                 kernel_initializer: str = 'VarianceScaling',
                 kernel_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 bias_regularizer: Optional[
                     tf.keras.regularizers.Regularizer] = None,
                 **kwargs):
        """Initializes a ResNet model with DeepLab modification.

        Builds the stem and a max pool, then the regular strided block groups
        up to the requested output stride, then the remaining block groups
        with stride 1 and a dilation rate that doubles after each group.

        Args:
          model_id: An `int` specifies depth of ResNet backbone model.
          output_stride: An `int` of output stride, ratio of input to output
            resolution.
          input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
          stem_type: A `str` of stem type. Can be `v0` or `v1`. `v1` replaces
            7x7 conv by 3 3x3 convs.
          se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation
            layer.
          init_stochastic_depth_rate: A `float` of initial stochastic depth
            rate.
          multigrid: A tuple of the same length as the number of blocks in
            the last resnet stage.
          last_stage_repeats: An `int` that specifies how many times last
            stage is repeated.
          activation: A `str` name of the activation function.
          use_sync_bn: If True, use synchronized batch normalization.
          norm_momentum: A `float` of normalization momentum for the moving
            average.
          norm_epsilon: A `float` added to variance to avoid dividing by
            zero.
          kernel_initializer: A str for kernel initializer of convolutional
            layers.
          kernel_regularizer: A `tf.keras.regularizers.Regularizer` object
            for Conv2D. Default to None.
          bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
            Conv2D. Default to None.
          **kwargs: Additional keyword arguments to be passed.

        Raises:
          ValueError: If `stem_type` is neither `v0` nor `v1`, or if a block
            type in the spec is not `bottleneck`.
        """
        self._model_id = model_id
        self._output_stride = output_stride
        self._input_specs = input_specs
        self._use_sync_bn = use_sync_bn
        self._activation = activation
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        if use_sync_bn:
            self._norm = layers.experimental.SyncBatchNormalization
        else:
            self._norm = layers.BatchNormalization
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer
        self._stem_type = stem_type
        self._se_ratio = se_ratio
        self._init_stochastic_depth_rate = init_stochastic_depth_rate

        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        def conv_norm_act(tensor, filters, kernel_size, strides):
            """Applies a bias-free 'same' conv, normalization and activation."""
            tensor = layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                use_bias=False,
                padding='same',
                kernel_initializer=self._kernel_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer)(tensor)
            tensor = self._norm(axis=bn_axis,
                                momentum=norm_momentum,
                                epsilon=norm_epsilon)(tensor)
            return tf_utils.get_activation(activation)(tensor)

        def select_block_fn(spec):
            """Returns the block class for a spec; only bottleneck is valid."""
            if spec[0] == 'bottleneck':
                return nn_blocks.BottleneckBlock
            raise ValueError('Block fn `{}` is not supported.'.format(
                spec[0]))

        # Build ResNet.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])

        if stem_type == 'v0':
            x = conv_norm_act(inputs, 64, 7, 2)
        elif stem_type == 'v1':
            # Three 3x3 convs; only the first one is strided.
            x = conv_norm_act(inputs, 64, 3, 2)
            x = conv_norm_act(x, 64, 3, 1)
            x = conv_norm_act(x, 128, 3, 1)
        else:
            raise ValueError('Stem type {} not supported.'.format(stem_type))

        x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

        # Index of the last block group built with regular striding.
        # `np.math` was only an alias of the stdlib `math` module and has been
        # removed in NumPy 2.0; `np.log2` yields the same truncated value.
        normal_resnet_stage = int(np.log2(self._output_stride)) - 2

        # Denominator shared by every stochastic depth rate computation.
        total_stages = 4 + last_stage_repeats

        endpoints = {}
        for i in range(normal_resnet_stage + 1):
            spec = RESNET_SPECS[model_id][i]
            block_fn = select_block_fn(spec)
            x = self._block_group(
                inputs=x,
                filters=spec[1],
                # Only the first block group keeps the spatial resolution.
                strides=(1 if i == 0 else 2),
                dilation_rate=1,
                block_fn=block_fn,
                block_repeats=spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 2, total_stages),
                name='block_group_l{}'.format(i + 2))
            endpoints[str(i + 2)] = x

        # Remaining groups keep the resolution and dilate instead of striding.
        dilation_rate = 2
        for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
            # Indices past the spec table are repeats of the last stage.
            spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[
                model_id][-1]
            block_fn = select_block_fn(spec)
            x = self._block_group(
                inputs=x,
                filters=spec[1],
                strides=1,
                dilation_rate=dilation_rate,
                block_fn=block_fn,
                block_repeats=spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 2, total_stages),
                # Multigrid only applies to the last stage and its repeats.
                multigrid=multigrid if i >= 3 else None,
                name='block_group_l{}'.format(i + 2))
            dilation_rate *= 2

        # All dilated groups share one resolution, so only the final feature
        # is exposed, under the level right after the last strided group.
        endpoints[str(normal_resnet_stage + 2)] = x

        self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

        super(DilatedResNet, self).__init__(inputs=inputs,
                                            outputs=endpoints,
                                            **kwargs)