Example #1
    def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor, MPCTensor]:
        """Perform the backward pass for the conv2d operation.

        Args:
            ctx (Dict[str, Any]): Context used to retrieve the information for the backward pass
            grad (MPCTensor): The gradient that came from the child nodes

        Returns:
            (input_grad, weight_grad) (Tuple[MPCTensor, MPCTensor]): The gradients passed
            to the input and kernel nodes.
        """
        x = ctx["x"]
        weight = ctx["weight"]
        stride = ctx["stride"]
        padding = ctx["padding"]
        dilation = ctx["dilation"]
        groups = ctx["groups"]
        weight_size = (weight.shape[2], weight.shape[3])
        in_channels = x.shape[1]
        out_channels = grad.shape[1]
        min_batch = x.shape[0]

        # Only shapes matter here, so a plaintext dummy tensor stands in for the
        # secret-shared gradient when computing the transposed convolution's
        # output padding (_grad_input_padding is a private torch helper).
        output_padding = torch.nn.grad._grad_input_padding(
            grad_output=torch.empty(grad.shape),
            input_size=x.shape,
            stride=(stride, stride),
            padding=(padding, padding),
            kernel_size=weight_size,
            dilation=(dilation, dilation),
        )

        # Gradient w.r.t. the input of the Conv: a transposed convolution of the
        # output gradient with the same kernel. Argument order follows
        # torch.nn.functional.conv_transpose2d, so `padding` must be passed
        # before `output_padding`.
        input_grad = grad.conv_transpose2d(
            weight, None, stride, padding, output_padding, groups, dilation
        )

        # Gradient w.r.t weights of the Conv.
        grad = grad.repeat(1, in_channels // groups, 1, 1)

        grad = grad.view(grad.shape[0] * grad.shape[1], 1, grad.shape[2],
                         grad.shape[3])

        x = x.view(1, x.shape[0] * x.shape[1], x.shape[2], x.shape[3])

        # Note the intentional swap below: when differentiating w.r.t. the
        # kernel, the forward stride acts as the dilation and the forward
        # dilation as the stride (mirrors torch.nn.grad.conv2d_weight).
        weight_grad = x.conv2d(
            weight=grad,
            bias=None,
            dilation=stride,
            padding=padding,
            stride=dilation,
            groups=in_channels * min_batch,
        )

        weight_grad = weight_grad.view(
            min_batch,
            weight_grad.shape[1] // min_batch,
            weight_grad.shape[2],
            weight_grad.shape[3],
        )

        weight_grad = (
            weight_grad.sum(0)
            .view(
                in_channels // groups,
                out_channels,
                weight_grad.shape[2],
                weight_grad.shape[3],
            )
            .transpose(0, 1)
        )

        # Trim the result down to the kernel's spatial size: dim 2 is the kernel
        # height (weight_size[0]) and dim 3 the kernel width (weight_size[1]).
        weight_grad = weight_grad.narrow(2, 0, weight_size[0])
        weight_grad = weight_grad.narrow(3, 0, weight_size[1])

        return input_grad, weight_grad
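
Both gradients can be sanity-checked against autograd on plaintext tensors: the
input gradient is a transposed convolution of the output gradient with the same
kernel, and the weight gradient is a convolution of the input with the output
gradient in which stride and dilation trade places. A minimal sketch, assuming
plain torch.Tensor values instead of MPCTensor, with illustrative shapes and
hyperparameters:

import torch
import torch.nn.functional as F

# Illustrative shapes/hyperparameters (assumptions, not from the example above).
x = torch.randn(2, 3, 8, 8, requires_grad=True)  # (batch, in_channels, H, W)
w = torch.randn(4, 3, 3, 3, requires_grad=True)  # (out_channels, in_channels, kH, kW)
stride, padding, dilation, groups = 1, 1, 1, 1

out = F.conv2d(x, w, None, stride, padding, dilation, groups)
grad = torch.randn_like(out)  # stand-in for the gradient from the child nodes
out.backward(grad)

# Input gradient: transposed convolution of grad with the same kernel.
input_grad = F.conv_transpose2d(grad, w, None, stride, padding, 0, groups, dilation)
assert torch.allclose(input_grad, x.grad, atol=1e-4)

# Weight gradient: torch's reference helper performs the same repeat/view/
# conv2d(stride and dilation swapped)/narrow steps as the example above.
weight_grad = torch.nn.grad.conv2d_weight(
    x.detach(), w.shape, grad, stride, padding, dilation, groups
)
assert torch.allclose(weight_grad, w.grad, atol=1e-3)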
Example #2
    def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor, MPCTensor]:
        """Perform the backward pass for the conv2d operation.

        Args:
            ctx (Dict[str, Any]): Context used to retrieve the information for the backward pass
            grad (MPCTensor): The gradient that came from the child nodes

        Returns:
            (input_grad, weight_grad) (Tuple[MPCTensor, MPCTensor]): The gradients passed
            to the input and kernel nodes.
        """
        input = ctx["input"]
        weight = ctx["weight"]
        stride = ctx["stride"]
        padding = ctx["padding"]
        dilation = ctx["dilation"]
        groups = ctx["groups"]
        weight_size = (weight.shape[2], weight.shape[3])
        in_channels = input.shape[1]
        out_channels = grad.shape[1]
        min_batch = input.shape[0]

        # Gradient w.r.t input of the Conv.
        common_args = [
            tuple(input.shape),
            stride,
            padding,
            weight_size,
            dilation,
            grad.session,
        ]
        args = [[el] + common_args for el in grad.share_ptrs]

        shares = parallel_execution(
            GradConv2d.get_grad_input_padding, grad.session.parties
        )(args)
        grad_input_padding = MPCTensor(shares=shares, session=grad.session)

        # Every party derives the same public padding from the (public) share
        # shapes, so reconstruction sums nr_parties identical copies; divide to
        # recover the actual value.
        output_padding_tensor = grad_input_padding.reconstruct()
        output_padding_tensor /= grad.session.nr_parties
        output_padding = tuple(output_padding_tensor.to(torch.int).tolist())

        # Argument order follows torch.nn.functional.conv_transpose2d, so
        # `padding` must be passed before `output_padding`.
        input_grad = grad.conv_transpose2d(
            weight, None, stride, padding, output_padding, groups, dilation
        )

        # Gradient w.r.t weights of the Conv.
        grad = grad.repeat(1, in_channels // groups, 1, 1)

        grad = grad.view(grad.shape[0] * grad.shape[1], 1, grad.shape[2], grad.shape[3])

        input = input.view(
            1, input.shape[0] * input.shape[1], input.shape[2], input.shape[3]
        )

        # Note the intentional swap below: when differentiating w.r.t. the
        # kernel, the forward stride acts as the dilation and the forward
        # dilation as the stride (mirrors torch.nn.grad.conv2d_weight).
        weight_grad = input.conv2d(
            weight=grad,
            bias=None,
            dilation=stride,
            padding=padding,
            stride=dilation,
            groups=in_channels * min_batch,
        )

        weight_grad = weight_grad.view(
            min_batch,
            weight_grad.shape[1] // min_batch,
            weight_grad.shape[2],
            weight_grad.shape[3],
        )

        weight_grad = (
            weight_grad.sum(0)
            .view(
                in_channels // groups,
                out_channels,
                weight_grad.shape[2],
                weight_grad.shape[3],
            )
            .transpose(0, 1)
        )

        # Trim the result down to the kernel's spatial size: dim 2 is the kernel
        # height (weight_size[0]) and dim 3 the kernel width (weight_size[1]).
        weight_grad = weight_grad.narrow(2, 0, weight_size[0])
        weight_grad = weight_grad.narrow(3, 0, weight_size[1])

        return input_grad, weight_grad
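
The value the parties jointly compute and reconstruct above is the
output_padding that lets conv_transpose2d recover the exact forward input size;
per spatial dimension it has a simple closed form. A minimal plain-Python sketch
of that arithmetic, mirroring torch's private _grad_input_padding helper (the
function name is hypothetical, and single-int stride/padding/dilation are
assumed, as in the examples):

from typing import Tuple

def grad_input_padding(
    grad_size: Tuple[int, int],   # spatial size of the incoming gradient
    input_size: Tuple[int, int],  # spatial size of the forward input
    stride: int,
    padding: int,
    kernel_size: Tuple[int, int],
    dilation: int,
) -> Tuple[int, int]:
    # For each spatial dimension, compute the smallest input size that could
    # have produced this gradient size; the difference to the actual input
    # size is the output padding conv_transpose2d needs.
    pads = []
    for d in range(2):
        min_size = (
            (grad_size[d] - 1) * stride
            - 2 * padding
            + dilation * (kernel_size[d] - 1)
            + 1
        )
        pads.append(input_size[d] - min_size)
    return pads[0], pads[1]

# An 8x8 input with a 3x3 kernel, stride 2, padding 1 yields a 4x4 gradient;
# one extra row/column of output padding maps 4x4 back to 8x8.
assert grad_input_padding((4, 4), (8, 8), 2, 1, (3, 3), 1) == (1, 1)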