def backward(ctx, out_grad):
    # https://kevinzakka.github.io/2016/09/14/batch_normalization/
    in_coords_key, glob_coords_key, coords_manager, gpooling_mode = ctx.saved_vars
    # To prevent memory leakage, reuse the tensors saved in forward rather than
    # recomputing the norm.
    inv_std, norm_feat = ctx.saved_tensors
    gpool_avg_forward = get_minkowski_function("GlobalPoolingForward", out_grad)
    broadcast_forward = get_minkowski_function("BroadcastForward", out_grad)
    # 1/N \sum dout
    mean_dout, num_nonzero = gpool_avg_forward(
        out_grad,
        gpooling_mode,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    # 1/N \sum (dout * out)
    mean_dout_feat, num_nonzero = gpool_avg_forward(
        out_grad * norm_feat,
        gpooling_mode,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    # out * 1/N \sum (dout * out)
    feat_mean_dout_feat = broadcast_forward(
        norm_feat,
        mean_dout_feat,
        BroadcastMode.ELEMENTWISE_MULTIPLICATION,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    unnorm_din = broadcast_forward(
        out_grad - feat_mean_dout_feat,
        -mean_dout,
        BroadcastMode.ELEMENTWISE_ADDITON,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    norm_din = broadcast_forward(
        unnorm_din,
        inv_std,
        BroadcastMode.ELEMENTWISE_MULTIPLICATION,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    return norm_din, None, None, None, None
def forward(
    ctx,
    in_feat: torch.Tensor,
    in_coords_key: CoordinateMapKey,
    glob_coords_key: CoordinateMapKey = None,
    coords_manager: CoordinateManager = None,
    gpooling_mode=PoolingMode.GLOBAL_AVG_POOLING_KERNEL,
):
    if glob_coords_key is None:
        glob_coords_key = CoordinateMapKey(in_coords_key.get_coordinate_size())
    gpool_avg_forward = get_minkowski_function("GlobalPoolingForward", in_feat)
    broadcast_forward = get_minkowski_function("BroadcastForward", in_feat)
    mean, num_nonzero = gpool_avg_forward(
        in_feat,
        gpooling_mode,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    # X - \mu
    centered_feat = broadcast_forward(
        in_feat,
        -mean,
        BroadcastMode.ELEMENTWISE_ADDITON,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    # Variance = 1/N \sum (X - \mu) ** 2
    variance, num_nonzero = gpool_avg_forward(
        centered_feat**2,
        gpooling_mode,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    # norm_feat = (X - \mu) / \sigma
    inv_std = 1 / (variance + 1e-8).sqrt()
    norm_feat = broadcast_forward(
        centered_feat,
        inv_std,
        BroadcastMode.ELEMENTWISE_MULTIPLICATION,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    ctx.saved_vars = (in_coords_key, glob_coords_key, coords_manager, gpooling_mode)
    # For GPU tensors, must use save_for_backward.
    ctx.save_for_backward(inv_std, norm_feat)
    return norm_feat
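# The backward/forward pair above assembles, via GlobalPoolingForward (per-batch
# mean) and BroadcastForward (scatter back to points), the standard normalization
# gradient dX = inv_std * (dY - mean(dY) - Y * mean(dY * Y)). A minimal
# dense-tensor sanity check of that formula with plain torch; the names below are
# illustrative only and are not MinkowskiEngine calls.
import torch

x = torch.randn(8, 4, dtype=torch.float64, requires_grad=True)
mu = x.mean(dim=0, keepdim=True)
var = ((x - mu) ** 2).mean(dim=0, keepdim=True)
inv_std = 1 / (var + 1e-8).sqrt()
y = (x - mu) * inv_std

dout = torch.randn_like(x)
# Gradient computed by autograd through the normalization.
(dx_auto,) = torch.autograd.grad(y, x, grad_outputs=dout)
# Closed-form gradient, the same expression the sparse backward builds.
dx_manual = inv_std * (
    dout - dout.mean(dim=0, keepdim=True) - y * (dout * y).mean(dim=0, keepdim=True)
)
assert torch.allclose(dx_auto, dx_manual)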
def backward(ctx, grad_out_feat):
    grad_out_feat = grad_out_feat.contiguous()
    bw_fn = get_minkowski_function("LocalPoolingBackward", grad_out_feat)
    grad_in_feat = bw_fn(
        ctx.input_features,
        grad_out_feat,
        ctx.num_nonzero,
        ctx.kernel_generator.kernel_size,
        ctx.kernel_generator.kernel_stride,
        ctx.kernel_generator.kernel_dilation,
        ctx.kernel_generator.region_type,
        ctx.kernel_generator.region_offsets,
        ctx.pooling_mode,
        ctx.in_coordinate_map_key,
        ctx.out_coordinate_map_key,
        ctx.coordinate_manager._manager,
    )
    return (
        grad_in_feat,
        None,
        None,
        None,
        None,
        None,
    )
def forward(
    ctx,
    input_features: torch.Tensor,
    input_features_global: torch.Tensor,
    operation_type: BroadcastMode,
    in_coords_key: CoordinateMapKey,
    glob_coords_key: CoordinateMapKey,
    coords_manager: CoordinateManager,
):
    assert isinstance(operation_type, BroadcastMode)
    ctx.saved_vars = (
        input_features,
        input_features_global,
        operation_type,
        in_coords_key,
        glob_coords_key,
        coords_manager,
    )
    fw_fn = get_minkowski_function("BroadcastForward", input_features)
    return fw_fn(
        input_features,
        input_features_global,
        operation_type,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
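# In dense terms, the broadcast op pairs every point with the global feature of
# its batch. A rough torch-only analogue; the batch indices and shapes are made
# up for illustration, and the real op resolves the pairing through the
# coordinate manager rather than an explicit index tensor.
import torch

feats = torch.randn(5, 4)                # per-point features
batch = torch.tensor([0, 0, 1, 1, 1])    # batch index of each point
glob = torch.randn(2, 4)                 # one global feature vector per batch

added = feats + glob[batch]              # ELEMENTWISE_ADDITON analogue
scaled = feats * glob[batch]             # ELEMENTWISE_MULTIPLICATION analogue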
def backward(ctx, grad_out_feat: torch.Tensor):
    grad_out_feat = grad_out_feat.contiguous()
    (
        kernel_generator,
        convolution_mode,
        in_coordinate_map_key,
        out_coordinate_map_key,
        coordinate_manager,
    ) = ctx.misc
    bw_fn = get_minkowski_function("ConvolutionBackward", grad_out_feat)
    grad_in_feat, grad_kernel = bw_fn(
        ctx.input_features,
        grad_out_feat,
        ctx.kernel_weights,
        kernel_generator.kernel_size,
        kernel_generator.kernel_stride,
        kernel_generator.kernel_dilation,
        kernel_generator.region_type,
        kernel_generator.region_offsets,
        convolution_mode,
        in_coordinate_map_key,
        out_coordinate_map_key,
        coordinate_manager._manager,
    )
    return (
        grad_in_feat,
        grad_kernel,
        None,
        None,
        None,
        None,
        None,
    )
def forward(
    ctx,
    input_features: torch.Tensor,
    pooling_mode: PoolingMode,
    in_coordinate_map_key: CoordinateMapKey,
    out_coordinate_map_key: CoordinateMapKey = None,
    coordinate_manager: CoordinateManager = None,
):
    if out_coordinate_map_key is None:
        out_coordinate_map_key = CoordinateMapKey(
            in_coordinate_map_key.get_coordinate_size()
        )
    input_features = input_features.contiguous()
    ctx.input_features = input_features
    ctx.in_coords_key = in_coordinate_map_key
    ctx.out_coords_key = out_coordinate_map_key
    ctx.coordinate_manager = coordinate_manager
    ctx.pooling_mode = pooling_mode
    fw_fn = get_minkowski_function("GlobalPoolingForward", input_features)
    out_feat, num_nonzero = fw_fn(
        input_features,
        pooling_mode,
        ctx.in_coords_key,
        ctx.out_coords_key,
        ctx.coordinate_manager._manager,
    )
    ctx.num_nonzero = num_nonzero
    return out_feat
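# Dense analogue of the global average pooling above: average the features of all
# points that share a batch index, and record how many points each batch
# contributed (the role played by num_nonzero). Plain torch, illustrative names
# only.
import torch

feats = torch.randn(5, 4)
batch = torch.tensor([0, 0, 1, 1, 1])
num_batches = int(batch.max()) + 1

num_nonzero = torch.bincount(batch, minlength=num_batches).clamp(min=1)
sums = torch.zeros(num_batches, feats.shape[1]).index_add_(0, batch, feats)
mean = sums / num_nonzero.unsqueeze(1).to(feats.dtype)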
def backward(ctx, grad_out_feat: torch.Tensor):
    bw_fn = get_minkowski_function("PruningBackward", grad_out_feat)
    grad_in_feat = bw_fn(
        grad_out_feat,
        ctx.in_coords_key,
        ctx.out_coords_key,
        ctx.coords_manager._manager,
    )
    return grad_in_feat, None, None, None, None
def forward(
    ctx,
    input_features,
    tensor_stride=1,
    stride=1,
    kernel_size=-1,
    dilation=1,
    region_type=-1,
    region_offset=None,
    average=False,
    in_coords_key=None,
    out_coords_key=None,
    coords_manager=None,
):
    assert isinstance(region_type, RegionType)
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert in_coords_key.D == out_coords_key.D
    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type, in_coords_key.D
    )
    if region_offset is None:
        region_offset = torch.IntTensor()
    ctx.in_feat = input_features
    out_feat = input_features.new()
    ctx.num_nonzero = input_features.new()
    ctx = save_ctx(
        ctx,
        tensor_stride,
        stride,
        kernel_size,
        dilation,
        region_type,
        in_coords_key,
        out_coords_key,
        coords_manager,
    )
    D = in_coords_key.D
    fw_fn = get_minkowski_function("PoolingTransposeForward", input_features)
    fw_fn(
        ctx.in_feat,
        out_feat,
        ctx.num_nonzero,
        convert_to_int_list(ctx.tensor_stride, D),
        convert_to_int_list(ctx.stride, D),
        convert_to_int_list(ctx.kernel_size, D),
        convert_to_int_list(ctx.dilation, D),
        region_type,
        region_offset,
        ctx.in_coords_key.CPPCoordsKey,
        ctx.out_coords_key.CPPCoordsKey,
        ctx.coords_man.CPPCoordsManager,
    )
    return out_feat
def backward(ctx, grad_out_feat):
    bw_fn = get_minkowski_function("GlobalPoolingBackward", grad_out_feat)
    grad_in_feat = bw_fn(
        ctx.input_features,
        grad_out_feat,
        ctx.num_nonzero,
        ctx.pooling_mode,
        ctx.in_coords_key,
        ctx.out_coords_key,
        ctx.coordinate_manager._manager,
    )
    # One gradient per forward input: features plus four non-differentiable args.
    return grad_in_feat, None, None, None, None
def backward(ctx, grad_out_feat):
    grad_in_feat = grad_out_feat.new()
    D = ctx.in_coords_key.D
    bw_fn = get_minkowski_function("PoolingTransposeBackward", grad_out_feat)
    bw_fn(
        ctx.in_feat,
        grad_in_feat,
        grad_out_feat,
        ctx.num_nonzero,
        convert_to_int_list(ctx.tensor_stride, D),
        convert_to_int_list(ctx.stride, D),
        convert_to_int_list(ctx.kernel_size, D),
        convert_to_int_list(ctx.dilation, D),
        ctx.region_type,
        ctx.in_coords_key.CPPCoordsKey,
        ctx.out_coords_key.CPPCoordsKey,
        ctx.coords_man.CPPCoordsManager,
    )
    return grad_in_feat, None, None, None, None, None, None, None, None, None, None
def backward(
    ctx, grad_out_feat=None, grad_in_map=None, grad_out_map=None, grad_weights=None
):
    grad_out_feat = grad_out_feat.contiguous()
    bw_fn = get_minkowski_function("InterpolationBackward", grad_out_feat)
    (
        in_coordinate_map_key,
        coordinate_manager,
    ) = ctx.inputs
    in_map, out_map, weights = ctx.saved_tensors
    grad_in_feat = bw_fn(
        grad_out_feat,
        in_map,
        out_map,
        weights,
        in_coordinate_map_key,
        coordinate_manager._manager,
    )
    return grad_in_feat, None, None, None
def forward(
    ctx,
    in_feat: torch.Tensor,
    mask: torch.Tensor,
    in_coords_key: CoordinateMapKey,
    out_coords_key: CoordinateMapKey = None,
    coords_manager: CoordinateManager = None,
):
    ctx.in_coords_key = in_coords_key
    ctx.out_coords_key = out_coords_key
    ctx.coords_manager = coords_manager
    fw_fn = get_minkowski_function("PruningForward", in_feat)
    return fw_fn(
        in_feat,
        mask,
        ctx.in_coords_key,
        ctx.out_coords_key,
        ctx.coords_manager._manager,
    )
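# Feature-side analogue of the pruning above (a sketch, not the backend call):
# keep only the rows whose mask entry is True. The real op additionally creates a
# new coordinate map key for the surviving coordinates, which is what
# out_coords_key tracks.
import torch

feats = torch.randn(6, 4)
mask = torch.tensor([True, False, True, True, False, True])
kept = feats[mask]  # roughly what PruningForward does to the feature matrix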
def forward(
    ctx,
    input_features: torch.Tensor,
    kernel_weights: torch.Tensor,
    kernel_generator: KernelGenerator,
    convolution_mode: ConvolutionMode,
    in_coordinate_map_key: CoordinateMapKey,
    out_coordinate_map_key: CoordinateMapKey = None,
    coordinate_manager: CoordinateManager = None,
):
    if out_coordinate_map_key is None:
        out_coordinate_map_key = CoordinateMapKey(
            in_coordinate_map_key.get_coordinate_size()
        )
    input_features = input_features.contiguous()
    ctx.input_features = input_features
    ctx.kernel_weights = kernel_weights
    ctx.misc = [
        kernel_generator,
        convolution_mode,
        in_coordinate_map_key,
        out_coordinate_map_key,
        coordinate_manager,
    ]
    fw_fn = get_minkowski_function("ConvolutionForward", input_features)
    return fw_fn(
        ctx.input_features,
        kernel_weights,
        kernel_generator.kernel_size,
        kernel_generator.kernel_stride,
        kernel_generator.kernel_dilation,
        kernel_generator.region_type,
        kernel_generator.region_offsets,
        kernel_generator.expand_coordinates,
        convolution_mode,
        in_coordinate_map_key,
        out_coordinate_map_key,
        coordinate_manager._manager,
    )
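# The sparse convolution kernel behind ConvolutionForward follows a
# gather-GEMM-scatter pattern: the coordinate manager builds one (in_map, out_map)
# pair per kernel offset, and that offset's weight matrix is applied to the
# gathered rows before scattering into the output. A torch-only sketch; the
# kernel maps and sizes below are made up for illustration.
import torch

feats = torch.randn(10, 4)      # N_in x C_in
kernel = torch.randn(3, 4, 8)   # K x C_in x C_out, one GEMM per kernel offset
out = torch.zeros(7, 8)         # N_out x C_out

kernel_maps = [                 # hypothetical (in_map, out_map) per offset
    (torch.tensor([0, 2, 5]), torch.tensor([0, 1, 6])),
    (torch.tensor([1, 3]), torch.tensor([2, 3])),
    (torch.tensor([4, 9]), torch.tensor([4, 5])),
]
for k, (in_map, out_map) in enumerate(kernel_maps):
    # gather rows, multiply by this offset's weights, scatter-add into the output
    out.index_add_(0, out_map, feats[in_map] @ kernel[k])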
def forward(
    ctx,
    input_features: torch.Tensor,
    pooling_mode: PoolingMode,
    kernel_generator: KernelGenerator,
    in_coordinate_map_key: CoordinateMapKey,
    out_coordinate_map_key: CoordinateMapKey = None,
    coordinate_manager: CoordinateManager = None,
):
    if out_coordinate_map_key is None:
        out_coordinate_map_key = CoordinateMapKey(
            in_coordinate_map_key.get_coordinate_size()
        )
    input_features = input_features.contiguous()
    ctx.input_features = input_features
    ctx = save_ctx(
        ctx,
        kernel_generator,
        in_coordinate_map_key,
        out_coordinate_map_key,
        coordinate_manager,
    )
    ctx.pooling_mode = pooling_mode
    fw_fn = get_minkowski_function("LocalPoolingTransposeForward", input_features)
    out_feat, num_nonzero = fw_fn(
        ctx.input_features,
        kernel_generator.kernel_size,
        kernel_generator.kernel_stride,
        kernel_generator.kernel_dilation,
        kernel_generator.region_type,
        kernel_generator.region_offsets,
        kernel_generator.expand_coordinates,
        pooling_mode,
        ctx.in_coordinate_map_key,
        ctx.out_coordinate_map_key,
        ctx.coordinate_manager._manager,
    )
    ctx.num_nonzero = num_nonzero
    return out_feat
def forward(
    ctx,
    input_features: torch.Tensor,
    tfield: torch.Tensor,
    in_coordinate_map_key: CoordinateMapKey,
    coordinate_manager: CoordinateManager = None,
):
    input_features = input_features.contiguous()
    # in_map, out_map, weights = coordinate_manager.interpolation_map_weight(
    #     in_coordinate_map_key, tfield)
    fw_fn = get_minkowski_function("InterpolationForward", input_features)
    out_feat, in_map, out_map, weights = fw_fn(
        input_features,
        tfield,
        in_coordinate_map_key,
        coordinate_manager._manager,
    )
    ctx.save_for_backward(in_map, out_map, weights)
    ctx.inputs = (
        in_coordinate_map_key,
        coordinate_manager,
    )
    return out_feat, in_map, out_map, weights
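# One plausible reading of the (in_map, out_map, weights) triplet returned above
# (an assumption; the actual layout is defined by the C++ backend): each entry
# pairs a source voxel with a target field point and an interpolation weight, so
# the output is a weighted scatter-add of gathered features. Torch-only sketch
# with made-up maps.
import torch

feats = torch.randn(6, 4)                  # features on sparse voxels
in_map = torch.tensor([0, 1, 2, 2, 3])     # source voxel for each pair
out_map = torch.tensor([0, 0, 1, 2, 2])    # target field point for each pair
weights = torch.rand(5)                    # interpolation weight for each pair

out = torch.zeros(3, 4)
out.index_add_(0, out_map, feats[in_map] * weights.unsqueeze(1))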
def backward(ctx, grad_out_feat: torch.Tensor):
    bw_fn = get_minkowski_function("ConvolutionTransposeBackward", grad_out_feat)
    grad_in_feat, grad_kernel = bw_fn(
        ctx.input_features,
        grad_out_feat,
        ctx.kernel_weights,
        ctx.kernel_generator.kernel_size,
        ctx.kernel_generator.kernel_stride,
        ctx.kernel_generator.kernel_dilation,
        ctx.kernel_generator.region_type,
        ctx.kernel_generator.region_offsets,
        ctx.in_coordinate_map_key,
        ctx.out_coordinate_map_key,
        ctx.coordinate_manager._manager,
    )
    return (
        grad_in_feat,
        grad_kernel,
        None,
        None,
        None,
        None,
    )
def backward(ctx, grad_out_feat):
    if not grad_out_feat.is_contiguous():
        grad_out_feat = grad_out_feat.contiguous()
    (
        input_features,
        input_features_global,
        operation_type,
        in_coords_key,
        glob_coords_key,
        coords_manager,
    ) = ctx.saved_vars
    bw_fn = get_minkowski_function("BroadcastBackward", grad_out_feat)
    grad_in_feat, grad_in_feat_glob = bw_fn(
        input_features,
        input_features_global,
        grad_out_feat,
        operation_type,
        in_coords_key,
        glob_coords_key,
        coords_manager._manager,
    )
    return grad_in_feat, grad_in_feat_glob, None, None, None, None