Beispiel #1
0
def conv3d(
    input,
    pointwise,
    spatial,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    """Run a factorised 3-D convolution: one pointwise pass, then one pass per axis.

    The input is first convolved with ``pointwise`` at stride 1, padding 0
    and dilation 1, using ``bias`` and ``groups``. Each weight in ``spatial``
    is then applied in order as a depthwise convolution, i.e. with ``groups``
    equal to the current channel count.

    Args:
        input: Input tensor handed to ``F.conv3d``.
        pointwise: Weight for the first convolution (presumably a 1x1x1
            kernel -- confirm with callers).
        spatial: Sequence of depthwise weights; the i-th weight receives the
            i-th stride/padding/dilation entry.
        bias: Optional bias. It is added by the pointwise pass, i.e. BEFORE
            the spatial convolutions, not at the very end.
        stride, padding, dilation: Normalised with ``triple`` (presumably an
            int or a 3-tuple is accepted).
        groups: Channel groups for the pointwise pass only.

    Returns:
        The output of the last convolution (the pointwise output if
        ``spatial`` is empty).
    """
    strides = triple(stride)
    paddings = triple(padding)
    dilations = triple(dilation)

    out = F.conv3d(input, pointwise, bias, 1, 0, 1, groups)

    for axis, kernel in enumerate(spatial):
        # NOTE(review): assumes one_diff_tuple(n, fill, value, i) returns an
        # n-tuple of ``fill`` with ``value`` at index i, so this pass only
        # strides/pads/dilates along ``axis``.
        out = F.conv3d(
            out,
            kernel,
            bias=None,
            stride=one_diff_tuple(3, 1, strides[axis], axis),
            padding=one_diff_tuple(3, 0, paddings[axis], axis),
            dilation=one_diff_tuple(3, 1, dilations[axis], axis),
            groups=out.shape[1],  # depthwise: one filter per channel
        )
    return out
Beispiel #2
0
def conv_transpose2d(
    input,
    pointwise,
    spatial,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """Run a factorised 2-D transposed convolution, one spatial axis at a time.

    The input is first mixed across channels by a regular (not transposed)
    ``F.conv2d`` with ``pointwise`` at stride 1, padding 0 and dilation 1,
    using ``bias`` and ``groups``. Each weight in ``spatial`` is then applied
    in order with ``F.conv_transpose2d`` as a depthwise pass, i.e. with
    ``groups`` equal to the current channel count.

    Args:
        input: Input tensor.
        pointwise: Weight for the channel-mixing pass. Because that pass uses
            ``F.conv2d``, the weight is interpreted in conv2d layout
            (presumably a 1x1 kernel -- confirm with callers).
        spatial: Sequence of depthwise transposed-conv weights; the i-th
            weight receives the i-th stride/padding/output_padding/dilation
            entry.
        bias: Optional bias, added by the pointwise pass (before the spatial
            passes).
        stride, padding, output_padding, dilation: Normalised with ``double``
            (presumably an int or a 2-tuple is accepted).
        groups: Channel groups for the pointwise pass only.

    Returns:
        The output of the last pass (the pointwise output if ``spatial`` is
        empty).
    """
    strides = double(stride)
    paddings = double(padding)
    output_paddings = double(output_padding)
    dilations = double(dilation)

    out = F.conv2d(input, pointwise, bias, 1, 0, 1, groups)

    for axis, kernel in enumerate(spatial):
        # NOTE(review): assumes one_diff_tuple(n, fill, value, i) returns an
        # n-tuple of ``fill`` with ``value`` at index i, so this pass only
        # affects ``axis``.
        out = F.conv_transpose2d(
            out,
            kernel,
            bias=None,
            stride=one_diff_tuple(2, 1, strides[axis], axis),
            padding=one_diff_tuple(2, 0, paddings[axis], axis),
            output_padding=one_diff_tuple(2, 0, output_paddings[axis], axis),
            dilation=one_diff_tuple(2, 1, dilations[axis], axis),
            groups=out.shape[1],  # depthwise: one filter per channel
        )
    return out
Beispiel #3
0
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        output_padding,
        groups,
        bias,
        padding_mode,
    ):
        """Allocate the parameters of a factorised (separable) N-d convolution.

        The layer owns one pointwise kernel of extent 1 on every spatial axis
        (``self.pointwise``) plus one depthwise kernel per spatial axis,
        registered as ``spatial_0`` .. ``spatial_{d-1}`` and also collected in
        ``self.spatial``. ``self._dim`` (presumably a class attribute set by
        the concrete subclass -- confirm) gives the number of spatial axes.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel extent, normalised with ``self._tuple``.
            stride, padding, dilation, output_padding: Normalised with
                ``self._tuple`` and stored on the module.
            groups: Channel groups for the pointwise kernel only; every
                spatial kernel has a single input channel per filter.
            bias: If truthy, allocate a ``bias`` parameter of size
                ``out_channels``; otherwise register ``bias`` as ``None``.
            padding_mode: Stored as-is; not interpreted here.

        Raises:
            ValueError: If ``groups`` is not positive, or does not divide
                both ``in_channels`` and ``out_channels``.
        """
        super(_ConvNd, self).__init__()
        # Check positivity first: groups == 0 would otherwise surface as a
        # ZeroDivisionError from the modulo checks below, and a negative value
        # would pass them and fail later with an obscure shape error.
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._tuple(kernel_size)
        self.stride = self._tuple(stride)
        self.padding = self._tuple(padding)
        self.dilation = self._tuple(dilation)
        self.output_padding = self._tuple(output_padding)
        self.groups = groups
        self.padding_mode = padding_mode

        _pw_kernel = (1, ) * self._dim

        # torch.empty allocates uninitialised storage, as the legacy
        # torch.Tensor(*sizes) constructor did. Values are presumably filled
        # in by reset_parameters() at the end of __init__.
        self.pointwise = Parameter(
            torch.empty(out_channels, in_channels // groups, *_pw_kernel))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)

        # NOTE(review): self.spatial is a plain Python list that aliases the
        # registered spatial_<x> parameters. Anything that later replaces
        # those registered objects (e.g. module replication, or
        # load_state_dict(assign=True)) would not be reflected here; consider
        # resolving spatial_<x> by name at use time.
        self.spatial = []
        for x in range(self._dim):
            # NOTE(review): assumes one_diff_tuple(n, fill, value, i) returns
            # an n-tuple of ``fill`` with ``value`` at index i, i.e. a kernel
            # that is only wide along axis x.
            kernel = one_diff_tuple(self._dim, 1, self.kernel_size[x], x)
            weight = Parameter(torch.empty(out_channels, 1, *kernel))
            setattr(self, f"spatial_{x}", weight)
            # Read back the registered object so the list holds exactly what
            # the module tracks under spatial_<x>.
            self.spatial.append(getattr(self, f"spatial_{x}"))

        self.reset_parameters()