def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor, MPCTensor]:
    """Perform the backward pass for the conv2d operation.

    Args:
        ctx (Dict[str, Any]): Context used to retrieve the information for the
            backward pass. The keys read are "x", "weight", "stride",
            "padding", "dilation" and "groups".
        grad (MPCTensor): The gradient that came from the child nodes.

    Returns:
        (input_grad, weight_grad) (Tuple[MPCTensor, MPCTensor]): The gradients
            passed to the input and kernel nodes.
    """
    x = ctx["x"]
    weight = ctx["weight"]
    stride = ctx["stride"]
    padding = ctx["padding"]
    dilation = ctx["dilation"]
    groups = ctx["groups"]

    # Spatial extent of the kernel, taken from dims 2 and 3 of the weight.
    kernel_height, kernel_width = weight.shape[2], weight.shape[3]
    weight_size = (kernel_height, kernel_width)
    in_channels = x.shape[1]
    out_channels = grad.shape[1]
    min_batch = x.shape[0]

    # Gradient w.r.t input of the Conv.
    # A placeholder tensor with grad's shape is handed to the torch helper,
    # so no secret-shared values are involved in computing the padding.
    # stride/padding/dilation are duplicated into pairs, i.e. this path
    # assumes they are scalars -- TODO confirm against the forward pass.
    output_padding = torch.nn.grad._grad_input_padding(
        grad_output=torch.empty(grad.shape),
        input_size=x.shape,
        stride=(stride, stride),
        padding=(padding, padding),
        kernel_size=weight_size,
        dilation=(dilation, dilation),
    )
    # NOTE(review): `padding` is not forwarded to conv_transpose2d here.
    # torch's own positional order is (weight, bias, stride, padding,
    # output_padding, groups, dilation) -- confirm this call against the
    # signature of MPCTensor.conv_transpose2d.
    input_grad = grad.conv_transpose2d(
        weight, None, stride, output_padding, dilation, groups
    )

    # Gradient w.r.t weights of the Conv.
    grad = grad.repeat(1, in_channels // groups, 1, 1)
    grad = grad.view(grad.shape[0] * grad.shape[1], 1, grad.shape[2], grad.shape[3])
    x = x.view(1, x.shape[0] * x.shape[1], x.shape[2], x.shape[3])

    # stride and dilation trade places in this convolution, following the
    # formulation used by torch's reference conv2d_weight implementation.
    weight_grad = x.conv2d(
        weight=grad,
        bias=None,
        dilation=stride,
        padding=padding,
        stride=dilation,
        groups=in_channels * min_batch,
    )
    weight_grad = weight_grad.view(
        min_batch,
        weight_grad.shape[1] // min_batch,
        weight_grad.shape[2],
        weight_grad.shape[3],
    )

    # Spatial dims are read before the batch axis is summed away.
    grad_height, grad_width = weight_grad.shape[2], weight_grad.shape[3]
    weight_grad = (
        weight_grad.sum(0)
        .view(in_channels // groups, out_channels, grad_height, grad_width)
        .transpose(0, 1)
    )

    # Trim each spatial dim to the matching kernel dim: dim 2 is the height
    # (weight.shape[2]) and dim 3 is the width (weight.shape[3]).
    weight_grad = weight_grad.narrow(2, 0, kernel_height)
    weight_grad = weight_grad.narrow(3, 0, kernel_width)

    return input_grad, weight_grad
def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor, MPCTensor]:
    """Perform the backward pass for the conv2d operation.

    Args:
        ctx (Dict[str, Any]): Context used to retrieve the information for the
            backward pass. The keys read are "input", "weight", "stride",
            "padding", "dilation" and "groups".
        grad (MPCTensor): The gradient that came from the child nodes.

    Returns:
        (input_grad, weight_grad) (Tuple[MPCTensor, MPCTensor]): The gradients
            passed to the input and kernel nodes.
    """
    # Bound to a local name that does not shadow the `input` builtin.
    x = ctx["input"]
    weight = ctx["weight"]
    stride = ctx["stride"]
    padding = ctx["padding"]
    dilation = ctx["dilation"]
    groups = ctx["groups"]

    # Spatial extent of the kernel, taken from dims 2 and 3 of the weight.
    kernel_height, kernel_width = weight.shape[2], weight.shape[3]
    weight_size = (kernel_height, kernel_width)
    in_channels = x.shape[1]
    out_channels = grad.shape[1]
    min_batch = x.shape[0]

    # Gradient w.r.t input of the Conv.
    # One call of get_grad_input_padding is issued per share pointer; every
    # call receives that pointer followed by the same shared arguments.
    common_args = [
        tuple(x.shape),
        stride,
        padding,
        weight_size,
        dilation,
        grad.session,
    ]
    args = [[share_ptr] + common_args for share_ptr in grad.share_ptrs]
    shares = parallel_execution(
        GradConv2d.get_grad_input_padding, grad.session.parties
    )(args)
    grad_input_padding = MPCTensor(shares=shares, session=grad.session)

    # The reconstructed value is divided by the number of parties --
    # presumably every party returns the same padding and reconstruction
    # adds the per-party results together; verify against reconstruct().
    output_padding_tensor = grad_input_padding.reconstruct()
    output_padding_tensor /= grad.session.nr_parties
    output_padding = tuple(output_padding_tensor.to(torch.int).tolist())

    # NOTE(review): `padding` is not forwarded to conv_transpose2d here.
    # torch's own positional order is (weight, bias, stride, padding,
    # output_padding, groups, dilation) -- confirm this call against the
    # signature of MPCTensor.conv_transpose2d.
    input_grad = grad.conv_transpose2d(
        weight, None, stride, output_padding, dilation, groups
    )

    # Gradient w.r.t weights of the Conv.
    grad = grad.repeat(1, in_channels // groups, 1, 1)
    grad = grad.view(grad.shape[0] * grad.shape[1], 1, grad.shape[2], grad.shape[3])
    x = x.view(1, x.shape[0] * x.shape[1], x.shape[2], x.shape[3])

    # stride and dilation trade places in this convolution, following the
    # formulation used by torch's reference conv2d_weight implementation.
    weight_grad = x.conv2d(
        weight=grad,
        bias=None,
        dilation=stride,
        padding=padding,
        stride=dilation,
        groups=in_channels * min_batch,
    )
    weight_grad = weight_grad.view(
        min_batch,
        weight_grad.shape[1] // min_batch,
        weight_grad.shape[2],
        weight_grad.shape[3],
    )

    # Spatial dims are read before the batch axis is summed away.
    grad_height, grad_width = weight_grad.shape[2], weight_grad.shape[3]
    weight_grad = (
        weight_grad.sum(0)
        .view(in_channels // groups, out_channels, grad_height, grad_width)
        .transpose(0, 1)
    )

    # Trim each spatial dim to the matching kernel dim: dim 2 is the height
    # (weight.shape[2]) and dim 3 is the width (weight.shape[3]).
    weight_grad = weight_grad.narrow(2, 0, kernel_height)
    weight_grad = weight_grad.narrow(3, 0, kernel_width)

    return input_grad, weight_grad