def fit_nonparametric_nonlinearity(self, nbins=50, w=None):
    """Estimate a nonparametric output nonlinearity for a linear filter.

    The design matrix ``self.X`` (flattened to 2D) is projected onto
    ``uvec(w)``. The histogram density of projections for rows with a
    nonzero response (``self.y != 0``) is divided by the histogram density
    of all projections. The ratio over non-empty bins is interpolated
    against the right bin edges.

    Parameters
    ----------
    nbins : int, default=50
        Number of histogram bins for the projections.
    w : array-like or None, default=None
        Filter to project onto. If None, the first non-None of
        ``self.w_spl``, ``self.w_mle``, ``self.w_sta`` (flattened) is used.

    Raises
    ------
    ValueError
        If ``w`` is None and none of ``w_spl``, ``w_mle``, ``w_sta`` is set.

    Notes
    -----
    Sets ``self.nl_bins`` (right bin edges) and ``self.fnl_nonparametric``
    (an ``interp1d`` over the non-empty bins).
    """
    if w is None:
        if self.w_spl is not None:
            w = self.w_spl.flatten()
        elif self.w_mle is not None:
            w = self.w_mle.flatten()
        elif self.w_sta is not None:
            w = self.w_sta.flatten()
        else:
            # Previously w stayed None here and the method failed later
            # with an unrelated error.
            raise ValueError(
                "No filter available: pass `w` or fit w_spl, w_mle or w_sta first.")
    else:
        w = jnp.array(w)

    X = self.X
    X = X.reshape(X.shape[0], -1)
    y = self.y

    # uvec presumably unit-normalizes the filter; compute it once.
    w_unit = uvec(w)
    output_raw = X @ w_unit
    output_spk = X[y != 0] @ w_unit

    hist_raw, bins = jnp.histogram(output_raw, bins=nbins, density=True)
    hist_spk, _ = jnp.histogram(output_spk, bins=bins, density=True)

    # Only bins that contain projections can be divided by.
    mask = hist_raw != 0
    yy0 = hist_spk[mask] / hist_raw[mask]

    self.nl_bins = bins[1:]
    self.fnl_nonparametric = interp1d(bins[1:][mask], yy0)
def get_bins_and_bincounts(samples, normalized=False):
    """Histogram one or several sample sets over one shared set of bin edges.

    The edges are evenly spaced over ``[samples.min(), samples.max()]``.
    There are ``int(max(log2(samples.size), 5))`` edges, and therefore one
    fewer bin than that.

    Parameters
    ------------
    samples : np.array of shape (n,) or shape (k, n).
        - If shape (n,): interpreted as a set of n scalar-valued samples.
        - If shape (k, n): interpreted as k sets of n scalar-valued samples.
    normalized : bool, default=False
        Forwarded to ``np.histogram`` as ``density``: return densities
        instead of raw counts.

    Returns
    --------
    counts : np.ndarray of shape (nr_edges - 1,) or (k, nr_edges - 1)
        Per-bin counts (or densities) for each sample set.
    bins : np.ndarray of shape (nr_edges,)
        The shared bin edges.

    Raises
    ------
    ValueError
        If ``samples`` is neither 1- nor 2-dimensional.
    """
    n_edges = int(max(np.log2(np.prod(samples.shape)), 5))
    edges = np.linspace(np.min(samples), np.max(samples), num=n_edges)

    def _count(row):
        # Counts (or densities) of one sample set on the shared edges.
        return np.histogram(row, bins=edges, density=normalized)[0]

    if samples.ndim == 1:
        return _count(samples), edges
    if samples.ndim == 2:
        return np.asarray([_count(row) for row in samples]), edges
    raise ValueError(
        f"Input must have shape (n,) or shape (k,n). Instead received shape {samples.shape}"
    )
def contrast(image, factor):
    """Equivalent of PIL Contrast (``PIL.ImageEnhance.Contrast``).

    Blends ``image`` with a uniform grey image whose level is the mean
    grayscale intensity of ``image``, rounded half-up. If present, a 4th
    (alpha) channel is split off first and re-attached unchanged.

    Args:
      image: image tensor indexed as (H, W, C) with C == 3 or 4; values
        presumably in [0, 255] (TODO confirm against callers).
      factor: float factor forwarded to ``blend``.

    Returns:
      Augmented image.
    """
    has_alpha = image.shape[-1] == 4
    alpha = None
    if has_alpha:
        image, alpha = image[:, :, :3], image[:, :, -1:]

    degenerate = rgb_to_grayscale(image)
    # Integer grey levels: with 256 bins over [0, 255], grey level i falls
    # into bin i, so the histogram is a per-level pixel count.
    degenerate = degenerate.astype('int32')
    hist, _ = jnp.histogram(degenerate, bins=256, range=(0, 255))
    hist = hist.astype('float32')

    # Mean grey level = sum_i i * count_i / pixel_count, rounded half-up as
    # PIL does (int(mean + 0.5)). The previous code divided the total pixel
    # count by 256, which is not the mean intensity.
    num_pixels = jnp.maximum(jnp.sum(hist), 1.0)  # guard against empty images
    mean = jnp.sum(jnp.arange(256, dtype='float32') * hist) / num_pixels
    mean = jnp.floor(mean + 0.5)

    # Uniform grey image at the mean level, used as the blend target.
    degenerate = jnp.ones_like(degenerate, dtype='float32') * mean
    degenerate = jnp.clip(degenerate, 0.0, 255.0)
    degenerate = grayscale_to_rgb(degenerate).astype(image.dtype)
    degenerate = blend(degenerate, image, factor)
    if has_alpha:
        return jnp.concatenate([degenerate, alpha], axis=-1)
    return degenerate
def scale_channel(img):
    """Equalize a single image channel through its cumulative histogram.

    Args:
      img: channel to scale; values presumably integers in [0, 255].

    Returns:
      scaled channel as uint8. If ``step`` is zero, the channel is returned
      unchanged (cast to uint8).
    """
    img = img.astype('int32')
    # One bin per grey level: with 256 bins over [0, 255], value v lands in
    # bin v. The former ``bins=255`` merged levels 254 and 255, and the LUT
    # built from that histogram (presumably one entry per bin; see
    # build_lut) was then indexed with values up to 255, one past its end.
    histo = jnp.histogram(img, bins=256, range=(0, 255))[0]
    # Position (counted from the end) of the last non-empty bin. The step is
    # the sum over non-empty bins excluding the last non-empty one, / 255.
    last_nonzero = jnp.argmax(histo[::-1] > 0)
    step = (jnp.sum(histo) - jnp.take(histo[::-1], last_nonzero)) // 255
    # If step is zero, return the original image. Otherwise, build
    # lut from the full histogram and step and then index from it.
    lut = build_lut(histo, step)
    return jnp.where(step == 0, img.astype('uint8'), jnp.take(lut, img).astype('uint8'))
def draw_uniform(samples, bins, desired_size, rng=None):
    """Draw sample indices whose histogram is as flat as possible.

    ``samples`` is histogrammed into ``bins`` equal-width bins. Each bin
    gets a quota of about ``desired_size / bins`` indices, drawn without
    replacement. Bins holding fewer samples than their quota contribute all
    of them, and the leftover quota is spread over the remaining bins (up to
    four refinement passes).

    Parameters
    ----------
    samples : array-like, shape (n,)
        Scalar samples.
    bins : int
        Number of equal-width bins.
    desired_size : int
        Target number of indices to return.
    rng : np.random.Generator or None, default=None
        Source of randomness; None uses the global ``np.random`` state.

    Returns
    -------
    np.ndarray of int
        Indices into ``samples``, grouped by bin.
    """
    samples = np.asarray(samples)
    chooser = np.random if rng is None else rng

    hist, bin_edges = np.histogram(samples, bins=bins)
    avg_nb = int(desired_size / float(bins))
    numbers = np.repeat(avg_nb, bins)
    for _ in range(4):
        # Bins that cannot fill their quota contribute everything they have.
        full = hist <= numbers
        numbers[full] = hist[full]
        open_bins = hist > numbers
        n_open = np.sum(open_bins)
        if n_open == 0:
            # Every bin is exhausted; the old code divided by zero here.
            break
        nb_rest = desired_size - np.sum(numbers[~open_bins])
        numbers[open_bins] = max(round(nb_rest / n_open), 0)

    result = []
    last = bin_edges.size - 2
    for i in range(bin_edges.size - 1):
        # Half-open bins [lo, hi) like np.histogram (last bin closed), so a
        # sample lying on an interior edge is never drawn from two bins.
        if i == last:
            upper = samples <= bin_edges[i + 1]
        else:
            upper = samples < bin_edges[i + 1]
        positions = np.flatnonzero((samples >= bin_edges[i]) & upper)
        if positions.size > 0:
            nb = min(numbers[i], positions.size)
            result.append(chooser.choice(positions, nb, replace=False))
    if not result:
        return np.array([], dtype=int)
    return np.concatenate(result)
def scale_channel(img):
    """Equalize a single image channel through its cumulative histogram.

    Args:
      img: channel to scale; values presumably integers in [0, 255].

    Returns:
      scaled channel as uint8. If ``step`` is zero, the channel is returned
      unchanged (cast to uint8).
    """
    img = img.astype('int32')
    # One bin per grey level: with 256 bins over [0, 255], value v lands in
    # bin v. The former ``bins=255`` merged levels 254 and 255, and the LUT
    # built from that histogram (presumably one entry per bin; see
    # build_lut) was then indexed with values up to 255, one past its end.
    histo = jnp.histogram(img, bins=256, range=(0, 255))[0]
    # Position (counted from the end) of the last non-empty bin. The step is
    # the sum over non-empty bins excluding the last non-empty one, / 255.
    last_nonzero = jnp.argmax(histo[::-1] > 0)
    step = (jnp.sum(histo) - jnp.take(histo[::-1], last_nonzero)) // 255
    # If step is zero, return the original image. Otherwise, build
    # lut from the full histogram and step and then index from it.
    lut = build_lut(histo, step)
    return jnp.where(step == 0, img.astype('uint8'), jnp.take(lut, img).astype('uint8'))
def log_tomographic_weight_function_stochastic(key, u, x1, p1, x2, p2):
    """Monte-Carlo estimate of normalized log tomographic weights.

    int w(x) f(x) dx = sum_i w(dx * i) f(dx * i) dx
    where,
    int w(x) dx = sum_i w(dx * i) dx = 1

    ``u.size**2`` points are drawn along each ray ``x + t * p`` with t
    uniform in [0, 1). ``squared_norm`` is applied per draw, presumably
    giving pairwise squared distances of shape (N, M). Those values are
    histogrammed per ray pair on edges ``u`` plus one extra edge, and the
    log of the normalized bin counts is returned.

    Args:
        key: PRNG key. It is split internally so the two ray families are
            sampled independently.
        u: [U] bin edges (U >= 2); the edge ``u[-1] + (u[-1] - u[-2])`` is
            appended.
        x1: [N, 3]
        p1: [N, 3]
        x2: [M, 3]
        p2: [M, 3]

    Returns:
        [N, M, U] log of w(dx*i) dx / sum_i w(dx * i) dx; -inf where a bin
        received no samples.
    """
    n = u.size**2
    N = x1.shape[0]
    M = x2.shape[0]
    # Separate keys: drawing both from the same key made t1 and t2 identical
    # whenever N == M.
    key1, key2 = random.split(key)
    t1 = random.uniform(key1, shape=(n, N, 1))
    t2 = random.uniform(key2, shape=(n, M, 1))
    # n, N, M
    norm_squared = vmap(squared_norm)(x1 + t1 * p1, x2 + t2 * p2)
    bins = jnp.concatenate([u, u[-1:] + u[-1] - u[-2]])
    # N*M, U
    hist = vmap(lambda x: jnp.histogram(x, bins)[0])(jnp.reshape(norm_squared, (n, -1)).T)
    # N, M, U
    hist = jnp.reshape(hist, (N, M, u.size))
    log_hist = jnp.log(hist)
    # NOTE(review): log_du holds du, not log(du). It cancels in the
    # normalization below only when u is uniformly spaced. Counts already
    # integrate w over each bin, so confirm whether this term is intended.
    log_du = jnp.diff(bins)
    log_w = log_hist + log_du
    # N,M,U
    log_w = log_w - logsumexp(log_w, axis=-1, keepdims=True)
    log_w = jnp.where(hist == 0., -jnp.inf, log_w)
    return log_w
def histogram_entropy(data, nbins: int = 10):
    """Calculates the histogram entropy of 1D data.

    The discrete entropy of the normalized bin counts (computed by
    ``univariate_entropy``) is corrected with the Miller-Madow bias term
    and shifted by ``log2(bin_width)`` to approximate the differential
    entropy.

    Parameters
    ----------
    data : np.ndarray, (n_samples,)
        the input data for the entropy
    nbins : int, default=10
        number of equal-width histogram bins (any value >= 1).

    Returns
    -------
    S : float
        the entropy
    """
    # get histogram counts and bin edges
    counts, bin_edges = np.histogram(data, bins=nbins, density=False)
    # All bins share one width. Reading it from the first two edges also
    # works for nbins < 4, where bin_centers[3] - bin_centers[2] raised
    # IndexError.
    delta = bin_edges[1] - bin_edges[0]
    # normalize counts (probability mass per bin)
    pk = counts / counts.sum()
    # calculate the entropy
    S = univariate_entropy(pk)
    # Miller-Madow correction: (m - 1) / (2 N), with m non-empty bins and
    # N samples.
    # NOTE(review): this term is in nats while log2(delta) is in bits;
    # confirm the log base used by univariate_entropy.
    correction = 0.5 * (np.sum(counts > 0) - 1) / counts.sum()
    return S + correction + np.log2(delta)
def get_marginals(self, accepted_parameters=None, ranges=None, gridsize=None,
                  smoothing=None):
    """Creates the 1D and 2D marginal distribution list for plotting.

    Using sets of parameter values (accepted by the ABC), approximate
    marginal distributions for plotting are created by histogramming the
    points. Smoothing can be applied to each histogram to reduce
    undersampling artefacts.

    Row ``i`` of the result holds ``i + 1`` arrays: the 2D marginals of the
    parameter pairs (column, i) for every column < i, followed by the 1D
    marginal of parameter ``i`` (the diagonal). Each array stacks one
    histogram per set in ``accepted_parameters`` along its first axis.

    Parameters
    ----------
    accepted_parameters : iterable of float(n_samples, n_params) or None
        Sets of accepted parameter values; each set is indexed as
        ``parameters[:, column]``. If None, ``self.parameters.accepted`` is
        used.
    ranges : list or None, default=None
        One array of bin *edges* per parameter, used directly as
        histogram bins. If None, edges are built from the (uniformly
        spaced) bin centres in ``self.ranges``: one extra centre is added
        above the last, then everything is shifted down by half a bin width.
    gridsize : list or None, default=None
        Currently not used by this method; kept for interface
        compatibility.
    smoothing : float or None, default=None
        Width of a Gaussian filter (``mode="mirror"``) applied to each
        histogram. Smoothing not done if smoothing is None.

    Returns
    -------
    list of lists:
        The 1D and 2D marginal distributions described above.
    """
    if accepted_parameters is None:
        accepted_parameters = self.parameters.accepted
    if ranges is None:
        ranges = []
        for centres in self.ranges:
            width = centres[1] - centres[0]
            # n centres -> n + 1 edges. The old code appended the bin
            # *width* rather than last_centre + width, which produced
            # non-monotonic edges that np.histogram rejects.
            ranges.append(np.append(centres, centres[-1] + width) - width / 2)
    if gridsize is None:
        gridsize = self.gridsize
    if smoothing is not None:
        def smooth(x):
            return gaussian_filter(x, smoothing, mode="mirror")
    else:
        def smooth(x):
            return x

    marginals = []
    for row in range(self.n_params):
        marginals.append([])
        for column in range(self.n_params):
            if column == row:
                marginals[row].append(
                    np.array([
                        smooth(
                            np.histogram(parameters[:, column],
                                         bins=ranges[column],
                                         density=True)[0])
                        for parameters in accepted_parameters
                    ]))
            elif column < row:
                marginals[row].append(
                    np.array([
                        smooth(
                            np.histogramdd(
                                parameters[:, [column, row]],
                                bins=[ranges[column], ranges[row]],
                                density=True)[0])
                        for parameters in accepted_parameters
                    ]))
    return marginals
grads).item() grads_single_mean, grads_single_var = jnp.mean( grads[:, 0]).item(), jnp.var(grads[:, 0]).item() grads_norm_mean, grads_norm_var = jnp.mean(grad_norms).item(), jnp.var( grad_norms).item() logging_output = OrderedDict(grad_component_all_mean=grads_all_mean, grad_component_all_var=grads_all_var, grad_component_single_mean=grads_single_mean, grad_component_single_var=grads_single_var, grad_norm_mean=grads_norm_mean, grad_norm_var=grads_norm_var) expmgr.log(step=n_layers, logging_output=logging_output) wandb.log(dict( grad_component_all=wandb.Histogram( np_histogram=jnp.histogram(grads, bins=64, density=True)), grad_component_single=wandb.Histogram( np_histogram=jnp.histogram(grads[:, 0], bins=64, density=True)), grad_norm=wandb.Histogram( np_histogram=jnp.histogram(grad_norms, bins=64, density=True))), step=n_layers) suffix = f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size}_g{g}h{h}' expmgr.save_array(f'params_{suffix}.npy', params) expmgr.save_array(f'grads_{suffix}.npy', grads) del params, grads gc.collect()
forward and inverse transformation Examples -------- >>> # single set of parameters >>> X_transform, params = get_params(x_samples, 10, 1000) >>> # example with multiple dimensions >>> multi_dims = jax.vmap(get_params, in_axes=(0, None, None)) >>> X_transform, params = multi_dims(X, 10, 1000) """ # get number of samples n_samples = np.shape(X)[0] # get histogram counts and bin edges counts, bin_edges = np.histogram(X, bins=nbins) # add regularization counts = np.array(counts) + alpha # get bin centers and sizes bin_centers = np.mean(np.vstack((bin_edges[0:-1], bin_edges[1:])), axis=0) bin_size = bin_edges[2] - bin_edges[1] # ================================= # PDF Estimation # ================================= # pdf support pdf_support = np.hstack( (bin_centers[0] - bin_size, bin_centers, bin_centers[-1] + bin_size)) # empirical PDF
def histogram(a, bins=10, range=None, weights=None, density=None):
    """NumPy-compatible histogram that accepts ``JaxArray`` wrappers.

    ``a`` and ``weights`` are unwrapped to their ``.value`` when they are
    ``JaxArray`` instances. ``jnp.histogram`` is then called with the same
    arguments, and both results are wrapped back into ``JaxArray``.

    Returns:
      (hist, bin_edges) as JaxArray objects.
    """
    raw_a = a.value if isinstance(a, JaxArray) else a
    raw_weights = weights.value if isinstance(weights, JaxArray) else weights
    counts, edges = jnp.histogram(a=raw_a, bins=bins, range=range,
                                  weights=raw_weights, density=density)
    return JaxArray(counts), JaxArray(edges)
def _entropy(v, uniq): uniq = jnp.concatenate([uniq, jnp.array([jnp.inf])], axis=0) hist, _ = jnp.histogram(v, bins=uniq) hist = hist / jnp.sum(hist) entropy = -jnp.sum(hist * jnp.log2(hist)) return entropy
def test_sorted_piecewise_constant_pdf_train_mode(self):
    """Test that piecewise-constant sampling reproduces its distribution.

    Builds ``batch_size`` random piecewise-constant PDFs (some zero-width
    bins, some zero weights), plus one all-zero-weight PDF whose expected
    histogram is uniform. For both ``randomized`` settings, draws samples
    with ``math.sorted_piecewise_constant_pdf``, checks that they are sorted,
    and compares their histogram to the target. The comparison uses the
    angle between the histograms and their Jensen-Shannon divergence.
    """
    batch_size = 4
    num_bins = 16
    num_samples = 1000000
    precision = 1e5
    rng = random.PRNGKey(20202020)

    # Generate a series of random PDFs to sample from.
    data = []
    for _ in range(batch_size):
        rng, key = random.split(rng)
        # Randomly initialize the distances between bins.
        # We're rolling our own fixed precision here to make cumsum exact.
        bins_delta = jnp.round(precision * jnp.exp(
            random.uniform(
                key, shape=(num_bins + 1, ), minval=-3, maxval=3)))

        # Set some of the bin distances to 0.
        rng, key = random.split(rng)
        bins_delta *= random.uniform(key, shape=bins_delta.shape) < 0.9

        # Integrate the bins.
        bins = jnp.cumsum(bins_delta) / precision
        # Shift the whole set of edges by a random offset.
        rng, key = random.split(rng)
        bins += random.normal(key) * num_bins / 2
        rng, key = random.split(rng)

        # Randomly generate weights, allowing some to be zero.
        weights = jnp.maximum(
            0, random.uniform(key, shape=(num_bins, ), minval=-0.5, maxval=1.))
        gt_hist = weights / weights.sum()
        data.append((bins, weights, gt_hist))

    # Tack on an "all zeros" weight matrix, which is a common cause of NaNs.
    # It reuses the bins of the last random PDF; the expected histogram is uniform.
    weights = jnp.zeros_like(weights)
    gt_hist = jnp.ones_like(gt_hist) / num_bins
    data.append((bins, weights, gt_hist))

    # Stack the per-PDF tuples into batched arrays (batch_size + 1 rows).
    bins, weights, gt_hist = [jnp.stack(x) for x in zip(*data)]

    for randomized in [True, False]:
        rng, key = random.split(rng)
        # Draw samples from the batch of PDFs.
        samples = math.sorted_piecewise_constant_pdf(
            key,
            bins,
            weights,
            num_samples,
            randomized,
        )
        self.assertEqual(samples.shape[-1], num_samples)

        # Check that samples are sorted.
        self.assertTrue(jnp.all(samples[..., 1:] >= samples[..., :-1]))

        # Verify that each set of samples resembles the target distribution.
        for i_samples, i_bins, i_gt_hist in zip(samples, bins, gt_hist):
            # Empirical probability mass per bin.
            i_hist = jnp.float32(jnp.histogram(i_samples, i_bins)[0]) / num_samples
            i_gt_hist = jnp.array(i_gt_hist)

            # Merge any of the zero-span bins until there aren't any left.
            # Each pass sums the masses of bins j and j+1 in both histograms
            # and drops the duplicated edge.
            while jnp.any(i_bins[:-1] == i_bins[1:]):
                j = int(jnp.where(i_bins[:-1] == i_bins[1:])[0][0])
                i_hist = jnp.concatenate([
                    i_hist[:j],
                    jnp.array([i_hist[j] + i_hist[j + 1]]),
                    i_hist[j + 2:]
                ])
                i_gt_hist = jnp.concatenate([
                    i_gt_hist[:j],
                    jnp.array([i_gt_hist[j] + i_gt_hist[j + 1]]),
                    i_gt_hist[j + 2:]
                ])
                i_bins = jnp.concatenate([i_bins[:j], i_bins[j + 1:]])

            # Angle between the two histograms in degrees
            # (arccos of their cosine similarity, clamped at 1).
            angle = 180 / jnp.pi * jnp.arccos(
                jnp.minimum(
                    1.,
                    jnp.mean((i_hist * i_gt_hist) / jnp.sqrt(
                        jnp.mean(i_hist**2) * jnp.mean(i_gt_hist**2)))))

            # Jensen-Shannon divergence (mean of the two KL terms to the midpoint m).
            m = (i_hist + i_gt_hist) / 2
            js_div = jnp.sum(
                sp.special.kl_div(i_hist, m) + sp.special.kl_div(i_gt_hist, m)) / 2
            self.assertLessEqual(angle, 0.5)
            self.assertLessEqual(js_div, 1e-5)
def test_get_Q():
    """Visual check of the polynomial CDF approximation of the tomographic weight.

    For 10 random ray pairs, plots a Monte-Carlo histogram of squared
    distances between points on the two rays against the S=150 reference
    weight. It then plots the dimensionless reference CDF against the
    polynomial (``Q``) CDF over gamma' = gamma / h**2.
    """
    Q = get_polynomial_form()
    import pylab as plt
    from jax import jit

    def random_ray(origin_key, direction_key, spread):
        # Origin in the z=0 plane with (x, y) uniform in [0, spread).
        origin = jnp.concatenate(
            [spread * random.uniform(origin_key, shape=(2, )), jnp.zeros((1, ))],
            axis=-1)
        # Small xy components (uniform in +-4*pi/180) plus unit z,
        # rescaled to length 4.
        direction = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(direction_key, shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ], axis=-1)
        direction = 4 * direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)
        return origin, direction

    @jit
    def tomo_weight_ref(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def cumulative_tomo_weight_function_dimensionless(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return cumulative_tomographic_weight_dimensionless_function(
            gamma_prime, n, w1, w2, S=150)

    @jit
    def cumulative_tomo_weight_polynomial_dimensionless(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return vmap(lambda gamma_prime: cumulative_tomographic_weight_dimensionless_polynomial(
            Q, gamma_prime, n, w1, w2))(gamma_prime)

    for i in range(10):
        keys = random.split(random.PRNGKey(i), 6)
        x1, p1 = random_ray(keys[0], keys[1], 10.)
        x2, p2 = random_ray(keys[2], keys[3], 4.)
        t1 = random.uniform(keys[4], shape=(10000, ))
        t2 = random.uniform(keys[5], shape=(10000, ))
        u1 = x1 + t1[:, None] * p1
        u2 = x2 + t2[:, None] * p2
        gamma = jnp.linalg.norm(u1 - u2, axis=1)**2

        plt.hist(gamma.flatten(), bins=100, density=True, label='histogram')
        _, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
        gamma = 0.5 * (bins[:-1] + bins[1:])
        w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
        plt.plot(gamma, w_ref, label='analytic ref')
        plt.legend()
        plt.show()

        cdf_ref = cumulative_tomo_weight_function_dimensionless(
            gamma, x1, x2, p1, p2)
        cdf_poly = cumulative_tomo_weight_polynomial_dimensionless(
            gamma, x1, x2, p1, p2)
        # Same normalization as gamma_prime inside the helpers: gamma / h**2.
        # The old code divided by h only.
        gamma_prime = gamma / jnp.linalg.norm(x1 - x2)**2
        plt.plot(gamma_prime, cdf_ref, label='ref')
        plt.plot(gamma_prime, cdf_poly, label='poly')
        plt.legend()
        plt.show()
def test_tomographic_weight_rel_err():
    """Plot the relative error of the tomographic weight versus quadrature size.

    For S in 5, 10, 15, 20, 25, ``tomographic_weight_function`` is compared
    with an S=150 reference on 400 random ray pairs. The error measure is
    max|w - w_ref| / max(w_ref) over a 20-point grid. A histogram of these
    errors is shown, titled with the 5/50/95 percentiles.
    """
    import pylab as plt
    from jax import jit

    def random_ray(origin_key, direction_key):
        # Origin in the z=0 plane with (x, y) uniform in [0, 4).
        origin = jnp.concatenate(
            [4. * random.uniform(origin_key, shape=(2, )), jnp.zeros((1, ))],
            axis=-1)
        # xy components uniform in +-4*pi/180, unit z, rescaled to length 4.
        tilt = 4. * jnp.pi / 180. * random.uniform(
            direction_key, shape=(2, ), minval=-1, maxval=1)
        direction = jnp.concatenate([tilt, jnp.ones((1, ))], axis=-1)
        return origin, 4 * direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)

    for S in range(5, 30, 5):

        @jit
        def tomo_weight(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=S)

        @jit
        def tomo_weight_ref(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

        errors = []
        for seed in range(400):
            k = random.split(random.PRNGKey(seed), 6)
            x1, p1 = random_ray(k[0], k[1])
            x2, p2 = random_ray(k[2], k[3])
            u1 = x1 + random.uniform(k[4], shape=(10000, ))[:, None] * p1
            u2 = x2 + random.uniform(k[5], shape=(10000, ))[:, None] * p2
            sq_dist = jnp.linalg.norm(u1 - u2, axis=1)**2
            # Only the histogram's outer edges are used, to span the data range.
            _, edges = jnp.histogram(sq_dist.flatten(), density=True, bins=100)
            grid = jnp.linspace(edges.min(), edges.max(), 20)
            w = tomo_weight(grid, x1, x2, p1, p2)
            w_ref = tomo_weight_ref(grid, x1, x2, p1, p2)
            errors.append(jnp.max(jnp.abs(w - w_ref)) / jnp.max(w_ref))

        errors = jnp.array(errors)
        plt.hist(errors, bins='auto')
        p5, p50, p95 = jnp.percentile(errors, [5, 50, 95])
        plt.title("{} : {:.2f}|{:.2f}|{:.2f}".format(S, p5, p50, p95))
        plt.show()
def safe_gaussian_kde(samples, weights):
    """Return a density estimate of ``samples``, falling back to a histogram.

    First tries ``gaussian_kde`` with Silverman's bandwidth. If that raises
    (e.g. singular covariance for degenerate data), a weighted histogram
    *density* is built instead. It is returned as a piecewise-constant
    callable that is zero outside the data range.

    Args:
      samples: samples to estimate the density of. The histogram fallback
        supports 1D data only.
      weights: per-sample weights, or None.

    Returns:
      A callable ``x -> density``: the KDE object, or the histogram fallback.
    """
    import numpy as np

    try:
        return gaussian_kde(samples, weights=weights, bw_method='silverman')
    except Exception:  # deliberate best effort: any KDE failure falls back
        pass

    samples_np = np.asarray(samples)
    weights_np = None if weights is None else np.asarray(weights)
    # jnp.histogram rejects string bin rules, and NumPy's 'auto' rule does
    # not accept weights, so the edges come from the unweighted samples.
    edges = np.histogram_bin_edges(samples_np, bins='auto')
    density, edges = np.histogram(samples_np, bins=edges, weights=weights_np,
                                  density=True)
    density = jnp.asarray(density)
    edges = jnp.asarray(edges)
    last = density.shape[0] - 1

    def histogram_pdf(x):
        x = jnp.asarray(x)
        # Bin i covers [edges[i], edges[i+1]); the last bin includes its
        # right edge, matching np.histogram.
        idx = jnp.searchsorted(edges, x, side='right') - 1
        idx = jnp.where(x == edges[-1], last, idx)
        inside = (idx >= 0) & (idx <= last)
        return jnp.where(inside, density[jnp.clip(idx, 0, last)], 0.0)

    return histogram_pdf
# Scalar summaries (mean / variance) of all gradient components, of the
# first component only, and of the gradient norms.
grads_all_mean = jnp.mean(grads).item()
grads_all_var = jnp.var(grads).item()
grads_single_mean = jnp.mean(grads[:, 0]).item()
grads_single_var = jnp.var(grads[:, 0]).item()
grads_norm_mean = jnp.mean(grad_norms).item()
grads_norm_var = jnp.var(grad_norms).item()

logging_output = OrderedDict([
    ('grad_component_all_mean', grads_all_mean),
    ('grad_component_all_var', grads_all_var),
    ('grad_component_single_mean', grads_single_mean),
    ('grad_component_single_var', grads_single_var),
    ('grad_norm_mean', grads_norm_mean),
    ('grad_norm_var', grads_norm_var),
])
expmgr.log(step=n_layers, logging_output=logging_output)

# Full distributions as 64-bin density histograms.
wandb.log(
    {
        'grad_component_all': wandb.Histogram(
            np_histogram=jnp.histogram(grads, bins=64, density=True)),
        'grad_component_single': wandb.Histogram(
            np_histogram=jnp.histogram(grads[:, 0], bins=64, density=True)),
        'grad_norm': wandb.Histogram(
            np_histogram=jnp.histogram(grad_norms, bins=64, density=True)),
    },
    step=n_layers,
)

# Persist the raw arrays, then free them before the next configuration.
suffix = f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size}_g{g}h{h}'
expmgr.save_array(f'params_{suffix}.npy', params)
expmgr.save_array(f'grads_{suffix}.npy', grads)
del params, grads
gc.collect()
def test_tomographic_weight():
    """Debug plots comparing the S=10 tomographic weight with S=150 references.

    For 100 random ray pairs, the Monte-Carlo histogram of squared distances
    between points on the two rays is plotted together with both weights
    (PDF figure). The cumulative sums of the weights are then plotted (CDF
    figure). Figures are saved to a hard-coded debug directory.
    """
    import pylab as plt
    from jax import jit

    def make_ray(origin_key, direction_key, spread):
        # Origin in the z=0 plane with (x, y) uniform in [0, spread).
        origin = jnp.concatenate(
            [spread * random.uniform(origin_key, shape=(2, )), jnp.zeros((1, ))],
            axis=-1)
        # xy components uniform in +-4*pi/180, unit z, rescaled to length 4.
        tilt = 4. * jnp.pi / 180. * random.uniform(
            direction_key, shape=(2, ), minval=-1, maxval=1)
        direction = jnp.concatenate([tilt, jnp.ones((1, ))], axis=-1)
        direction = 4 * direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)
        return origin, direction

    # Deliberately not jitted.
    def tomo_weight(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=10)

    @jit
    def tomo_weight_ref(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def _tomo_weight_ref(gamma, x1, x2, p1, p2):
        return _tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    # Alternative dimensionless reference; not called below.
    @jit
    def tomo_weight_dimensionless_ref(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        return jnp.exp(
            log_tomographic_weight_dimensionless_function(
                gamma / h**2, x12 / h, p1 / h, p2 / h, S=150)) / h**2

    for i in range(100):
        k = random.split(random.PRNGKey(i), 6)
        x1, p1 = make_ray(k[0], k[1], 10.)
        x2, p2 = make_ray(k[2], k[3], 4.)
        u1 = x1 + random.uniform(k[4], shape=(10000, ))[:, None] * p1
        u2 = x2 + random.uniform(k[5], shape=(10000, ))[:, None] * p2
        sq_dist = jnp.linalg.norm(u1 - u2, axis=1)**2

        plt.hist(sq_dist.flatten(), bins=100, density=True, label='histogram')
        _, edges = jnp.histogram(sq_dist.flatten(), density=True, bins=100)
        bins = jnp.linspace(edges.min(), edges.max(), 50)
        centres = 0.5 * (bins[:-1] + bins[1:])

        # PDF figure.
        w = tomo_weight(bins, x1, x2, p1, p2)
        plt.plot(centres, w, label='analytic')
        w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
        _w_ref = _tomo_weight_ref(centres, x1, x2, p1, p2)
        plt.plot(centres, w_ref, label='analytic ref')
        plt.legend()
        plt.savefig(
            '/home/albert/git/jaxns/debug_figs/pdf_fig{:03d}.png'.format(i))
        plt.close('all')

        # CDF figure.
        plt.plot(centres, jnp.cumsum(w), label='analytic')
        plt.plot(centres, jnp.cumsum(w_ref), label='analytic ref')
        plt.legend()
        plt.savefig(
            '/home/albert/git/jaxns/debug_figs/cdf_fig{:03d}.png'.format(i))
        plt.close('all')
def hist1bin(z, w):
    """Weighted histogram of ``z`` on the module-level ``zedges``; counts only."""
    counts, _ = jnp.histogram(z, bins=zedges, weights=w)
    return counts
# Monte-Carlo study: re-draw the random basis weights/biases, solve, and
# record the maximum absolute test error for each trial.
for k in tqdm.trange(nMC):
    # Solve for xi
    tfc.basisClass.w = np.array(2. * onp.random.rand(*tfc.basisClass.w.shape) - 1.)
    tfc.basisClass.b = np.array(2. * onp.random.rand(*tfc.basisClass.b.shape) - 1.)
    xi = LS()

    # Calculate the error
    ur = real(*xTest)
    ue = u(xi, *xTest)
    err = ur - ue
    testErr[k] = np.max(np.abs(err))

# Histogram of the maximum errors on log-spaced bins (edges from a
# histogram of log10(testErr)).
p1 = MakePlot('Maximum Error', 'Number of Occurrences')
hist, binEdge = np.histogram(np.log10(testErr), bins=20)
p1.ax[0].hist(testErr, bins=10**binEdge, color=(76. / 256., 0., 153. / 256.), edgecolor='black', zorder=20)
p1.ax[0].set_xscale('log')
p1.ax[0].xaxis.set_major_locator(plt.LogLocator(base=10, numticks=10))
p1.ax[0].locator_params(axis='both', tight=True)
p1.ax[0].grid(True, which='both')
# Push existing line artists behind the bars (plain loop: side effects only).
for line in p1.ax[0].lines:
    line.set_zorder(0)
mTicks = p1.ax[0].xaxis.get_minor_ticks()
p1.PartScreen(11, 8)
p1.show()