def conv3d(
    input,
    pointwise,
    spatial,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    stride = triple(stride)
    padding = triple(padding)
    dilation = triple(dilation)
    # 1x1x1 pointwise convolution: mixes channels and applies the bias.
    _ = F.conv3d(input, pointwise, bias, 1, 0, 1, groups)
    # One depthwise convolution per spatial axis; each kernel, stride, padding
    # and dilation is non-trivial only along axis i. groups equals the current
    # channel count, so every channel is filtered independently.
    for i, weight in enumerate(spatial):
        stri = one_diff_tuple(3, 1, stride[i], i)
        pad = one_diff_tuple(3, 0, padding[i], i)
        dil = one_diff_tuple(3, 1, dilation[i], i)
        _ = F.conv3d(_, weight, None, stri, pad, dil, _.shape[1])
    return _
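
# The helpers triple, double, and one_diff_tuple are used in the functions
# above and below but are not shown in this excerpt. The stand-ins here are a
# minimal sketch inferred from their call sites and are only illustrative; the
# real definitions presumably live elsewhere in this module and may differ.
def double(x):
    # Normalize an int (or pass through a sequence) into a 2-tuple, e.g. 1 -> (1, 1).
    return tuple(x) if isinstance(x, (tuple, list)) else (x,) * 2


def triple(x):
    # Normalize an int (or pass through a sequence) into a 3-tuple, e.g. 1 -> (1, 1, 1).
    return tuple(x) if isinstance(x, (tuple, list)) else (x,) * 3


def one_diff_tuple(length, fill, value, index):
    # A tuple of `fill` repeated `length` times, with `value` at position `index`,
    # e.g. one_diff_tuple(3, 1, 5, 1) -> (1, 5, 1).
    out = [fill] * length
    out[index] = value
    return tuple(out)
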
def conv_transpose2d(
    input,
    pointwise,
    spatial,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    stride = double(stride)
    padding = double(padding)
    output_padding = double(output_padding)
    dilation = double(dilation)
    # 1x1 pointwise convolution: mixes channels and applies the bias.
    _ = F.conv2d(input, pointwise, bias, 1, 0, 1, groups)
    # One depthwise transposed convolution per spatial axis; stride, padding,
    # output_padding and dilation act only along axis i.
    for i, weight in enumerate(spatial):
        stri = one_diff_tuple(2, 1, stride[i], i)
        pad = one_diff_tuple(2, 0, padding[i], i)
        out_pad = one_diff_tuple(2, 0, output_padding[i], i)
        dil = one_diff_tuple(2, 1, dilation[i], i)
        _ = F.conv_transpose2d(_, weight, None, stri, pad, out_pad, _.shape[1], dil)
    return _
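
# A minimal shape-check sketch for the separable conv_transpose2d above; this
# is a hypothetical example, not part of the original module. Weight shapes
# follow the convention used by _ConvNd.__init__ below: a 1x1 pointwise kernel
# plus one depthwise kernel per spatial axis.
if __name__ == "__main__":
    import torch

    N, C_in, C_out, k = 2, 4, 8, 4
    x = torch.randn(N, C_in, 16, 16)

    pointwise = torch.randn(C_out, C_in, 1, 1)   # 1x1 channel-mixing kernel
    spatial = [
        torch.randn(C_out, 1, k, 1),             # depthwise kernel along H
        torch.randn(C_out, 1, 1, k),             # depthwise kernel along W
    ]

    y = conv_transpose2d(x, pointwise, spatial, stride=2, padding=1)
    print(y.shape)  # torch.Size([2, 8, 32, 32]), the usual 2x upsampling setup
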
# __init__ of the _ConvNd base class (the class statement itself is not shown
# in this excerpt); concrete subclasses are expected to supply _dim and _tuple.
def __init__(
    self,
    in_channels,
    out_channels,
    kernel_size,
    stride,
    padding,
    dilation,
    output_padding,
    groups,
    bias,
    padding_mode,
):
    super(_ConvNd, self).__init__()
    if in_channels % groups != 0:
        raise ValueError("in_channels must be divisible by groups")
    if out_channels % groups != 0:
        raise ValueError("out_channels must be divisible by groups")
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = self._tuple(kernel_size)
    self.stride = self._tuple(stride)
    self.padding = self._tuple(padding)
    self.dilation = self._tuple(dilation)
    self.output_padding = self._tuple(output_padding)
    self.groups = groups
    self.padding_mode = padding_mode
    # Pointwise (1 x ... x 1) kernel that mixes channels within each group.
    _pw_kernel = (1,) * self._dim
    self.pointwise = Parameter(
        torch.Tensor(out_channels, in_channels // groups, *_pw_kernel))
    if bias:
        self.bias = Parameter(torch.Tensor(out_channels))
    else:
        self.register_parameter("bias", None)
    # One depthwise weight per spatial axis, non-trivial only along that axis.
    # Each is registered as an attribute (spatial_0, spatial_1, ...) so the
    # module tracks it as a parameter, and also collected in self.spatial.
    self.spatial = []
    for x in range(self._dim):
        kernel = one_diff_tuple(self._dim, 1, self.kernel_size[x], x)
        weight = Parameter(torch.Tensor(out_channels, 1, *kernel))
        setattr(self, f"spatial_{x}", weight)
        self.spatial.append(getattr(self, f"spatial_{x}"))
    self.reset_parameters()
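
# A hypothetical sketch of a concrete subclass, showing how the _dim and
# _tuple hooks used by _ConvNd.__init__ above might be supplied. This is an
# assumption based on the attribute names, not part of the original source;
# it also assumes _ConvNd subclasses torch.nn.Module and defines
# reset_parameters() elsewhere in this module.
class Conv3d(_ConvNd):
    _dim = 3                        # number of spatial dimensions
    _tuple = staticmethod(triple)   # normalizes an int into a 3-tuple

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode="zeros"):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, 0, groups, bias, padding_mode)

    def forward(self, input):
        # Delegates to the separable conv3d defined at the top of this section.
        return conv3d(input, self.pointwise, self.spatial, self.bias,
                      self.stride, self.padding, self.dilation, self.groups)
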