コード例 #1
0
ファイル: test_conv.py プロジェクト: zueigung1419/backpack
def convolution_with_unfold(input, module):
    """Perform convolution via matrix multiplication.

    The input is unfolded with ``unfold_by_conv``. The unfolded patches are then
    contracted with the group-reshaped weight via ``torch.einsum``. The result is
    reshaped to the output shape that ``module(input)`` produces.

    Args:
        input (torch.Tensor): Convolution input. Axis 0 is read as the batch
            size ``N`` and axis 1 as the number of input channels ``C_in``.
        module (torch.nn.Module): Convolution module without bias. Its
            forward pass (for the output shape), ``weight`` and ``groups``
            are used.

    Returns:
        torch.Tensor: Tensor of shape ``(N, C_out, *spatial_out)``, where
        ``C_out`` and ``spatial_out`` are taken from ``module(input).shape``.

    Raises:
        AssertionError: If ``module.bias`` is not ``None`` (no bias is added).
    """
    assert module.bias is None

    N, C_in = input.shape[0], input.shape[1]

    # The module is run once only to learn the output shape.
    output_shape = module(input).shape
    C_out = output_shape[1]
    spatial_out_size = output_shape[2:]
    spatial_out_numel = spatial_out_size.numel()

    # Number of kernel elements as an exact Python int. It is read from the
    # weight's trailing dims, which are exactly the dims flattened by the
    # reshape below. This avoids building a float tensor from
    # ``module.kernel_size``.
    kernel_size_numel = module.weight.shape[2:].numel()

    G = module.groups

    # Weight split into groups: (G, C_out // G, C_in // G, kernel elements).
    weight_matrix = module.weight.data.reshape(
        G, C_out // G, C_in // G, kernel_size_numel
    )
    # Unfolded input split the same way:
    # (N, G, C_in // G, kernel elements, output locations).
    unfolded_input = unfold_by_conv(input, module).reshape(
        N, G, C_in // G, kernel_size_numel, spatial_out_numel
    )

    # Per group g, contract over input channels (c) and kernel elements (x).
    result = torch.einsum("gocx,ngcxh->ngoh", weight_matrix, unfolded_input)

    return result.reshape(N, C_out, *spatial_out_size)
コード例 #2
0
ファイル: test_conv.py プロジェクト: zueigung1419/backpack
def test_unfold_by_conv(problem):
    """Test ``unfold_by_conv`` against ``unfold_func`` on a convolution problem.

    Both unfold implementations are applied to the same random input for
    ``problem.module``, and their results are compared in size and value.

    Args:
        problem (ConvProblem): Problem for testing unfold operation.
    """
    problem.set_up()
    try:
        inputs = torch.rand(problem.input_shape).to(problem.device)

        result_unfold = unfold_func(problem.module)(inputs)
        result_unfold_by_conv = unfold_by_conv(inputs, problem.module)

        check_sizes_and_values(result_unfold, result_unfold_by_conv)
    finally:
        # Undo set_up even when the comparison fails, so that state does not
        # leak into subsequent tests.
        problem.tear_down()
コード例 #3
0
ファイル: convnd.py プロジェクト: f-dangel/backpack
 def get_unfolded_input(self, module):
     """Return the unfolded version of the input stored on ``module``.

     Args:
         module: Module carrying an ``input0`` attribute (presumably the
             input saved during the forward pass; confirm against the hooks).

     Returns:
         The result of ``unfold_by_conv`` applied to ``module.input0``.
     """
     stored_input = module.input0
     unfolded = unfold_by_conv(stored_input, module)
     return unfolded