示例#1
0
 def make_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
     """Stack `modules` residual blocks, switching to `out_dim` on the last.

     The first `modules - 1` blocks are built with `inp_dim`; the final block
     is built with `out_dim` and a projection shortcut.

     Args:
       k: Unused by this method; kept for interface compatibility.
       inp_dim: First positional argument of every block but the last
         (presumably the filter count -- confirm against `ResidualBlock`).
       out_dim: First positional argument of the final block.
       modules: Total number of residual blocks in the stack.
       **kwargs: Forwarded unchanged to every `nn_blocks.ResidualBlock`.

     Returns:
       A `tf.keras.Sequential` wrapping the blocks in application order.
     """
     # The second positional argument is always 1 (presumably the stride).
     leading_blocks = [
         nn_blocks.ResidualBlock(inp_dim, 1, **kwargs)
         for _ in range(modules - 1)
     ]
     closing_block = nn_blocks.ResidualBlock(
         out_dim, 1, use_projection=True, **kwargs)
     return tf.keras.Sequential(leading_blocks + [closing_block])
示例#2
0
    def __init__(
            self,
            model_id: int,
            input_channel_dims: int,
            input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
            num_hourglasses: int = 1,
            initial_downsample: bool = True,
            activation: str = 'relu',
            use_sync_bn: bool = True,
            norm_momentum=0.1,
            norm_epsilon=1e-5,
            kernel_initializer: str = 'VarianceScaling',
            kernel_regularizer: Optional[
                tf.keras.regularizers.Regularizer] = None,
            bias_regularizer: Optional[
                tf.keras.regularizers.Regularizer] = None,
            **kwargs):
        """Build the Hourglass backbone as a Keras functional model.

        The graph is a stem (conv block followed by a residual block) and then
        `num_hourglasses` hourglass stacks. Consecutive stacks are joined by
        1x1 conv branches that are summed, passed through ReLU and a residual
        block.

        Args:
          model_id: An `int` key into `HOURGLASS_SPECS` selecting the scale of
            the Hourglass backbone model.
          input_channel_dims: `int`, number of filters used by the stem conv;
            also multiplies the per-stage channel factors from the spec.
          input_specs: A `tf.keras.layers.InputSpec` of specs of the input
            tensor. Its shape, minus the leading dimension, sizes the model
            input.
          num_hourglasses: `int`, number of hourglass blocks in backbone. For
            example, hourglass-104 has two hourglass-52 modules.
          initial_downsample: `bool`. If True the stem uses a 7x7 kernel with
            stride 2; otherwise a 3x3 kernel with stride 1. The stem residual
            block uses the same stride.
          activation: A `str` name of the activation function.
          use_sync_bn: If True, use synchronized batch normalization.
          norm_momentum: `float`, momentum for the batch normalization layers.
          norm_epsilon: `float`, epsilon for the batch normalization layers.
          kernel_initializer: A `str` for kernel initializer of conv layers.
          kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
            Conv2D. Default to None.
          bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
            Conv2D. Default to None.
          **kwargs: Additional keyword arguments passed to the parent
            constructor.
        """
        # Keep the constructor arguments on the instance.
        self._model_id = model_id
        self._input_channel_dims = input_channel_dims
        self._num_hourglasses = num_hourglasses
        self._initial_downsample = initial_downsample
        self._activation = activation
        self._use_sync_bn = use_sync_bn
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._bias_regularizer = bias_regularizer

        # The spec table gives blocks per stage and per-stage channel
        # multipliers; the multipliers are scaled by the stem width.
        spec = HOURGLASS_SPECS[model_id]
        self._blocks_per_stage = spec['blocks_per_stage']
        self._channel_dims_per_stage = [
            multiplier * self._input_channel_dims
            for multiplier in spec['channel_dims_per_stage']
        ]

        inputs = tf.keras.layers.Input(shape=input_specs.shape[1:])

        stem_filters = self._channel_dims_per_stage[0]

        # Settings shared by every conv / residual block created below.
        shared = dict(
            bias_regularizer=self._bias_regularizer,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
        )

        # Stem: a larger strided kernel when downsampling the input.
        stem_kernel, stem_stride = (7, 2) if initial_downsample else (3, 1)

        stem_conv = mobilenet.Conv2DBNBlock(
            filters=self._input_channel_dims,
            kernel_size=stem_kernel,
            strides=stem_stride,
            use_explicit_padding=True,
            activation=self._activation,
            **shared)
        features = stem_conv(inputs)

        stem_residual = nn_blocks.ResidualBlock(
            filters=stem_filters,
            use_projection=True,
            use_explicit_padding=True,
            strides=stem_stride,
            **shared)
        features = stem_residual(features)

        # Given the two down-sampling blocks above, the starting level is set
        # to 2. To stay compatible with the remaining backbones, the outputs
        # are keyed as
        #   '2'   -> output of the last hourglass
        #   '2_0' -> output of the first hourglass
        #   ...
        #   '2_{num_hourglasses-2}' -> output of the second to last hourglass
        last_index = num_hourglasses - 1
        all_heatmaps = {}
        for i in range(num_hourglasses):
            # One hourglass stack followed by a 3x3 conv block.
            hourglass = cn_nn_blocks.HourglassBlock(
                channel_dims_per_stage=self._channel_dims_per_stage,
                blocks_per_stage=self._blocks_per_stage,
            )
            x_hg = hourglass(features)

            x_hg = mobilenet.Conv2DBNBlock(
                filters=stem_filters,
                kernel_size=3,
                strides=1,
                use_explicit_padding=True,
                activation=self._activation,
                **shared)(x_hg)

            if i == last_index:
                # Final stack: nothing follows it, so no joining layers.
                all_heatmaps['2'] = x_hg
                continue

            all_heatmaps['2_{}'.format(i)] = x_hg

            # Intermediate layers between hourglasses: a 1x1 conv on the
            # stack's input and on its output, summed and passed through ReLU
            # and a residual block to form the next stack's input.
            skip_branch = mobilenet.Conv2DBNBlock(
                filters=stem_filters,
                kernel_size=1,
                strides=1,
                activation='identity',
                **shared)(features)

            hg_branch = mobilenet.Conv2DBNBlock(
                filters=stem_filters,
                kernel_size=1,
                strides=1,
                activation='identity',
                **shared)(x_hg)

            merged = tf.keras.layers.Add()([skip_branch, hg_branch])
            merged = tf.keras.layers.ReLU()(merged)

            features = nn_blocks.ResidualBlock(
                filters=stem_filters,
                use_projection=False,
                use_explicit_padding=True,
                strides=1,
                **shared)(merged)

        # Record the static shape of every output under its level key.
        self._output_specs = {
            level: heatmap.get_shape()
            for level, heatmap in all_heatmaps.items()
        }

        super().__init__(inputs=inputs, outputs=all_heatmaps, **kwargs)
示例#3
0
def _make_repeated_residual_blocks(
    reps: int,
    out_channels: int,
    use_sync_bn: bool = True,
    norm_momentum: float = 0.1,
    norm_epsilon: float = 1e-5,
    residual_channels: Optional[int] = None,
    initial_stride: int = 1,
    initial_skip_conv: bool = False,
    kernel_initializer: str = 'VarianceScaling',
    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
):
    """Stack residual blocks one after the other.

    Args:
      reps: `int` for desired number of residual blocks. Values below 1 still
        produce a single final block, because that block is always appended.
      out_channels: `int`, filter depth of the final residual block.
      use_sync_bn: A `bool`, if True, use synchronized batch normalization.
      norm_momentum: `float`, momentum for the batch normalization layers.
      norm_epsilon: `float`, epsilon for the batch normalization layers.
      residual_channels: `int`, filter depth for the first reps - 1 residual
        blocks. If None, defaults to the same value as out_channels. If not
        equal to out_channels, then uses a projection shortcut in the final
        residual block.
      initial_stride: `int`, stride for the first residual block.
      initial_skip_conv: `bool`, if set, the first residual block uses a skip
        convolution. This is useful when the number of channels in the input
        are not the same as residual_channels. It is also honored when
        `reps == 1`, where the first block is the final block.
      kernel_initializer: A `str` for kernel initializer of convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.

    Returns:
      A `tf.keras.Sequential` that applies the residual blocks in sequence.
    """
    if residual_channels is None:
        residual_channels = out_channels

    def _residual_block(filters, strides, use_projection):
        # Every block shares the same padding, norm and regularizer settings.
        return nn_blocks.ResidualBlock(filters=filters,
                                       strides=strides,
                                       use_explicit_padding=True,
                                       use_projection=use_projection,
                                       use_sync_bn=use_sync_bn,
                                       norm_momentum=norm_momentum,
                                       norm_epsilon=norm_epsilon,
                                       kernel_initializer=kernel_initializer,
                                       kernel_regularizer=kernel_regularizer,
                                       bias_regularizer=bias_regularizer)

    blocks = []
    for i in range(reps - 1):
        is_first = i == 0

        # Only use the stride at the first block so we don't repeatedly
        # downsample the input.
        stride = initial_stride if is_first else 1

        # If the stride is more than 1, we cannot use an identity layer for
        # the skip connection and are forced to use a conv instead.
        skip_conv = stride > 1
        if is_first and initial_skip_conv:
            skip_conv = True

        blocks.append(_residual_block(residual_channels, stride, skip_conv))

    if reps == 1:
        # With a single block the loop above does not run, so the final block
        # is also the first one: honor the requested stride here.
        stride = initial_stride
        # A conv shortcut is needed when striding, or when the caller asked
        # for one on the first block. The flag used to be ignored here.
        skip_conv = stride > 1
        if initial_skip_conv:
            skip_conv = True
    else:
        stride = 1
        skip_conv = residual_channels != out_channels

    blocks.append(_residual_block(out_channels, stride, skip_conv))

    return tf.keras.Sequential(blocks)