Exemplo n.º 1
0
    def test_s3d_stem_cells(self, depth_multiplier, first_temporal_kernel_size,
                            temporal_conv_endpoints):
        """Checks the output shape and endpoint names of the inception-v1 stem.

        Args:
          depth_multiplier: Channel multiplier forwarded to the stem builder;
            the expected channel count is `int(192 * depth_multiplier)`.
          first_temporal_kernel_size: Forwarded unchanged as the stem's
            `first_temporal_kernel_size` argument.
          temporal_conv_endpoints: Forwarded unchanged as the stem's
            `temporal_conv_endpoints` argument.
        """
        batch_size, num_frames = 1, 64
        height = width = 224
        video_shape = (num_frames, height, width, 3)

        inputs = tf.keras.layers.Input(shape=video_shape,
                                       batch_size=batch_size)

        stem_output, stem_endpoints = inception_utils.inception_v1_stem_cells(
            inputs,
            depth_multiplier,
            'Mixed_5c',
            temporal_conv_endpoints=temporal_conv_endpoints,
            self_gating_endpoints={'Conv2d_2c_3x3'},
            first_temporal_kernel_size=first_temporal_kernel_size)

        # Only the channel dimension depends on the parameterization.
        expected_shape = [
            batch_size, 32, 28, 28,
            int(192 * depth_multiplier)
        ]
        self.assertListEqual(stem_output.shape.as_list(), expected_shape)

        # The stem must report exactly these endpoints, no more and no fewer.
        expected_endpoints = {
            'Conv2d_1a_7x7',
            'MaxPool_2a_3x3',
            'Conv2d_2b_1x1',
            'Conv2d_2c_3x3',
            'MaxPool_3a_3x3',
        }
        self.assertSetEqual(expected_endpoints, set(stem_endpoints))
Exemplo n.º 2
0
    def __init__(self,
                 input_specs: tf.keras.layers.InputSpec,
                 final_endpoint: Text = 'Mixed_5c',
                 first_temporal_kernel_size: int = 3,
                 temporal_conv_start_at: Text = 'Conv2d_2c_3x3',
                 gating_start_at: Text = 'Conv2d_2c_3x3',
                 swap_pool_and_1x1x1: bool = True,
                 gating_style: Text = 'CELL',
                 use_sync_bn: bool = False,
                 norm_momentum: float = 0.999,
                 norm_epsilon: float = 0.001,
                 temporal_conv_initializer: Union[
                     Text,
                     initializers.Initializer] = initializers.TruncatedNormal(
                         mean=0.0, stddev=0.01),
                 temporal_conv_type: Text = '2+1d',
                 kernel_initializer: Union[
                     Text,
                     initializers.Initializer] = initializers.TruncatedNormal(
                         mean=0.0, stddev=0.01),
                 kernel_regularizer: Union[Text,
                                           regularizers.Regularizer] = 'l2',
                 depth_multiplier: float = 1.0,
                 **kwargs):
        """Constructs the S3D network up to `final_endpoint`.

        The stem cells are built first, followed by the inception cells listed
        in `inception_utils.INCEPTION_V1_ARCH_SKELETON`. The model's outputs
        are the dictionary of all endpoints that were built.

        Args:
          input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
          final_endpoint: Specifies the endpoint to construct the network up
            to.
          first_temporal_kernel_size: Temporal kernel size of the first
            convolution layer.
          temporal_conv_start_at: Specifies the endpoint where to start
            performing temporal convolution from.
          gating_start_at: Specifies the endpoint where to start performing
            self gating from.
          swap_pool_and_1x1x1: A boolean flag that indicates whether to swap
            the order of convolution and max pooling in Branch_3 of the
            inception v1 cell.
          gating_style: A string that specifies self gating to be applied
            after each branch and/or after each cell. It can be one of
            ['BRANCH', 'CELL', 'BRANCH_AND_CELL'].
          use_sync_bn: If True, use synchronized batch normalization.
          norm_momentum: A `float` of normalization momentum for the moving
            average.
          norm_epsilon: A `float` added to variance to avoid dividing by zero.
          temporal_conv_initializer: Weight initializer for temporal
            convolutional layers.
          temporal_conv_type: The type of parameterized convolution.
            Currently, we support '2d', '3d', '2+1d', '1+2d'.
          kernel_initializer: Weight initializer for convolutional layers
            other than temporal convolution.
          kernel_regularizer: Weight regularizer for all convolutional layers.
          depth_multiplier: A float to reduce/increase number of channels.
          **kwargs: Keyword arguments forwarded to the parent class
            constructor.

        Raises:
          ValueError: If `final_endpoint` is not among the endpoints that were
            built.
        """

        self._input_specs = input_specs
        self._final_endpoint = final_endpoint
        self._first_temporal_kernel_size = first_temporal_kernel_size
        self._temporal_conv_start_at = temporal_conv_start_at
        self._gating_start_at = gating_start_at
        self._swap_pool_and_1x1x1 = swap_pool_and_1x1x1
        self._gating_style = gating_style
        self._use_sync_bn = use_sync_bn
        self._norm_momentum = norm_momentum
        self._norm_epsilon = norm_epsilon
        self._temporal_conv_initializer = temporal_conv_initializer
        self._temporal_conv_type = temporal_conv_type
        self._kernel_initializer = kernel_initializer
        self._kernel_regularizer = kernel_regularizer
        self._depth_multiplier = depth_multiplier

        # Endpoint sets derived from the two "start at" names.
        # NOTE(review): presumably each set holds the start endpoint and every
        # later entry of INCEPTION_V1_CONV_ENDPOINTS -- confirm in net_utils.
        self._temporal_conv_endpoints = net_utils.make_set_from_start_endpoint(
            temporal_conv_start_at,
            inception_utils.INCEPTION_V1_CONV_ENDPOINTS)
        self._self_gating_endpoints = net_utils.make_set_from_start_endpoint(
            gating_start_at, inception_utils.INCEPTION_V1_CONV_ENDPOINTS)

        # `tf.keras.Input` expects a shape without the batch dimension, so the
        # leading entry of the spec's shape is dropped.
        inputs = tf.keras.Input(shape=input_specs.shape[1:])
        net, end_points = inception_utils.inception_v1_stem_cells(
            inputs,
            depth_multiplier,
            final_endpoint,
            temporal_conv_endpoints=self._temporal_conv_endpoints,
            self_gating_endpoints=self._self_gating_endpoints,
            temporal_conv_type=self._temporal_conv_type,
            first_temporal_kernel_size=self._first_temporal_kernel_size,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
            temporal_conv_initializer=self._temporal_conv_initializer,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            parameterized_conv_layer=self._get_parameterized_conv_layer_impl(),
            layer_naming_fn=self._get_layer_naming_fn(),
        )

        # If the stem already produced the requested endpoint, the network is
        # complete: appending inception cells would build layers beyond
        # `final_endpoint` (no skeleton entry would ever match it to stop the
        # loop) and expose them as outputs.
        if final_endpoint not in end_points:
            for end_point, filters in inception_utils.INCEPTION_V1_ARCH_SKELETON:
                net, end_points = self._s3d_cell(net, end_point, end_points,
                                                 filters)
                if end_point == final_endpoint:
                    break

        # Reached only when neither the stem nor any inception cell matched.
        if final_endpoint not in end_points:
            raise ValueError(
                'Unrecognized final endpoint %s (available endpoints: %s).' %
                (final_endpoint, end_points.keys()))

        super(S3D, self).__init__(inputs=inputs, outputs=end_points, **kwargs)