def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
    """Contract ``mat`` with the unfolded layer input into weight-shaped results.

    Args:
        module: Layer providing ``groups``, ``weight``, ``input0`` and
            ``output``.
        g_inp: Not used by this method.
        g_out: Not used by this method.
        mat: Tensor whose leading axis indexes ``V`` vectors. It is reshaped
            to ``(V, N, G, C_out // G, *module.output.shape[2:])``.
        sum_batch: If truthy, the batch axis is contracted away.

    Returns:
        Tensor of shape ``(V, *module.weight.shape)`` when ``sum_batch`` is
        truthy, otherwise ``(V, N, *module.weight.shape)``.
    """
    num_vecs = mat.shape[0]
    groups = module.groups
    in_channels = module.input0.shape[1]
    out_shape = module.output.shape
    batch, out_channels = out_shape[0], out_shape[1]
    spatial_out = out_shape[2:]
    kernel_shape = module.weight.shape[2:]

    # Expose the group axis of the output channels.
    grouped_mat = mat.reshape(
        num_vecs, batch, groups, out_channels // groups, *spatial_out
    )

    # Expose the group axis of the input channels of the unfolded input.
    unfolded = unfold_by_conv_transpose(module.input0, module)
    unfolded = unfolded.reshape(
        batch, groups, in_channels // groups, *kernel_shape, *spatial_out
    )

    # One einsum letter per kernel axis and per output spatial axis.
    dims_kern = "xyz"[:self.conv_dims]
    dims_data = "abc"[:self.conv_dims]

    if sum_batch:
        result_str = "vgio" + dims_kern
        final_shape = (num_vecs, *module.weight.shape)
    else:
        result_str = "vngio" + dims_kern
        final_shape = (num_vecs, batch, *module.weight.shape)

    equation = f"ngi{dims_kern}{dims_data},vngo{dims_data}->{result_str}"

    contracted = einsum(equation, unfolded, grouped_mat)
    return contracted.reshape(final_shape)
Exemplo n.º 2
0
def conv_transpose_with_unfold(input, module):
    """Perform transpose convolution via matrix multiplication.

    The input is unfolded with ``unfold_by_conv_transpose`` and contracted
    with the module's weight in a single einsum.

    Args:
        input: Batched input; axis 0 is read as the batch size ``N`` and
            axis 1 as the number of input channels ``C_in``.
        module: Bias-free transpose convolution module. Its ``weight``,
            ``kernel_size`` and ``groups`` are read, and it is called once on
            ``input`` to find out the output shape.

    Returns:
        Tensor of shape ``(N, C_out, *spatial_out)``, where ``C_out`` and
        ``spatial_out`` are taken from the shape of ``module(input)``.

    Raises:
        AssertionError: If ``module.bias`` is not ``None``.
    """
    assert module.bias is None

    N, C_in = input.shape[0], input.shape[1]

    # A forward pass is used only to learn the output's shape.
    output_shape = module(input).shape
    C_out = output_shape[1]
    spatial_out_size = output_shape[2:]
    spatial_out_numel = spatial_out_size.numel()

    kernel_size = module.kernel_size
    kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size)))

    G = module.groups

    # Group-major layout: PyTorch assigns input channels to groups in
    # contiguous chunks of C_in // G, so the group axis must come BEFORE the
    # per-group channel axis. The channel-major layout (C_in // G, G)
    # interleaves channels across groups and is only right for G == 1 or
    # G == C_in.
    weight_matrix = module.weight.data.reshape(
        G, C_in // G, C_out // G, kernel_size_numel
    )
    # NOTE(review): assumes unfold_by_conv_transpose keeps the input channels
    # in their original order on axis 1 -- confirm against the helper.
    unfolded_input = unfold_by_conv_transpose(input, module).reshape(
        N, G, C_in // G, kernel_size_numel, spatial_out_numel
    )

    # g: group, c: input channel in group, o: output channel in group,
    # x: flattened kernel position, h: flattened output position.
    result = torch.einsum("gcox,ngcxh->ngoh", weight_matrix, unfolded_input)

    return result.reshape(N, C_out, *spatial_out_size)
Exemplo n.º 3
0
    def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
        """Contract ``mat`` with the unfolded layer input (single group only).

        Args:
            module: Layer providing ``groups``, ``weight``, ``input0`` and
                ``output``.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            mat: Tensor whose leading axis indexes ``V`` vectors. It is
                reshaped to ``(V, N, G, C_out // G, *module.output.shape[2:])``.
            sum_batch: If truthy, the batch axis is contracted away.

        Returns:
            Tensor of shape ``(V, *module.weight.shape)`` when ``sum_batch``
            is truthy, otherwise ``(V, N, *module.weight.shape)``.

        Raises:
            NotImplementedError: If ``module.groups`` is not 1.
        """
        groups = module.groups
        if groups != 1:
            raise NotImplementedError(
                "Groups greater than 1 are not supported yet")

        num_vecs = mat.shape[0]
        in_channels = module.input0.shape[1]
        out_shape = module.output.shape
        batch, out_channels = out_shape[0], out_shape[1]
        spatial_out = out_shape[2:]
        kernel_shape = module.weight.shape[2:]

        grouped_mat = mat.reshape(
            num_vecs, batch, groups, out_channels // groups, *spatial_out
        )

        # The group axis follows the per-group channel axis in this layout.
        unfolded = unfold_by_conv_transpose(module.input0, module)
        unfolded = unfolded.reshape(
            batch, in_channels // groups, groups, *kernel_shape, *spatial_out
        )

        # One einsum letter per kernel axis and per output spatial axis.
        dims_kern = "xyz"[:self.conv_dims]
        dims_data = "abc"[:self.conv_dims]

        if sum_batch:
            result_str = "vigo" + dims_kern
            final_shape = (num_vecs, *module.weight.shape)
        else:
            result_str = "vnigo" + dims_kern
            final_shape = (num_vecs, batch, *module.weight.shape)

        equation = "nig{0}{1},vngo{1}->{2}".format(dims_kern, dims_data,
                                                   result_str)

        contracted = einsum(equation, unfolded, grouped_mat)
        return contracted.reshape(final_shape)
Exemplo n.º 4
0
    def _weight_jac_t_mat_prod(
        self,
        module: Union[ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Contract ``mat`` with the (optionally subsampled) unfolded input.

        Args:
            module: Layer providing ``groups``, ``weight``, ``input0`` and
                ``output``.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            mat: Tensor whose leading axis indexes ``V`` vectors. It is
                reshaped to ``(V, N, G, C_out // G, *module.output.shape[2:])``.
            sum_batch: If truthy, the batch axis is contracted away.
            subsampling: Forwarded to ``subsample`` together with
                ``module.input0``. When not ``None``, its length is used as
                the batch size ``N`` instead of ``module.output.shape[0]``.

        Returns:
            Tensor of shape ``(V, *module.weight.shape)`` when ``sum_batch``
            is truthy, otherwise ``(V, N, *module.weight.shape)``.
        """
        num_vecs = mat.shape[0]
        groups = module.groups
        in_channels = module.input0.shape[1]
        out_shape = module.output.shape
        if subsampling is None:
            batch = out_shape[0]
        else:
            batch = len(subsampling)
        out_channels = out_shape[1]
        spatial_out = out_shape[2:]
        kernel_shape = module.weight.shape[2:]

        # Expose the group axis of the output channels.
        grouped_mat = mat.reshape(
            num_vecs, batch, groups, out_channels // groups, *spatial_out
        )

        # Unfold only the selected inputs, then expose the group axis.
        selected_input = subsample(module.input0, subsampling=subsampling)
        unfolded = unfold_by_conv_transpose(selected_input, module)
        unfolded = unfolded.reshape(
            batch, groups, in_channels // groups, *kernel_shape, *spatial_out
        )

        # One einsum letter per kernel axis and per output spatial axis.
        dims_kern = "xyz"[:self.conv_dims]
        dims_data = "abc"[:self.conv_dims]

        if sum_batch:
            result_str = "vgio" + dims_kern
            final_shape = (num_vecs, *module.weight.shape)
        else:
            result_str = "vngio" + dims_kern
            final_shape = (num_vecs, batch, *module.weight.shape)

        equation = f"ngi{dims_kern}{dims_data},vngo{dims_data}->{result_str}"

        contracted = einsum(equation, unfolded, grouped_mat)
        return contracted.reshape(final_shape)
Exemplo n.º 5
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Accumulate signed, batch-summed weight-diagonal contributions.

        Args:
            ext: Not used by this method.
            module: Layer providing ``input0`` and ``weight``.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            backproped: Mapping with the keys ``"matrices"`` and ``"signs"``,
                iterated in lockstep.

        Returns:
            Tensor shaped like ``module.weight`` holding the sum over all
            pairs of ``sign * extract_weight_diagonal(...)``.
        """
        factors = backproped["matrices"]
        signs = backproped["signs"]
        unfolded_input = convUtils.unfold_by_conv_transpose(module.input0,
                                                            module)

        diag_total = torch.zeros_like(module.weight)
        for factor, sign in zip(factors, signs):
            contribution = convUtils.extract_weight_diagonal(
                module, unfolded_input, factor, sum_batch=True)
            # In-place scaled accumulation: diag_total += sign * contribution.
            diag_total.add_(contribution, alpha=sign)

        return diag_total
    def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        """Contract weight-shaped ``mat`` with the unfolded layer input.

        Args:
            module: Layer providing ``groups``, ``weight``, ``input0`` and
                ``output``.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            mat: Tensor whose leading axis indexes ``V`` vectors. It is
                reshaped to
                ``(V, G, C_in // G, C_out // G, *module.weight.shape[2:])``.

        Returns:
            The einsum result (axes ``v, n, g, o`` plus the output spatial
            axes) after being passed through ``self.reshape_like_output``.
        """
        num_vecs = mat.shape[0]
        groups = module.groups
        in_per_group = module.input0.shape[1] // groups
        out_shape = module.output.shape
        batch = out_shape[0]
        out_per_group = out_shape[1] // groups
        kernel_shape = module.weight.shape[2:]

        grouped_mat = mat.reshape(
            num_vecs, groups, in_per_group, out_per_group, *kernel_shape
        )

        unfolded = unfold_by_conv_transpose(module.input0, module)
        unfolded = unfolded.reshape(
            batch, groups, in_per_group, *kernel_shape, *out_shape[2:]
        )

        # One einsum letter per kernel axis and per output spatial axis.
        dims_kern = "xyz"[:self.conv_dims]
        dims_data = "abc"[:self.conv_dims]
        einstr = "ngi{0}{1},vgio{0}->vngo{1}".format(dims_kern, dims_data)

        return self.reshape_like_output(
            einsum(einstr, unfolded, grouped_mat), module
        )
Exemplo n.º 7
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Accumulate signed, per-sample weight-diagonal contributions.

        Args:
            ext: Not used by this method.
            module: Layer providing ``input0`` and ``weight``.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            backproped: Mapping with the keys ``"matrices"`` and ``"signs"``,
                iterated in lockstep.

        Returns:
            Tensor of shape ``(N, *module.weight.shape)``, with
            ``N = module.input0.shape[0]``, holding the sum over all pairs of
            ``sign * extract_weight_diagonal(..., sum_batch=False)``.
        """
        batch = module.input0.shape[0]
        factors = backproped["matrices"]
        signs = backproped["signs"]
        unfolded_input = convUtils.unfold_by_conv_transpose(module.input0,
                                                            module)

        # Accumulator matching the weight's device and dtype, one slice per
        # sample.
        params = module.weight
        diag_total = torch.zeros(
            batch, *params.shape, device=params.device, dtype=params.dtype
        )

        for factor, sign in zip(factors, signs):
            contribution = convUtils.extract_weight_diagonal(
                module, unfolded_input, factor, sum_batch=False)
            # In-place scaled accumulation: diag_total += sign * contribution.
            diag_total.add_(contribution, alpha=sign)

        return diag_total
Exemplo n.º 8
0
    def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        """Contract weight-shaped ``mat`` with the unfolded input (one group).

        Args:
            module: Layer providing ``groups``, ``weight``, ``input0`` and
                ``output``.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            mat: Tensor whose leading axis indexes ``V`` vectors. It is
                reshaped to
                ``(V, C_in, G, C_out // G, *module.weight.shape[2:])``.

        Returns:
            The einsum result (axes ``v, n, g, o`` plus the output spatial
            axes) after being passed through ``self.reshape_like_output``.

        Raises:
            NotImplementedError: If ``module.groups`` is not 1.
        """
        groups = module.groups
        if groups != 1:
            raise NotImplementedError(
                "Groups greater than 1 are not supported yet")

        num_vecs = mat.shape[0]
        in_channels = module.input0.shape[1]
        out_shape = module.output.shape
        batch, out_channels = out_shape[0], out_shape[1]
        kernel_shape = module.weight.shape[2:]

        grouped_mat = mat.reshape(
            num_vecs, in_channels, groups, out_channels // groups,
            *kernel_shape
        )

        # The group axis follows the per-group channel axis in this layout.
        unfolded = unfold_by_conv_transpose(module.input0, module)
        unfolded = unfolded.reshape(
            batch, in_channels // groups, groups, *kernel_shape,
            *out_shape[2:]
        )

        # One einsum letter per kernel axis and per output spatial axis.
        dims_kern = "xyz"[:self.conv_dims]
        dims_data = "abc"[:self.conv_dims]
        einstr = "nig{0}{1},vigo{0}->vngo{1}".format(dims_kern, dims_data)

        return self.reshape_like_output(
            einsum(einstr, unfolded, grouped_mat), module
        )
Exemplo n.º 9
0
 def weight(self, ext, module, grad_inp, grad_out, backproped):
     """Return the batch-summed weight diagonal for ``backproped``.

     ``module.input0`` is unfolded with
     ``convUtils.unfold_by_conv_transpose`` and handed, together with
     ``backproped``, to ``convUtils.extract_weight_diagonal`` with
     ``sum_batch=True``. ``ext``, ``grad_inp`` and ``grad_out`` are not used.
     """
     unfolded_input = convUtils.unfold_by_conv_transpose(module.input0, module)
     return convUtils.extract_weight_diagonal(
         module, unfolded_input, backproped, sum_batch=True
     )