def test_create_acoustic_stem_with_callable(self):
    """
    Test builder `create_acoustic_res_basic_stem` with callable inputs.

    Every combination of pool / activation / norm callables (including None)
    is built twice — once via the builder and once by hand — the weights are
    synced, and the two models are checked for identical outputs.
    """
    pool_options = (nn.AvgPool3d, nn.MaxPool3d, None)
    activation_options = (nn.ReLU, nn.Softmax, nn.Sigmoid, None)
    norm_options = (nn.BatchNorm3d, None)

    for pool, activation, norm in itertools.product(
        pool_options, activation_options, norm_options
    ):
        model = create_acoustic_res_basic_stem(
            in_channels=3,
            out_channels=64,
            pool=pool,
            activation=activation,
            norm=norm,
        )
        # Hand-built ground-truth model with the same configuration.
        model_gt = ResNetBasicStem(
            conv=ConvReduce3D(
                in_channels=3,
                out_channels=64,
                kernel_size=((3, 1, 1), (1, 7, 7)),
                stride=((1, 1, 1), (1, 1, 1)),
                padding=((1, 0, 0), (0, 3, 3)),
                bias=(False, False),
            ),
            norm=norm(64) if norm is not None else None,
            activation=activation() if activation is not None else None,
            pool=(
                pool(kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1])
                if pool is not None
                else None
            ),
        )
        # Sync weights; explicitly use strict mode.
        model.load_state_dict(model_gt.state_dict(), strict=True)

        # Test forwarding.
        for tensor in TestResNetBasicStem._get_inputs():
            with torch.no_grad():
                if tensor.shape[1] != 3:
                    # Wrong channel count must raise at forward time.
                    with self.assertRaises(RuntimeError):
                        model(tensor)
                    continue
                output_tensor = model(tensor)
                output_tensor_gt = model_gt(tensor)
            self.assertEqual(
                output_tensor.shape,
                output_tensor_gt.shape,
                "Output shape {} is different from expected shape {}".format(
                    output_tensor.shape, output_tensor_gt.shape
                ),
            )
            self.assertTrue(
                np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
            )
def test_create_complex_stem(self):
    """
    Test complex ResNetBasicStem.

    Builds a conv + BN + ReLU + max-pool stem for several channel sizes and
    checks the output shape: two stride-2 stages halve each spatial dim.
    """

    def halved(size):
        # Spatial extent after one stride-2 stage with "same"-style padding.
        return (size - 1) // 2 + 1

    for in_dim, out_dim in itertools.product((2, 3), (4, 8, 16)):
        model = ResNetBasicStem(
            conv=nn.Conv3d(
                in_dim,
                out_dim,
                kernel_size=[3, 7, 7],
                stride=[1, 2, 2],
                padding=[1, 3, 3],
                bias=False,
            ),
            norm=nn.BatchNorm3d(out_dim),
            activation=nn.ReLU(),
            pool=nn.MaxPool3d(
                kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
            ),
        )

        # Test forwarding.
        for tensor in TestResNetBasicStem._get_inputs(in_dim):
            if tensor.shape[1] != in_dim:
                # Channel mismatch must fail.
                with self.assertRaises(Exception):
                    model(tensor)
                continue
            output_tensor = model(tensor)

            in_shape = tensor.shape
            # Time dim is untouched; H and W go through conv stride 2
            # followed by pool stride 2.
            expected_shape = (
                in_shape[0],
                out_dim,
                in_shape[2],
                halved(halved(in_shape[3])),
                halved(halved(in_shape[4])),
            )
            self.assertEqual(
                output_tensor.shape,
                expected_shape,
                "Output shape {} is different from expected shape {}".format(
                    output_tensor.shape, expected_shape
                ),
            )
def test_create_stem_with_conv_reduced_3d(self):
    """
    Test simple ResNetBasicStem with ConvReduce3D.

    With stride 1, padding 1 and no pooling, the stem must preserve the
    spatial/temporal shape and only change the channel count.
    """
    for in_dim, out_dim in itertools.product((2, 3), (4, 8, 16)):
        model = ResNetBasicStem(
            conv=ConvReduce3D(
                in_channels=in_dim,
                out_channels=out_dim,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=(False, False),
            ),
            norm=nn.BatchNorm3d(out_dim),
            activation=nn.ReLU(),
            pool=None,
        )

        # Test forwarding.
        for input_tensor in TestResNetBasicStem._get_inputs(in_dim):
            if input_tensor.shape[1] != in_dim:
                # Wrong channel count must raise at forward time.
                with self.assertRaises(RuntimeError):
                    model(input_tensor)
                continue
            output_tensor = model(input_tensor)

            batch, _, t, h, w = input_tensor.shape
            # Only the channel dim changes; everything else passes through.
            expected_shape = (batch, out_dim, t, h, w)
            self.assertEqual(
                output_tensor.shape,
                expected_shape,
                "Output shape {} is different from expected shape {}".format(
                    output_tensor.shape, expected_shape
                ),
            )
def create_x3d_stem(
    *,
    # Conv configs.
    in_channels: int,
    out_channels: int,
    conv_kernel_size: Tuple[int, int, int] = (5, 3, 3),
    conv_stride: Tuple[int, int, int] = (1, 2, 2),
    conv_padding: Tuple[int, int, int] = (2, 1, 1),
    # BN configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Creates the stem layer for X3D. It performs spatial Conv, temporal Conv, BN, and Relu.

    ::

                                        Conv_xy
                                           ↓
                                        Conv_t
                                           ↓
                                     Normalization
                                           ↓
                                       Activation

    Args:
        in_channels (int): input channel size of the convolution.
        out_channels (int): output channel size of the convolution.
        conv_kernel_size (tuple): convolutional kernel size(s), ordered (t, h, w).
        conv_stride (tuple): convolutional stride size(s), ordered (t, h, w).
        conv_padding (tuple): convolutional padding size(s), ordered (t, h, w).

        norm (callable): a callable that constructs normalization layer, options
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer, options
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).

    Returns:
        (nn.Module): X3D stem layer.
    """
    # Spatial-only convolution: 1 x k_h x k_w kernel, no temporal extent.
    conv_xy_module = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(1, conv_kernel_size[1], conv_kernel_size[2]),
        stride=(1, conv_stride[1], conv_stride[2]),
        padding=(0, conv_padding[1], conv_padding[2]),
        bias=False,
    )
    # Temporal-only convolution: k_t x 1 x 1 kernel, depthwise
    # (groups == out_channels), applied on the spatial conv's output channels.
    conv_t_module = nn.Conv3d(
        in_channels=out_channels,
        out_channels=out_channels,
        kernel_size=(conv_kernel_size[0], 1, 1),
        stride=(conv_stride[0], 1, 1),
        padding=(conv_padding[0], 0, 0),
        bias=False,
        groups=out_channels,
    )
    # NOTE(review): the spatial conv is passed as `conv_t` and the temporal
    # conv as `conv_xy`. This is presumably deliberate: if Conv2plus1d applies
    # `conv_t` before `conv_xy`, this ordering makes the spatial conv run
    # first, matching the Conv_xy -> Conv_t diagram above — confirm against
    # Conv2plus1d's forward().
    stacked_conv_module = Conv2plus1d(
        conv_t=conv_xy_module,
        norm=None,
        activation=None,
        conv_xy=conv_t_module,
    )

    # `norm`/`activation` may be None, in which case the stem skips that stage.
    norm_module = (
        None
        if norm is None
        else norm(num_features=out_channels, eps=norm_eps, momentum=norm_momentum)
    )
    activation_module = None if activation is None else activation()

    return ResNetBasicStem(
        conv=stacked_conv_module,
        norm=norm_module,
        activation=activation_module,
        pool=None,
    )