def forward(ctx, in_feat, mode=GlobalPoolingMode.AUTO, in_coords_key=None,
            glob_coords_key=None, coords_manager=None):
    """Instance-normalize ``in_feat`` over each global-pooling group.

    The per-group mean and variance come from the backend
    ``GlobalPoolingForward`` kernel. They are broadcast back onto every
    input row with ``BroadcastForward``, giving
    ``(X - mean) * 1 / sqrt(var + 1e-8)``.

    Args:
        ctx: autograd context. ``mode``, both coordinate keys, the coordinate
            manager and the tensors ``(inv_std, norm_feat)`` are stored on it
            for ``backward``.
        in_feat: input feature tensor. It presumably holds the per-point
            features of a sparse tensor (TODO: confirm the layout against
            callers).
        mode: a ``GlobalPoolingMode``. It is forwarded to the pooling kernel
            as ``mode.value``.
        in_coords_key: key of the input coordinates. It must not be None,
            because its ``.D`` and ``.CPPCoordsKey`` are read.
        glob_coords_key: key of the pooled (global) coordinates. When None,
            a new ``CoordsKey(in_coords_key.D)`` is created.
        coords_manager: coordinate manager whose ``CPPCoordsManager`` is
            passed to every kernel call.

    Returns:
        The normalized features produced by the final ``BroadcastForward``.

    Raises:
        AssertionError: if ``mode`` is not a ``GlobalPoolingMode``. This only
            happens when assertions are enabled, i.e. not under ``python -O``.
    """
    assert isinstance(mode, GlobalPoolingMode), \
        f"Mode must be an instance of GlobalPoolingMode, {mode}"

    if glob_coords_key is None:
        glob_coords_key = CoordsKey(in_coords_key.D)

    gpool_forward = getattr(MEB, 'GlobalPoolingForward' + get_postfix(in_feat))
    broadcast_forward = getattr(MEB, 'BroadcastForward' + get_postfix(in_feat))
    add = operation_type_to_int(OperationType.ADDITION)
    multiply = operation_type_to_int(OperationType.MULTIPLICATION)

    cpp_in_coords_key = in_coords_key.CPPCoordsKey
    cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
    cpp_coords_manager = coords_manager.CPPCoordsManager

    # The pooling kernel returns a pair; only the pooled value is used here,
    # so no placeholder tensors are pre-allocated. The ``True`` flag
    # presumably selects averaging (1/N sums), which matches the comments
    # below; confirm against the backend.
    mean, _ = gpool_forward(in_feat, cpp_in_coords_key, cpp_glob_coords_key,
                            cpp_coords_manager, True, mode.value)

    # X - \mu
    centered_feat = broadcast_forward(in_feat, -mean, add, cpp_in_coords_key,
                                      cpp_glob_coords_key, cpp_coords_manager)

    # Variance = 1/N \sum (X - \mu) ** 2
    variance, _ = gpool_forward(centered_feat**2, cpp_in_coords_key,
                                cpp_glob_coords_key, cpp_coords_manager, True,
                                mode.value)

    # norm_feat = (X - \mu) / \sigma. The 1e-8 guards against division by
    # zero for constant groups.
    inv_std = 1 / (variance + 1e-8).sqrt()
    norm_feat = broadcast_forward(centered_feat, inv_std, multiply,
                                  cpp_in_coords_key, cpp_glob_coords_key,
                                  cpp_coords_manager)

    ctx.mode = mode
    ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
    ctx.coords_manager = coords_manager
    # For GPU tensors, must use save_for_backward.
    ctx.save_for_backward(inv_std, norm_feat)
    return norm_feat
def forward(ctx, in_feat, batch_size=0, in_coords_key=None,
            glob_coords_key=None, coords_manager=None):
    """Normalize ``in_feat`` per pooling group using out-parameter kernels.

    The backend kernels in this variant write their results into tensors
    passed by the caller. Each intermediate is therefore created with
    ``in_feat.new()`` before the call that fills it. The returned value is
    ``(X - mean) * 1 / sqrt(var + 1e-8)``, where mean and variance are
    computed per pooling group and broadcast back onto the input rows.

    Args:
        ctx: autograd context. ``batch_size``, both coordinate keys, the
            coordinate manager and ``(inv_std, norm_feat)`` are stored on it
            for ``backward``.
        in_feat: input feature tensor (layout not visible here; TODO: confirm
            against callers).
        batch_size: forwarded unchanged to the pooling kernel.
        in_coords_key: key of the input coordinates. It must not be None,
            because its ``.D`` and ``.CPPCoordsKey`` are read.
        glob_coords_key: key of the pooled coordinates. When None, a new
            ``CoordsKey(in_coords_key.D)`` is created.
        coords_manager: coordinate manager passed to every kernel call.

    Returns:
        The normalized feature tensor filled by the last broadcast kernel.
    """
    if glob_coords_key is None:
        glob_coords_key = CoordsKey(in_coords_key.D)

    postfix = get_postfix(in_feat)
    pool = getattr(MEB, 'GlobalPoolingForward' + postfix)
    bcast = getattr(MEB, 'BroadcastForward' + postfix)
    op_add = operation_type_to_int(OperationType.ADDITION)
    op_mul = operation_type_to_int(OperationType.MULTIPLICATION)

    dim = in_coords_key.D
    # Trailing arguments shared by every kernel call.
    keys = (in_coords_key.CPPCoordsKey,
            glob_coords_key.CPPCoordsKey,
            coords_manager.CPPCoordsManager)

    # mu = 1/N sum X  (mu and count are written in place by the kernel)
    mu, count = in_feat.new(), in_feat.new()
    pool(dim, in_feat, mu, count, *keys, batch_size, True)

    # centered = X - mu
    centered = in_feat.new()
    bcast(dim, in_feat, -mu, centered, op_add, *keys)

    # var = 1/N sum (X - mu)^2  (the count buffer is reused)
    var = in_feat.new()
    pool(dim, centered ** 2, var, count, *keys, batch_size, True)

    # out = (X - mu) / sigma
    inv_std = 1 / (var + 1e-8).sqrt()
    out = in_feat.new()
    bcast(dim, centered, inv_std, out, op_mul, *keys)

    ctx.batch_size = batch_size
    ctx.in_coords_key = in_coords_key
    ctx.glob_coords_key = glob_coords_key
    ctx.coords_manager = coords_manager
    # For GPU tensors, must use save_for_backward.
    ctx.save_for_backward(inv_std, out)
    return out
def backward(ctx, out_grad):
    """Gradient of the batch_size-based instance normalization.

    Let ``g = out_grad``, ``y = norm_feat`` and ``s = inv_std``, both saved in
    forward. Let ``mean(.)`` be the per-group average from
    ``GlobalPoolingForward``, broadcast back onto the rows. The function
    returns ``s * (g - mean(g) - y * mean(g * y))``.
    Derivation: https://kevinzakka.github.io/2016/09/14/batch_normalization/

    Args:
        ctx: autograd context populated by the matching forward
            (``batch_size``, coordinate keys, coordinate manager and the
            saved ``(inv_std, norm_feat)`` tensors).
        out_grad: gradient of the loss with respect to forward's output.

    Returns:
        A 5-tuple: the gradient for ``in_feat``, then ``None`` for
        ``batch_size``, ``in_coords_key``, ``glob_coords_key`` and
        ``coords_manager``.
    """
    batch_size = ctx.batch_size
    in_coords_key, glob_coords_key = ctx.in_coords_key, ctx.glob_coords_key
    coords_manager = ctx.coords_manager
    # Retrieve the tensors stored with save_for_backward in forward.
    # ``saved_tensors`` replaces the deprecated ``saved_variables`` alias.
    inv_std, norm_feat = ctx.saved_tensors
    D = in_coords_key.D

    gpool_forward = getattr(MEB, 'GlobalPoolingForward' + get_postfix(out_grad))
    broadcast_forward = getattr(MEB, 'BroadcastForward' + get_postfix(out_grad))
    add = operation_type_to_int(OperationType.ADDITION)
    multiply = operation_type_to_int(OperationType.MULTIPLICATION)

    cpp_in_coords_key = in_coords_key.CPPCoordsKey
    cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
    cpp_coords_manager = coords_manager.CPPCoordsManager

    # 1/N \sum dout  (kernels write into the pre-allocated tensors)
    num_nonzero = out_grad.new()
    mean_dout = out_grad.new()
    gpool_forward(D, out_grad, mean_dout, num_nonzero, cpp_in_coords_key,
                  cpp_glob_coords_key, cpp_coords_manager, batch_size, True)

    # 1/N \sum (dout * out)
    mean_dout_feat = out_grad.new()
    gpool_forward(D, out_grad * norm_feat, mean_dout_feat, num_nonzero,
                  cpp_in_coords_key, cpp_glob_coords_key, cpp_coords_manager,
                  batch_size, True)

    # out * 1/N \sum (dout * out)
    feat_mean_dout_feat = out_grad.new()
    broadcast_forward(D, norm_feat, mean_dout_feat, feat_mean_dout_feat,
                      multiply, cpp_in_coords_key, cpp_glob_coords_key,
                      cpp_coords_manager)

    # dout - out * mean(dout * out) - mean(dout)
    unnorm_din = out_grad.new()
    broadcast_forward(D, out_grad - feat_mean_dout_feat, -mean_dout,
                      unnorm_din, add, cpp_in_coords_key, cpp_glob_coords_key,
                      cpp_coords_manager)

    # Scale by 1 / sigma.
    norm_din = out_grad.new()
    broadcast_forward(D, unnorm_din, inv_std, norm_din, multiply,
                      cpp_in_coords_key, cpp_glob_coords_key,
                      cpp_coords_manager)
    return norm_din, None, None, None, None
def backward(ctx, out_grad):
    """Gradient of the mode-based instance normalization.

    Let ``g = out_grad``, ``y = norm_feat`` and ``s = inv_std``, both saved in
    forward. Let ``mean(.)`` be the per-group average from the pooling
    kernel, broadcast back onto the rows. The function returns
    ``s * (g - mean(g) - y * mean(g * y))``.
    Derivation: https://kevinzakka.github.io/2016/09/14/batch_normalization/

    Args:
        ctx: autograd context populated by the matching forward (``mode``,
            coordinate keys, coordinate manager and the saved
            ``(inv_std, norm_feat)`` tensors).
        out_grad: gradient of the loss with respect to forward's output.

    Returns:
        A 5-tuple: the gradient for ``in_feat``, then ``None`` for the four
        remaining forward arguments.
    """
    coords_manager = ctx.coords_manager
    in_coords_key = ctx.in_coords_key
    glob_coords_key = ctx.glob_coords_key
    # Tensors stored with save_for_backward in forward (nothing is recomputed).
    inv_std, norm_feat = ctx.saved_tensors

    pool = get_minkowski_function('GlobalPoolingForward', out_grad)
    bcast = get_minkowski_function('BroadcastForward', out_grad)
    op_add = operation_type_to_int(OperationType.ADDITION)
    op_mul = operation_type_to_int(OperationType.MULTIPLICATION)

    # Trailing arguments shared by every kernel call.
    keys = (in_coords_key.CPPCoordsKey,
            glob_coords_key.CPPCoordsKey,
            coords_manager.CPPCoordsManager)

    # mean(g); the kernel's second return value is not needed.
    avg_grad, _ = pool(out_grad, *keys, True, ctx.mode.value)
    # mean(g * y)
    avg_grad_feat, _ = pool(out_grad * norm_feat, *keys, True, ctx.mode.value)

    # y * mean(g * y)
    proj = bcast(norm_feat, avg_grad_feat, op_mul, *keys)
    # g - y * mean(g * y) - mean(g)
    residual = bcast(out_grad - proj, -avg_grad, op_add, *keys)
    # Scale by 1 / sigma.
    grad_in = bcast(residual, inv_std, op_mul, *keys)

    return grad_in, None, None, None, None