Example #1
    def fit_nonparametric_nonlinearity(self, nbins=50, w=None):

        if w is None:
            if self.w_spl is not None:
                w = self.w_spl.flatten()
            elif self.w_mle is not None:
                w = self.w_mle.flatten()
            elif self.w_sta is not None:
                w = self.w_sta.flatten()
            else:
                raise ValueError(
                    'No fitted filter found: fit a filter first or pass `w`.')
        else:
            w = jnp.array(w)

        X = self.X
        X = X.reshape(X.shape[0], -1)
        y = self.y

        output_raw = X @ uvec(w)
        output_spk = X[y != 0] @ uvec(w)

        hist_raw, bins = jnp.histogram(output_raw, bins=nbins, density=True)
        hist_spk, _ = jnp.histogram(output_spk, bins=bins, density=True)

        # Keep only bins where the raw-output histogram is non-empty.
        mask = hist_raw != 0

        yy0 = hist_spk[mask] / hist_raw[mask]

        self.nl_bins = bins[1:]
        self.fnl_nonparametric = interp1d(bins[1:][mask], yy0)
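The method above estimates the nonlinearity as the ratio of the spike-triggered to the raw filter-output histograms. A minimal self-contained sketch of that ratio-of-histograms estimator, using made-up data and plain NumPy (`uvec` and the model class are not needed here):

import numpy as np
from scipy.interpolate import interp1d

rng = np.random.default_rng(0)
output_raw = rng.normal(size=10_000)                 # hypothetical filter outputs
p_spike = 1.0 / (1.0 + np.exp(-output_raw))          # toy sigmoid nonlinearity
spiked = rng.random(10_000) < p_spike

hist_raw, bins = np.histogram(output_raw, bins=50, density=True)
hist_spk, _ = np.histogram(output_raw[spiked], bins=bins, density=True)

mask = hist_raw != 0
# P(spike | output) is proportional to the ratio of the two densities.
fnl = interp1d(bins[1:][mask], hist_spk[mask] / hist_raw[mask])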
Example #2
def get_bins_and_bincounts(samples, normalized=False):
    """take in samples, create a common set of bins, and compute the counts count(x in bin)
    for each bin and each sample x.
    Parameters
    ------------
    samples : np.array of shape (n,) or shape (k, n).
    - If shape (n,): interpreted as a set of n scalar-valued samples.
    - If shape (k, n): interpreted as k sets of n scalar-valued samples.

    Returns
    --------
    probabilities :
    bins :
    """
    nr_samples = np.prod(samples.shape)
    nr_bins = np.log2(nr_samples)
    nr_bins = int(max(nr_bins, 5))

    lims = [np.min(samples), np.max(samples)]
    bins = np.linspace(*lims, num=nr_bins)

    if samples.ndim == 2:
        out = np.asarray([
            np.histogram(x, bins=bins, density=normalized)[0] for x in samples
        ])
        return out, bins
    elif samples.ndim == 1:
        return np.histogram(samples, bins=bins, density=normalized)[0], bins
    else:
        raise ValueError(
            f"Input must have shape (n,) or shape (k,n). Instead received shape {samples.shape}"
        )
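A usage sketch (inputs are invented): because the edges are computed from the global min and max, the k histograms share one binning and are directly comparable.

import numpy as np

chains = np.random.default_rng(1).normal(size=(3, 1000))   # three sample sets
probs, bins = get_bins_and_bincounts(chains, normalized=True)
assert probs.shape == (3, len(bins) - 1)   # one row of densities per chain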
Example #3
def contrast(image, factor):
    """
    Equivalent of PIL Contrast.
    Args:
        image: image tensor
        factor: float factor

    Returns:
        Augmented image
    """
    has_alpha = image.shape[-1] == 4
    alpha = None

    if has_alpha:
        image, alpha = image[:, :, :3], image[:, :, -1:]

    degenerate = rgb_to_grayscale(image)
    # Cast before calling tf.histogram.
    degenerate = degenerate.astype('int32')

    # Compute the grayscale histogram, then compute the mean pixel value,
    # and create a constant image size of that value.  Use that as the
    # blending degenerate target of the original image.
    hist, _ = jnp.histogram(degenerate, bins=256, range=(0, 255))
    # Mean grayscale value: weight each intensity by its pixel count.
    mean = jnp.sum(hist * jnp.arange(256, dtype='float32')) / jnp.sum(hist)
    degenerate = jnp.ones_like(degenerate, dtype='float32') * mean
    degenerate = jnp.clip(degenerate, 0.0, 255.0)
    degenerate = grayscale_to_rgb(degenerate).astype(image.dtype)
    degenerate = blend(degenerate, image, factor)

    if has_alpha:
        return jnp.concatenate([degenerate, alpha], axis=-1)
    return degenerate
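A hedged usage sketch, assuming `rgb_to_grayscale`, `grayscale_to_rgb`, and `blend` from the same module are in scope; `factor` follows PIL semantics (0 removes contrast, 1 is identity, values above 1 boost it):

import jax.numpy as jnp

# Hypothetical 8x8 RGB ramp image with values in [0, 255].
image = jnp.tile(jnp.linspace(0., 255., 8)[:, None, None], (1, 8, 3))
low = contrast(image, 0.5)    # pixels pulled toward the mean gray
high = contrast(image, 1.5)   # pixels pushed away from the mean gray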
Example #4
    def scale_channel(img):
        """
        Scale the data in the channel to implement equalize.
        Args:
            img: channel to scale.

        Returns:
            scaled channel
        """
        img = img.astype('int32')
        # Compute the histogram of the image channel: one bin per 8-bit
        # intensity value (256 bins over [0, 255]).
        histo = jnp.histogram(img, bins=256, range=(0, 255))[0]

        last_nonzero = jnp.argmax(histo[::-1] > 0)  # jnp.nonzero(histo)[0][-1]
        step = (jnp.sum(histo) - jnp.take(histo[::-1], last_nonzero)) // 255

        # If step is zero, return the original image.  Otherwise, build
        # lut from the full histogram and step and then index from it.
        return jnp.where(step == 0,
                         img.astype('uint8'),
                         jnp.take(build_lut(histo, step), img).astype('uint8'))
Example #5
def draw_uniform(samples, bins, desired_size):
    """
    Draw uniform set of samples


    """
    hist, bin_edges = np.histogram(samples, bins=bins)
    avg_nb = int(desired_size / float(bins))
    numbers = np.repeat(avg_nb, bins)
    # Iteratively cap under-populated bins at their actual counts and
    # spread the remaining budget over the well-populated bins.
    for _ in range(4):
        numbers[hist <= numbers] = hist[hist <= numbers]
        nb_rest = desired_size - np.sum(numbers[hist <= numbers])
        avg_nb = round(nb_rest / np.sum(hist > numbers))
        numbers[hist > numbers] = avg_nb

    result = []
    count = 0
    for i in range(bin_edges.size - 1):
        ind = samples >= bin_edges[i]
        ind &= samples <= bin_edges[i + 1]
        if ind.sum() > 0:
            positions = np.where(ind)[0]
            nb = min(numbers[i], ind.sum())
            # Use np.random.choice: jax.random.choice requires a PRNG key
            # as its first argument, which this numpy routine does not need.
            result.append(np.random.choice(positions, nb, replace=False))

    return np.concatenate(result)
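A usage sketch with synthetic skewed data: the returned indices select a subset whose histogram over the bins is roughly flat.

import numpy as np

rng = np.random.default_rng(0)
samples = rng.exponential(size=5000)              # heavily skewed distribution
idx = draw_uniform(samples, bins=20, desired_size=1000)
balanced = samples[idx]                           # roughly uniform across bins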
Example #6
def log_tomographic_weight_function_stochastic(key, u, x1, p1, x2, p2):
    """
    int w(x) f(x) dx = sum_i w(dx * i) f(dx * i) dx
    where,
    int w(x) dx = sum_i w(dx * i) dx = 1
    Args:
        key:
        u:
        x1: [N, 3]
        p1: [N, 3]
        x2: [M, 3]
        p2: [M, 3]

    Returns:
        w(dx*i) dx / sum_i w(dx * i) dx
        [N, M] shaped

    """
    n = u.size**2
    N = x1.shape[0]
    M = x2.shape[0]
    # Split the key: reusing one key for both draws would correlate t1 and t2.
    key1, key2 = random.split(key)
    t1 = random.uniform(key1, shape=(n, N, 1))
    t2 = random.uniform(key2, shape=(n, M, 1))
    # n, N, M
    norm_squared = vmap(squared_norm)(x1 + t1 * p1, x2 + t2 * p2)
    bins = jnp.concatenate([u, u[-1:] + u[-1] - u[-2]])
    # N*M, U
    hist = vmap(lambda x: jnp.histogram(x, bins)[0])(jnp.reshape(
        norm_squared, (n, -1)).T)
    # N,M,U
    hist = jnp.reshape(hist, (x1.shape[0], x2.shape[0], u.size))
    log_hist = jnp.log(hist)
    # Bin widths enter the weight in log space: log(w * du) = log w + log du.
    log_du = jnp.log(jnp.diff(bins))
    log_w = log_hist + log_du
    # N,M,U
    log_w = log_w - logsumexp(log_w, axis=-1, keepdims=True)
    log_w = jnp.where(hist == 0., -jnp.inf, log_w)
    return log_w
Example #8
def histogram_entropy(data, nbins: int = 10):
    """Calculates the histogram entropy of 1D data.
    This function uses the histogram and then calculates
    the entropy. Does the miller-maddow correction
    
    Parameters
    ----------
    data : np.ndarray, (n_samples,)
        the input data for the entropy
    
    base : int, default=2
        the log base for the calculation.
    
    Returns
    -------
    S : float
        the entropy"""

    # get histogram counts and bin edges
    counts, bin_edges = np.histogram(data, bins=nbins, density=False)

    # get bin centers and sizes
    bin_centers = np.mean(np.vstack((bin_edges[0:-1], bin_edges[1:])), axis=0)

    # bin width (np.histogram bins are uniform, so any adjacent pair works)
    delta = bin_centers[1] - bin_centers[0]

    # normalize counts (density)
    pk = 1.0 * np.array(counts) / np.sum(counts)

    # calculate the entropy
    S = univariate_entropy(pk)

    # Miller Maddow Correction
    correction = 0.5 * (np.sum(counts > 0) - 1) / counts.sum()

    return S + correction + np.log2(delta)
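A quick sanity check (values illustrative, assuming `univariate_entropy` computes base-2 entropy): for a standard normal the differential entropy is 0.5 * log2(2*pi*e), about 2.05 bits.

import numpy as np

data = np.random.default_rng(0).normal(size=100_000)
S = histogram_entropy(data, nbins=100)
# Expect S close to 0.5 * np.log2(2 * np.pi * np.e) ~ 2.05 bits.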
Example #9
    def get_marginals(self,
                      accepted_parameters=None,
                      ranges=None,
                      gridsize=None,
                      smoothing=None):
        """ Creates the 1D and 2D marginal distribution list for plotting

        Using list of parameter values (accepted by the ABC) an approximate set
        of marginal distributions for plotting are created based on
        histogramming the points. Smoothing can be performed on the histogram
        to avoid undersampling artefacts.

        For every parameter the full distribution is summed over every other
        parameter to get the 1D marginals and for every combination the 2D
        marginals are calculated by summing over the remaining parameters. The
        list is made up of a list of n_params lists which contain n_columns
        number of objects.

        Parameters
        ----------
        accepted_parameters : float(any, n_params) or None, default=None
            An array of all accepted parameter values. If None, the accepted
            parameters from the `parameters` class attribute are used
        ranges : list or None, default=None
            A list of arrays containing the bin centres for the marginal
            distribution obtained by histogramming for each parameter. If None
            the ranges are constructed from the ranges of the prior
            distribution.
        gridsize : list or None, default=None
            The number of grid points to evaluate the marginal distribution on
            for each parameter. This needs to be set if ranges is passed (and
            different from the gridsize set on initialisation)
        smoothing : float or None, default=None
            A Gaussian smoothing for the marginal distributions. Smoothing not
            done if smoothing is None

        Returns
        -------
        list of lists:
            The 1D and 2D marginal distributions for each parameter (or pair)
        """
        if accepted_parameters is None:
            accepted_parameters = self.parameters.accepted
        if ranges is None:
            # Extend each range by one point and shift by half a step so the
            # bin centres become bin edges around the original grid.
            ranges = [
                np.hstack([r, np.array([r[1] - r[0]])]) - (r[1] - r[0]) / 2
                for r in self.ranges
            ]
        if gridsize is None:
            gridsize = self.gridsize
        if smoothing is not None:

            def smooth(x):
                return gaussian_filter(x, smoothing, mode="mirror")
        else:

            def smooth(x):
                return x

        marginals = []
        for row in range(self.n_params):
            marginals.append([])
            for column in range(self.n_params):
                if column == row:
                    marginals[row].append(
                        np.array([
                            smooth(
                                np.histogram(parameters[:, column],
                                             bins=ranges[column],
                                             density=True)[0])
                            for parameters in accepted_parameters
                        ]))
                elif column < row:
                    marginals[row].append(
                        np.array([
                            smooth(
                                np.histogramdd(
                                    parameters[:, [column, row]],
                                    bins=[ranges[column], ranges[row]],
                                    density=True)[0])
                            for parameters in accepted_parameters
                        ]))
        return marginals
Example #10
    grads_all_mean, grads_all_var = jnp.mean(grads).item(), jnp.var(grads).item()
    grads_single_mean, grads_single_var = jnp.mean(grads[:, 0]).item(), jnp.var(grads[:, 0]).item()
    grads_norm_mean, grads_norm_var = jnp.mean(grad_norms).item(), jnp.var(grad_norms).item()

    logging_output = OrderedDict(grad_component_all_mean=grads_all_mean,
                                 grad_component_all_var=grads_all_var,
                                 grad_component_single_mean=grads_single_mean,
                                 grad_component_single_var=grads_single_var,
                                 grad_norm_mean=grads_norm_mean,
                                 grad_norm_var=grads_norm_var)

    expmgr.log(step=n_layers, logging_output=logging_output)

    wandb.log(dict(
        grad_component_all=wandb.Histogram(
            np_histogram=jnp.histogram(grads, bins=64, density=True)),
        grad_component_single=wandb.Histogram(
            np_histogram=jnp.histogram(grads[:, 0], bins=64, density=True)),
        grad_norm=wandb.Histogram(
            np_histogram=jnp.histogram(grad_norms, bins=64, density=True))),
              step=n_layers)

    suffix = f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size}_g{g}h{h}'
    expmgr.save_array(f'params_{suffix}.npy', params)
    expmgr.save_array(f'grads_{suffix}.npy', grads)

    del params, grads
    gc.collect()
Example #11
        forward and inverse transformation
    
    Examples
    --------
    >>> # single set of parameters
    >>> X_transform, params = get_params(x_samples, 10, 1000)
    
    >>> # example with multiple dimensions
    >>> multi_dims = jax.vmap(get_params, in_axes=(0, None, None))
    >>> X_transform, params = multi_dims(X, 10, 1000)
    """
    # get number of samples
    n_samples = np.shape(X)[0]

    # get histogram counts and bin edges
    counts, bin_edges = np.histogram(X, bins=nbins)

    # add regularization
    counts = np.array(counts) + alpha

    # get bin centers and sizes
    bin_centers = np.mean(np.vstack((bin_edges[0:-1], bin_edges[1:])), axis=0)
    bin_size = bin_edges[2] - bin_edges[1]

    # =================================
    # PDF Estimation
    # =================================
    # pdf support
    pdf_support = np.hstack(
        (bin_centers[0] - bin_size, bin_centers, bin_centers[-1] + bin_size))
    # empirical PDF
Example #12
def histogram(a, bins=10, range=None, weights=None, density=None):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(weights, JaxArray): weights = weights.value
  hist, bin_edges = jnp.histogram(a=a, bins=bins, range=range, weights=weights, density=density)
  return JaxArray(hist), JaxArray(bin_edges)
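A usage sketch, assuming a minimal JaxArray container that stores its payload in `.value`:

import jax.numpy as jnp

a = JaxArray(jnp.array([0.1, 0.4, 0.4, 0.9]))
hist, edges = histogram(a, bins=4, range=(0., 1.))
# hist.value  -> counts per bin, here [1, 2, 0, 1]
# edges.value -> the five bin edges [0., 0.25, 0.5, 0.75, 1.]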
Example #13
def _entropy(v, uniq):
  uniq = jnp.concatenate([uniq, jnp.array([jnp.inf])], axis=0)
  hist, _ = jnp.histogram(v, bins=uniq)
  hist = hist / jnp.sum(hist)
  # Guard empty bins: 0 * log2(0) would otherwise give NaN.
  entropy = -jnp.sum(jnp.where(hist > 0, hist * jnp.log2(hist), 0.))
  return entropy
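Intended usage (inferred from the signature): `uniq` holds the sorted unique values of `v`, and the appended +inf sentinel gives the largest value its own closed bin, so each unique value maps to exactly one bin.

import jax.numpy as jnp

v = jnp.array([0., 0., 1., 1.])   # two equally likely values
uniq = jnp.unique(v)              # [0., 1.]
print(_entropy(v, uniq))          # ~1.0 bit for a 50/50 split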
Example #14
    def test_sorted_piecewise_constant_pdf_train_mode(self):
        """Test that piecewise-constant sampling reproduces its distribution."""
        batch_size = 4
        num_bins = 16
        num_samples = 1000000
        precision = 1e5
        rng = random.PRNGKey(20202020)

        # Generate a series of random PDFs to sample from.
        data = []
        for _ in range(batch_size):
            rng, key = random.split(rng)
            # Randomly initialize the distances between bins.
            # We're rolling our own fixed precision here to make cumsum exact.
            bins_delta = jnp.round(precision * jnp.exp(
                random.uniform(
                    key, shape=(num_bins + 1, ), minval=-3, maxval=3)))

            # Set some of the bin distances to 0.
            rng, key = random.split(rng)
            bins_delta *= random.uniform(key, shape=bins_delta.shape) < 0.9

            # Integrate the bins.
            bins = jnp.cumsum(bins_delta) / precision
            rng, key = random.split(rng)
            bins += random.normal(key) * num_bins / 2
            rng, key = random.split(rng)

            # Randomly generate weights, allowing some to be zero.
            weights = jnp.maximum(
                0,
                random.uniform(key, shape=(num_bins, ), minval=-0.5,
                               maxval=1.))
            gt_hist = weights / weights.sum()
            data.append((bins, weights, gt_hist))

        # Tack on an "all zeros" weight matrix, which is a common cause of NaNs.
        weights = jnp.zeros_like(weights)
        gt_hist = jnp.ones_like(gt_hist) / num_bins
        data.append((bins, weights, gt_hist))

        bins, weights, gt_hist = [jnp.stack(x) for x in zip(*data)]

        for randomized in [True, False]:
            rng, key = random.split(rng)
            # Draw samples from the batch of PDFs.
            samples = math.sorted_piecewise_constant_pdf(
                key,
                bins,
                weights,
                num_samples,
                randomized,
            )
            self.assertEqual(samples.shape[-1], num_samples)

            # Check that samples are sorted.
            self.assertTrue(jnp.all(samples[..., 1:] >= samples[..., :-1]))

            # Verify that each set of samples resembles the target distribution.
            for i_samples, i_bins, i_gt_hist in zip(samples, bins, gt_hist):
                i_hist = jnp.float32(jnp.histogram(i_samples,
                                                   i_bins)[0]) / num_samples
                i_gt_hist = jnp.array(i_gt_hist)

                # Merge any of the zero-span bins until there aren't any left.
                while jnp.any(i_bins[:-1] == i_bins[1:]):
                    j = int(jnp.where(i_bins[:-1] == i_bins[1:])[0][0])
                    i_hist = jnp.concatenate([
                        i_hist[:j],
                        jnp.array([i_hist[j] + i_hist[j + 1]]), i_hist[j + 2:]
                    ])
                    i_gt_hist = jnp.concatenate([
                        i_gt_hist[:j],
                        jnp.array([i_gt_hist[j] + i_gt_hist[j + 1]]),
                        i_gt_hist[j + 2:]
                    ])
                    i_bins = jnp.concatenate([i_bins[:j], i_bins[j + 1:]])

                # Angle between the two histograms in degrees.
                angle = 180 / jnp.pi * jnp.arccos(
                    jnp.minimum(
                        1.,
                        jnp.mean((i_hist * i_gt_hist) / jnp.sqrt(
                            jnp.mean(i_hist**2) * jnp.mean(i_gt_hist**2)))))
                # Jensen-Shannon divergence.
                m = (i_hist + i_gt_hist) / 2
                js_div = jnp.sum(
                    sp.special.kl_div(i_hist, m) +
                    sp.special.kl_div(i_gt_hist, m)) / 2
                self.assertLessEqual(angle, 0.5)
                self.assertLessEqual(js_div, 1e-5)
Example #15
def test_get_Q():
    Q = get_polynomial_form()
    import pylab as plt
    from jax import jit

    @jit
    def tomo_weight_ref(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def cumulative_tomo_weight_function_dimensionless(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return cumulative_tomographic_weight_dimensionless_function(
            gamma_prime, n, w1, w2, S=150)

    @jit
    def cumulative_tomo_weight_polynomial_dimensionless(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return vmap(lambda gamma_prime:
                    cumulative_tomographic_weight_dimensionless_polynomial(
                        Q, gamma_prime, n, w1, w2))(gamma_prime)
        # return jnp.exp(log_tomographic_weight_dimensionless_function(gamma_prime, n, w1, w2, S=150)) / h ** 2

    for i in range(10):
        keys = random.split(random.PRNGKey(i), 6)
        x1 = jnp.concatenate(
            [10. * random.uniform(keys[0], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p1 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

        x2 = jnp.concatenate(
            [4. * random.uniform(keys[2], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p2 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

        t1 = random.uniform(keys[4], shape=(10000, ))
        t2 = random.uniform(keys[5], shape=(10000, ))
        u1 = x1 + t1[:, None] * p1
        u2 = x2 + t2[:, None] * p2
        gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
        plt.hist(gamma.flatten(), bins=100, density=True, label='histogram')
        hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
        gamma = 0.5 * (bins[:-1] + bins[1:])
        w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
        plt.plot(gamma, w_ref, label='analytic ref')
        plt.legend()
        plt.show()
        cdf_ref = cumulative_tomo_weight_function_dimensionless(
            gamma, x1, x2, p1, p2)
        cdf_poly = cumulative_tomo_weight_polynomial_dimensionless(
            gamma, x1, x2, p1, p2)
        # The dimensionless gamma scales by h**2, matching the helpers above.
        gamma_prime = gamma / jnp.linalg.norm(x1 - x2)**2
        plt.plot(gamma_prime, cdf_ref, label='ref')
        plt.plot(gamma_prime, cdf_poly, label='poly')
        plt.legend()
        plt.show()
Example #16
def test_tomographic_weight_rel_err():
    import pylab as plt
    from jax import jit

    for S in range(5, 30, 5):

        @jit
        def tomo_weight(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=S)

        @jit
        def tomo_weight_ref(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

        rel_error = []
        for i in range(400):
            keys = random.split(random.PRNGKey(i), 6)
            x1 = jnp.concatenate(
                [4. * random.uniform(keys[0], shape=(2, )),
                 jnp.zeros((1, ))],
                axis=-1)
            p1 = jnp.concatenate([
                4. * jnp.pi / 180. *
                random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
                jnp.ones((1, ))
            ],
                                 axis=-1)
            p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

            x2 = jnp.concatenate(
                [4. * random.uniform(keys[2], shape=(2, )),
                 jnp.zeros((1, ))],
                axis=-1)
            p2 = jnp.concatenate([
                4. * jnp.pi / 180. *
                random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
                jnp.ones((1, ))
            ],
                                 axis=-1)
            p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

            t1 = random.uniform(keys[4], shape=(10000, ))
            t2 = random.uniform(keys[5], shape=(10000, ))
            u1 = x1 + t1[:, None] * p1
            u2 = x2 + t2[:, None] * p2
            gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
            hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
            bins = jnp.linspace(bins.min(), bins.max(), 20)
            w = tomo_weight(bins, x1, x2, p1, p2)
            w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
            rel_error.append(jnp.max(jnp.abs(w - w_ref)) / jnp.max(w_ref))
        rel_error = jnp.array(rel_error)
        plt.hist(rel_error, bins='auto')
        plt.title("{} : {:.2f}|{:.2f}|{:.2f}".format(
            S, *jnp.percentile(rel_error, [5, 50, 95])))
        plt.show()
Example #17
def safe_gaussian_kde(samples, weights):
    try:
        return gaussian_kde(samples, weights=weights, bw_method='silverman')
    except Exception:
        # Fallback for degenerate inputs; jnp.histogram has no 'auto' rule,
        # so use a fixed bin count and normalise to a density.
        hist, bin_edges = jnp.histogram(samples, weights=weights, bins=64, density=True)
        # Shift and clip the searchsorted index so each x maps to its bin.
        return lambda x: hist[jnp.clip(jnp.searchsorted(bin_edges, x) - 1, 0, hist.size - 1)]
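A usage sketch: with well-conditioned samples the scipy KDE path is taken; the histogram lookup only kicks in when gaussian_kde raises (e.g. a singular data covariance).

import numpy as np
from scipy.stats import gaussian_kde

samples = np.random.default_rng(0).normal(size=500)
weights = np.ones(500) / 500
pdf = safe_gaussian_kde(samples, weights)
print(pdf(0.0))   # ~0.4, the standard normal density at its mode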
Example #19
def test_tomographic_weight():
    import pylab as plt
    from jax import jit

    # @jit
    def tomo_weight(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=10)

    @jit
    def tomo_weight_ref(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def _tomo_weight_ref(gamma, x1, x2, p1, p2):
        return _tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def tomo_weight_dimensionless_ref(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return jnp.exp(
            log_tomographic_weight_dimensionless_function(
                gamma_prime, n, w1, w2, S=150)) / h**2

    for i in range(100):
        keys = random.split(random.PRNGKey(i), 6)
        x1 = jnp.concatenate(
            [10. * random.uniform(keys[0], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p1 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

        x2 = jnp.concatenate(
            [4. * random.uniform(keys[2], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p2 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

        t1 = random.uniform(keys[4], shape=(10000, ))
        t2 = random.uniform(keys[5], shape=(10000, ))
        u1 = x1 + t1[:, None] * p1
        u2 = x2 + t2[:, None] * p2
        gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
        plt.hist(gamma.flatten(), bins=100, density=True, label='histogram')
        hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
        bins = jnp.linspace(bins.min(), bins.max(), 50)
        gamma = 0.5 * (bins[:-1] + bins[1:])
        w = tomo_weight(bins, x1, x2, p1, p2)
        plt.plot(gamma, w, label='analytic')
        w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
        _w_ref = _tomo_weight_ref(gamma, x1, x2, p1, p2)
        # w_ref = tomo_weight_dimensionless_ref(bins, x1,x2,p1,p2)
        plt.plot(gamma, w_ref, label='analytic ref')
        plt.legend()
        plt.savefig(
            '/home/albert/git/jaxns/debug_figs/pdf_fig{:03d}.png'.format(i))
        plt.close('all')

        plt.plot(gamma, jnp.cumsum(w), label='analytic')
        # w_ref = tomo_weight_dimensionless_ref(bins, x1,x2,p1,p2)
        plt.plot(gamma, jnp.cumsum(w_ref), label='analytic ref')
        plt.legend()
        plt.savefig(
            '/home/albert/git/jaxns/debug_figs/cdf_fig{:03d}.png'.format(i))
        plt.close('all')
Example #20
 def hist1bin(z, w):
     return jnp.histogram(z, bins=zedges, weights=w)[0]
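`zedges` is captured from the enclosing scope. A sketch of how such a closure is typically applied, with hypothetical edges and weights, vmapping over several weight vectors at once:

import jax
import jax.numpy as jnp

zedges = jnp.linspace(0., 1., 11)        # hypothetical bin edges

def hist1bin(z, w):
    return jnp.histogram(z, bins=zedges, weights=w)[0]

z = jnp.linspace(0., 1., 1000)           # sample positions
W = jnp.ones((5, 1000))                  # five hypothetical weight vectors
counts = jax.vmap(hist1bin, in_axes=(None, 0))(z, W)   # shape (5, 10)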
Example #21
for k in tqdm.trange(nMC):

    # Solve for xi
    tfc.basisClass.w = np.array(2. * onp.random.rand(*tfc.basisClass.w.shape) -
                                1.)
    tfc.basisClass.b = np.array(2. * onp.random.rand(*tfc.basisClass.b.shape) -
                                1.)
    xi = LS()

    # Calculate the error
    ur = real(*xTest)
    ue = u(xi, *xTest)
    err = ur - ue
    testErr[k] = np.max(np.abs(err))

p1 = MakePlot('Maximum Error', 'Number of Occurrences')
hist, binEdge = np.histogram(np.log10(testErr), bins=20)
p1.ax[0].hist(testErr,
              bins=10**binEdge,
              color=(76. / 256., 0., 153. / 256.),
              edgecolor='black',
              zorder=20)
p1.ax[0].set_xscale('log')
p1.ax[0].xaxis.set_major_locator(plt.LogLocator(base=10, numticks=10))
p1.ax[0].locator_params(axis='both', tight=True)
p1.ax[0].grid(True, which='both')
for line in p1.ax[0].lines:
    line.set_zorder(0)
mTicks = p1.ax[0].xaxis.get_minor_ticks()
p1.PartScreen(11, 8)
p1.show()