Example #1
 def loss_fun(params,
              rng,
              data,
              batch_size=None,
              n=None,
              loss_type="nlp",
              reduce="sum"):
     """
     :param batch_size: How large a batch to subselect from the provided data
     :param n: The total size of the dataset (to multiply batch estimate by)
     """
     assert loss_type in ("nlp", "mse")
     assert reduce in ("sum", "mean")
     inputs, targets = data
     n = inputs.shape[0] if n is None else n
     if batch_size is not None:
         rng, rng_batch = random.split(rng)
         i = random.permutation(rng_batch, n)[:batch_size]
         inputs, targets = inputs[i], targets[i]
     preds = apply_fun(params, rng, inputs).squeeze()
     mean_loss = (
         -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean()
         if loss_type == "nlp" else np.power(targets.squeeze() -
                                             preds, 2).mean())
     if reduce == "sum":
         loss = n * mean_loss
     elif reduce == "mean":
         loss = mean_loss
     return loss
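A usage sketch, assuming a stax-style `apply_fun` and a `params` dict carrying a `"noise"` entry (none of these are shown in the excerpt):

from jax import random

rng = random.PRNGKey(0)
value = loss_fun(params, rng, (inputs, targets), batch_size=128, loss_type="nlp")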
Example #2
def generate_nested_circles(key,
                            n_samples,
                            inner_radius=2,
                            outer_radius=4,
                            noise=0.15):

    k1, k2, k3, k4 = random.split(key, 4)

    # Generate the circles
    inner_t = random.uniform(k1, shape=(n_samples // 2, )) * 2 * jnp.pi
    inner_circle = inner_radius * jnp.vstack(
        [jnp.cos(inner_t), jnp.sin(inner_t)])

    outer_t = random.uniform(k2, shape=(n_samples // 2, )) * 2 * jnp.pi
    outer_circle = outer_radius * jnp.vstack(
        [jnp.cos(outer_t), jnp.sin(outer_t)])

    data = jnp.vstack([inner_circle.T, outer_circle.T])

    # Keep track of the labels
    y = jnp.hstack([jnp.zeros(n_samples // 2), jnp.ones(n_samples // 2)])

    # Shuffle the data
    idx = jnp.arange(n_samples)
    idx = random.permutation(k3, idx)
    data = data[idx]
    y = y[idx]

    data += random.normal(k4, data.shape) * noise
    return data, y
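For instance, drawing 1,000 labeled points (a usage sketch; only the function above is assumed):

from jax import random

key = random.PRNGKey(0)
data, y = generate_nested_circles(key, 1000)  # data: (1000, 2), y: (1000,) labels in {0., 1.}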
Example #3
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key, z, pe, potential_fn=partial(potential_fn, z_hmc=hmc_sites),
                idx=idx, support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                         (z_new, pe_new), identity,
                         (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, pe)
        _, gibbs_sites, pe = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites, pe
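The commented acceptance trick is worth a standalone check: for u ~ Uniform(0, 1), accepting when u < exp(log_accept_ratio) is equivalent in distribution to accepting when an Exponential(1) draw exceeds -log_accept_ratio, since -log(u) ~ Exponential(1). A minimal sketch (hypothetical, not part of the original):

from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)
log_accept_ratio = jnp.log(0.3)                       # i.e. accept with probability 0.3
accept = random.exponential(key) > -log_accept_ratio
# Same acceptance distribution as random.uniform(key) < jnp.exp(log_accept_ratio),
# but it avoids exponentiating a very negative log-ratio.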
Example #4
  def test_permutation_invariance(self):

    num_nodes = 4
    num_features = 2
    rng = random.PRNGKey(0)

    # Generate random graph.
    adjacency = random.randint(rng, (num_nodes, num_nodes), 0, 2)
    node_feats = random.normal(rng, (num_nodes, num_features))
    sources, targets = jnp.where(adjacency)

    # Get permuted graph.
    perm = random.permutation(rng, jnp.arange(num_nodes))
    node_feats_perm = node_feats[perm]
    adjacency_perm = adjacency[perm]
    for j in range(len(adjacency)):
      adjacency_perm = jax.ops.index_update(
          adjacency_perm, j, adjacency_perm[j][perm])
    sources_perm, targets_perm = jnp.where(adjacency_perm)

    # Create GNN.
    _, initial_params = GNN.init(
      rng, node_x=node_feats, edge_x=None, sources=sources, targets=targets)
    model = nn.Model(GNN, initial_params)

    # Feedforward both original and permuted graph.
    logits = model(node_feats, None, sources, targets)
    logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)

    self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)
Example #5
    def loss(params, key):
        keys = random.split(key, 5)
        indices = random.permutation(keys[0],
                                     jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]

        wind_velocity = random.uniform(keys[1],
                                       shape=(3, ),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.
                                                           ])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch,
                   X_batch,
                   bottom,
                   width,
                   l,
                   sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch,
                                 X_batch,
                                 bottom,
                                 width,
                                 l,
                                 sigma,
                                 wind_velocity=wind_velocity)

        return jnp.mean((K - neural_K)**2) / width**2
Example #6
    def _make_minibatches(self, observations, batch_size, rng_key):
        '''
        Creates minibatches consisting of random permutations of the
        given observation sequences

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset

        batch_size : int
            The number of observation sequences that will be included in
            each minibatch

        rng_key : array
            Random key of shape (2,) and dtype uint32

        Returns
        -------
        * array(num_batches, batch_size, max_len)
            Minibatches
        '''
        num_train = len(observations)
        perm = permutation(rng_key, num_train)

        def create_mini_batch(batch_idx):
            return observations[batch_idx]

        num_batches = num_train // batch_size
        batch_indices = perm.reshape((num_batches, -1))
        minibatches = vmap(create_mini_batch)(batch_indices)

        return minibatches
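One caveat: `perm.reshape((num_batches, -1))` only succeeds when `num_train` is divisible by `batch_size`. A guard that makes the constraint explicit (an assumed addition, not in the original):

perm = perm[:(num_train // batch_size) * batch_size]  # drop the ragged tail before reshape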
Example #7
    def fit(self, X):
        opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
        opt_state = opt_init((self.encoder_params, self.decoder_params))

        def loss(params, inputs):
            encoder_params, decoder_params = params
            # Encode the batch, then decode the latent code (the original
            # applied both networks to X, bypassing the encoder output).
            enc = self.encoder_apply(encoder_params, inputs)
            dec = self.decoder_apply(decoder_params, enc)
            # L1 penalty over all parameter leaves (tree_leaves is jax.tree_util.tree_leaves).
            l1 = sum(np.abs(p).sum() for p in tree_leaves(params))
            return np.square(inputs - dec).sum() + 1e-3 * l1

        @jit
        def step(i, opt_state, inputs):
            params = get_params(opt_state)
            gradient = grad(loss)(params, inputs)
            return opt_update(i, gradient, opt_state)

        print('Training autoencoder...')

        batch_size, itercount = 32, itertools.count()
        key = random.PRNGKey(0)
        for epoch in range(5):
            temp_key, key = random.split(key)
            X = random.permutation(temp_key, X)
            for batch_index in range(0, X.shape[0], batch_size):
                opt_state = step(next(itercount), opt_state,
                                 X[batch_index:batch_index + batch_size])

        self.encoder_params, self.decoder_params = get_params(opt_state)
Example #8
def loss(rng: jnp.ndarray, bij_params: Sequence[jnp.ndarray],
         bij_fns: Sequence[Callable], deq_params: Sequence[jnp.ndarray],
         deq_fn: Callable, xon: jnp.ndarray) -> float:
    """Loss function composed of the evidence lower bound and score matching
    loss.

    Args:
        rng: Pseudo-random number generator seed.
        bij_params: List of arrays parameterizing the RealNVP bijectors.
        bij_fns: List of functions that compute the shift and scale of the RealNVP
            affine transformation.
        deq_params: Parameters of the mean and scale functions used in
            the log-normal dequantizer.
        deq_fn: Function that computes the mean and scale of the dequantization
            distribution.
        xon: Observations on O(n).

    Returns:
        nelbo: The negative evidence lower bound.

    """
    rng, rng_loss, rng_idx = random.split(rng, 3)
    idx = random.permutation(rng_idx, len(xon))[:100]
    xobs = xon[idx]
    if args.elbo_loss:
        nelbo = negative_elbo(rng_loss, bij_params, bij_fns, deq_params,
                              deq_fn, xobs)
        nelbo = nelbo.mean()
        return nelbo
    else:
        log_is = importance_log_density(rng_loss, bij_params, bij_fns,
                                        deq_params, deq_fn,
                                        args.num_importance, xobs)
        log_target = log_density(xobs)
        return jnp.mean(log_target - log_is)
Example #9
File: mnist_lib.py Project: ramasesh/flax
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm] for k, v in train_ds.items()}
        optimizer, metrics = train_step(optimizer, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: onp.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                 epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

    return optimizer, epoch_metrics_np
Example #10
def data_stream(rng, batch_size, X_train, y_train):
    num_batches, leftover = divmod(X_train.shape[0], batch_size)
    while True:
        temp, rng = random.split(rng)
        perm = random.permutation(temp, X_train.shape[0])
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield X_train[batch_idx], y_train[batch_idx]
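Assumed usage of the generator above (`X_train` and `y_train` defined elsewhere):

from jax import random

batches = data_stream(random.PRNGKey(0), 32, X_train, y_train)
X_batch, y_batch = next(batches)  # one freshly shuffled minibatch per call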
Example #11
def data_stream(rng, batch_size, X, y):
    num_complete_batches, leftover = divmod(X.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    while True:
        temp, rng = random.split(rng)
        perm = random.permutation(temp, X.shape[0])
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield X[batch_idx], y[batch_idx]
Example #12
def _make_iaf_args(input_dim, hidden_dims):
    _, rng_perm = random.split(random.PRNGKey(0))
    perm = random.permutation(rng_perm, np.arange(input_dim))
    # we use Elu nonlinearity because the default one, Relu, masks out negative hidden values,
    # which in turn create some zero entries in the lower triangular part of Jacobian.
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1],
                                     permutation=perm, nonlinearity=stax.Elu)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim,))
    return partial(arn, init_params),
Example #13
def data_stream(key, X, y, batch_size):
    n_data = len(X)
    while True:
        perm_key, key = split(key)
        perm = permutation(perm_key, n_data)
        num_batches, mod = divmod(n_data, batch_size)
        num_batches += 1 if mod else 0
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:min((i + 1) * batch_size, n_data)]
            yield X[batch_idx], y[batch_idx]
Example #14
File: data.py Project: niloch/colin_net
    def __iter__(self) -> Iterator[Batch]:
        starts = np.arange(0, len(self.inputs), self.batch_size)
        self.key, subkey = random.split(self.key)
        starts = random.permutation(subkey, starts)

        for start in starts:
            end = start + self.batch_size
            batch_inputs = self.inputs[start:end]
            batch_targets = self.targets[start:end]
            yield Batch(batch_inputs, batch_targets)
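A sketch of how such a loader is typically consumed; `loader` (an instance of the surrounding class) and `train_step` are hypothetical names:

for batch in loader:  # batch starts are re-shuffled on each iteration pass
    params = train_step(params, batch.inputs, batch.targets)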
Example #15
 def data_stream(self, train_images, train_labels, num_train, num_batches,
                 batch_size):
     """Returns batches of data for training"""
     key = random.PRNGKey(0)
     while True:
         key, subkey = random.split(key)
         perm = random.permutation(subkey, num_train)
         for i in range(num_batches):
             batch_idx = perm[i * batch_size:(i + 1) * batch_size]
             yield train_images[batch_idx], train_labels[batch_idx]
Example #16
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model),
                                 -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Example #17
def train_one_epoch(params, X_train, y_train, epoch, r):
    num_samples = X_train.shape[0]
    random_sample_idx = random.permutation(random.PRNGKey(epoch), jnp.arange(num_samples))
    for idx in tqdm(range(0, num_samples, BATCH_SIZE)):
        mini_batch_idx = random_sample_idx[idx:idx + BATCH_SIZE]
        mini_batch_x = X_train[mini_batch_idx]
        mini_batch_y = y_train[mini_batch_idx]
        params_grad = grad(mlp_loss, argnums=0)(params, mini_batch_x, mini_batch_y)
        params, r = apply_grads(params, params_grad, r, LEARNING_RATE)
        loss = mlp_loss(params, mini_batch_x, mini_batch_y)
    return loss, params, r
Example #18
def sample_observations(key, f, n_obs, xmin, xmax, x_noise=0.1, y_noise=3.0):
    key_x, key_y, key_shuffle = split(key, 3)
    x_noise = normal(key_x, (n_obs,)) * x_noise
    y_noise = normal(key_y, (n_obs,)) * y_noise
    x = jnp.linspace(xmin, xmax, n_obs) + x_noise
    y = f(x) + y_noise
    X = np.c_[x, y]

    shuffled_ixs = permutation(key_shuffle, jnp.arange(n_obs))
    x, y = jnp.array(X[shuffled_ixs, :].T)
    return x, y
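For example, sampling noisy observations of a cubic (a usage sketch; only the function above is assumed):

from jax import random

key = random.PRNGKey(0)
x, y = sample_observations(key, lambda t: t**3, n_obs=100, xmin=-3.0, xmax=3.0)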
Example #19
    def init_fun(rng, input_dim, **kwargs):
        perm = random.permutation(rng, np.arange(input_dim))
        inv_perm = np.argsort(perm)

        def direct_fun(params, inputs, **kwargs):
            return inputs[:, perm], np.zeros(inputs.shape[:1])

        def inverse_fun(params, inputs, **kwargs):
            return inputs[:, inv_perm], np.zeros(inputs.shape[:1])

        return (), direct_fun, inverse_fun
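The returned pair of functions is a permutation bijector; a quick round-trip sketch (the input array is illustrative):

from jax import random
import jax.numpy as np

x = np.ones((8, 4))                                          # (batch, input_dim)
params, direct_fun, inverse_fun = init_fun(random.PRNGKey(0), input_dim=4)
z, _ = direct_fun(params, x)                                 # feature columns permuted
x_back, _ = inverse_fun(params, z)                           # recovers x exactly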
Example #20
File: random_test.py Project: samuela/jax
    def testPermutationInteger(self):
        key = random.PRNGKey(0)
        x = 100
        rand = lambda key: random.permutation(key, x)
        crand = api.jit(rand)

        perm1 = rand(key)
        perm2 = crand(key)

        self.assertAllClose(perm1, perm2)
        self.assertEqual(perm1.dtype, perm2.dtype)
        self.assertFalse(np.all(perm1 == np.arange(100)))  # seems unlikely!
        self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)
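The test above exercises the integer form, which is shorthand for permuting jnp.arange(x); both call forms side by side (a quick sketch):

from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)
perm_from_int = random.permutation(key, 10)              # permuted jnp.arange(10)
perm_from_arr = random.permutation(key, jnp.arange(10))  # shuffled copy of the array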
Example #21
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        z_hmc = hmc_sites
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model),
                                  -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model_,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Example #22
File: random_test.py Project: zizai/jax
  def testPermutationArray(self, dtype, shape):
    key = random.PRNGKey(0)
    x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2)
    self.assertFalse(np.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
    self.assertArraysAllClose(
      x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))
Example #23
def get_batch(sampling, key, X, minibatch_size, iteration):
  if sampling == 'batch':
    # Calculate epoch from iteration
    epoch = iteration // (X.shape[0] // minibatch_size)
    batch_index = iteration % (X.shape[0] // minibatch_size)
    batch_index_start = batch_index * minibatch_size
    # Regular batching
    if batch_index == 0:
      temp_key, key = random.split(key)
      X = random.permutation(temp_key, X)
    return X[batch_index_start:batch_index_start+minibatch_size], X
  elif sampling == 'uniform':
    # Uniform subsampling
    temp_key, key = random.split(key)
    X = random.permutation(temp_key, X)
    return X[:minibatch_size], X
  elif sampling == 'poisson':
    # Poisson subsampling
    temp_key, key = random.split(key)
    whether = random.uniform(temp_key, (X.shape[0],)) < (minibatch_size / X.shape[0])
    return X[whether], X
  else:
    raise Exception('Invalid sampling method: {}'.format(sampling))
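Assumed usage, one call per training iteration (`X` and `key` defined elsewhere):

batch, X = get_batch('uniform', key, X, minibatch_size=64, iteration=0)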
Example #24
        def epoch_step(opt_state, key):
            perm = permutation(key, len(observations))
            _observatios, _targets = observations[perm], targets[perm]
            sample_generator = self._sample_minibatches(
                (_observatios, _targets), batch_size)

            def train_step(opt_state, i):
                opt_state, loss = self.update(next(itercount), opt_state,
                                              next(sample_generator))
                return opt_state, loss

            opt_state, losses = scan(train_step, opt_state,
                                     jnp.arange(num_batches))
            return opt_state, losses.mean()
Example #25
def load_mnist(key, n_train, n_test, shuffle=True):
    (X, y), (X_test, y_test) = mnist.load_data()

    n_train = n_train if n_train < len(y) else len(y)
    n_test = n_test if n_test < len(y_test) else len(y_test)

    train_key, test_key = split(key)
    train_indices = jnp.arange(len(y))
    perm = permutation(train_key, train_indices)[:n_train] if shuffle else train_indices[:n_train]

    train_ds = {
        "X": jnp.float32(X[perm].reshape(n_train, -1)) / 255.,
        "y": jnp.array(y[perm])
    }

    test_indices = jnp.arange(len(y_test))
    perm = permutation(test_key, test_indices)[:n_test] if shuffle else test_indices[:n_test]

    test_ds = {
        "X": jnp.float32(X_test[perm].reshape(n_test, -1)) / 255.,
        "y": jnp.array(y_test[perm])
    }

    return train_ds, test_ds
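A usage sketch (assumes Keras's `mnist` is importable, as in the snippet):

from jax import random

train_ds, test_ds = load_mnist(random.PRNGKey(0), n_train=50000, n_test=10000)
print(train_ds["X"].shape)  # (50000, 784), pixel values scaled to [0, 1]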
Example #26
File: random_test.py Project: yotarok/jax
  def testPermutationArray(self, dtype):
    key = random.PRNGKey(0)
    x = onp.arange(100).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2, check_dtypes=True)
    self.assertEqual(perm1.dtype, perm2.dtype)
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
    self.assertArraysAllClose(x, onp.arange(100).astype(dtype),
                              check_dtypes=True)
Example #27
 def train_test_split(data, rng=None, n_test=None):
     """
     Create a train-test split
     """
     rng = PRNGKey(42) if rng is None else rng
     n = len(data.x)
     rng, rng_perm = split(rng)
     i = permutation(rng_perm, n)
     n_test = min(16384, int(0.1 * n)) if n_test is None else n_test
     if isinstance(n_test, float):
         n_test = int(n_test * n)
     n_train = n - n_test
     i_train, i_test = i[:n_train], i[n_train:]
     return (
         Data(data.x[i_train], data.y[i_train]),
         Data(data.x[i_test], data.y[i_test]),
     )
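A sketch of typical usage, assuming `Data` is the simple `(x, y)` container implied above:

train, test = train_test_split(Data(x, y))  # holds out min(16384, 10% of n) rows by default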
Example #28
File: made.py Project: jxzhangjhu/NuX
  def next_mask(self, prev_sel, size, rng):

    # Choose the degrees of the next layer
    max_connection = self.dim - 1 if not self.triangular_jacobian else self.dim

    if self.method == "random":
      sel = random.randint(rng, shape=(size,), minval=min(jnp.min(prev_sel), max_connection), maxval=self.dim)
    elif "sequential" in self.method:
      sel = jnp.arange(size)%max(1, max_connection) + min(1, max_connection)
      if self.method == "shuffled_sequential":
        sel = random.permutation(rng, sel)
    else:
      assert 0, "Invalid mask method"

    # Create the new mask
    mask = (prev_sel[:,None] <= sel).astype(jnp.int32)
    return mask, sel
Example #29
def data_preprocessing():
    """Separates data (spin configurations) into test and training sets and generates labels."""
    rng = random.PRNGKey(0)

    temperatures = jnp.linspace(1.0, 4.0, 7)
    temperatures1 = [1.0, 1.5, 3.0, 3.5, 4.0]
    temperatures2 = [2.0, 2.5]

    x_train = []
    y_train = []
    x_test = []
    y_test = []
    for T in temperatures:
        configs = jnp.load('data/spins_T%s.npy' % T)
        magnetization_density = jnp.abs(
            jnp.array([jnp.sum(config) / config.size for config in configs]))
        labels = jnp.where(magnetization_density < 0.5, 0, 1)
        if T in temperatures2:
            x_test.append(configs)
            y_test.append(labels)
        else:
            indices = random.permutation(rng, labels.size)
            y_test.append(labels[indices[:int(0.2 * labels.size)]])
            y_train.append(labels[indices[int(0.2 * labels.size):]])
            x_test.append(configs[indices[:int(0.2 * labels.size)]])
            x_train.append(configs[indices[int(0.2 * labels.size):]])

    y_test_new = jnp.array(y_test[0])
    x_test_new = jnp.array(x_test[0])
    for i in range(len(y_test) - 1):
        y_test_new = jnp.concatenate((y_test_new, y_test[i + 1]))
        x_test_new = jnp.concatenate((x_test_new, x_test[i + 1]))

    L = jnp.array(x_train).shape[2]
    x_test = jnp.array(x_test_new).reshape((-1, L, L, 1)).astype(jnp.float64)
    y_test = jnp.array(y_test_new).reshape((-1, 1))
    x_train = jnp.array(x_train).reshape((-1, L, L, 1)).astype(jnp.float64)
    y_train = jnp.array(y_train).reshape((-1, 1))

    jnp.save('data/x_test.npy', x_test)
    jnp.save('data/y_test.npy', y_test)
    jnp.save('data/x_train.npy', x_train)
    jnp.save('data/y_train.npy', y_train)

    return x_train, y_train, x_test, y_test
Example #30
    def init_fun(rng, input_dim, **kwargs):

        perm = random.permutation(rng, np.arange(input_dim))
        inv_perm = np.argsort(perm)

        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = inputs[:, perm]
            log_det = np.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs[:, inv_perm]
            log_det = np.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun