Example no. 1
0
    def test_create_acoustic_stem_with_callable(self):
        """
        Test builder `create_acoustic_res_basic_stem` with callable
        inputs.

        For every combination of pool/activation/norm callables, build a
        stem via the builder and a hand-assembled ground-truth stem, share
        weights between them, and check both produce matching outputs.
        """
        pool_options = (nn.AvgPool3d, nn.MaxPool3d, None)
        activation_options = (nn.ReLU, nn.Softmax, nn.Sigmoid, None)
        norm_options = (nn.BatchNorm3d, None)

        for pool, activation, norm in itertools.product(
            pool_options, activation_options, norm_options
        ):
            model = create_acoustic_res_basic_stem(
                in_channels=3,
                out_channels=64,
                pool=pool,
                activation=activation,
                norm=norm,
            )
            model_gt = ResNetBasicStem(
                conv=ConvReduce3D(
                    in_channels=3,
                    out_channels=64,
                    kernel_size=((3, 1, 1), (1, 7, 7)),
                    stride=((1, 1, 1), (1, 1, 1)),
                    padding=((1, 0, 0), (0, 3, 3)),
                    bias=(False, False),
                ),
                norm=norm(64) if norm is not None else None,
                activation=activation() if activation is not None else None,
                pool=pool(
                    kernel_size=[1, 3, 3],
                    stride=[1, 2, 2],
                    padding=[0, 1, 1],
                ) if pool is not None else None,
            )

            # Weights must transfer exactly; strict mode catches any
            # key mismatch between builder output and ground truth.
            model.load_state_dict(model_gt.state_dict(), strict=True)

            # Test forwarding.
            for input_tensor in TestResNetBasicStem._get_inputs():
                with torch.no_grad():
                    if input_tensor.shape[1] != 3:
                        # Wrong channel count: forward must raise.
                        with self.assertRaises(RuntimeError):
                            model(input_tensor)
                        continue
                    output_tensor = model(input_tensor)
                    output_tensor_gt = model_gt(input_tensor)
                self.assertEqual(
                    output_tensor.shape,
                    output_tensor_gt.shape,
                    "Output shape {} is different from expected shape {}".
                    format(output_tensor.shape, output_tensor_gt.shape),
                )
                self.assertTrue(
                    np.allclose(output_tensor.numpy(),
                                output_tensor_gt.numpy()))
Example no. 2
0
    def test_create_complex_stem(self):
        """
        Test complex ResNetBasicStem.

        Builds a stem with a stride-2 conv plus a stride-2 max pool and
        checks the forwarded output shape against the analytically
        expected shape for several channel configurations.
        """
        for in_dim, out_dim in itertools.product((2, 3), (4, 8, 16)):
            stem = ResNetBasicStem(
                conv=nn.Conv3d(
                    in_dim,
                    out_dim,
                    kernel_size=[3, 7, 7],
                    stride=[1, 2, 2],
                    padding=[1, 3, 3],
                    bias=False,
                ),
                norm=nn.BatchNorm3d(out_dim),
                activation=nn.ReLU(),
                pool=nn.MaxPool3d(
                    kernel_size=[1, 3, 3],
                    stride=[1, 2, 2],
                    padding=[0, 1, 1],
                ),
            )

            # Test forwarding.
            for x in TestResNetBasicStem._get_inputs(in_dim):
                if x.shape[1] != in_dim:
                    # Wrong channel count: forward must fail.
                    with self.assertRaises(Exception):
                        stem(x)
                    continue
                y = stem(x)

                def halve(size):
                    # One stride-2 downsampling step: ceil(size / 2).
                    return (size - 1) // 2 + 1

                # Temporal dim is preserved; H and W are halved twice
                # (stride-2 conv followed by stride-2 pool).
                expected_shape = (
                    x.shape[0],
                    out_dim,
                    x.shape[2],
                    halve(halve(x.shape[3])),
                    halve(halve(x.shape[4])),
                )

                self.assertEqual(
                    y.shape,
                    expected_shape,
                    "Output shape {} is different from expected shape {}".
                    format(y.shape, expected_shape),
                )
Example no. 3
0
    def test_create_stem_with_conv_reduced_3d(self):
        """
        Test simple ResNetBasicStem with ConvReduce3D.

        With stride-1/padding-1 convolutions and no pooling, only the
        channel dimension should change; all other dimensions must be
        preserved by the forward pass.
        """
        for in_dim, out_dim in itertools.product((2, 3), (4, 8, 16)):
            stem = ResNetBasicStem(
                conv=ConvReduce3D(
                    in_channels=in_dim,
                    out_channels=out_dim,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=(1, 1),
                    bias=(False, False),
                ),
                norm=nn.BatchNorm3d(out_dim),
                activation=nn.ReLU(),
                pool=None,
            )

            # Test forwarding.
            for x in TestResNetBasicStem._get_inputs(in_dim):
                if x.shape[1] != in_dim:
                    # Wrong channel count: forward must raise.
                    with self.assertRaises(RuntimeError):
                        stem(x)
                    continue
                y = stem(x)

                # Shape-preserving stem: batch, T, H, W unchanged,
                # channels mapped to out_dim.
                expected_shape = (x.shape[0], out_dim) + tuple(x.shape[2:])

                self.assertEqual(
                    y.shape,
                    expected_shape,
                    "Output shape {} is different from expected shape {}".
                    format(y.shape, expected_shape),
                )
Example no. 4
0
def create_x3d_stem(
    *,
    # Conv configs.
    in_channels: int,
    out_channels: int,
    conv_kernel_size: Tuple[int, int, int] = (5, 3, 3),
    conv_stride: Tuple[int, int, int] = (1, 2, 2),
    conv_padding: Tuple[int, int, int] = (2, 1, 1),
    # BN configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Creates the stem layer for X3D. It performs spatial Conv, temporal Conv, BN, and Relu.

    ::

                                        Conv_xy
                                           ↓
                                        Conv_t
                                           ↓
                                     Normalization
                                           ↓
                                       Activation

    Args:
        in_channels (int): input channel size of the convolution.
        out_channels (int): output channel size of the convolution.
        conv_kernel_size (tuple): convolutional kernel size(s), given as
            (temporal, height, width).
        conv_stride (tuple): convolutional stride size(s), given as
            (temporal, height, width).
        conv_padding (tuple): convolutional padding size(s), given as
            (temporal, height, width).

        norm (callable): a callable that constructs normalization layer, options
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer, options
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).

    Returns:
        (nn.Module): X3D stem layer.
    """
    # Spatial-only convolution: kernel/stride/padding are 1 (resp. 0) along
    # the temporal axis, so only H and W are convolved here.
    conv_xy_module = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(1, conv_kernel_size[1], conv_kernel_size[2]),
        stride=(1, conv_stride[1], conv_stride[2]),
        padding=(0, conv_padding[1], conv_padding[2]),
        bias=False,
    )
    # Temporal-only convolution; groups=out_channels makes it depthwise
    # (one temporal filter per channel), matching the X3D channelwise design.
    conv_t_module = nn.Conv3d(
        in_channels=out_channels,
        out_channels=out_channels,
        kernel_size=(conv_kernel_size[0], 1, 1),
        stride=(conv_stride[0], 1, 1),
        padding=(conv_padding[0], 0, 0),
        bias=False,
        groups=out_channels,
    )
    # NOTE(review): the spatial conv is passed as `conv_t` and the temporal
    # conv as `conv_xy` — apparently deliberate cross-wiring so that the
    # spatial conv runs first (per the diagram above), assuming Conv2plus1d
    # applies `conv_t` before `conv_xy` in forward(). Confirm against
    # Conv2plus1d.forward before "fixing" this.
    stacked_conv_module = Conv2plus1d(
        conv_t=conv_xy_module,
        norm=None,
        activation=None,
        conv_xy=conv_t_module,
    )

    # Norm/activation are optional; the temporal conv feeds them out_channels
    # feature maps when present.
    norm_module = (None if norm is None else norm(
        num_features=out_channels, eps=norm_eps, momentum=norm_momentum))
    activation_module = None if activation is None else activation()

    return ResNetBasicStem(
        conv=stacked_conv_module,
        norm=norm_module,
        activation=activation_module,
        pool=None,
    )