def forward(ctx,
            in_feat,
            mode=GlobalPoolingMode.AUTO,
            in_coords_key=None,
            glob_coords_key=None,
            coords_manager=None):
    """Normalize ``in_feat`` to zero mean and unit variance per pooled group.

    The mean and variance are computed with the ``GlobalPoolingForward``
    kernel (averaging enabled) from ``in_coords_key`` onto
    ``glob_coords_key``. They are applied back onto the input coordinates
    with ``BroadcastForward``.

    Args:
        ctx: autograd context. Stores ``mode``, both coordinate keys and the
            coordinate manager, and saves ``inv_std`` / ``norm_feat`` for
            backward.
        in_feat: input feature tensor.
        mode: ``GlobalPoolingMode`` forwarded (as ``mode.value``) to the
            pooling kernel.
        in_coords_key: ``CoordsKey`` of the input coordinates.
        glob_coords_key: ``CoordsKey`` of the pooled coordinates. A new
            ``CoordsKey`` of the input dimension is created when ``None``.
        coords_manager: coordinate manager that owns both keys.

    Returns:
        ``(X - mean) / sqrt(var + 1e-8)`` on the input coordinates.
    """
    assert isinstance(mode, GlobalPoolingMode), \
        f"Mode must be an instance of GlobalPoolingMode, {mode}"
    if glob_coords_key is None:
        glob_coords_key = CoordsKey(in_coords_key.D)
    # The other functions in this file hand the native kernels contiguous
    # tensors; presumably they assume a dense layout.
    if not in_feat.is_contiguous():
        in_feat = in_feat.contiguous()

    gpool_forward = get_minkowski_function('GlobalPoolingForward', in_feat)
    broadcast_forward = get_minkowski_function('BroadcastForward', in_feat)
    add = operation_type_to_int(OperationType.ADDITION)
    multiply = operation_type_to_int(OperationType.MULTIPLICATION)

    cpp_in_coords_key = in_coords_key.CPPCoordsKey
    cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
    cpp_coords_manager = coords_manager.CPPCoordsManager

    # \mu = 1/N \sum X  (the pooled count returned alongside is not needed)
    mean, _ = gpool_forward(in_feat, cpp_in_coords_key, cpp_glob_coords_key,
                            cpp_coords_manager, True, mode.value)
    # X - \mu
    centered_feat = broadcast_forward(in_feat, -mean, add, cpp_in_coords_key,
                                      cpp_glob_coords_key, cpp_coords_manager)
    # Variance = 1/N \sum (X - \mu) ** 2
    variance, _ = gpool_forward(centered_feat**2, cpp_in_coords_key,
                                cpp_glob_coords_key, cpp_coords_manager, True,
                                mode.value)
    # norm_feat = (X - \mu) / \sigma   (1e-8 guards against zero variance)
    inv_std = 1 / (variance + 1e-8).sqrt()
    norm_feat = broadcast_forward(centered_feat, inv_std, multiply,
                                  cpp_in_coords_key, cpp_glob_coords_key,
                                  cpp_coords_manager)

    ctx.mode = mode
    ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
    ctx.coords_manager = coords_manager
    # For GPU tensors, must use save_for_backward.
    ctx.save_for_backward(inv_std, norm_feat)
    return norm_feat
def backward(ctx, grad_out_feat):
    """Backward of global pooling via the native ``GlobalPoolingBackward``.

    Reads ``in_feat``, ``num_nonzero``, the coordinate keys, the coordinate
    manager and ``average`` from ``ctx``. The forward pass must have stored
    them.

    Args:
        ctx: autograd context populated by the matching forward.
        grad_out_feat: gradient with respect to the pooled output.

    Returns:
        The input-feature gradient, followed by ``None`` for each of the five
        non-differentiable forward arguments. These are assumed to be
        average, mode, in_coords_key, out_coords_key and coords_manager.
    """
    # Every other backward in this module makes the incoming gradient
    # contiguous before handing it to a native kernel.
    if not grad_out_feat.is_contiguous():
        grad_out_feat = grad_out_feat.contiguous()
    bw_fn = get_minkowski_function('GlobalPoolingBackward', grad_out_feat)
    grad_in_feat = bw_fn(ctx.in_feat, grad_out_feat, ctx.num_nonzero,
                         ctx.in_coords_key.CPPCoordsKey,
                         ctx.out_coords_key.CPPCoordsKey,
                         ctx.coords_manager.CPPCoordsManager, ctx.average)
    return grad_in_feat, None, None, None, None, None
def forward(ctx,
            input_features,
            average=True,
            mode=GlobalPoolingMode.AUTO,
            in_coords_key=None,
            out_coords_key=None,
            coords_manager=None):
    """Pool ``input_features`` onto the coordinates of ``out_coords_key``.

    Args:
        ctx: autograd context. Receives the keys, the (contiguous) input, the
            ``average`` flag, the coordinate manager, ``mode.value`` and the
            ``num_nonzero`` tensor returned by the kernel, for backward.
        input_features: features on ``in_coords_key``.
        average: forwarded to ``GlobalPoolingForward``. Presumably selects
            mean rather than sum pooling.
        mode: ``GlobalPoolingMode`` forwarded as ``mode.value``.
        in_coords_key: ``CoordsKey`` of the input coordinates.
        out_coords_key: ``CoordsKey`` of the pooled coordinates. A new key of
            the input dimension is created when ``None``.
        coords_manager: coordinate manager that owns both keys.

    Returns:
        The pooled feature tensor.
    """
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert isinstance(mode, GlobalPoolingMode), \
        f"Mode must be an instance of GlobalPoolingMode, {mode}"
    # Match the other forwards in this module, which give the native kernel
    # a contiguous tensor.
    if not input_features.is_contiguous():
        input_features = input_features.contiguous()

    ctx.in_coords_key = in_coords_key
    ctx.out_coords_key = out_coords_key
    ctx.in_feat = input_features
    ctx.average = average
    ctx.coords_manager = coords_manager
    ctx.mode = mode.value

    fw_fn = get_minkowski_function('GlobalPoolingForward', input_features)
    out_feat, num_nonzero = fw_fn(ctx.in_feat, ctx.in_coords_key.CPPCoordsKey,
                                  ctx.out_coords_key.CPPCoordsKey,
                                  ctx.coords_manager.CPPCoordsManager,
                                  ctx.average, ctx.mode)
    # Kept for GlobalPoolingBackward.
    ctx.num_nonzero = num_nonzero
    return out_feat
def backward(ctx, out_grad):
    """Backward pass of the per-group normalization ``y = (x - mu) * inv_std``.

    Uses the standard batch-norm gradient, applied per pooled group:

        dx = inv_std * (dy - mean(dy) - y * mean(dy * y))

    The group means are computed with ``GlobalPoolingForward`` (averaging
    enabled), and the results are applied back to the input coordinates with
    ``BroadcastForward``.
    Reference: https://kevinzakka.github.io/2016/09/14/batch_normalization/

    Args:
        ctx: autograd context holding ``mode``, the coordinate keys, the
            coordinate manager, and the saved ``inv_std`` / ``norm_feat``.
        out_grad: gradient with respect to the normalized output.

    Returns:
        The input gradient, followed by ``None`` for the four
        non-differentiable forward arguments.
    """
    # The other backwards in this module give native kernels a contiguous
    # gradient.
    if not out_grad.is_contiguous():
        out_grad = out_grad.contiguous()

    in_coords_key, glob_coords_key = ctx.in_coords_key, ctx.glob_coords_key
    coords_manager = ctx.coords_manager
    # Reuse the tensors saved by forward rather than recomputing them.
    inv_std, norm_feat = ctx.saved_tensors

    gpool_forward = get_minkowski_function('GlobalPoolingForward', out_grad)
    broadcast_forward = get_minkowski_function('BroadcastForward', out_grad)
    add = operation_type_to_int(OperationType.ADDITION)
    multiply = operation_type_to_int(OperationType.MULTIPLICATION)

    cpp_in_coords_key = in_coords_key.CPPCoordsKey
    cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
    cpp_coords_manager = coords_manager.CPPCoordsManager

    # 1/N \sum dout
    mean_dout, _ = gpool_forward(out_grad, cpp_in_coords_key,
                                 cpp_glob_coords_key, cpp_coords_manager,
                                 True, ctx.mode.value)
    # 1/N \sum (dout * out)
    mean_dout_feat, _ = gpool_forward(out_grad * norm_feat, cpp_in_coords_key,
                                      cpp_glob_coords_key, cpp_coords_manager,
                                      True, ctx.mode.value)
    # out * 1/N \sum (dout * out)
    feat_mean_dout_feat = broadcast_forward(norm_feat, mean_dout_feat,
                                            multiply, cpp_in_coords_key,
                                            cpp_glob_coords_key,
                                            cpp_coords_manager)
    # dout - out * mean(dout * out) - mean(dout)
    unnorm_din = broadcast_forward(out_grad - feat_mean_dout_feat, -mean_dout,
                                   add, cpp_in_coords_key,
                                   cpp_glob_coords_key, cpp_coords_manager)
    # ... scaled by 1 / sigma
    norm_din = broadcast_forward(unnorm_din, inv_std, multiply,
                                 cpp_in_coords_key, cpp_glob_coords_key,
                                 cpp_coords_manager)
    return norm_din, None, None, None, None
def backward(ctx, grad_out_feat):
    """Backward of global max pooling via ``GlobalMaxPoolingBackward``.

    ``grad_in_feat`` starts as an empty tensor. The native call is expected
    to fill it in place, using the ``max_index`` recorded by forward.

    Args:
        ctx: autograd context holding ``in_feat``, ``max_index``, the
            coordinate keys and the coordinate manager.
        grad_out_feat: gradient with respect to the pooled output.

    Returns:
        The input-feature gradient followed by ``None`` placeholders.
        NOTE(review): if the paired forward takes only four inputs, the
        trailing extra ``None``s are truncated by autograd. Confirm before
        trimming them.
    """
    # Consistent with the other backwards here: native kernels get a
    # contiguous gradient.
    if not grad_out_feat.is_contiguous():
        grad_out_feat = grad_out_feat.contiguous()
    grad_in_feat = grad_out_feat.new()
    bw_fn = get_minkowski_function('GlobalMaxPoolingBackward', grad_out_feat)
    bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.max_index,
          ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_manager.CPPCoordsManager)
    return grad_in_feat, None, None, None, None, None
def backward(ctx, grad_out_feat):
    """Map the gradient of the pruned output back onto the unpruned input.

    An empty ``grad_in_feat`` is allocated and passed to the native
    ``PruningBackward``, which presumably resizes and fills it in place.

    Returns:
        The input gradient followed by five ``None`` placeholders for the
        non-tensor forward arguments.
    """
    dense_grad = (grad_out_feat if grad_out_feat.is_contiguous()
                  else grad_out_feat.contiguous())
    grad_in_feat = dense_grad.new()
    pruning_backward = get_minkowski_function('PruningBackward', dense_grad)
    pruning_backward(
        grad_in_feat,
        dense_grad,
        ctx.in_coords_key.CPPCoordsKey,
        ctx.out_coords_key.CPPCoordsKey,
        ctx.coords_manager.CPPCoordsManager,
    )
    return (grad_in_feat,) + (None,) * 5
def backward(ctx, grad_out_feat):
    """Split the union-output gradient back into one gradient per input.

    The first three forward arguments (input keys, output key, coordinate
    manager) are not differentiable. They get ``None``, followed by one
    gradient per input feature tensor as returned by ``UnionBackward``.
    """
    dense_grad = (grad_out_feat if grad_out_feat.is_contiguous()
                  else grad_out_feat.contiguous())
    union_backward = get_minkowski_function('UnionBackward', dense_grad)
    cpp_in_keys = [in_key.CPPCoordsKey for in_key in ctx.in_coords_keys]
    per_input_grads = union_backward(dense_grad, cpp_in_keys,
                                     ctx.out_coords_key.CPPCoordsKey,
                                     ctx.coords_manager.CPPCoordsManager)
    return (None,) * 3 + tuple(per_input_grads)
def forward(ctx,
            input_features,
            kernel,
            tensor_stride=1,
            stride=1,
            kernel_size=-1,
            dilation=1,
            region_type=0,
            region_offset=None,
            generate_new_coords=False,
            in_coords_key=None,
            out_coords_key=None,
            coords_manager=None):
    """Forward pass of the sparse transposed convolution.

    region_type=0 HyperCube

    The kernel is laid out as ``(n_spatial_kernels, in_nfeat, out_nfeat)``,
    so ``kernel.shape[1]`` must equal the input feature width. The stride,
    kernel and dilation arguments are normalized by ``prep_args`` and saved
    on ``ctx`` via ``save_ctx``. The native ``ConvolutionTransposeForward``
    fills the returned ``out_feat`` in place.
    """
    assert input_features.shape[1] == kernel.shape[1], (
        f"The input shape {list(input_features.shape)}"
        f" does not match the kernel shape {list(kernel.shape)}")
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert in_coords_key.D == out_coords_key.D
    assert input_features.type() == kernel.type(), \
        f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"

    feats = (input_features if input_features.is_contiguous()
             else input_features.contiguous())
    D = in_coords_key.D
    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type, D)
    offsets = torch.IntTensor() if region_offset is None else region_offset

    ctx.in_feat = feats
    ctx.kernel = kernel
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)

    out_feat = feats.new()
    conv_tr_forward = get_minkowski_function('ConvolutionTransposeForward',
                                             feats)
    geometry = [
        convert_to_int_list(value, D) for value in
        (ctx.tensor_stride, ctx.stride, ctx.kernel_size, ctx.dilation)
    ]
    conv_tr_forward(ctx.in_feat, out_feat, kernel, *geometry, region_type,
                    offsets, ctx.in_coords_key.CPPCoordsKey,
                    ctx.out_coords_key.CPPCoordsKey,
                    ctx.coords_man.CPPCoordsManager, generate_new_coords)
    return out_feat
def backward(ctx, grad_out_feat):
    """Backward of the broadcast op via the native ``BroadcastBackward``.

    Two empty gradient buffers are allocated, one for the per-point input and
    one for the global input. The native call presumably fills both in place,
    using the operation code ``ctx.op`` recorded in forward.

    Returns:
        ``(grad_in_feat, grad_in_feat_glob)`` followed by four ``None``
        placeholders for the non-tensor forward arguments.
    """
    dense_grad = (grad_out_feat if grad_out_feat.is_contiguous()
                  else grad_out_feat.contiguous())
    grad_in_feat, grad_in_feat_glob = dense_grad.new(), dense_grad.new()
    broadcast_backward = get_minkowski_function('BroadcastBackward',
                                                dense_grad)
    broadcast_backward(
        ctx.in_feat,
        grad_in_feat,
        ctx.in_feat_glob,
        grad_in_feat_glob,
        dense_grad,
        ctx.op,
        ctx.in_coords_key.CPPCoordsKey,
        ctx.glob_coords_key.CPPCoordsKey,
        ctx.coords_manager.CPPCoordsManager,
    )
    return (grad_in_feat, grad_in_feat_glob) + (None,) * 4
def backward(ctx, grad_out_feat):
    """Backward of pooling transpose via ``PoolingTransposeBackward``.

    Uses ``in_feat``, ``num_nonzero``, the geometry saved by ``save_ctx``
    (tensor_stride, stride, kernel_size, dilation, region_type) and the
    coordinate keys/manager stored on ``ctx``. ``grad_in_feat`` is allocated
    empty and presumably filled in place by the native call.

    Returns:
        The input gradient followed by ten ``None`` placeholders, one per
        non-differentiable forward argument.
    """
    # The max-pooling and other backwards make the gradient contiguous before
    # the native call; do the same here.
    if not grad_out_feat.is_contiguous():
        grad_out_feat = grad_out_feat.contiguous()
    grad_in_feat = grad_out_feat.new()
    D = ctx.in_coords_key.D
    bw_fn = get_minkowski_function('PoolingTransposeBackward', grad_out_feat)
    bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
          convert_to_int_list(ctx.tensor_stride, D),
          convert_to_int_list(ctx.stride, D),
          convert_to_int_list(ctx.kernel_size, D),
          convert_to_int_list(ctx.dilation, D), ctx.region_type,
          ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_man.CPPCoordsManager)
    return (grad_in_feat, None, None, None, None, None, None, None, None, None,
            None)
def backward(ctx, grad_out_feat):
    """Backward of sparse max pooling via ``MaxPoolingBackward``.

    The native call receives the ``max_index`` recorded in forward, together
    with the saved geometry (tensor_stride, stride, kernel_size, dilation,
    region_type). It presumably fills the empty ``grad_in_feat`` in place.

    Returns:
        The input gradient followed by nine ``None`` placeholders.
    """
    dense_grad = (grad_out_feat if grad_out_feat.is_contiguous()
                  else grad_out_feat.contiguous())
    grad_in_feat = dense_grad.new()
    dim = ctx.in_coords_key.D
    max_pool_backward = get_minkowski_function('MaxPoolingBackward',
                                               dense_grad)
    geometry = [
        convert_to_int_list(value, dim) for value in
        (ctx.tensor_stride, ctx.stride, ctx.kernel_size, ctx.dilation)
    ]
    max_pool_backward(ctx.in_feat, grad_in_feat, dense_grad, ctx.max_index,
                      *geometry, ctx.region_type,
                      ctx.in_coords_key.CPPCoordsKey,
                      ctx.out_coords_key.CPPCoordsKey,
                      ctx.coords_man.CPPCoordsManager)
    return (grad_in_feat,) + (None,) * 9
def forward(ctx, in_coords_keys, out_coords_key, coords_manager, *in_feats):
    """Union several sparse feature tensors onto ``out_coords_key``.

    Each entry of ``in_feats`` pairs with the key at the same position in
    ``in_coords_keys``. At least two inputs are required. The inputs are made
    contiguous and passed to the native ``UnionForward``, whose result is
    returned unchanged.
    """
    assert isinstance(in_feats, (list, tuple)), \
        "Input must be a list or a set of Tensors"
    assert len(in_feats) > 1, \
        "input must be a set with at least 2 Tensors"
    dense_feats = [feat.contiguous() for feat in in_feats]

    ctx.in_coords_keys = in_coords_keys
    ctx.out_coords_key = out_coords_key
    ctx.coords_manager = coords_manager

    union_forward = get_minkowski_function('UnionForward', dense_feats[0])
    cpp_in_keys = [in_key.CPPCoordsKey for in_key in ctx.in_coords_keys]
    return union_forward(dense_feats, cpp_in_keys,
                         ctx.out_coords_key.CPPCoordsKey,
                         ctx.coords_manager.CPPCoordsManager)
def forward(ctx,
            input_features,
            tensor_stride=1,
            stride=1,
            kernel_size=-1,
            dilation=1,
            region_type=0,
            region_offset=None,
            in_coords_key=None,
            out_coords_key=None,
            coords_manager=None):
    """Forward pass of sparse max pooling via ``MaxPoolingForward``.

    The native call fills ``out_feat`` and an int tensor ``max_index`` in
    place. ``max_index`` is kept on ``ctx`` for backward and is presumably
    the arg-max input row per output.

    NOTE(review): the default ``region_type=0`` is a plain int and cannot pass
    the ``isinstance(region_type, RegionType)`` assertion. Callers must pass a
    ``RegionType`` explicitly.
    """
    assert isinstance(region_type, RegionType)
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert in_coords_key.D == out_coords_key.D

    feats = (input_features if input_features.is_contiguous()
             else input_features.contiguous())
    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type,
        in_coords_key.D)
    offsets = torch.IntTensor() if region_offset is None else region_offset

    ctx.in_feat = feats
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)
    dim = in_coords_key.D

    out_feat = feats.new()
    ctx.max_index = max_index = feats.new().int()

    max_pool_forward = get_minkowski_function('MaxPoolingForward', feats)
    geometry = [
        convert_to_int_list(value, dim) for value in
        (ctx.tensor_stride, ctx.stride, ctx.kernel_size, ctx.dilation)
    ]
    max_pool_forward(feats, out_feat, max_index, *geometry, region_type,
                     offsets, ctx.in_coords_key.CPPCoordsKey,
                     ctx.out_coords_key.CPPCoordsKey,
                     ctx.coords_man.CPPCoordsManager)
    return out_feat
def forward(ctx, in_feat, mask, in_coords_key, out_coords_key, coords_manager):
    """Keep only the rows of ``in_feat`` selected by ``mask``.

    Args:
        ctx: autograd context. Receives both coordinate keys and the manager.
        in_feat: input features, one row per input coordinate.
        mask: CPU ``torch.BoolTensor`` with one entry per row of ``in_feat``.
        in_coords_key: ``CoordsKey`` of the input coordinates.
        out_coords_key: ``CoordsKey`` for the pruned coordinates.
        coords_manager: coordinate manager that owns both keys.

    Returns:
        The pruned feature tensor, filled in place by ``PruningForward``.
    """
    # Validate the mask type before touching ``mask.size`` so a wrong mask
    # fails with the intended message instead of an unrelated error.
    assert isinstance(mask, torch.BoolTensor), "Mask must be a cpu bool tensor."
    assert in_feat.size(0) == mask.size(0), (
        f"Mask length {mask.size(0)} does not match the number of input "
        f"features {in_feat.size(0)}")
    if not in_feat.is_contiguous():
        in_feat = in_feat.contiguous()
    if not mask.is_contiguous():
        mask = mask.contiguous()

    ctx.in_coords_key = in_coords_key
    ctx.out_coords_key = out_coords_key
    ctx.coords_manager = coords_manager

    out_feat = in_feat.new()
    fw_fn = get_minkowski_function('PruningForward', in_feat)
    fw_fn(in_feat, out_feat, mask, ctx.in_coords_key.CPPCoordsKey,
          ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_manager.CPPCoordsManager)
    return out_feat
def forward(ctx,
            input_features,
            tensor_stride=1,
            stride=1,
            kernel_size=-1,
            dilation=1,
            region_type=-1,
            region_offset=None,
            average=False,
            in_coords_key=None,
            out_coords_key=None,
            coords_manager=None):
    """Forward pass of pooling transpose via ``PoolingTransposeForward``.

    The geometry is normalized by ``prep_args`` and saved on ``ctx`` via
    ``save_ctx``. The native call fills ``out_feat`` and ``ctx.num_nonzero``
    in place. ``ctx.num_nonzero`` is later read by the pooling-transpose
    backward.

    NOTE(review): ``average`` is accepted but never forwarded to the native
    call. It is kept for interface compatibility; confirm whether the kernel
    should receive it.
    NOTE(review): the default ``region_type=-1`` is a plain int and cannot
    pass the ``RegionType`` assertion. Callers must supply a ``RegionType``.
    """
    assert isinstance(region_type, RegionType)
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert in_coords_key.D == out_coords_key.D
    # The max-pooling and convolution-transpose forwards give the native
    # kernel a contiguous tensor; do the same here.
    if not input_features.is_contiguous():
        input_features = input_features.contiguous()

    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type,
        in_coords_key.D)
    if region_offset is None:
        region_offset = torch.IntTensor()

    ctx.in_feat = input_features
    out_feat = input_features.new()
    ctx.num_nonzero = input_features.new()
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)
    D = in_coords_key.D

    fw_fn = get_minkowski_function('PoolingTransposeForward', input_features)
    fw_fn(ctx.in_feat, out_feat, ctx.num_nonzero,
          convert_to_int_list(ctx.tensor_stride, D),
          convert_to_int_list(ctx.stride, D),
          convert_to_int_list(ctx.kernel_size, D),
          convert_to_int_list(ctx.dilation, D), region_type, region_offset,
          ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_man.CPPCoordsManager)
    return out_feat
def forward(ctx, input_features, input_features_global, operation_type,
            in_coords_key, glob_coords_key, coords_manager):
    """Combine per-point features with their pooled (global) counterparts.

    ``BroadcastForward`` applies the operation selected by ``operation_type``
    between ``input_features`` and ``input_features_global``, mapped through
    ``in_coords_key`` / ``glob_coords_key``. Both tensors must have the same
    feature width and tensor type. The operation code, the contiguous inputs,
    the keys and the manager are kept on ``ctx`` for backward.
    """
    assert input_features.shape[1] == input_features_global.shape[1]
    assert input_features.type() == input_features_global.type()
    assert isinstance(operation_type, OperationType)

    local_feat = (input_features if input_features.is_contiguous()
                  else input_features.contiguous())
    global_feat = (input_features_global
                   if input_features_global.is_contiguous()
                   else input_features_global.contiguous())

    ctx.op = operation_type_to_int(operation_type)
    ctx.in_feat, ctx.in_feat_glob = local_feat, global_feat
    ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
    ctx.coords_manager = coords_manager

    broadcast_forward = get_minkowski_function('BroadcastForward', local_feat)
    return broadcast_forward(ctx.in_feat, ctx.in_feat_glob, ctx.op,
                             ctx.in_coords_key.CPPCoordsKey,
                             ctx.glob_coords_key.CPPCoordsKey,
                             ctx.coords_manager.CPPCoordsManager)
def forward(ctx,
            input_features,
            in_coords_key=None,
            out_coords_key=None,
            coords_manager=None):
    """Forward pass of global max pooling via ``GlobalMaxPoolingForward``.

    Args:
        ctx: autograd context. Receives both keys, the (contiguous) input,
            ``max_index`` and the coordinate manager for backward.
        input_features: features on ``in_coords_key``.
        in_coords_key: ``CoordsKey`` of the input coordinates.
        out_coords_key: ``CoordsKey`` of the pooled coordinates. A new key of
            the input dimension is created when ``None``.
        coords_manager: coordinate manager that owns both keys.

    Returns:
        The pooled features. The native call fills them in place, together
        with the int tensor ``max_index`` (presumably the arg-max input row
        per output).
    """
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    # Match the other forwards in this module, which give native kernels a
    # contiguous tensor.
    if not input_features.is_contiguous():
        input_features = input_features.contiguous()

    ctx.in_coords_key = in_coords_key
    ctx.out_coords_key = out_coords_key
    ctx.in_feat = input_features

    out_feat = input_features.new()
    max_index = input_features.new().int()
    ctx.max_index = max_index
    ctx.coords_manager = coords_manager

    fw_fn = get_minkowski_function('GlobalMaxPoolingForward', input_features)
    fw_fn(ctx.in_feat, out_feat, ctx.max_index,
          ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_manager.CPPCoordsManager)
    return out_feat