Example #1
    def build(self, input_shape):
        if self._add_pos_embed:
            self._pos_embed = AddPositionEmbs(
                posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
                name='posembed_input')
        self._dropout = layers.Dropout(rate=self._dropout_rate)

        self._encoder_layers = []
        # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
        # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
        for i in range(self._num_layers):
            encoder_layer = nn_blocks.TransformerEncoderBlock(
                inner_activation=activations.gelu,
                num_attention_heads=self._num_heads,
                inner_dim=self._mlp_dim,
                output_dropout=self._dropout_rate,
                attention_dropout=self._attention_dropout_rate,
                kernel_regularizer=self._kernel_regularizer,
                kernel_initializer=self._kernel_initializer,
                norm_first=True,
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 1, self._num_layers),
                norm_epsilon=1e-6)
            self._encoder_layers.append(encoder_layer)
        self._norm = layers.LayerNormalization(epsilon=1e-6)
        super().build(input_shape)
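Example #1 (and every example below) scales the per-layer drop rate with nn_layers.get_stochastic_depth_rate. A minimal sketch of that helper, assuming the usual linear `init_rate * i / n` rule; the shipped implementation may differ in details:

# Minimal sketch of linear stochastic-depth scaling (assumed rule, not the
# verified nn_layers source).
def get_stochastic_depth_rate(init_rate, i, n):
    """Drop rate for the i-th of n layers; deeper layers drop more often."""
    if init_rate is None:
        return None
    if not 0 <= init_rate <= 1:
        raise ValueError('Initial drop rate must be within 0 and 1.')
    # Linear schedule: layer 1 gets init_rate / n, layer n gets init_rate.
    return init_rate * float(i) / n

With init_rate=0.2 and n=12 encoder layers, layer 1 drops with rate ~0.017 while layer 12 uses the full 0.2.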
Example #2
    def _build_scale_permuted_network(self,
                                      net,
                                      input_width,
                                      weighted_fusion=False):
        """Builds scale-permuted network."""
        # Initial blocks sit at level 2, i.e. stride 2**2 relative to the input.
        net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
        net_block_fns = [self._init_block_fn] * len(net)
        num_outgoing_connections = [0] * len(net)

        endpoints = {}
        for i, block_spec in enumerate(self._block_specs):
            # Find out specs for the target block.
            target_width = int(math.ceil(input_width / 2**block_spec.level))
            target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
                                     self._filter_size_scale)
            target_block_fn = block_spec.block_fn

            # Resample then merge input0 and input1.
            parents = []
            input0 = block_spec.input_offsets[0]
            input1 = block_spec.input_offsets[1]

            x0 = self._resample_with_alpha(
                inputs=net[input0],
                input_width=net_sizes[input0],
                input_block_fn=net_block_fns[input0],
                target_width=target_width,
                target_num_filters=target_num_filters,
                target_block_fn=target_block_fn,
                alpha=self._resample_alpha)
            parents.append(x0)
            num_outgoing_connections[input0] += 1

            x1 = self._resample_with_alpha(
                inputs=net[input1],
                input_width=net_sizes[input1],
                input_block_fn=net_block_fns[input1],
                target_width=target_width,
                target_num_filters=target_num_filters,
                target_block_fn=target_block_fn,
                alpha=self._resample_alpha)
            parents.append(x1)
            num_outgoing_connections[input1] += 1

            # Merge 0 outdegree blocks to the output block.
            if block_spec.is_output:
                for j, (j_feat, j_connections) in enumerate(
                        zip(net, num_outgoing_connections)):
                    if (j_connections == 0 and j_feat.shape[2] == target_width
                            and j_feat.shape[3] == x0.shape[3]):
                        parents.append(j_feat)
                        num_outgoing_connections[j] += 1

            if weighted_fusion:
                dtype = parents[0].dtype
                parent_weights = [
                    tf.nn.relu(
                        tf.cast(tf.Variable(1.0,
                                            name='block{}_fusion{}'.format(
                                                i, j)),
                                dtype=dtype)) for j in range(len(parents))
                ]
                weights_sum = tf.add_n(parent_weights)
                parents = [
                    parent * weight / (weights_sum + 0.0001)
                    for parent, weight in zip(parents, parent_weights)
                ]

            # Fuse all parent nodes then build a new block.
            x = tf_utils.get_activation(self._activation_fn)(tf.add_n(parents))
            x = self._block_group(
                inputs=x,
                filters=target_num_filters,
                strides=1,
                block_fn_cand=target_block_fn,
                block_repeats=self._block_repeats,
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 1,
                    len(self._block_specs)),
                name='scale_permuted_block_{}'.format(i + 1))

            net.append(x)
            net_sizes.append(target_width)
            net_block_fns.append(target_block_fn)
            num_outgoing_connections.append(0)

            # Save output feats.
            if block_spec.is_output:
                if block_spec.level in endpoints:
                    raise ValueError(
                        'Duplicate feats found for output level {}.'.format(
                            block_spec.level))
                if (block_spec.level < self._min_level
                        or block_spec.level > self._max_level):
                    logging.warning(
                        'SpineNet output level is out of range [min_level, '
                        'max_level] = [%s, %s] and will not be used for '
                        'further processing.', self._min_level, self._max_level)
                endpoints[str(block_spec.level)] = x

        return endpoints
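The weighted-fusion branch above clamps each learnable scalar with ReLU and normalizes by their sum, so every parent contributes a non-negative fraction of the merged feature map. A standalone sketch of just that arithmetic (tensor shapes and initial weight values here are illustrative):

import tensorflow as tf

parents = [tf.ones([1, 8, 8, 4]), 2.0 * tf.ones([1, 8, 8, 4])]
raw_weights = [tf.Variable(1.0), tf.Variable(0.5)]
weights = [tf.nn.relu(w) for w in raw_weights]   # clamp to >= 0
total = tf.add_n(weights) + 0.0001               # epsilon avoids divide-by-zero
fused = tf.add_n([p * w / total for p, w in zip(parents, weights)])
# Each parent contributes w_j / (sum_k w_k + eps) of the merged map.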
Example #3
    def __init__(
            self,
            model_id: int,
            temporal_strides: List[int],
            temporal_kernel_sizes: List[Tuple[int, ...]],
            use_self_gating: Optional[List[bool]] = None,
            input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
            stem_type='v0',
            stem_conv_temporal_kernel_size=5,
            stem_conv_temporal_stride=2,
            stem_pool_temporal_stride=2,
            init_stochastic_depth_rate=0.0,
            activation='relu',
            se_ratio=None,
            use_sync_bn=False,
            norm_momentum=0.99,
            norm_epsilon=0.001,
            kernel_initializer='VarianceScaling',
            kernel_regularizer=None,
            bias_regularizer=None,
            **kwargs):
        """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` of depth of ResNet backbone model.
      temporal_strides: A list of integers that specifies the temporal strides
        for all 3d blocks.
      temporal_kernel_sizes: A list of tuples that specifies the temporal kernel
        sizes for all 3d blocks in different block groups.
      use_self_gating: A list of booleans to specify applying self-gating module
        or not in each block group. If None, self-gating is not applied.
      input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
        first conv layer.
      stem_conv_temporal_stride: An `int` of temporal stride for the first conv
        layer.
      stem_pool_temporal_stride: An `int` of temporal stride for the first pool
        layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` of name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      **kwargs: Additional keyword arguments to be passed.
    """
        self._model_id = model_id
        self._temporal_strides = temporal_strides
        self._temporal_kernel_sizes = temporal_kernel_sizes
        self._input_specs = input_specs
        self._stem_type = stem_type
        self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
        self._stem_conv_temporal_stride = stem_conv_temporal_stride
        self._stem_pool_temporal_stride = stem_pool_temporal_stride
        self._use_self_gating = use_self_gating
        self._se_ratio = se_ratio
        self._init_stochastic_depth_rate = init_stochastic_depth_rate
        self._use_sync_bn = use_sync_bn
        self._activation = activation
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        if use_sync_bn:
            self._norm = layers.experimental.SyncBatchNormalization
        else:
            self._norm = layers.BatchNormalization
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer
        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        # Build ResNet3D backbone.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])

        # Build stem.
        if stem_type == 'v0':
            x = layers.Conv3D(
                filters=64,
                kernel_size=[stem_conv_temporal_kernel_size, 7, 7],
                strides=[stem_conv_temporal_stride, 2, 2],
                use_bias=False,
                padding='same',
                kernel_initializer=self._kernel_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer)(inputs)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
        elif stem_type == 'v1':
            x = layers.Conv3D(
                filters=32,
                kernel_size=[stem_conv_temporal_kernel_size, 3, 3],
                strides=[stem_conv_temporal_stride, 2, 2],
                use_bias=False,
                padding='same',
                kernel_initializer=self._kernel_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer)(inputs)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
            x = layers.Conv3D(filters=32,
                              kernel_size=[1, 3, 3],
                              strides=[1, 1, 1],
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
            x = layers.Conv3D(filters=64,
                              kernel_size=[1, 3, 3],
                              strides=[1, 1, 1],
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
        else:
            raise ValueError(f'Stem type {stem_type} not supported.')

        # Pool over 3 frames only when the stem pool actually strides in time.
        temporal_kernel_size = 1 if stem_pool_temporal_stride == 1 else 3
        x = layers.MaxPool3D(pool_size=[temporal_kernel_size, 3, 3],
                             strides=[stem_pool_temporal_stride, 2, 2],
                             padding='same')(x)

        # Build intermediate blocks and endpoints.
        resnet_specs = RESNET_SPECS[model_id]
        if len(temporal_strides) != len(resnet_specs) or len(
                temporal_kernel_sizes) != len(resnet_specs):
            raise ValueError(
                'Number of entries in temporal_strides and '
                'temporal_kernel_sizes should equal the number of block '
                'groups in resnet_specs.')

        endpoints = {}
        for i, resnet_spec in enumerate(resnet_specs):
            if resnet_spec[0] == 'bottleneck3d':
                block_fn = nn_blocks_3d.BottleneckBlock3D
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    resnet_spec[0]))

            x = self._block_group(
                inputs=x,
                filters=resnet_spec[1],
                temporal_kernel_sizes=temporal_kernel_sizes[i],
                temporal_strides=temporal_strides[i],
                spatial_strides=(1 if i == 0 else 2),
                block_fn=block_fn,
                block_repeats=resnet_spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 2, 5),
                use_self_gating=use_self_gating[i]
                if use_self_gating else False,
                name='block_group_l{}'.format(i + 2))
            endpoints[str(i + 2)] = x

        self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

        super(ResNet3D, self).__init__(inputs=inputs,
                                       outputs=endpoints,
                                       **kwargs)
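A hypothetical instantiation of the constructor above. The values assume RESNET_SPECS[50] defines four bottleneck3d block groups with repeats (3, 4, 6, 3); treat them as illustrative, not a recommended configuration:

# Hypothetical usage sketch; specs and values are illustrative.
backbone = ResNet3D(
    model_id=50,
    temporal_strides=[1, 1, 1, 1],
    temporal_kernel_sizes=[(3, 3, 3), (3, 1, 3, 1),
                           (3, 1, 3, 1, 3, 1), (1, 3, 1)],
    use_self_gating=[False, False, True, True],
    input_specs=layers.InputSpec(shape=[None, 8, 224, 224, 3]))
endpoints = backbone(tf.ones([2, 8, 224, 224, 3]))  # dict keyed '2'..'5'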
Example #4
    def __init__(self,
                 model_id,
                 input_specs=layers.InputSpec(shape=[None, None, None, 3]),
                 depth_multiplier=1.0,
                 stem_type='v0',
                 se_ratio=None,
                 init_stochastic_depth_rate=0.0,
                 activation='relu',
                 use_sync_bn=False,
                 norm_momentum=0.99,
                 norm_epsilon=0.001,
                 kernel_initializer='VarianceScaling',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 **kwargs):
        """ResNet initialization function.

    Args:
      model_id: `int` depth of ResNet backbone model.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      depth_multiplier: `float` depth multiplier to uniformly scale up all
        layers in channel size in ResNet.
      stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`,
        use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
      se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
      init_stochastic_depth_rate: `float` initial stochastic depth rate.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
                          Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
                        Default to None.
      **kwargs: keyword arguments to be passed.
    """
        self._model_id = model_id
        self._input_specs = input_specs
        self._depth_multiplier = depth_multiplier
        self._stem_type = stem_type
        self._se_ratio = se_ratio
        self._init_stochastic_depth_rate = init_stochastic_depth_rate
        self._use_sync_bn = use_sync_bn
        self._activation = activation
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        if use_sync_bn:
            self._norm = layers.experimental.SyncBatchNormalization
        else:
            self._norm = layers.BatchNormalization
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer

        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        # Build ResNet.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])

        if stem_type == 'v0':
            x = layers.Conv2D(filters=int(64 * self._depth_multiplier),
                              kernel_size=7,
                              strides=2,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(inputs)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
        elif stem_type == 'v1':
            x = layers.Conv2D(filters=int(32 * self._depth_multiplier),
                              kernel_size=3,
                              strides=2,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(inputs)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
            x = layers.Conv2D(filters=int(32 * self._depth_multiplier),
                              kernel_size=3,
                              strides=1,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
            x = layers.Conv2D(filters=int(64 * self._depth_multiplier),
                              kernel_size=3,
                              strides=1,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
        else:
            raise ValueError('Stem type {} not supported.'.format(stem_type))

        x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

        endpoints = {}
        for i, spec in enumerate(RESNET_SPECS[model_id]):
            if spec[0] == 'residual':
                block_fn = nn_blocks.ResidualBlock
            elif spec[0] == 'bottleneck':
                block_fn = nn_blocks.BottleneckBlock
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    spec[0]))
            x = self._block_group(
                inputs=x,
                filters=int(spec[1] * self._depth_multiplier),
                strides=(1 if i == 0 else 2),
                block_fn=block_fn,
                block_repeats=spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 2, 5),
                name='block_group_l{}'.format(i + 2))
            endpoints[str(i + 2)] = x

        self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

        super(ResNet, self).__init__(inputs=inputs,
                                     outputs=endpoints,
                                     **kwargs)
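A hypothetical usage sketch for this backbone, assuming RESNET_SPECS defines model_id=50. Endpoint keys '2' through '5' correspond to spatial strides 4 through 32:

# Hypothetical usage; assumes RESNET_SPECS[50] exists (ResNet-50 bottlenecks).
backbone = ResNet(model_id=50, depth_multiplier=1.0, stem_type='v1')
features = backbone(tf.ones([1, 224, 224, 3]))
for level in sorted(features):
    print(level, features[level].shape)  # e.g. '2' -> (1, 56, 56, 256)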
Example #5
    def __init__(self,
                 model_id,
                 output_stride,
                 input_specs=layers.InputSpec(shape=[None, None, None, 3]),
                 stem_type='v0',
                 se_ratio=None,
                 init_stochastic_depth_rate=0.0,
                 multigrid=None,
                 last_stage_repeats=1,
                 activation='relu',
                 use_sync_bn=False,
                 norm_momentum=0.99,
                 norm_epsilon=0.001,
                 kernel_initializer='VarianceScaling',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 **kwargs):
        """Initializes a ResNet model with DeepLab modification.

    Args:
      model_id: An `int` specifying the depth of the ResNet backbone model.
      output_stride: An `int` of output stride, ratio of input to output
        resolution.
      input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type. Can be `v0` or `v1`. `v1` replaces the
        7x7 stem conv with three 3x3 convs
        (https://arxiv.org/abs/1812.01187).
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      multigrid: A tuple of the same length as the number of blocks in the last
        resnet stage.
      last_stage_repeats: An `int` that specifies how many times last stage is
        repeated.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      **kwargs: Additional keyword arguments to be passed.
    """
        self._model_id = model_id
        self._output_stride = output_stride
        self._input_specs = input_specs
        self._use_sync_bn = use_sync_bn
        self._activation = activation
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        if use_sync_bn:
            self._norm = layers.experimental.SyncBatchNormalization
        else:
            self._norm = layers.BatchNormalization
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer
        self._stem_type = stem_type
        self._se_ratio = se_ratio
        self._init_stochastic_depth_rate = init_stochastic_depth_rate

        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        # Build ResNet.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])

        if stem_type == 'v0':
            x = layers.Conv2D(filters=64,
                              kernel_size=7,
                              strides=2,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(inputs)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
        elif stem_type == 'v1':
            x = layers.Conv2D(filters=64,
                              kernel_size=3,
                              strides=2,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(inputs)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
            x = layers.Conv2D(filters=64,
                              kernel_size=3,
                              strides=1,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
            x = layers.Conv2D(filters=128,
                              kernel_size=3,
                              strides=1,
                              use_bias=False,
                              padding='same',
                              kernel_initializer=self._kernel_initializer,
                              kernel_regularizer=self._kernel_regularizer,
                              bias_regularizer=self._bias_regularizer)(x)
            x = self._norm(axis=bn_axis,
                           momentum=norm_momentum,
                           epsilon=norm_epsilon)(x)
            x = tf_utils.get_activation(activation)(x)
        else:
            raise ValueError('Stem type {} not supported.'.format(stem_type))

        x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

        normal_resnet_stage = int(np.log2(self._output_stride)) - 2

        endpoints = {}
        for i in range(normal_resnet_stage + 1):
            spec = RESNET_SPECS[model_id][i]
            if spec[0] == 'bottleneck':
                block_fn = nn_blocks.BottleneckBlock
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    spec[0]))
            x = self._block_group(
                inputs=x,
                filters=spec[1],
                strides=(1 if i == 0 else 2),
                dilation_rate=1,
                block_fn=block_fn,
                block_repeats=spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 2,
                    4 + last_stage_repeats),
                name='block_group_l{}'.format(i + 2))
            endpoints[str(i + 2)] = x

        dilation_rate = 2
        for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
            spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[
                model_id][-1]
            if spec[0] == 'bottleneck':
                block_fn = nn_blocks.BottleneckBlock
            else:
                raise ValueError('Block fn `{}` is not supported.'.format(
                    spec[0]))
            x = self._block_group(
                inputs=x,
                filters=spec[1],
                strides=1,
                dilation_rate=dilation_rate,
                block_fn=block_fn,
                block_repeats=spec[2],
                stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
                    self._init_stochastic_depth_rate, i + 2,
                    4 + last_stage_repeats),
                multigrid=multigrid if i >= 3 else None,
                name='block_group_l{}'.format(i + 2))
            dilation_rate *= 2

        endpoints[str(normal_resnet_stage + 2)] = x

        self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

        super(DilatedResNet, self).__init__(inputs=inputs,
                                            outputs=endpoints,
                                            **kwargs)
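The output-stride bookkeeping above is easy to verify by hand; a worked sketch for output_stride=16:

import math

# Worked example of the dilation schedule above, for output_stride=16.
output_stride = 16
normal_resnet_stage = int(math.log2(output_stride)) - 2   # = 2
# Stages 0..2 stride spatially as usual, yielding endpoints '2'..'4' at
# strides 4, 8 and 16. Every later stage runs at stride 1 with
# dilation_rate 2, 4, ... so the final endpoint stays at stride 16.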
Example #6
  def __init__(
      self,
      model_id: int,
      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
      depth_multiplier: float = 1.0,
      stem_type: str = 'v0',
      resnetd_shortcut: bool = False,
      replace_stem_max_pool: bool = False,
      se_ratio: Optional[float] = None,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a ResNet model.

    Args:
      model_id: An `int` of the depth of ResNet backbone model.
      input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
      depth_multiplier: A `float` of the depth multiplier to uniformly scale up
        all layers in channel size. This argument is also referred to as
        `width_multiplier` in (https://arxiv.org/abs/2103.07579).
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
        downsampling blocks.
      replace_stem_max_pool: A `bool` of whether to replace the max pool in stem
        with a stride-2 conv.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A small `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      **kwargs: Additional keyword arguments to be passed.
    """
    self._model_id = model_id
    self._input_specs = input_specs
    self._depth_multiplier = depth_multiplier
    self._stem_type = stem_type
    self._resnetd_shortcut = resnetd_shortcut
    self._replace_stem_max_pool = replace_stem_max_pool
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build ResNet.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    if stem_type == 'v0':
      x = layers.Conv2D(
          filters=int(64 * self._depth_multiplier),
          kernel_size=7,
          strides=2,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              inputs)
      x = self._norm(
          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
              x)
      x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
    elif stem_type == 'v1':
      x = layers.Conv2D(
          filters=int(32 * self._depth_multiplier),
          kernel_size=3,
          strides=2,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              inputs)
      x = self._norm(
          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
              x)
      x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
      x = layers.Conv2D(
          filters=int(32 * self._depth_multiplier),
          kernel_size=3,
          strides=1,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              x)
      x = self._norm(
          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
              x)
      x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
      x = layers.Conv2D(
          filters=int(64 * self._depth_multiplier),
          kernel_size=3,
          strides=1,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              x)
      x = self._norm(
          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
              x)
      x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
    else:
      raise ValueError('Stem type {} not supported.'.format(stem_type))

    if replace_stem_max_pool:
      x = layers.Conv2D(
          filters=int(64 * self._depth_multiplier),
          kernel_size=3,
          strides=2,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              x)
      x = self._norm(
          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
              x)
      x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
    else:
      x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    endpoints = {}
    for i, spec in enumerate(RESNET_SPECS[model_id]):
      if spec[0] == 'residual':
        block_fn = nn_blocks.ResidualBlock
      elif spec[0] == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
      x = self._block_group(
          inputs=x,
          filters=int(spec[1] * self._depth_multiplier),
          strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=spec[2],
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2, 5),
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(ResNet, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
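A hypothetical ResNet-RS style configuration of this constructor; the flag combination follows the ResNet-RS recipe (https://arxiv.org/abs/2103.07579), but the exact values here are illustrative:

# Hypothetical ResNet-RS style configuration; values are illustrative.
backbone = ResNet(
    model_id=50,
    stem_type='v1',               # three 3x3 convs instead of the 7x7
    resnetd_shortcut=True,        # ResNet-D downsampling shortcut
    replace_stem_max_pool=True,   # stride-2 conv instead of stem max pool
    se_ratio=0.25,
    init_stochastic_depth_rate=0.0)
features = backbone(tf.ones([1, 224, 224, 3]))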