Example #1
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples,
                                               args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train, ), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test, ),
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=args.sigma,
                clipping_threshold=args.clip_threshold,
                d=args.dimensions,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)

            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10) == 0):
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state,
                                  num_test_batches)

            print(
                "Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])
    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
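To make the example runnable end to end, a hypothetical argument parser such as the following could supply the `args` object; the flag names mirror the attributes used above, while the default values are purely illustrative.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_samples', type=int, default=10000)
    parser.add_argument('--dimensions', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--sigma', type=float, default=1.0)
    parser.add_argument('--clip_threshold', type=float, default=1.0)
    parser.add_argument('--delta', type=float, default=1e-5)
    parser.add_argument('--num_epochs', type=int, default=100)
    main(parser.parse_args())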
Example #2
    def dz(self, samples, noise_scale=0.4, **args):
        '''Daily deaths with observation noise'''
        dz_mean = self.dz_mean(samples, **args)
        dz = dist.Normal(dz_mean, noise_scale * dz_mean).sample(PRNGKey(10))
        return dz
Example #3

def on_slider_update(change):
    global lr, momentum
    if change["owner"].description == "log(lr)":
        lr = 10**change["new"]
    #  change["owner"].description = "LR : " + str(round(lr, 3))

    elif change["owner"].description == "momentum":
        momentum = change["new"]

    descend_and_update(nn_loss_xy)


X, Y, X_test = get_data(seed=5)
key = PRNGKey(0)
x_tr, x_te = split(X, key)
x_tr = x_tr[:, [1]]
x_te = x_te[:, [1]]
y_tr, y_te = split(Y, key)

n_layers = 3
n_neurons = 3

nn_init_fn, nn_apply_fn = stax.serial(
    *chain(*[(Tanh, Dense(n_neurons)) for _ in range(n_layers)]),
    Dense(1),
)

out_shape, init_params = nn_init_fn(PRNGKey(9), x_tr.shape[1:])
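As a reminder of the `stax` init/apply contract used above, here is a minimal, self-contained sketch (current JAX exposes it as `jax.example_libraries.stax`; older releases used `jax.experimental.stax`); the layer sizes are illustrative only.

import jax.numpy as jnp
from jax.random import PRNGKey
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Tanh

init_fn, apply_fn = stax.serial(Dense(3), Tanh, Dense(1))
out_shape, params = init_fn(PRNGKey(0), (-1, 1))  # one input feature, batch dim left as -1
x = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y = apply_fn(params, x)  # shape (5, 1)
print(out_shape, y.shape)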
Example #4
    # *** MLP configuration ***
    n_hidden = 6
    n_in, n_out = 1, 1
    n_params = (n_in + 1) * n_hidden + (n_hidden + 1) * n_out
    fwd_mlp = partial(mlp, n_hidden=n_hidden)
    # vectorised for multiple observations
    fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0])
    # vectorised for multiple weights
    fwd_mlp_weights = jax.vmap(fwd_mlp, in_axes=[1, None])
    # vectorised for multiple observations and weights
    fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None])

    # *** Generating training and test data ***
    n_obs = 200
    key = PRNGKey(314)
    key_sample_obs, key_weights = split(key, 2)
    xmin, xmax = -3, 3
    sigma_y = 3.0
    x, y = sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y)
    xtest = jnp.linspace(x.min(), x.max(), n_obs)

    # *** MLP Training with EKF ***

    W0 = normal(key_weights, (n_params,)) * 1 # initial random guess
    Q = jnp.eye(n_params) * 1e-4  # parameters do not change
    R = jnp.eye(1) * sigma_y**2  # observation noise is fixed
    Vinit = jnp.eye(n_params) * 100 # vague prior

    ekf = ds.ExtendedKalmanFilter(fz, fwd_mlp, Q, R)
    ekf_mu_hist, ekf_Sigma_hist = ekf.filter(W0, y[:, None], x[:, None], Vinit)
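The three `jax.vmap` wrappers above differ only in their `in_axes`; a small, self-contained sketch of that pattern, with a toy function standing in for the MLP, is:

import jax
import jax.numpy as jnp

def f(w, x):
    # one weight matrix applied to one observation
    return jnp.tanh(w @ x)

w = jnp.ones((4, 3))      # a single weight matrix
ws = jnp.ones((7, 4, 3))  # 7 sampled weight matrices
xs = jnp.ones((10, 3))    # 10 observations

f_obs = jax.vmap(f, in_axes=(None, 0))      # batch over observations only
f_weights = jax.vmap(f, in_axes=(0, None))  # batch over weights only
f_obs_weights = jax.vmap(f_obs, in_axes=(0, None))  # batch over both

print(f_obs(w, xs).shape)           # (10, 4)
print(f_weights(ws, xs[0]).shape)   # (7, 4)
print(f_obs_weights(ws, xs).shape)  # (7, 10, 4)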
Example #5
    def body(v):
        t, rmse, clusters = v
        jac_fn = jacrev(partial(cost_sp, features))
        hes_fn = jacfwd(jac_fn)

        new_cluster = clusters - jac_fn(clusters) / hes_fn(clusters).sum(
            (0, 1))
        rmse = ((new_cluster - clusters)**2).sum()
        return t + 1, rmse, new_cluster

    t, rmse, clusters = while_loop(cond, sparsify(body),
                                   (0, float("inf"), clusters))
    return clusters


if __name__ == '__main__':
    data_key, sparse_key = split(PRNGKey(8))
    num_datapoints = 100
    num_features = 7
    sparsity = .5
    num_clusters = 5
    max_iter = 10
    features = normal(data_key, (num_datapoints, num_features))
    features = features * bernoulli(sparse_key, sparsity,
                                    (num_datapoints, num_features))
    clusters = features[:num_clusters]
    features = sparse.BCOO.fromdense(features)
    new_cluster = kmeans(max_iter, clusters, features)
    print(new_cluster)
Example #6
@jit
def accuracy(params, batch):
    logits = predict(params, batch["X"])
    # logits = log_softmax(logits)
    return jnp.mean(jnp.argmax(logits, -1) == batch["y"])


@jit
def logprior(params):
    # Spherical Gaussian prior
    leaves_of_params = tree_leaves(params)
    return sum(tree_map(lambda p: jnp.sum(jax.scipy.stats.norm.logpdf(p, scale=l2_regularizer)), leaves_of_params))


key = PRNGKey(42)
data_key, init_key, opt_key, sample_key, warmstart_key = split(key, 5)

n_train, n_test = 20000, 1000
train_ds, test_ds = load_mnist(data_key, n_train, n_test)
data = (train_ds["X"], train_ds["y"])
n_features = train_ds["X"].shape[1]
n_classes = 10

# model
init_random_params, predict = stax.serial(
    Dense(200), Relu,
    Dense(50), Relu,
    Dense(n_classes), LogSoftmax)

_, params_init_tree = init_random_params(init_key, input_shape=(-1, n_features))
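The `logprior` above sums a spherical-Gaussian log-density over all parameter leaves; a self-contained sketch of that pytree pattern on a toy parameter dictionary (the scale value is illustrative) is:

import jax
import jax.numpy as jnp
from jax import tree_util

toy_params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

def gaussian_logprior(params, scale=1.0):
    # sum the isotropic Gaussian log-density over every leaf array
    leaves = tree_util.tree_leaves(params)
    return sum(jnp.sum(jax.scipy.stats.norm.logpdf(p, scale=scale)) for p in leaves)

print(gaussian_logprior(toy_params))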
Example #7
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2*math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`HMCState` that can be used to
            generate samples using MCMC. Else, returns the arguments and callable
            that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        """
        step_size = float(step_size)
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
        trajectory_len = float(trajectory_length)
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)

        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
                             wa_state.step_size, wa_state.inverse_mass_matrix, wa_state.mass_matrix_sqrt,
                             rng_hmc)

        wa_update = jit(wa_update)
        if run_warmup:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state, _ = jit(fori_loop, static_argnums=(2,))(0, num_warmup,
                                                                   warmup_update,
                                                                   (hmc_state, wa_state))
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
                        # TODO: set refresh=True when its performance issue is resolved
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
            # Reset `i` and `mean_accept_prob` for fresh diagnostics.
            hmc_state = hmc_state.update(i=0, mean_accept_prob=0)
            return hmc_state
        else:
            return hmc_state, wa_state, warmup_update
Example #8
def fit(observations,
        lens,
        num_hidden,
        num_obs,
        batch_size,
        optimizer,
        rng_key=None,
        num_epochs=1):
    '''
    Trains the HMM with the given number of hidden states and observable events using the given optimizer.

    Parameters
    ----------
    observations: array(N, seq_len)
        All observation sequences

    lens : array(N,)
        The valid length of each observation sequence

    num_hidden : int
        The number of hidden states

    num_obs : int
        The number of observable events

    batch_size : int
        The number of observation sequences that will be included in each minibatch

    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer that is used during training

    rng_key : array
        Random key of shape (2,) and dtype uint32

    num_epochs : int
        The total number of training epochs

    Returns
    -------
    * HMMJax
        Hidden Markov Model

    * array
      Consists of training losses
    '''
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)

    rng_init, rng_iter = split(rng_key)
    params = init_random_params([num_hidden, num_obs], rng_init)
    opt_init, opt_update, get_params = optimizer
    opt_state = opt_init(params)
    itercount = itertools.count()

    def epoch_step(opt_state, key):
        def train_step(opt_state, params):
            batch, length = params
            opt_state, loss = update(next(itercount), opt_state, batch, length)
            return opt_state, loss

        batches, valid_lens = hmm_sample_minibatches(observations, lens,
                                                     batch_size, key)
        params = (batches, valid_lens)
        opt_state, losses = jax.lax.scan(train_step, opt_state, params)
        return opt_state, losses.mean()

    epochs = split(rng_iter, num_epochs)
    opt_state, losses = jax.lax.scan(epoch_step, opt_state, epochs)

    losses = losses.flatten()

    params = get_params(opt_state)
    params = HMMJax(softmax(params.trans_mat, axis=1),
                    softmax(params.obs_mat, axis=1), softmax(params.init_dist))
    return params, losses
Example #9
            gmm = GMM(pi, mu, Sigma)
            gmm.fit_em(X, num_of_iters=5)
            n_success_ml += 1
        except Exception as E:
            print(str(E))
        try:
            gmm = GMM(pi, mu, Sigma)
            gmm.fit_em(X, num_of_iters=5, S=S, eta=eta)
            n_success_map += 1
        except Exception as E:
            print(str(E))
    pct_ml = n_success_ml / n_attempts
    pct_map = n_success_map / n_attempts
    return [1-pct_ml, 1-pct_map]

rng_key = PRNGKey(0)
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False

n_comps = 3
pi = jnp.ones((n_comps, )) / n_comps
hist_ml, hist_map = [], []

test_dims = jnp.arange(10, 60, 10)
keys = split(rng_key, 10)

n_samples = 150
mu_base = jnp.array([[-1, 1], [1, -1], [3, -1]])

Sigma1_base = jnp.array([[1, -0.7], [-0.7, 1]])
Sigma2_base = jnp.array([[1, 0.7], [0.7, 1]])
Example #10
    S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])

    S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])

    cov_collection = jnp.array([S1, S2, S3]) / 60
    mu_collection = jnp.array([[0.3, 0.3], [0.8, 0.5], [0.3, 0.8]])

    hmm = HMM(
        trans_dist=distrax.Categorical(probs=A),
        init_dist=distrax.Categorical(probs=initial_probs),
        obs_dist=distrax.as_distribution(
            tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                loc=mu_collection, covariance_matrix=cov_collection)))
    n_samples, seed = 50, 100
    samples_state, samples_obs = hmm_sample(hmm, n_samples, PRNGKey(seed))

    xmin, xmax = 0, 1
    ymin, ymax = 0, 1.2
    colors = ["tab:green", "tab:blue", "tab:red"]

    fig, ax = plt.subplots()
    _, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax,
                                 xmin, xmax, ymin, ymax)
    pml.savefig("hmm_lillypad_2d.pdf")

    fig, ax = plt.subplots()
    ax.step(range(n_samples),
            samples_state,
            where="post",
            c="black",
Example #11
    def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=3):
        '''
        Finds the parameters of the Gaussian Mixture Model using a gradient-descent algorithm with the given hyperparameters.

        Parameters
        ----------
        observations : array
            The observation sequences on which the Gaussian Mixture Model is trained

        batch_size : int
            The size of the batch

        rng_key : array
            Random key of shape (2,) and dtype uint32

        optimizer : jax.experimental.optimizers.Optimizer
            Optimizer to be used

        num_epochs : int
            The number of epochs the training process takes

        Returns
        -------
        * array
            Mean loss values found per epoch

        * array
            Mixing coefficients found per epoch

        * array
            Means of Gaussian distribution found per epoch

        * array
            Covariances of Gaussian distribution found per epoch

        * array
            Responsibilities found per epoch
        '''
        global opt_init, opt_update, get_params

        if rng_key is None:
            rng_key = PRNGKey(0)

        if optimizer is not None:
            opt_init, opt_update, get_params = optimizer

        opt_state = opt_init((softmax(self.mixing_coeffs), self.means, self.covariances))
        itercount = itertools.count()

        def epoch_step(opt_state, key):

            def train_step(opt_state, batch):
                opt_state, loss = self.update(next(itercount), opt_state, batch)
                return opt_state, loss

            batches = self._make_minibatches(observations, batch_size, key)
            opt_state, losses = scan(train_step, opt_state, batches)

            params = get_params(opt_state)
            mixing_coeffs, means, untransformed_cov = params
            cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
            self.model = (softmax(mixing_coeffs), means, cov_matrix)
            responsibilities = self.responsibilities(observations)

            return opt_state, (losses.mean(), *params, responsibilities)

        epochs = split(rng_key, num_epochs)
        opt_state, history = scan(epoch_step, opt_state, epochs)

        params = get_params(opt_state)
        mixing_coeffs, means, untransformed_cov = params
        cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
        self.model = (softmax(mixing_coeffs), means, cov_matrix)

        return history
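Both `fit` (Example #8) and `fit_sgd` above unpack an optimizer into the `(opt_init, opt_update, get_params)` triple; a minimal, self-contained sketch of that cycle on a toy quadratic objective (in current JAX the module is `jax.example_libraries.optimizers`, formerly `jax.experimental.optimizers`) is:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss(params):
    return jnp.sum((params - 3.0) ** 2)  # toy objective with minimum at 3

opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
opt_state = opt_init(jnp.zeros(2))

for i in range(100):
    grads = jax.grad(loss)(get_params(opt_state))
    opt_state = opt_update(i, grads, opt_state)

print(get_params(opt_state))  # approaches [3., 3.]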
Example #12
def test_pixelcnn():
    loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
    images = jnp.zeros((2, 16, 16, 3), image_dtype)
    opt = optimizers.Adam()
    state = opt.init(loss.init_parameters(images, key=PRNGKey(0)))
Example #13
def main(args):
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model,
                                         guide,
                                         elbo,
                                         opt_init,
                                         opt_update,
                                         encode=encode,
                                         decode=decode,
                                         z_dim=args.z_dim)
    svi_update = jit(svi_update)
    rng = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST,
                                           batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    _, encoder_params = encoder_init((args.batch_size, 28 * 28))
    _, decoder_params = decoder_init((args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    rng, sample_batch = binarize(rng, train_fetch(0, train_idx)[0])
    opt_state = svi_init(rng, (sample_batch, ), (sample_batch, ), params)
    rng, = random.split(rng, 1)

    @jit
    def epoch_train(opt_state, rng):
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(
                i,
                opt_state,
                rng,
                (batch, ),
                (batch, ),
            )
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        def body_fun(i, val):
            loss_sum, rng = val
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            loss = svi_eval(opt_state, rng, (batch, ), (batch, )) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR,
                                'original_epoch={}.png'.format(epoch)),
                   img,
                   cmap='gray')
        _, test_sample = binarize(rng, img)
        params = optimizers.get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.norm(z_mean, z_var).rvs(random_state=rng)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR,
                                'recons_epoch={}.png'.format(epoch)),
                   img_loc,
                   cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test = random.split(rng, 2)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))
Example #14
def test_Conv1DTranspose_runs(channels, filter_shape, padding, strides,
                              input_shape):
    conv = Conv1D(channels, filter_shape, strides=strides, padding=padding)
    inputs = random_inputs(input_shape)
    params = conv.init_parameters(PRNGKey(0), inputs)
    conv.apply(params, inputs)
Example #15
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01,  clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(
                svi_state, *batch
            )
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(
                svi_state, test_batchifier_state, num_test_batches
            )

            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start
                ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes  : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))
Example #16
               and the output (label) kets for given `params`
            
       """
    fidel = 0
    thetas, phis, omegas = params
    unitary = Unitary(N)(thetas, phis, omegas)
    for i in range(train_len):
        pred = jnp.dot(unitary, inputs[i])
        step_fidel = fidelity(pred, outputs[i])
        fidel += step_fidel

    return (fidel / train_len)[0][0]


# Fixed PRNGKeys to pick the same starting params
params = uniform(PRNGKey(0), (N**2, ), minval=0.0, maxval=2 * jnp.pi)
thetas = params[:N * (N - 1) // 2]
phis = params[N * (N - 1) // 2:N * (N - 1)]
omegas = params[N * (N - 1):]
params = [thetas, phis, omegas]

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-1)
opt_state = opt_init(params)


def step(i, opt_state, opt_update):
    params = get_params(opt_state)
    g = grad(cost)(params, ket_input, ket_output)
    return opt_update(i, g, opt_state)

Example #17
sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
holdout_sequences = [
    "HASTA",
    "VISTA",
    "ALAVA",
    "LIMED",
    "HAST",
    "HASVALTA",
] * 5
PROJECT_NAME = "evotuning_temp"

init_fun, apply_fun = mlstm64()

# The input_shape is always going to be (-1, 26),
# because that is the number of unique AA, one-hot encoded.
_, initial_params = init_fun(PRNGKey(42), input_shape=(-1, 26))

# 1. Evotuning with Optuna
n_epochs_config = {"low": 1, "high": 1}
lr_config = {"low": 1e-5, "high": 1e-3}
study, evotuned_params = evotune(
    sequences=sequences,
    model_func=apply_fun,
    params=initial_params,
    out_dom_seqs=holdout_sequences,
    n_trials=2,
    n_splits=2,
    n_epochs_config=n_epochs_config,
    learning_rate_config=lr_config,
)
Example #18
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    inverse_mass_matrix=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    rng_key=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        """
        step_size = lax.convert_element_type(
            step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa = random.split(rng_key)
        wa_state = wa_init(z,
                           rng_key_wa,
                           step_size,
                           inverse_mass_matrix=inverse_mass_matrix,
                           mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key)
        vv_state = vv_init(z, r)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)
        return hmc_state
Example #19
    def from_seed(cls: Type[T], seed: int) -> T:
        return cls(PRNGKey(seed))
Example #20
    def parameters_from(self, reuse, *example_inputs):
        return self._init_parameters(*example_inputs, key=PRNGKey(0), reuse=reuse, reuse_only=True)
Example #21
def get_batches(batches=100, sequence_length=1000, key=PRNGKey(0)):
    for _ in range(batches):
        key, batch_key = random.split(key)
        yield random.normal(batch_key,
                            (1, receptive_field + sequence_length, 1))
Example #22
    def _example_outputs(self, *inputs):
        _, outputs = self._init_and_apply_parameters_dict(*inputs, key=PRNGKey(0))
        return outputs
Example #23
def fit_copula_jregression(y, x, n_perm=10, seed=20, n_perm_optim=None, single_bandwidth=True):
    #Set seed for scipy
    np.random.seed(seed)
    
    #Combine x,y
    z = jnp.concatenate((x,y.reshape(-1,1)), axis = 1)

    #Generate random permutations
    key = PRNGKey(seed)
    key,*subkey = split(key,n_perm +1 )
    subkey = jnp.array(subkey)
    z_perm = vmap(permutation,(0,None))(subkey,z)

    #Initialize parameter and put on correct scale to lie in [0,1]
    d = jnp.shape(z)[1]

    if single_bandwidth == True:
        rho_init = 0.9*jnp.ones(1)
    else:
        rho_init = 0.9*jnp.ones(d) 
    hyperparam_init = jnp.log(1/rho_init - 1) 


    #calculate rho_opt
    #either use all permutations or a selected number to fit bandwidth
    if n_perm_optim is None:
        z_perm_opt = z_perm
    else:
        z_perm_opt = z_perm[0:n_perm_optim]

    #Compiling
    print('Compiling...')
    start = time.time()

    #Condit
    temp = mvcr.fun_jcll_perm_sp(hyperparam_init,z_perm_opt)
    temp = mvcr.grad_jcll_perm_sp(hyperparam_init,z_perm_opt)

    temp = mvcd.update_pn_loop_perm(rho_init,z_perm)[0].block_until_ready()
    end = time.time()
    print('Compilation time: {}s'.format(round(end-start, 3)))

    print('Optimizing...')
    start = time.time()
    # Condit preq loglik
    opt = minimize(fun=mvcr.fun_jcll_perm_sp, x0=hyperparam_init,
                   args=(z_perm_opt,), jac=mvcr.grad_jcll_perm_sp, method='SLSQP')

    # check optimization succeeded
    if not opt.success:
        print('Optimization failed')

    #unscale hyperparameter
    hyperparam_opt = opt.x
    rho_opt = 1/(1+jnp.exp(hyperparam_opt))
    end = time.time()

    print('Optimization time: {}s'.format(round(end-start, 3)))
        
    print('Fitting...')
    start = time.time()
    vn_perm= mvcd.update_pn_loop_perm(rho_opt,z_perm)[0].block_until_ready()
    end = time.time()
    print('Fit time: {}s'.format(round(end-start, 3)))

    copula_jregression_obj = namedtuple('copula_jregression_obj',['vn_perm','rho_opt','preq_loglik'])
    return copula_jregression_obj(vn_perm,rho_opt,-opt.fun)
Example #24
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model,
              guide,
              adam,
              Trace_ELBO(),
              hidden_dim=args.hidden_dim,
              z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST,
                                           batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key, train_idx):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key, test_idx):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR,
                                'original_epoch={}.png'.format(epoch)),
                   img,
                   cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'],
                                      test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR,
                                'recons_epoch={}.png'.format(epoch)),
                   img_loc,
                   cmap='gray')

    for i in range(args.num_epochs):
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
            rng_key, 4)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train, train_idx)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test, test_idx)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))
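For reference, the NumPyro `SVI` init/update/get_params cycle driving the loop above can be reduced to a minimal, self-contained sketch on a toy model; the model, guide, and data below are illustrative and not part of the VAE example.

import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def toy_model(data):
    mu = numpyro.sample("mu", dist.Normal(0., 10.))
    numpyro.sample("obs", dist.Normal(mu, 1.), obs=data)

def toy_guide(data):
    loc = numpyro.param("mu_loc", 0.)
    numpyro.sample("mu", dist.Delta(loc))

data = jnp.array([1.0, 2.0, 3.0])
svi = SVI(toy_model, toy_guide, Adam(0.05), Trace_ELBO())
svi_state = svi.init(PRNGKey(0), data)
for _ in range(500):
    svi_state, loss = svi.update(svi_state, data)
print(svi.get_params(svi_state))  # mu_loc near the data mean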
Example #25
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data,
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=0.01,
                clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val

            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print(
                "Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    #   w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true,
        jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluation accuracy with true parameters, we scale them to the same
    #   scale as the found posterior. (gives better results than normalized
    #   parameters (probably due to numerical instabilities))
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true,
                                              intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print(
        "avg accuracy on test set:  with true parameters: {} ; with found posterior: {}\n"
        .format(acc_true, acc_post))
Example #26
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng_key=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`~numpyro.infer.mcmc.HMCState` that can be used to
            generate samples using MCMC. Else, returns the arguments and callable
            that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        """
        step_size = lax.convert_element_type(
            step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa = random.split(rng_key)
        wa_state = wa_init(z,
                           rng_key_wa,
                           step_size,
                           mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key)
        vv_state = vv_init(z, r)
        energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, energy, 0, 0., 0.,
                             False, wa_state, rng_key_hmc)

        # TODO: Remove; this should be the responsibility of the MCMC class.
        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup,
                                      lambda *args: sample_kernel(args[1]),
                                      hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = jit(sample_kernel)(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state),
                                          refresh=False)
        return hmc_state
Example #27
    def dy(self, samples, noise_scale=0.4, **args):
        '''Daily confirmed cases with observation noise'''
        dy_mean = self.dy_mean(samples, **args)
        dy = dist.Normal(dy_mean, noise_scale * dy_mean).sample(PRNGKey(11))
        return dy
Example #28
in terms of the speed. It also checks whether or not the inference algorithms give the same result.
Author : Aleyna Kara (@karalleyna)
'''

import time

import jax.numpy as jnp
from jax.random import PRNGKey, split, uniform
import numpy as np

from hmm_lib_log import HMM, hmm_forwards_backwards_log, hmm_viterbi_log, hmm_sample_log

import distrax

seed = 0
rng_key = PRNGKey(seed)
rng_key, key_A, key_B = split(rng_key, 3)

# state transition matrix
n_hidden, n_obs = 100, 10
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1).reshape((-1, 1))

# observation matrix
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1).reshape((-1, 1))

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
Example #29
from jax.random import PRNGKey
import jax.numpy as np
from numpyro import sample
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from mixture_model import NormalMixture
from NMC import NMC

# Model key
rng = PRNGKey(0)

# Mixture model


def mix_model(data):
    w = sample("w", dist.Dirichlet((1 / 3) * np.ones(3)), rng_key=rng)
    mu = sample("mu", dist.Normal(np.zeros(3), np.ones(3)), rng_key=rng)
    std = sample("std", dist.Gamma(np.ones(3), np.ones(3)), rng_key=rng)
    sample("obs", NormalMixture(w, mu, std), rng_key=rng, obs=data)


# Data for mixture model
data_test1 = sample("norm1",
                    dist.Normal(10, 1),
                    rng_key=PRNGKey(0),
                    sample_shape=(1000, ))
data_test2 = sample("norm2",
                    dist.Normal(0, 1),
                    rng_key=PRNGKey(0),
                    sample_shape=(1000, ))
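The snippet imports `MCMC` and `NUTS` but the inference call itself is not shown; a minimal, self-contained sketch of that step on a toy model (not the `NMC` method from the example) would be:

from jax.random import PRNGKey
import jax.numpy as jnp
from numpyro import sample
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def toy_model(data):
    mu = sample("mu", dist.Normal(0., 10.))
    sample("obs", dist.Normal(mu, 1.), obs=data)

data = jnp.array([9.5, 10.2, 10.1, 9.8])
mcmc = MCMC(NUTS(toy_model), num_warmup=500, num_samples=1000)
mcmc.run(PRNGKey(1), data)
mcmc.print_summary()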
Example #30
    def __init__(self, cost, backend='torch'):
        self.normalized = cost.normalized
        self.constant_goal = cost.constant_goal
        self.c = cost
        self._key = PRNGKey(0)