Example #1
import torch
def conv2d_backward(dout, input_requires_grad, input_shape,
                    weight_requires_grad, weight, x_cols, stride, padding):
    """
    Backward pass for a convolutional layer.
    Inputs:
    - dout: Upstream derivatives.
    - cache: A tuple of (x, w, bi, conv_param) as in conv_forward
    Returns a tuple of:
    - dx: Gradient with respect to x
    - dw: Gradient with respect to w
    - db: Gradient with respect to b
    """
    dx, dw = None, None
    """
    """

    num_filters, _, kH, kW = weight.shape
    # Flatten upstream gradients to shape (num_filters, out_H * out_W * N).
    dout_reshaped = dout.transpose(1, 2, 3, 0).reshape(num_filters, -1)

    if weight_requires_grad:
        dw = dout_reshaped.dot(x_cols.T).reshape(weight.shape)

    if input_requires_grad:
        # Slower pure-NumPy alternative (~1.32 s) via col2im:
        # dx_cols = weight.reshape(num_filters, -1).T.dot(dout_reshaped)
        # dx = col2im_indices(dx_cols, input_shape, kH, kW, padding, stride)

        dx = torch.conv_transpose2d(torch.tensor(dout),
                                    torch.tensor(weight),
                                    bias=None,
                                    stride=stride,
                                    padding=padding).numpy()

    return dx, dw
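A minimal smoke test (not part of the original source) that checks conv2d_backward against autograd; the im2col matrix x_cols is rebuilt here with F.unfold in the (C*kH*kW, out_H*out_W*N) column order that dout_reshaped assumes:

import numpy as np
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 8, 8
num_filters, kH, kW, stride, padding = 4, 3, 3, 1, 1

x = torch.randn(N, C, H, W, requires_grad=True)
w = torch.randn(num_filters, C, kH, kW, requires_grad=True)
out = F.conv2d(x, w, stride=stride, padding=padding)
dout = torch.randn_like(out)
out.backward(dout)

# im2col layout: (C*kH*kW, out_H*out_W*N), matching dout_reshaped's ordering
cols = F.unfold(x.detach(), (kH, kW), padding=padding, stride=stride)
x_cols = cols.permute(1, 2, 0).reshape(C * kH * kW, -1).numpy()

dx, dw = conv2d_backward(dout.numpy(), True, tuple(x.shape), True,
                         w.detach().numpy(), x_cols, stride, padding)
print(np.allclose(dw, w.grad.numpy(), atol=1e-4))  # expect True
print(np.allclose(dx, x.grad.numpy(), atol=1e-4))  # expect True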
Example #2
 def forward(self, x):
     """
     forward pass of the layer
     :param x: input
     :return: y => output
     """
     return torch.conv_transpose2d(input=x,
                             weight=self.weight * self.scale,  # scale the weight on runtime
                             bias=self.bias if self.use_bias else None,
                             stride=self.stride,
                             padding=self.pad)
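The class around this forward is not shown; below is a minimal hypothetical host module in the equalized-learning-rate style that the runtime weight scaling suggests (the sizes and the fan-in based scale are assumptions):

import math
import torch
import torch.nn as nn

class ScaledConvTranspose2d(nn.Module):
    def __init__(self, in_ch, out_ch, k, stride=1, pad=0, use_bias=True):
        super().__init__()
        # conv_transpose2d expects weight of shape (in_ch, out_ch, kH, kW)
        self.weight = nn.Parameter(torch.randn(in_ch, out_ch, k, k))
        self.bias = nn.Parameter(torch.zeros(out_ch)) if use_bias else None
        self.use_bias = use_bias
        self.stride, self.pad = stride, pad
        # one common choice: He constant applied at runtime, not at init
        self.scale = math.sqrt(2.0 / (in_ch * k * k))

    def forward(self, x):
        return torch.conv_transpose2d(input=x,
                                      weight=self.weight * self.scale,
                                      bias=self.bias if self.use_bias else None,
                                      stride=self.stride,
                                      padding=self.pad)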
Example #3
    def forward(self, x, y):
        key = self.conv(x)
        key = torch.sigmoid(key)
        key = F.interpolate(key, (32, 32))
        key = key.flatten()
        # key = torch.tanh(key)
        attention = torch.matmul(self.keys, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))
        out = torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)
        # out = torch.sigmoid(out)
        # out = F.interpolate(out, size=(64, 64))
        it = torch.tensor([0])
        # while 1/kornia.psnr_loss(out, y, max_val=1.0) > 0.0:
        while F.mse_loss(out, y) > 0.0:
            # st.write(f'ITERATION: {it} ol: {ol}')
            it += 1
            key = self.conv(out)
            key = torch.sigmoid(key)
            key = F.interpolate(key, (32, 32))
            key = key.flatten()
            # key = torch.tanh(key)
            attention = torch.matmul(self.keys, key)
            attention = torch.softmax(attention, 0)
            attention = torch.reshape(attention, (-1, 1))

            kernel = self.values * attention
            kernel = torch.sum(kernel, 0)
            kernel = torch.reshape(kernel, (3, 3, 5, 5))
            out = torch.conv_transpose2d(out,
                                         weight=kernel,
                                         stride=2,
                                         padding=2)
            if it >= 5:
                break
        # st.write(f'ITERATIONS: {it}')
        return out
Example #4
 def forward(self, x: Tensor, output_size: Any = None) -> Tensor:
     output_padding = self._output_padding(x, output_size, self.stride,
                                           self.padding, self.kernel_size)
     return torch.conv_transpose2d(
         input=x,
         weight=self.weight * self.scale,  # scale the weight on runtime
         bias=self.bias,
         stride=self.stride,
         padding=self.padding,
         output_padding=output_padding,
         groups=self.groups,
         dilation=self.dilation,
     )
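A short demonstration (added for context) of why output_padding is needed here: with stride 2, inputs of different sizes can convolve to the same output size, so the transposed op must be told which input size to reconstruct. Shapes below are arbitrary:

import torch
import torch.nn.functional as F

# A 5x5 and a 6x6 input both convolve to 3x3 with stride 2, padding 1,
# so the transposed convolution needs output_padding to disambiguate.
w = torch.randn(1, 1, 3, 3)
print(F.conv2d(torch.randn(1, 1, 5, 5), w, stride=2, padding=1).shape)
print(F.conv2d(torch.randn(1, 1, 6, 6), w, stride=2, padding=1).shape)
# both: torch.Size([1, 1, 3, 3])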
Example #5
    def forward(self, x):
        key = self.conv1(x).flatten()
        attention = torch.matmul(self.keys, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))

        out = torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)

        return out
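For context, a hypothetical reconstruction of the module around this forward; the slot count, channel sizes, and 32x32 input are assumptions, and flatten() mixes batch entries, so it only works for batch size 1:

import torch
import torch.nn as nn

class AttentionKernelUp(nn.Module):
    def __init__(self, num_slots=16, in_hw=32):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
        # one key per slot; key length = conv1 output flattened (1 x 32 x 32)
        self.keys = nn.Parameter(torch.randn(num_slots, in_hw * in_hw))
        # each slot stores one flattened (3, 3, 5, 5) transposed-conv kernel
        self.values = nn.Parameter(torch.randn(num_slots, 3 * 3 * 5 * 5))

    def forward(self, x):
        key = self.conv1(x).flatten()          # assumes batch size 1
        attention = torch.softmax(self.keys @ key, 0).reshape(-1, 1)
        kernel = (self.values * attention).sum(0).reshape(3, 3, 5, 5)
        return torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)

up = AttentionKernelUp()
print(up(torch.randn(1, 3, 32, 32)).shape)     # torch.Size([1, 3, 63, 63])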
Example #6
    def forward(self, x):
        key = self.conv(x)
        key = torch.sigmoid(key)
        key = F.interpolate(key, (32, 32))
        key = key.flatten()
        # key = torch.tanh(key)
        attention = torch.matmul(self.keys, key)
        attention = torch.softmax(attention, 0)
        ol = torch.max(attention)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))
        out = torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)
        # st.write(ol)
        it = 0
        while ol < 1.0:
            it += 1
            key = self.conv(out)
            key = torch.sigmoid(key)
            key = F.interpolate(key, (32, 32))
            key = key.flatten()
            # key = torch.tanh(key)
            attention = torch.matmul(self.keys, key)
            attention = torch.softmax(attention, 0)
            ol += torch.max(attention)
            attention = torch.reshape(attention, (-1, 1))

            kernel = self.values * attention
            kernel = torch.sum(kernel, 0)
            kernel = torch.reshape(kernel, (3, 3, 5, 5))
            out = torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)

            if it >= 10:
                break
        st.write(f'ITERATIONS: {it}')
        return out
Example #7
    def forward(self, x):
        key = F.interpolate(x, (32, 32))
        # key = torch.sigmoid(key)
        key = key.flatten()
        # key = torch.tanh(key)
        attention = torch.matmul(self.keys, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))

        # out = torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)
        out = torch.conv_transpose2d(x, weight=kernel, stride=1)
        return out
Example #8
import torch
 def _impl(input, weight, stride, padding, output_padding, groups, dilation):
     stride = tuple(_x.item() for _x in stride)
     padding = tuple(_x.item() for _x in padding)
     output_padding = tuple(_x.item() for _x in output_padding)
     dilation = tuple(_x.item() for _x in dilation)
     groups = groups.item()
     return torch.conv_transpose2d(
         input,
         weight,
         None,
         stride,
         padding,
         output_padding,
         groups,
         dilation,
     )
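A hypothetical call illustrating the tensor-wrapped hyper-parameters that _impl unwraps with .item(); the shapes are arbitrary:

import torch

x = torch.randn(1, 8, 16, 16)
w = torch.randn(8, 4, 3, 3)  # (in_channels, out_channels/groups, kH, kW)
y = _impl(x, w,
          stride=torch.tensor([2, 2]),
          padding=torch.tensor([1, 1]),
          output_padding=torch.tensor([1, 1]),
          groups=torch.tensor(1),
          dilation=torch.tensor([1, 1]))
print(y.shape)  # torch.Size([1, 4, 32, 32]): (16-1)*2 - 2*1 + (3-1) + 1 + 1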
Example #9
def conv2d_input(input_size,
                 weight,
                 grad_output,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=None):
    r"""
    Computes the gradient of conv2d with respect to the input of the convolution.
    This is same as the 2D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias: optional bias tensor (out_channels). Default: None

    Examples::

        >>> input = torch.randn(1,1,3,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1,2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    kernel_size = (weight.shape[2], weight.shape[3])

    if input_size is None:
        raise ValueError("grad.conv2d_input requires specifying an input_size")

    grad_input_padding = _grad_input_padding(grad_output, input_size, stride,
                                             padding, kernel_size)

    return torch.conv_transpose2d(grad_output, weight, bias, stride, padding,
                                  grad_input_padding, groups, dilation)
Example #10
    def forward(self, x):
        # x = torch.sigmoid(x)
        # x = F.interpolate(x, (128, 128))
        key = self.conv(x)
        key = torch.sigmoid(key)
        key = F.interpolate(key, (64, 64))
        key = key.flatten()
        attention = torch.matmul(self.keys, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))

        out = torch.conv_transpose2d(x, weight=kernel, stride=2, padding=2)

        return out
Example #11
def main_conv(input, weight, bias=None, stride=None, padding=None, output_padding=None):
    start = time.time()
    out = conv_transpose2d(input, weight, stride=stride, padding=padding, output_padding=output_padding)
    g = rand_like(out)
    out.backward(g)
    Timer.show_time((time.time() - start), "Numpy conv2d")

    t_input, t_weight, v = to_torch([input, weight, g])
    start = time.time()
    t_out = torch.conv_transpose2d(t_input, t_weight, bias=None, stride=stride,
                                   padding=padding, output_padding=output_padding)
    t_out.backward(v)

    Timer.show_time((time.time() - start), "torch conv2d")

    check(out, t_out, grad=False, prefix="out", print_max=True)
    check(input, t_input, prefix="input grad", print_max=True)
    check(weight, t_weight, prefix="weight grad", print_max=True)
Example #12
def imfilter_transpose2D_SpatialDomain(input,
                                       kernel,
                                       padType="symmetric",
                                       mode="conv"):

    assert(mode in ("conv","corr")), "Valid filtering modes are"\
    +" 'conv' and 'corr'."
    assert(padType in ("periodic","symmetric","zero","valid")), "Valid padType"\
    +" values are 'periodic'|'symmetric'|'zero'|'valid'."

    assert (input.dim() < 5), "The input must be at most a 4D tensor."

    while input.dim() < 4:
        input = input.unsqueeze(0)

    while kernel.dim() < 4:
        kernel = kernel.unsqueeze(0)

    channels = input.size(1)
    assert kernel.size(1) == 1 or kernel.size(1) == channels, \
        "Invalid filtering kernel dimensions."

    if kernel.shape[1] == 1 and input.shape[1] != kernel.shape[1]:
        kernel = torch.cat([kernel] * input.shape[1], dim=1)

    if mode == "conv":
        kernel = reverse(reverse(kernel, dim=-1), dim=-2)

    if padType == "valid":
        padding = 0
    else:
        padding = getPad2RetainShape(kernel.shape[-2:])

    b, c, h, w = input.shape
    input = input.reshape(input.shape[0] * input.shape[1], input.shape[2],
                          input.shape[3])
    input = input[None]
    kernel = kernel.reshape(kernel.shape[0] * kernel.shape[1], kernel.shape[2],
                            kernel.shape[3])

    kernel = kernel[:, None]
    out = torch.conv_transpose2d(input, kernel, groups=kernel.shape[0])
    out = out[0].reshape(b, c, out.shape[2], out.shape[3])
    return pad_transpose2D(out, padding, padType)
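A hypothetical call (reverse, getPad2RetainShape, and pad_transpose2D come from the same, unshown, module, so this only sketches the intended shapes):

import torch

img = torch.randn(1, 3, 64, 64)
k = torch.randn(1, 1, 5, 5)  # one kernel, replicated over the 3 channels
out = imfilter_transpose2D_SpatialDomain(img, k, padType="symmetric",
                                         mode="conv")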
Example #13
import math
import torch as th
from functools import reduce
def im2patch_sinv(input, shape, patchSize, stride=1):
    r""" im2patch_sinv is the pseudo inverse of im2patch.
    
    shape : is the size of the original tensor from which the patches where 
    extracted.
    """
    assert (input.dim() == 4), "A 4D tensor is expected."
    assert (isinstance(patchSize,
                       tuple)), "patchSize is expected to be a tuple."

    if len(patchSize) < 2:
        patchSize *= 2
    if len(shape) < 4:
        shape = (1, ) * (4 - len(shape)) + shape
    elif len(shape) > 4:
        shape = shape[0:4]

    Pn = reduce(lambda x, y: x * y, patchSize[0:2])
    batch = shape[0]
    Nc = math.floor(input.shape[1] / Pn)
    if Nc != 1:
        input = input.view(batch * Nc, input.shape[1] // Nc, input.shape[2],
                           input.shape[3])

    h = th.eye(Pn).type(input.type())
    h = h.view(Pn, 1, patchSize[0], patchSize[1])

    out = th.conv_transpose2d(input, h, stride=stride)

    if Nc != 1:
        out = out.view(batch, Nc * out.shape[1], out.shape[2], out.shape[3])

    D = compute_patch_overlap(shape, patchSize, stride)
    D = D.type(input.type())
    out = out.div(D)

    if reduce(lambda x, y: x or y,
              [out.shape[i] < shape[i] for i in range(4)]):
        out = th.nn.functional.pad(
            out, (0, shape[3] - out.shape[3], 0, shape[2] - out.shape[2]))

    return out
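The identity kernel h = th.eye(Pn) makes conv_transpose2d act as patch accumulation; a standalone check (with assumed shapes, not from the original source) that this matches torch.nn.functional.fold:

import torch
import torch.nn.functional as F

pH = pW = 3
Pn = pH * pW
patches = torch.randn(1, Pn, 6, 6)      # (N, pH*pW, nH, nW), single channel
h = torch.eye(Pn).view(Pn, 1, pH, pW)
a = torch.conv_transpose2d(patches, h, stride=1)
b = F.fold(patches.flatten(2), output_size=(8, 8), kernel_size=(pH, pW))
print(torch.allclose(a, b, atol=1e-5))  # True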
Example #14
    def backward(ctx, grad_output):
        """Computes the gradient with respect to the input"""
        input_shape, output, kernel_size, padding, stride = ctx.saved_tensors
        assert stride[0] == stride[1], "stride must be same in all axes"
        assert padding[0] == padding[1], "padding must be same in all axes"

        in_channels = input_shape[1]
        # compute as d conv2d / d input with kernel as average filter
        kernel = torch.ones(in_channels, 1, kernel_size[0], kernel_size[1]) / (
            kernel_size[0] * kernel_size[1]
        )

        grad_input_padding = torch.nn.grad._grad_input_padding(
            grad_output, input_shape, stride, padding, kernel_size
        )

        # set groups=in_channels so input gradient is computed per channel
        if isinstance(grad_output, crypten.CrypTensor):
            return grad_output.conv_transpose2d(
                kernel,
                bias=None,
                stride=stride,
                padding=padding,
                output_padding=grad_input_padding,
                groups=in_channels,
                dilation=1,
            )

        return torch.conv_transpose2d(
            grad_output,
            kernel,
            bias=None,
            stride=stride,
            padding=padding,
            output_padding=grad_input_padding,
            groups=in_channels,
            dilation=1,
        )
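A standalone sanity check (plain tensors, no CrypTen) that the averaging-kernel trick above reproduces autograd's average-pooling input gradient; the sizes are chosen so that no output_padding is needed:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, requires_grad=True)
y = F.avg_pool2d(x, kernel_size=2, stride=2)
g = torch.randn_like(y)
(grad_auto,) = torch.autograd.grad(y, x, g)

kernel = torch.ones(3, 1, 2, 2) / 4     # per-channel averaging filter
grad_manual = torch.conv_transpose2d(g, kernel, stride=2, groups=3)
print(torch.allclose(grad_auto, grad_manual, atol=1e-6))  # True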
Example #15
    def forward(self, input):
        input = (2 * input) - 1
        ## 1
        key = self.conv1(input).flatten()
        attention = torch.matmul(self.keys, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))

        out = torch.conv_transpose2d(input, weight=kernel, stride=2, padding=2)
        ## 2
        key = self.conv2(out).flatten()
        attention = torch.matmul(self.keys2, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values2 * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))

        out = torch.conv_transpose2d(out, weight=kernel, stride=2, padding=2)

        ## 3

        key = self.conv2(out).flatten()
        attention = torch.matmul(self.keys3, key)
        attention = torch.softmax(attention, 0)
        attention = torch.reshape(attention, (-1, 1))

        kernel = self.values3 * attention
        kernel = torch.sum(kernel, 0)
        kernel = torch.reshape(kernel, (3, 3, 5, 5))

        out = torch.conv_transpose2d(out, weight=kernel, stride=2, padding=2)

        out = torch.sigmoid(out)
        out = F.interpolate(out, size=(64, 64))
        return out
Example #16
import torch as th
from functools import reduce
def compute_patch_overlap(shape,
                          patchSize,
                          stride=1,
                          padding=0,
                          GPU=False,
                          dtype='f'):
    r""" Returns a tensor whose dimensions are equal to 'shape' and it 
    indicates how many patches extracted from the image (the patches are of 
    size patchSize and are extracted using a specified stride) each pixel of 
    the image contributes. 

    For example below is the array which indicates how many times each pixel
    at the particular location of an image of size 16 x 16 has been found in 
    any of the 49 4x4 patches that have been extracted using a stride=2.

    T = 
     1     1     2     2     2     2     2     2     2     2     2     2     2     2     1     1
     1     1     2     2     2     2     2     2     2     2     2     2     2     2     1     1
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     2     2     4     4     4     4     4     4     4     4     4     4     4     4     2     2
     1     1     2     2     2     2     2     2     2     2     2     2     2     2     1     1
     1     1     2     2     2     2     2     2     2     2     2     2     2     2     1     1


     Based on this table (using zero-based indices), the pixel at location
     (3,2) has been used in 4 different patches, while the pixel at location
     (15,4) has been used in 2 different patches."""

    assert (isinstance(shape, tuple)), "shape is expected to be a tuple."
    assert (isinstance(patchSize,
                       tuple)), "patchSize is expected to be a tuple."
    if len(shape) < 4:
        shape = (1, ) * (4 - len(shape)) + shape
    elif len(shape) > 4:
        shape = shape[0:4]

    if len(patchSize) < 2:
        patchSize = patchSize * 2

    if dtype == 'f':
        dtype = th.FloatTensor
    elif dtype == 'd':
        dtype = th.DoubleTensor
    else:
        raise Exception(
            "Supported data types are 'f' (float) and 'd' (double).")

    shape_ = (shape[0] * shape[1], 1, shape[2], shape[3])

    Pn = reduce(lambda x, y: x * y, patchSize[0:2])
    h = th.eye(Pn).type(dtype)
    h = h.view(Pn, 1, patchSize[0], patchSize[1])

    x = th.ones(shape_).type(dtype)

    if th.cuda.is_available() and GPU:
        x = x.cuda()
        h = h.cuda()

    T = th.conv2d(x, h, stride=stride, padding=padding)
    T = th.conv_transpose2d(T, h, stride=stride, padding=padding)

    return T.view(shape)
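A quick call reproducing the 16x16 table from the docstring (4x4 patches, stride 2):

T = compute_patch_overlap((1, 1, 16, 16), (4, 4), stride=2)
print(T[0, 0])  # corners count 1, edges 2, interior pixels 4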
Example #17
 def forward(self, x):
     x = self.linear(x.flatten())
     x = x.reshape((1, 3, 22, 22))
     x = torch.conv_transpose2d(x, self.weight, stride=1, padding=1)
     return x
Example #18
import torch
def grad_2D_T(y):
    weight = y.new_zeros(2, 1, 3, 3)
    weight[0, 0] = torch.tensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
    weight[1, 0] = torch.tensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
    out = torch.conv_transpose2d(y, weight, padding=1)
    return out[:, 0, :, :]  # Remove channel dimension
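The forward counterpart grad_2D is not shown; assuming it is the conv2d with the same stencils, the adjoint identity <Ax, y> = <x, A^T y> can be checked numerically:

import torch

def grad_2D(x):
    # hypothetical forward operator matching grad_2D_T's stencils
    weight = x.new_zeros(2, 1, 3, 3)
    weight[0, 0] = torch.tensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
    weight[1, 0] = torch.tensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
    return torch.conv2d(x[:, None], weight, padding=1)

x = torch.randn(1, 16, 16)
y = torch.randn(1, 2, 16, 16)
lhs = (grad_2D(x) * y).sum()
rhs = (x * grad_2D_T(y)).sum()
print(torch.allclose(lhs, rhs, atol=1e-4))  # True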
Example #19
def conv_input_test():
    # grad_y = torch.rand((4, 8, 3, 3), device=torch.device('cuda:0'))
    # w = torch.rand((8, 3, 1, 1), device=torch.device('cuda:0'))
    # x = torch.rand((4, 3, 6, 6), device=torch.device('cuda:0'))
    # stride = 2
    # padding = 0

    # grad_y = torch.rand((1, 1, 3, 3), device=torch.device('cuda:0'))
    # w = torch.rand((1, 1, 3, 3), device=torch.device('cuda:0'))
    # x = torch.rand((1, 1, 5, 5), device=torch.device('cuda:0'))
    # stride = 2
    # padding = 1

    # grad_y = torch.rand((1, 1, 3, 3), device=torch.device('cuda:0'))
    # w = torch.rand((1, 1, 3, 3), device=torch.device('cuda:0'))
    # x = torch.rand((1, 1, 5, 5), device=torch.device('cuda:0'))
    # stride = 1
    # padding = 0

    grad_y = torch.rand((1, 1, 2, 2), device=torch.device('cuda:0'))
    w = torch.rand((1, 1, 1, 1), device=torch.device('cuda:0'))
    x = torch.rand((1, 1, 3, 3), device=torch.device('cuda:0'))
    stride = 2
    padding = 0
    print("x:=================")
    print(x)
    print("w:=================")
    print(w)
    print("grad_y:=================")
    print(grad_y)

    # grad_y = torch.rand((4, 32, 112, 112), device=torch.device('cuda:0'))
    # w = torch.rand((32, 3, 3, 3), device=torch.device('cuda:0'))
    # x = torch.rand((4, 3, 224, 224), device=torch.device('cuda:0'))
    # stride = 2
    # padding = 1

    # grad_y = torch.rand((256, 64, 32, 32), device=torch.device('cuda:0'))
    # w = torch.rand((64, 32, 3, 3), device=torch.device('cuda:0'))
    # x = torch.rand((256, 32, 32, 32), device=torch.device('cuda:0'))
    # stride = 2
    # padding = 1

    start = time.time()
    grad_x = conv_input(grad_y, x, w, [stride], [padding])
    end = time.time()
    # print(grad_x.view(-1))
    print(end - start)

    # start = time.time()
    # grad_x2 = conv_input_no_cuda(grad_y, x, w, stride, padding)
    # end = time.time()
    # # print(grad_x2.view(-1))
    # print(end - start)

    # sub = (grad_x - grad_x2).view(-1)
    # print(torch.sum(sub), torch.var(sub))

    start = time.time()
    out = torch.conv_transpose2d(grad_y, w, None, stride, padding)
    end = time.time()
    print(end - start)
    print(out.shape)
Example #20
import torch
def conv_transpose2d(input, *args, **kwargs):
    return torch.conv_transpose2d(input.q, *args, **kwargs)