def _sample(self, n, limits: ZfitSpace):
    """Draw unbinned events by spreading binned counts uniformly over their bins.

    A binned sample of size ``n`` is drawn from the first wrapped pdf. Every
    bin then contributes as many events as its count (cast to ``int32``),
    each placed uniformly at random between the lower and upper edges of
    that bin.

    Args:
        n: Number of events requested from the binned pdf.
        limits: Currently not used; see the TODO below.

    Returns:
        Tensor of shape ``(sum of counts, ndim)`` holding the sampled points.
    """
    # TODO: use real limits, currently not supported in binned sample
    binned_sample = self.pdfs[0].sample(n=n)
    raw_edges = binned_sample.space.binning.edges
    ndim = len(raw_edges)

    # Flatten every axis' edges and build the per-bin lower/upper corners.
    edges_1d = [znp.reshape(znp.array(edge), [-1]) for edge in raw_edges]
    lower_grids = znp.meshgrid(*[edge[:-1] for edge in edges_1d], indexing="ij")
    upper_grids = znp.meshgrid(*[edge[1:] for edge in edges_1d], indexing="ij")
    lower_per_bin = znp.stack(
        [znp.reshape(grid, [-1]) for grid in lower_grids], axis=-1
    )
    upper_per_bin = znp.stack(
        [znp.reshape(grid, [-1]) for grid in upper_grids], axis=-1
    )

    # TODO: what if we have fractions?
    counts = tf.cast(znp.reshape(binned_sample.values(), (-1,)), znp.int32)

    # Repeat each bin's corners once per event falling into that bin.
    lower_per_event = tf.repeat(lower_per_bin, counts, axis=0)
    upper_per_event = tf.repeat(upper_per_bin, counts, axis=0)
    return tf.random.uniform(
        (znp.sum(counts), ndim),
        minval=lower_per_event,
        maxval=upper_per_event,
        dtype=self.dtype,
    )
def _rel_counts(self, x, norm):
    """Integrate the first wrapped pdf over every bin of ``self.axes``.

    Args:
        x: Not used by this implementation.
        norm: If truthy, the result is divided by ``pdf.normalization(norm)``.

    Returns:
        The per-bin integrals, reshaped to the shape of the bin grid.
    """
    pdf = self.pdfs[0]

    # Per-bin lower/upper corners, one row per bin.
    edges_1d = [znp.reshape(znp.array(edge), [-1]) for edge in self.axes.edges]
    lower_grids = znp.meshgrid(*[edge[:-1] for edge in edges_1d], indexing="ij")
    upper_grids = znp.meshgrid(*[edge[1:] for edge in edges_1d], indexing="ij")
    grid_shape = tf.shape(lower_grids[0])
    lower_corners = znp.stack(
        [znp.reshape(grid, [-1]) for grid in lower_grids], axis=-1
    )
    upper_corners = znp.stack(
        [znp.reshape(grid, [-1]) for grid in upper_grids], axis=-1
    )
    bin_limits = znp.stack([lower_corners, upper_corners], axis=1)

    options = {"type": "bins"}

    @z.function
    def integrate_one(limits):
        low, high = tf.unstack(limits)
        bin_space = zfit.Space(obs=self.obs, limits=[low, high])
        return pdf.integrate(bin_space, norm=False, options=options)

    values = tf.vectorized_map(integrate_one, bin_limits)
    values = znp.reshape(values, grid_shape)
    if norm:
        values /= pdf.normalization(norm)
    return values
def to_unbinned(self):
    """Convert to weighted unbinned data located at the bin centers.

    Returns:
        A ``Data`` object with one point per bin (the bin center) and the
        flattened bin values as weights, defined on a copy of ``self.space``
        without binning.
    """
    from zfit import Data

    center_grids = znp.meshgrid(*self.axes.centers, indexing="ij")
    points = znp.stack(
        [znp.reshape(grid, (-1,)) for grid in center_grids], axis=-1
    )
    weights = znp.reshape(self.values(), (-1,))  # TODO: flow?
    unbinned_space = self.space.copy(binning=None)
    return Data.from_tensor(obs=unbinned_space, tensor=points, weights=weights)
def _counts(self, x, norm):
    """Integrate the first wrapped pdf over every bin to obtain counts.

    If the pdf is extended, ``ext_integrate`` is used. Otherwise ``integrate``
    is used and the result is multiplied by ``self.get_yield()``.

    Args:
        x: Not used by this implementation.
        norm: If truthy, the result is divided by ``pdf.normalization(norm)``.

    Returns:
        The per-bin values, reshaped to the shape of the bin grid.
    """
    pdf = self.pdfs[0]

    # Per-bin lower/upper corners, one row per bin.
    edges_1d = [znp.reshape(znp.array(edge), [-1]) for edge in self.axes.edges]
    lower_grids = znp.meshgrid(*[edge[:-1] for edge in edges_1d], indexing="ij")
    upper_grids = znp.meshgrid(*[edge[1:] for edge in edges_1d], indexing="ij")
    grid_shape = tf.shape(lower_grids[0])
    lower_corners = znp.stack(
        [znp.reshape(grid, [-1]) for grid in lower_grids], axis=-1
    )
    upper_corners = znp.stack(
        [znp.reshape(grid, [-1]) for grid in upper_grids], axis=-1
    )
    bin_limits = znp.stack([lower_corners, upper_corners], axis=1)

    options = {"type": "bins"}
    extended = pdf.is_extended
    # Pick the integration method once; the yield is applied later if missing.
    integrate = pdf.ext_integrate if extended else pdf.integrate

    @z.function
    def integrate_one(limits):
        low, high = tf.unstack(limits)
        bin_space = zfit.Space(obs=self.obs, limits=[low, high])
        return integrate(bin_space, norm=False, options=options)

    try:
        values = tf.vectorized_map(integrate_one, bin_limits)[:, 0]
    except ValueError:
        # Fall back to the non-vectorized map when vectorization fails.
        values = tf.map_fn(integrate_one, bin_limits)
    values = znp.reshape(values, grid_shape)
    if not extended:
        values *= self.get_yield()
    if norm:
        values /= pdf.normalization(norm)
    return values
def spline_interpolator(alpha, alphas, densities):
    """Interpolate between ``densities`` at ``alpha`` with an order-2 spline.

    Args:
        alpha: Query value(s); reshaped to ``[1, -1, 1]`` for the interpolation.
        alphas: One-dimensional training points, one per entry of ``densities``.
        densities: Sequence of equally shaped densities; the first one defines
            the shape of the result.

    Returns:
        The first interpolated row, reshaped to the shape of ``densities[0]``.
    """
    target_shape = tf.shape(densities[0])
    train_values = znp.stack(
        [znp.reshape(density, [-1]) for density in densities], axis=0
    )
    interpolated = tfa.image.interpolate_spline(
        train_points=alphas[None, :, None],
        train_values=train_values[None, ...],
        query_points=znp.reshape(alpha, [1, -1, 1]),
        order=2,
    )
    return tf.reshape(interpolated[0, 0], target_shape)
def func_integral_hermite(limits, norm, params, model):
    """Integrate a Hermite polynomial sum between the bounds of ``limits``.

    The antiderivative is expressed as another Hermite sum: every coefficient
    of order ``i`` is moved to order ``i + 1`` and divided by ``2 * (i + 1)``,
    and the order-0 coefficient is set to zero.

    Args:
        limits: Object whose ``limit1d`` yields the lower and upper bound.
        norm: Not used.
        params: Mapping of coefficient names (integer order after the first
            underscore, e.g. ``c_2``) to coefficients.
        model: Provides ``_polynomials_rescale``, ``dtype`` and ``space``.

    Returns:
        The integral, reshaped to ``()`` and multiplied by half the area of
        ``model.space``.
    """
    raw_lower, raw_upper = limits.limit1d
    rescaled_lower = model._polynomials_rescale(raw_lower)
    rescaled_upper = model._polynomials_rescale(raw_upper)
    lower = z.convert_to_tensor(rescaled_lower)
    upper = z.convert_to_tensor(rescaled_upper)

    # the integral of hermite is a hermite_ni. We add the ni to the coeffs.
    shifted_coeffs = {"c_0": z.constant(0.0, dtype=model.dtype)}
    for name, coeff in params.items():
        order_plus_one = int(name.split("_", 1)[-1]) + 1
        divisor = z.convert_to_tensor(order_plus_one * 2.0, dtype=model.dtype)
        shifted_coeffs[f"c_{order_plus_one}"] = coeff / divisor
    coeff_list = convert_coeffs_dict_to_list(shifted_coeffs)

    def antiderivative(point):
        return hermite_shape(x=point, coeffs=coeff_list)

    integral = antiderivative(upper) - antiderivative(lower)
    integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    return integral
def _check_init_values(self, space, values, variances):
    """Check that values, variances and the binning of ``space`` are compatible.

    Args:
        space: Object whose ``binning.edges`` holds one edge array per axis.
        values: The bin values.
        variances: The bin variances, or ``None`` to skip the variance checks.

    Raises:
        ShapeIncompatibleError: If values and variances differ in rank, or if
            the rank of the values differs from the number of binning axes.
    """
    binning_edges = space.binning.edges
    value_shape = tf.shape(values)
    # Number of edges along every axis (each edge array flattened first).
    edges_shape = znp.array(
        [tf.shape(znp.reshape(edge, (-1,)))[0] for edge in binning_edges]
    )
    values_rank = value_shape.shape[0]

    if variances is not None:
        variances_shape = tf.shape(variances)
        variances_rank = variances_shape.shape[0]
        if variances_rank != values_rank:
            raise ShapeIncompatibleError(
                f"Values {values} and variances {variances} differ in rank: {values_rank} vs {variances_rank}"
            )
        tf.assert_equal(
            variances_shape,
            value_shape,
            message=f"Variances and values do not have the same shape:"
            f" {variances_shape} vs {value_shape}",
        )

    binning_rank = len(binning_edges)
    if values_rank != binning_rank:
        raise ShapeIncompatibleError(
            f"Values and binning differ in rank: {values_rank} vs {binning_rank}"
        )
    # n edges delimit n - 1 bins along each axis.
    tf.assert_equal(
        edges_shape - 1,
        value_shape,
        message=f"Edges (minus one) and values do not have the same shape:"
        f" {edges_shape} vs {value_shape}",
    )
def func_integral_chebyshev2(limits, norm, params, model):
    """Integrate a Chebyshev (second kind) sum between the bounds of ``limits``.

    The antiderivative is expressed through ``chebyshev_shape``: every
    coefficient of order ``n`` is moved to order ``n + 1`` and divided by
    ``n + 1``, and the order-0 coefficient is set to zero.

    Args:
        limits: Object whose ``limit1d`` yields the lower and upper bound.
        norm: Not used.
        params: Mapping of coefficient names (integer order after the first
            underscore, e.g. ``c_2``) to coefficients.
        model: Provides ``_polynomials_rescale``, ``dtype`` and ``space``.

    Returns:
        The integral, reshaped to ``()`` and multiplied by half the area of
        ``model.space``.
    """
    raw_lower, raw_upper = limits.limit1d
    rescaled_lower = model._polynomials_rescale(raw_lower)
    rescaled_upper = model._polynomials_rescale(raw_upper)
    lower = z.convert_to_tensor(rescaled_lower)
    upper = z.convert_to_tensor(rescaled_upper)

    # the integral of cheby2_ni is a cheby1_ni+1/(n+1). We add the (n+1) to the coeffs.
    # The cheby1 shape makes the sum for us.
    shifted_coeffs = {"c_0": z.constant(0.0, dtype=model.dtype)}
    for name, coeff in params.items():
        order_plus_one = int(name.split("_", 1)[-1]) + 1
        divisor = z.convert_to_tensor(order_plus_one, dtype=model.dtype)
        shifted_coeffs[f"c_{order_plus_one}"] = coeff / divisor
    coeff_list = convert_coeffs_dict_to_list(shifted_coeffs)

    def antiderivative(point):
        return chebyshev_shape(x=point, coeffs=coeff_list)

    integral = antiderivative(upper) - antiderivative(lower)
    integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    return integral
def _counts_with_modifiers(self, x, norm):
    """Return the counts of the first pdf scaled bin-wise by the modifiers.

    Args:
        x: Passed on to ``counts`` of the first pdf.
        norm: Passed on to ``counts`` of the first pdf.

    Returns:
        The counts, multiplied by the values of ``self._binwise_modifiers``
        (stacked and reshaped to the shape of the counts) if there are any.
    """
    counts = self.pdfs[0].counts(x, norm=norm)
    factors = list(self._binwise_modifiers.values())
    if not factors:
        return counts
    stacked_factors = tf.stack(factors)
    return counts * znp.reshape(stacked_factors, counts.shape)
def sumfunc(params):
    """Sum the counts of the first pdf after bin-wise scaling by ``params``.

    ``self`` and ``obs`` are taken from the enclosing scope.

    Args:
        params: Mapping whose values are the per-bin scale factors; if empty,
            the counts are summed unscaled.

    Returns:
        The sum over all (scaled) counts.
    """
    counts = self.pdfs[0].counts(obs)
    factors = list(params.values())
    if factors:
        stacked_factors = tf.stack(factors)
        counts = counts * znp.reshape(stacked_factors, counts.shape)
    return znp.sum(counts)
def _ext_pdf(self, x, norm):
    """Spline-interpolate the extended binned density at the points of ``x``.

    The extended density of the first pdf is evaluated on the space of ``x``
    and used as training values at the bin centers of that pdf's binning.

    Args:
        x: Data providing ``space`` and ``value()``.
        norm: Passed on to ``ext_pdf`` of the first pdf.

    Returns:
        The interpolated values, one per query point.

    Raises:
        SpecificFunctionNotImplemented: If this pdf is not automatically extended.
    """
    if not self._automatically_extended:
        raise SpecificFunctionNotImplemented

    pdf = self.pdfs[0]
    density_flat = znp.reshape(pdf.ext_pdf(x.space, norm=norm), (-1,))

    center_grids = znp.meshgrid(*pdf.space.binning.centers, indexing="ij")
    centers = znp.stack(
        [znp.reshape(grid, (-1,)) for grid in center_grids], axis=-1
    )
    # [None, :, None]  # TODO: only 1 dim now

    interpolated = tfa.image.interpolate_spline(
        train_points=centers[None, ...],
        train_values=density_flat[None, :, None],
        query_points=x.value()[None, ...],
        order=self.order,
    )
    return interpolated[0, ..., 0]
def unbinned_to_binindex(data, space, flow=False):
    """Find, for every event and observable, the index of the bin it falls in.

    Args:
        data: Data whose ``value(ob)`` returns the values of observable ``ob``.
        space: Space providing ``obs`` and ``binning.edges``.
        flow: If true, a warning is emitted and all indices are shifted by one.

    Returns:
        An ``int32`` array with one column per observable. Entries for which
        no bin was found (NaN) are set to 0.
    """
    if flow:
        warnings.warn(
            "Flow currently not fully supported. Values outside the edges are all 0."
        )

    flat_values = [znp.reshape(data.value(ob), (-1,)) for ob in space.obs]
    flat_edges = [znp.reshape(edge, (-1,)) for edge in space.binning.edges]
    per_axis_bins = [
        tfp.stats.find_bins(x=axis_values, edges=axis_edges)
        for axis_values, axis_edges in zip(flat_values, flat_edges)
    ]
    bin_indices = znp.stack(per_axis_bins, axis=-1)
    if flow:
        bin_indices += 1

    # Replace NaN (no bin found) by 0 before converting to integers.
    not_found = tf.math.is_nan(bin_indices)
    cleaned = znp.where(not_found, znp.zeros_like(bin_indices), bin_indices)
    return znp.asarray(cleaned, dtype=znp.int32)
def legendre_integral(
    limits: ztyping.SpaceType,
    norm: ztyping.SpaceType,
    params: list[zfit.Parameter],
    model: RecursivePolynomial,
):
    """Recursive integral of Legendre polynomials.

    The order-0 term is integrated directly as ``c_0 * (upper - lower)``. For
    every order ``n >= 1`` the antiderivative term is
    ``c_n * (P_{n+1} - P_{n-1}) / (2 n + 1)``, with the polynomials obtained
    from ``do_recurrence``.

    Args:
        limits: Object whose ``limit1d`` yields the lower and upper bound.
        norm: Not used.
        params: Not used; the coefficients are read from ``model.params``.
        model: Provides ``_polynomials_rescale``, ``params``, ``degree`` and ``space``.

    Returns:
        The integral multiplied by half the area of ``model.space``.
    """
    raw_lower, raw_upper = limits.limit1d
    rescaled_lower = model._polynomials_rescale(raw_lower)
    rescaled_upper = model._polynomials_rescale(raw_upper)
    lower = z.convert_to_tensor(rescaled_lower)
    upper = z.convert_to_tensor(rescaled_upper)

    def antiderivative(point):
        # needed +1 for integral, max poly in term for n is n+1
        max_degree = model.degree + 1
        polys = do_recurrence(
            x=point,
            polys=legendre_polys,
            degree=max_degree,
            recurrence=legendre_recurrence,
        )
        terms = [
            model.params[f"c_{degree}"]
            * (polys[degree + 1] - polys[degree - 1])
            / (2.0 * (z.convert_to_tensor(degree)) + 1)
            for degree in range(1, max_degree)
        ]
        return z.reduce_sum(terms, axis=0)

    integral_0 = model.params["c_0"] * (upper - lower)  # if polynomial 0 is 1
    if model.degree != 0:
        integral = antiderivative(upper) - antiderivative(lower) + integral_0
        integral = znp.reshape(integral, newshape=())
    else:
        integral = integral_0
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    return integral
def test_unbinned_data2D():
    """Compare the binned 2D Gaussian to the unbinned one, inside and outside.

    For both ``pdf`` and ``ext_pdf`` the binned values have to agree with the
    true ones within a tenth of the true maximum on random points, and within
    a hundredth of it at the bin centers. For the hand-picked points in
    ``x_outside`` only the third one is expected to give a non-zero value.
    """
    n = 751
    gauss, gauss_binned, obs, obs_binned = create_gauss2d_binned(n, 50)
    data = znp.random.uniform([-5, 50], [10, 600], size=(1000, 2))
    centers = None

    for method in ("pdf", "ext_pdf"):
        evaluate_binned = getattr(gauss_binned, method)
        evaluate_true = getattr(gauss, method)

        y_binned = evaluate_binned(data)
        y_true = evaluate_true(data)
        max_error = np.max(y_true) / 10
        np.testing.assert_allclose(y_true, y_binned, atol=max_error)

        if centers is None:
            grids = znp.meshgrid(*obs_binned.binning.centers, indexing="ij")
            centers = znp.stack(
                [znp.reshape(grid, (-1,)) for grid in grids], axis=-1
            )
        ycenter_binned = evaluate_binned(centers)
        ycenter_true = evaluate_true(centers)
        np.testing.assert_allclose(
            ycenter_binned, ycenter_true, atol=max_error / 10
        )

    x_outside = znp.array([[-7.0, 55], [3.0, 13], [2, 150], [12, 30], [14, 1000]])
    expected_nonzero = (False, False, True, False, False)
    for evaluate_binned in (gauss_binned.pdf, gauss_binned.ext_pdf):
        y_outside = evaluate_binned(x_outside)
        for index, nonzero in enumerate(expected_nonzero):
            if nonzero:
                assert y_outside[index] > 0
            else:
                assert y_outside[index] == 0
def create_covariance(mu, sigma):
    """Build a covariance matrix from ``sigma`` and check it against ``mu``.

    ``params`` is taken from the enclosing scope and only used for the length
    check.

    Args:
        mu: The observed values; converted to a tensor.
        sigma: The uncertainties. With more than one dimension it is used as
            the covariance as is. Otherwise it is squared and put on the
            diagonal; a value that is not one-dimensional is first reshaped
            to length one.

    Returns:
        The covariance matrix.

    Raises:
        ShapeIncompatibleError: If the first dimensions of the params, ``mu``
            and the covariance, and the second dimension of the covariance,
            are not all equal.
    """
    mu = z.convert_to_tensor(mu)
    sigma = z.convert_to_tensor(sigma)  # TODO (Mayou36): fix as above?
    params_tensor = z.convert_to_tensor(params)

    if sigma.shape.ndims > 1:
        covariance = sigma  # TODO: square as well?
    else:
        if sigma.shape.ndims != 1:
            sigma = znp.reshape(sigma, [1])
        covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))

    if not (
        params_tensor.shape[0]
        == mu.shape[0]
        == covariance.shape[0]
        == covariance.shape[1]
    ):
        # A space after "Currently" keeps the message from reading "Currentlyparam".
        raise ShapeIncompatibleError(
            "params_tensor, observation and uncertainty have to have the"
            " same length. Currently "
            f"param: {params_tensor.shape[0]}, mu: {mu.shape[0]}, "
            f"covariance (from uncertainty): {covariance.shape[0:2]}"
        )
    return covariance
def func_integral_chebyshev1(limits, norm, params, model):
    """Integrate a Chebyshev (first kind) sum between the bounds of ``limits``.

    Orders 0 and 1 are integrated directly. For every order ``n >= 2`` the
    antiderivative term is
    ``c_n * (n * T_{n+1} / (n^2 - 1) - x * T_n / (n - 1))``, with the
    polynomials obtained from ``do_recurrence``.

    Args:
        limits: Object whose ``rect_limits`` yields the lower and upper bound.
        norm: Not used.
        params: Not used; the coefficients are read from ``model.params``.
        model: Provides ``_polynomials_rescale``, ``params``, ``degree`` and ``space``.

    Returns:
        The integral multiplied by half the area of ``model.space``; the first
        entry along the last axis is returned.
    """
    raw_lower, raw_upper = limits.rect_limits
    rescaled_lower = model._polynomials_rescale(raw_lower)
    rescaled_upper = model._polynomials_rescale(raw_upper)
    lower = z.convert_to_tensor(rescaled_lower)
    upper = z.convert_to_tensor(rescaled_upper)

    def antiderivative_high_orders(point):
        max_degree = model.degree + 1
        polys = do_recurrence(
            x=point,
            polys=chebyshev_polys,
            degree=max_degree,
            recurrence=chebyshev_recurrence,
        )
        terms = []
        for degree in range(2, max_degree):
            coeff = model.params[f"c_{degree}"]
            n_float = z.convert_to_tensor(degree)
            term = n_float * polys[degree + 1] / (
                z.square(n_float) - 1
            ) - point * polys[degree] / (n_float - 1)
            terms.append(coeff * term)
        return z.reduce_sum(terms, axis=0)

    # if polynomial 0 is defined as T_0 = 1
    integral = model.params["c_0"] * (upper - lower)
    if model.degree >= 1:
        integral += model.params["c_1"] * 0.5 * (upper**2 - lower**2)
    if model.degree >= 2:
        integral += antiderivative_high_orders(upper) - antiderivative_high_orders(
            lower
        )
        integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    integral = tf.gather(integral, indices=0, axis=-1)
    return integral
def func_integral_laguerre(limits, norm, params: dict, model):
    r"""The integral of the simple laguerre polynomials.

    Defined as :math:`\int L_{n} = (-1) L_{n+1}^{(-1)}` with :math:`L^{(\alpha)}` the generalized Laguerre polynom.

    Args:
        limits: Object whose ``limit1d`` yields the lower and upper bound.
        norm: Not used.
        params: Mapping of coefficient names (integer order after the first
            underscore, e.g. ``c_2``) to coefficients.
        model: Provides ``_polynomials_rescale``, ``dtype`` and ``space``.

    Returns:
        The integral, reshaped to ``()`` and multiplied by half the area of
        ``model.space``.
    """
    lower, upper = limits.limit1d
    lower_rescaled = model._polynomials_rescale(lower)
    upper_rescaled = model._polynomials_rescale(upper)

    lower = z.convert_to_tensor(lower_rescaled)
    upper = z.convert_to_tensor(upper_rescaled)

    # The laguerre shape makes the sum for us. setting the 0th coeff to 0, since no -1 term exists.
    # increase n -> n+1 of naming
    coeffs_laguerre_nup = {
        f'c_{int(name.split("_", 1)[-1]) + 1}': coeff
        for name, coeff in params.items()
    }
    coeffs_laguerre_nup["c_0"] = tf.constant(0.0, dtype=model.dtype)
    coeffs_laguerre_nup = convert_coeffs_dict_to_list(coeffs_laguerre_nup)

    def indefinite_integral(limits):
        return -1 * laguerre_shape_alpha_minusone(x=limits, coeffs=coeffs_laguerre_nup)

    integral = indefinite_integral(upper) - indefinite_integral(lower)
    integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    return integral
def _unnormalized_pdf(self, x):
    """Evaluate the numerical convolution of the two wrapped pdfs at ``x``.

    ``self.pdfs[0]`` (the "func") and ``self.pdfs[1]`` (the "kernel") are
    evaluated unnormalized on regular grids defined by ``self._conv_limits``.
    The grids are convolved with ``tf.nn.convolution`` and the result is
    interpolated at the values of ``x``, using either a spline or linear
    interpolation depending on ``self.conv_interpolation``.

    Args:
        x: Data whose ``value()`` gives the points to evaluate at.

    Returns:
        The interpolated convolution at the points of ``x``.
    """
    # Grid over the func limits.
    lower_func, upper_func = self._conv_limits["func"]
    nbins_func = self._conv_limits["nbins_func"]
    x_funcs = tf.linspace(lower_func, upper_func, tf.cast(nbins_func, tf.int32))

    # Grid over the kernel limits.
    lower_kernel, upper_kernel = self._conv_limits["kernel"]
    nbins_kernel = self._conv_limits["nbins_kernel"]
    x_kernels = tf.linspace(
        lower_kernel, upper_kernel, tf.cast(nbins_kernel, tf.int32)
    )

    # Mesh the per-axis grid points and flatten to (n_points, n_obs).
    x_func = tf.meshgrid(*tf.unstack(x_funcs, axis=-1), indexing="ij")
    x_func = znp.transpose(x_func)
    x_func_flatish = znp.reshape(x_func, (-1, self.n_obs))
    data_func = Data.from_tensor(tensor=x_func_flatish, obs=self.obs)

    x_kernel = tf.meshgrid(*tf.unstack(x_kernels, axis=-1), indexing="ij")
    x_kernel = znp.transpose(x_kernel)
    x_kernel_flatish = znp.reshape(x_kernel, (-1, self.n_obs))
    data_kernel = Data.from_tensor(tensor=x_kernel_flatish, obs=self.obs)

    y_func = self.pdfs[0].pdf(data_func, norm=False)
    y_kernel = self.pdfs[1].pdf(data_kernel, norm=False)

    func_dims = [nbins_func] * self.n_obs
    kernel_dims = [nbins_kernel] * self.n_obs

    y_func = znp.reshape(y_func, func_dims)
    y_kernel = znp.reshape(y_kernel, kernel_dims)

    # Flip the kernel: TF's `convolution` function is a cross-correlation, and
    # convolution = cross-correlation with flipped kernel.
    # This introduces a shift that has to be corrected when interpolating/x_func,
    # because the convolution has to be independent of the kernel's **limits**.
    # We assume they are symmetric when doing the FFT, so shift them back.
    y_kernel = tf.reverse(y_kernel, axis=range(self.n_obs))
    kernel_shift = (upper_kernel + lower_kernel) / 2
    x_func += kernel_shift
    lower_func += kernel_shift
    upper_func += kernel_shift

    # make rectangular grid
    y_func_rect = znp.reshape(y_func, func_dims)
    y_kernel_rect = znp.reshape(y_kernel, kernel_dims)

    # needed for multi dims?
    # if self.n_obs == 2:
    #     y_kernel_rect = tf.linalg.adjoint(y_kernel_rect)

    # get correct shape for tf.nn.convolution
    y_func_rect_conv = znp.reshape(y_func_rect, (1, *func_dims, 1))
    y_kernel_rect_conv = znp.reshape(y_kernel_rect, (*kernel_dims, 1, 1))

    conv = tf.nn.convolution(
        input=y_func_rect_conv,
        filters=y_kernel_rect_conv,
        strides=1,
        padding="SAME",
    )

    # needed for multidims?
    # if self.n_obs == 2:
    #     conv = tf.linalg.adjoint(conv[0, ..., 0])[None, ..., None]
    # conv = scipy.signal.convolve(
    #     y_func_rect,
    #     y_kernel_rect,
    #     mode='same'
    # )[None, ..., None]
    train_points = znp.expand_dims(x_func, axis=0)
    query_points = znp.expand_dims(x.value(), axis=0)
    # NOTE(review): there is no else branch; any other value of
    # `conv_interpolation` leaves `prob` unbound at the return below.
    if self.conv_interpolation == "spline":
        conv_points = znp.reshape(conv, (1, -1, 1))
        prob = tfa.image.interpolate_spline(
            train_points=train_points,
            train_values=conv_points,
            query_points=query_points,
            order=self._conv_spline_order,
        )
        prob = prob[0, ..., 0]
    elif self.conv_interpolation == "linear":
        prob = tfp.math.batch_interp_regular_nd_grid(
            x=query_points[0],
            x_ref_min=lower_func,
            x_ref_max=upper_func,
            y_ref=conv[0, ..., 0],
            # y_ref=tf.reverse(conv[0, ..., 0], axis=[0]),
            axis=-self.n_obs,
        )
        prob = prob[0]

    return prob
def cut_edges_and_bins(
    edges: Iterable[znp.array], limits: ZfitSpace, axis=None, unscaled=None
) -> tuple[list[znp.array], tuple[znp.array, znp.array], list | None]:
    """Cut the *edges* according to *limits* and calculate the bins inside.

    The edges within limits are calculated and returned together with the corresponding bin indices. The indices
    mark the lowest and the highest index of the edges that are returned. Additionally, the unscaled edges are returned.

    If the limits are between two edges, this will be treated as the new edge. If the limits are outside the edges,
    all edges in this direction will be returned (but not extended to the limit). For example:
    [0, 0.5, 1., 1.5, 2.] and the limits (0.8, 3.) will return [0.8, 1., 1.5, 2.], ([1], [4])

    .. code-block::

        cut_edges_and_bins([[0., 0.5, 1., 1.5, 2.]], ([[0.8]], [[3]]))

    Args:
        edges: Iterable of tensor-like objects that describe the edges of a histogram. Every object should have rank
            n (where n is the length of *edges*) but only have the dimension i filled out. These are
            tensors that are ready to be broadcasted together.
        limits: The limits that will be used to confine the edges. Either a ``ZfitSpace`` (its ``limits`` are used)
            or a ``(lower, upper)`` pair.
        axis: If given, only the edges whose position in *edges* is contained in *axis* are cut; all other edges
            are returned unchanged. The limits are consumed in order, one entry per edge that is cut.
        unscaled: If true, the unscaled edges are collected and returned as third element, otherwise ``None`` is
            returned in their place. Defaults to false.

    Returns:
        edges, (lower bins, upper bins), unscaled_edges: The edges and the bins are returned.
            The upper bin number corresponds
            to the highest bin which was still (partially) inside the limits **plus one** (so it's the index of the
            edge that is right outside). The unscaled edges are like *edges* but the last edge is the edge
            that is lying not inside anymore, so the actual edge of the last bin number returned.
            This can be used to determine the fraction cut away.
    """
    if axis is not None:
        axis = convert_to_container(axis)
    if unscaled is None:
        unscaled = False
    if unscaled:
        cut_unscaled_edges = []
    else:
        cut_unscaled_edges = None
    cut_scaled_edges = []

    all_lower_bins = []
    all_upper_bins = []
    if isinstance(limits, ZfitSpace):
        lower, upper = limits.limits
    else:
        lower, upper = limits
        lower = znp.asarray(lower)
        upper = znp.asarray(upper)
    # Only the first entry along the leading axis of the limits is used.
    lower_all = lower[0]
    upper_all = upper[0]
    rank = len(edges)
    current_axis = 0  # position in the limits of the next edge that is cut
    for i, edge in enumerate(edges):
        edge = znp.asarray(edge)
        edge = znp.reshape(edge, (-1,))
        if axis is None or i in axis:
            # Clip the limits to the range covered by the edges.
            lower_i = lower_all[current_axis, None]
            edge_minimum = edge[0]
            # edge_minimum = tf.gather(edge, indices=0, axis=i)
            lower_i = znp.maximum(lower_i, edge_minimum)
            upper_i = upper_all[current_axis, None]
            edge_maximum = edge[-1]
            # edge_maximum = tf.gather(edge, indices=tf.shape(edge)[i] - 1, axis=i)
            upper_i = znp.minimum(upper_i, edge_maximum)
            # we get the bins that are just one too far. Then we update this whole bin tensor with the actual edge.
            # The bins index is the index below the value.
            lower_bin_float = tfp.stats.find_bins(
                lower_i, edge, extend_lower_interval=True, extend_upper_interval=True
            )
            lower_bin = tf.reshape(tf.cast(lower_bin_float, dtype=znp.int32), [-1])
            # lower_bins = tf.tensor_scatter_nd_update(zero_bins, [[i]], lower_bin)
            # +1 below because the outer bin is searched, meaning the one that is higher than the value

            upper_bin_float = tfp.stats.find_bins(
                upper_i, edge, extend_lower_interval=True, extend_upper_interval=True
            )
            upper_bin = tf.reshape(tf.cast(upper_bin_float, dtype=znp.int32), [-1]) + 1
            size = upper_bin - lower_bin
            new_edge = tf.slice(
                edge, lower_bin, size + 1
            )  # +1 because stop is exclusive
            # Overwrite the first and the last edge of the slice with the clipped limits.
            new_edge = tf.tensor_scatter_nd_update(
                new_edge, [tf.constant([0]), size], [lower_i[0], upper_i[0]]
            )

            if unscaled:
                new_edge_unscaled = tf.slice(
                    edge, lower_bin, size + 1
                )  # +1 because stop is exclusive

            current_axis += 1
        else:
            # This axis is not cut: keep all edges and the full bin range.
            lower_bin = [0]
            upper_bin = znp.asarray([edge.shape[0] - 1], dtype=znp.int32)
            new_edge = edge
            if unscaled:
                new_edge_unscaled = edge
        # Put the edges back along axis i so that they broadcast against each other.
        new_shape = [1] * rank
        new_shape[i] = -1
        new_edge = znp.reshape(new_edge, new_shape)
        all_lower_bins.append(lower_bin)
        all_upper_bins.append(upper_bin)
        cut_scaled_edges.append(new_edge)
        if unscaled:
            new_edge_unscaled = znp.reshape(new_edge_unscaled, new_shape)
            cut_unscaled_edges.append(new_edge_unscaled)

    # partial = axis is not None and len(axis) < rank
    #
    # if partial:
    #     scaled_edges_full = list(edges)
    #     for edge, ax in zip(cut_scaled_edges, axis):
    #         scaled_edges_full[ax] = edge
    #     scaled_edges = scaled_edges_full
    #     indices = tf.convert_to_tensor(axis)[:, None]
    #     lower_bins = tf.scatter_nd(indices, lower_bins, shape=(ndims,))
    #     upper_bins = tf.tensor_scatter_nd_update(tf.convert_to_tensor(values.shape),
    #                                              indices, upper_bins)
    # lower_bins_indices = tf.stack([lower_bins, dims], axis=-1)
    # upper_bins_indices = tf.stack([upper_bins, dims], axis=-1)
    # all_lower_bins = tf.cast(znp.sum(all_lower_bins, axis=0), dtype=znp.int32)
    all_lower_bins = tf.concat(all_lower_bins, axis=0)
    all_upper_bins = tf.concat(all_upper_bins, axis=0)
    return cut_scaled_edges, (all_lower_bins, all_upper_bins), cut_unscaled_edges