def test_mask_conv_output_size_and_seq_lens(
    mask_conv1d_input: Tuple[MaskConv1d, Tuple[torch.Tensor, torch.Tensor]]
) -> None:
    """Checks the shape and sequence lengths returned by a MaskConv1d.

    Args:
        mask_conv1d_input: A ``(module, inputs)`` pair, where ``inputs`` is
            the ``(tensor, seq_lens)`` tuple passed to the module. The
            tensor's size is unpacked as ``(batch, channels, seq_len)``.

    Raises:
        ValueError: If the module's padding mode is neither
            ``PaddingMode.NONE`` nor ``PaddingMode.SAME``.
    """
    mask_conv1d, inputs = mask_conv1d_input
    in_tensor, in_seq_lens = inputs[0], inputs[1]
    batch, _, seq_len = in_tensor.size()

    result, result_seq_lens = mask_conv1d(inputs)
    result_batch, result_channels, result_len = result.size()

    # Batch and channel dimensions.
    assert result_batch == len(result_seq_lens) == batch
    assert result_channels == mask_conv1d.out_channels

    # Convolution geometry shared by every helper call below.
    geometry = {
        "kernel_size": mask_conv1d.kernel_size[0],
        "stride": mask_conv1d.stride[0],
        "dilation": mask_conv1d.dilation[0],
    }

    # Sequence dimension of the output tensor. Each branch also records the
    # padding that is needed afterwards to check the returned seq_lens.
    mode = mask_conv1d._padding_mode
    if mode == PaddingMode.NONE:
        padding = (0, 0)
        # out_lens is relied on here to give the expected unpadded length.
        expected_len = out_lens(
            seq_lens=torch.tensor(seq_len), padding=0, **geometry
        ).item()
    elif mode == PaddingMode.SAME:
        padding = pad_same(length=seq_len, **geometry)
        # SAME padding is defined by out_len == ceil(in_len / stride).
        expected_len = math.ceil(float(seq_len) / mask_conv1d.stride[0])
    else:
        raise ValueError(f"unknown PaddingMode {mask_conv1d._padding_mode}")
    assert result_len == expected_len

    # Per-example sequence lengths returned alongside the tensor.
    expected_seq_lens = out_lens(
        seq_lens=in_seq_lens, padding=sum(padding), **geometry
    )
    assert torch.all(result_seq_lens == expected_seq_lens)
def test_pad_same_output_has_correct_size_after_conv(
    conv1d_valid_input: Tuple[torch.nn.Conv1d, torch.Tensor]
) -> None:
    """Checks that padding with pad_same then applying Conv1d gives SAME size.

    Args:
        conv1d_valid_input: A ``(conv1d, tensor)`` pair. The tensor's size is
            unpacked into three dimensions, the last being the sequence
            length.
    """
    conv1d, tensor = conv1d_valid_input
    # Only the sequence length is needed; the unpack still requires 3 dims.
    _, _, seq_len = tensor.size()
    stride = conv1d.stride[0]

    padding = pad_same(
        length=seq_len,
        kernel_size=conv1d.kernel_size[0],
        stride=stride,
        dilation=conv1d.dilation[0],
    )
    padded = torch.nn.functional.pad(tensor, padding)
    convolved = conv1d(padded)

    # SAME padding: output length equals ceil(input length / stride).
    expected_len = math.ceil(float(seq_len) / stride)
    assert convolved.size(2) == expected_len
def test_pad_same_raises_value_error_invalid_parameters(
    length: int, kernel_size: int, stride: int, dilation: int
) -> None:
    """Checks pad_same rejects arguments where any value is non-positive.

    Args:
        length: Candidate ``length`` argument for ``pad_same``.
        kernel_size: Candidate ``kernel_size`` argument for ``pad_same``.
        stride: Candidate ``stride`` argument for ``pad_same``.
        dilation: Candidate ``dilation`` argument for ``pad_same``.
    """
    # Only keep examples in which at least one argument is zero or negative.
    arguments = (length, kernel_size, stride, dilation)
    assume(any(value <= 0 for value in arguments))

    with pytest.raises(ValueError):
        pad_same(length, kernel_size, stride, dilation)