Example #1
def cov_estimate(*, optimization_path: Sequence[jnp.ndarray],
                 optimization_path_grads: Sequence[jnp.ndarray], history: int):
    """Estimate covariance from an optimization path."""
    dim = optimization_path[0].shape[0]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    approximations: List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]] = []
    diagonal_estimate = jnp.ones(dim)
    for j in range(len(optimization_path) - 1):
        _, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        position_diff = optimization_path[j + 1] - optimization_path[j]
        gradient_diff = optimization_path_grads[j] - optimization_path_grads[
            j + 1]
        b = position_diff @ gradient_diff
        gradient_diff_norm = gradient_diff**2
        new_diagonal_estimate = diagonal_estimate
        # Update the history and diagonal only when the curvature condition
        # s^T y > eps * ||y||^2 holds (same check as in the lbfgs example below).
        if b > 1e-12 * jnp.sum(gradient_diff_norm):
            position_diffs = jnp.column_stack(
                (position_diffs[:, -history + 1:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, -history + 1:], gradient_diff))
            a = gradient_diff @ (diagonal_estimate * gradient_diff)
            c = position_diff @ (position_diff / diagonal_estimate)
            new_diagonal_estimate = 1.0 / (a / (b * diagonal_estimate) +
                                           gradient_diff_norm / b -
                                           (a * position_diff**2) /
                                           (b * c * diagonal_estimate**2))
        approximations.append(
            (diagonal_estimate, thin_factors, scaling_outer_product))
        diagonal_estimate = new_diagonal_estimate
    return approximations
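
For orientation, the elementwise update at the end of the loop can be read directly off the code above (this is just a transcription, not taken from a reference): with D the current diagonal estimate viewed as a diagonal matrix, s the position difference and y the gradient difference,

    a = y^\top D y, \qquad b = s^\top y, \qquad c = s^\top D^{-1} s,
    D^{\mathrm{new}}_{ii} = \Bigl(\frac{a}{b\,D_{ii}} + \frac{y_i^2}{b} - \frac{a\,s_i^2}{b\,c\,D_{ii}^2}\Bigr)^{-1},

and both this update and the history append happen only when the curvature term b is safely positive.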
Example #2
def lbfgs(
    *,
    log_target_density: Callable[[jnp.ndarray], jnp.ndarray],
    initial_value: jnp.ndarray,  # theta_init
    inverse_hessian_history: int = 6,  # J
    relative_tolerance: float = 1e-13,  # tau_rel
    max_iters: int = 1000,  # L
    wolfe_bounds: Tuple[float, float] = (1e-4, 0.9),
    positivity_threshold: float = 2.2e-16,
):
    """LBFGS implementation which returns the optimization path and gradients."""
    dim = initial_value.shape[0]
    grad_log_density = jax.grad(log_target_density)
    optimization_path = [initial_value]
    current_lp = log_target_density(initial_value)
    grad_optimization_path = [grad_log_density(initial_value)]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    for _ in range(max_iters):
        diagonal_estimate, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        grad_lp = grad_optimization_path[-1]
        search_direction = diagonal_estimate * grad_lp + thin_factors @ (
            scaling_outer_product @ (jnp.transpose(thin_factors) @ grad_lp))
        step_size = 1.0
        while step_size > 1e-8:
            proposed = optimization_path[-1] + step_size * search_direction
            proposed_lp = log_target_density(proposed)
            if proposed_lp >= current_lp + (wolfe_bounds[0] * grad_lp) @ (
                    step_size * search_direction):
                proposed_grad = grad_log_density(proposed)
                if (proposed_grad @ search_direction <=
                        wolfe_bounds[1] * grad_lp @ search_direction):
                    break
            step_size = 0.5 * step_size
        optimization_path.append(proposed)
        grad_optimization_path.append(proposed_grad)
        if (proposed_lp -
                current_lp) / jnp.abs(current_lp) < relative_tolerance:
            return optimization_path, grad_optimization_path
        current_lp = proposed_lp

        position_diff: jnp.ndarray = optimization_path[-1] - optimization_path[
            -2]
        grad_diff = -grad_optimization_path[-1] + grad_optimization_path[-2]
        if position_diff @ grad_diff > positivity_threshold * jnp.sum(grad_diff
                                                                      **2):
            position_diffs = jnp.column_stack(
                (position_diffs[:,
                                -inverse_hessian_history + 1:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, -inverse_hessian_history + 1:], grad_diff))
    return optimization_path, grad_optimization_path
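
Both cov_estimate and lbfgs above keep their (dim, history) update matrices with the same jnp.column_stack idiom: slice off the oldest columns, then append the newest vector as a new column on the right. A minimal self-contained sketch of that sliding window (buffer name and sizes here are illustrative, not taken from the code above):

import jax.numpy as jnp

dim, history = 3, 4
buffer = jnp.empty((dim, 0))                 # starts with zero columns
for step in range(6):
    new_col = jnp.full((dim,), float(step))
    # keep at most history-1 old columns, then append the new one
    buffer = jnp.column_stack((buffer[:, -history + 1:], new_col))
print(buffer.shape)                          # (3, 4): only the newest columns survive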
Example #3
def main():
    smc = np.load('mc_ddpip_3d_smeared.npy')
    print(smc.shape)

    data = np.column_stack(get_vars(smc)[:3])
    print(data.shape)

    do_fit(data, 100)
Example #4
 def _indices(key):
     if not sparse_shape:
         return jnp.empty((nse, n_sparse), dtype=int)
     flat_ind = random.choice(key,
                              sparse_size,
                              shape=(nse, ),
                              replace=not unique_indices)
     return jnp.column_stack(jnp.unravel_index(flat_ind, sparse_shape))
Example #5
def momentum_from_cluster(clu: Cluster) -> (Momentum):
    """ Assuming photon (massless particle) """
    sinth = clu.sinth
    return Momentum.from_ndarray(
        np.column_stack([
            clu.energy * sinth * np.sin(clu.phi),
            clu.energy * sinth * np.cos(clu.phi),
            clu.energy * clu.costh,
        ]))
Example #6
def test_cluster_from_ndarray():
    """ """
    N = 100
    energy = rjax.uniform(rng, (N,), minval=0., maxval=3.)
    costh = rjax.uniform(rng, (N,), minval=-1., maxval=1.)
    phi = rjax.uniform(rng, (N,), minval=-np.pi, maxval=np.pi)
    clu = Cluster.from_ndarray(np.column_stack([energy, costh, phi]))
    assert np.allclose(clu.energy, energy)
    assert np.allclose(clu.costh, costh)
    assert np.allclose(clu.phi, phi)
    assert clu.as_array.shape == (N, 3)
Example #7
def sample(events: np.ndarray) -> (np.ndarray):
    """ Resolution sampler for MC events
    Args:
        - events: [E (MeV), m^2(DD) (GeV^2), m^2(Dpi) (GeV^2)]
    """
    e = events[:, 0] * 10**-3
    tdd = jnp.sqrt(events[:, 1]) - 2 * mdn
    mdpi = jnp.sqrt(events[:, 2])

    offsets = jax.random.normal(jax.random.PRNGKey(1), events.shape)

    return jnp.column_stack([
        (e + smddpi(e, tdd) * offsets[:, 0]) * 10**3,
        (tdd + stdd(tdd) * offsets[:, 1] + 2 * mdn)**2,
        (mdpi + smdstp() * offsets[:, 2])**2,
    ])
Example #8
def test_linear_regression(in_dim=5, out_dim=2, shape=(1000, )):
    key = jr.PRNGKey(time.time_ns())
    key1, key2 = jr.split(key, 2)
    covariates = jr.normal(key1, shape + (in_dim, ))
    covariates = np.column_stack([covariates, np.ones(shape)])
    data = jr.normal(key2, shape + (out_dim, ))
    lr = regrs.GaussianLinearRegression.fit(
        dict(data=data, covariates=covariates))

    # compare to least squares fit.  note that the covariance matrix is only the
    # covariance of the residuals if we fit the intercept term
    what = np.linalg.lstsq(covariates, data)[0].T
    assert np.allclose(lr.weights, what)
    resid = data - covariates @ what.T
    assert np.allclose(lr.covariance_matrix,
                       np.cov(resid, rowvar=False, bias=True),
                       atol=1e-6)
Example #9
def real_fourier_basis(n: int) -> Tuple[np.ndarray, np.ndarray]:
    """Construct a real unitary fourier basis of size (n).

    Args:
        n: The basis size.

    Returns:
        A tuple of the basis, and the frequencies at which to evaluate
        the spectral distribution function to get the variances for the
        Fourier-domain coefficients.

    """
    assert n > 1
    dc = np.ones((n, ))
    dc_freq = 0

    cosine_basis_vectors = []
    cosine_freqs = []
    sine_basis_vectors = []
    sine_freqs = []

    ts = np.arange(n)
    for w in range(1, 1 + (n - 1) // 2):
        x = w * (2 * np.pi / n) * ts
        cosine_basis_vectors.append(np.sqrt(2) * np.cos(x))
        cosine_freqs.append(w)
        sine_basis_vectors.append(-np.sqrt(2) * np.sin(x))
        sine_freqs.append(w)

    if n % 2 == 0:
        w = n // 2
        x = w * 2 * np.pi * ts / n
        cosine_basis_vectors.append(np.cos(x))
        cosine_freqs.append(w)

    basis = np.column_stack(
        (dc, *cosine_basis_vectors, *sine_basis_vectors[::-1]))
    freqs = np.concatenate([
        np.array([dc_freq]),
        np.array(cosine_freqs),
        np.array(sine_freqs[::-1])
    ])

    return basis / np.sqrt(n), freqs / n
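
A quick usage sketch (plain NumPy, function as defined above): the returned basis is square, and because its columns are orthonormal, multiplying it by its transpose should recover the identity.

import numpy as np

basis, freqs = real_fourier_basis(8)
print(basis.shape, freqs.shape)                 # (8, 8) (8,)
print(np.allclose(basis @ basis.T, np.eye(8)))  # True: the basis is unitary
print(freqs * 8)                                # [0. 1. 2. 3. 4. 3. 2. 1.]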
Example #10
    def __call__(self, task_features, node_features, neighbor_relations,
                 neighbor_features):
        """Maps task, node, edge and neighbors to logits.

    We use both multiplicative and additive dependency between
      representations of task, node, neighbors, neighbor_features.

    Args:
      task_features: Not used.
      node_features: Integer tensor of size `num_neighbors x num_node_features`
        with (duplicated) features of the current node.
      neighbor_relations: Integer tensor of size `num_neighbors`.
      neighbor_features: Integer tensor of size `num_neighbors x
        num_node_features` of neighbor features.

    Returns:
      A float tensor of size `num_neighbors` with logits associated with
      the probability of transitioning to that neighbor.
    """
        del task_features
        num_neighbors = len(neighbor_relations)
        # Embed all inputs
        inputs = (node_features, neighbor_relations, neighbor_features)
        embeddings = [emb(x) for emb, x in zip(self.input_embedding, inputs)]

        # Reshape
        node_embs, relation_embs, neighbor_embs = embeddings
        node_embs = node_embs.reshape(num_neighbors, -1)
        relation_embs = relation_embs.reshape(num_neighbors, -1)
        neighbor_embs = neighbor_embs.reshape(num_neighbors, -1)

        local_feature_embs = jnp.column_stack((node_embs, neighbor_embs))
        hidden_local_features = self.node_hidden(local_feature_embs)
        hidden_relation_embs = self.relation_hidden_mul(relation_embs)
        hidden_relation_bias = self.relation_hidden_sum(relation_embs)

        # FiLM-like layer to mix representations of neighborhood (node+neighbor)
        # and edge
        mixed_representation = hidden_local_features * hidden_relation_embs + hidden_relation_bias

        mixed_representation = nn.relu(mixed_representation)
        neighbor_logits = self.output_layer(mixed_representation)
        return neighbor_logits.squeeze(-1)
Example #11
    def __call__(self, param_name, param):
        """Shuffles the weight matrix/mask for a given parameter, per-neuron.

    This is to be used with mask_map, and accepts the standard mask_map
    function parameters.

    Args:
      param_name: The parameter's name.
      param: The parameter's weight or mask matrix.

    Returns:
      A shuffled weight/mask matrix, with each neuron shuffled independently.
    """
        del param_name  # Unused.

        incoming_connections = jnp.prod(jnp.array(param.shape[:-1]))
        num_neurons = param.shape[-1]

        # Ensure each input neuron has at least one connection unmasked.
        mask = _fill_diagonal_wrap((incoming_connections, num_neurons),
                                   1,
                                   dtype=jnp.uint8)

        # Randomly shuffle which of the neurons have these connections.
        mask = jax.random.shuffle(self._get_rng(), mask, axis=0)

        # Add extra required random connections to mask to satisfy sparsity.
        mask_cols = []
        for col in range(mask.shape[-1]):
            neuron_mask = mask[:, col]
            off_diagonal_count = max(
                round((1 - self._sparsity) * incoming_connections) -
                jnp.count_nonzero(neuron_mask), 0)

            zero_indices = jnp.flatnonzero(neuron_mask == 0)
            random_entries = _random_neuron_mask(len(zero_indices),
                                                 off_diagonal_count,
                                                 self._get_rng())

            neuron_mask = neuron_mask.at[zero_indices].set(random_entries)
            mask_cols.append(neuron_mask)

        return jnp.column_stack(mask_cols).reshape(param.shape)
Example #12
def table_ambisonics_order_vs_rE(max_order=20):
    """Return a dataframe with rE as a function of order."""
    order = np.arange(1, max_order + 1, dtype=np.int32)
    rE3 = np.array(list(map(shelf.max_rE_3d, order)))
    drE3 = np.append(np.nan, rE3[1:] - rE3[:-1])

    rE2 = np.array(list(map(shelf.max_rE_2d, order)))
    drE2 = np.append(np.nan, rE2[1:] - rE2[:-1])

    df = pd.DataFrame(np.column_stack((
        order,
        rE2,
        100 * drE2 / rE2,
        2 * np.arccos(rE2) * 180 / π,
        rE3,
        100 * drE3 / rE3,
        2 * np.arccos(rE3) * 180 / π,
    )),
                      columns=('order', '2D', '% change', 'asw', '3D',
                               '% change', 'asw'))
    return df
Example #13
def do_fit(data, steps=100):
    """ """
    norm_sample_size = 10**6
    norm_space = (
        (-2.5, 15),  # energy range (MeV)
        (0, 22),  # T(DD) range (MeV)
        (2.004, 2.026),  # m(Dpi) range (GeV)
    )

    _, initial_params = NN.init(rjax.PRNGKey(0), data)
    print('Model initialized:')
    print(jax.tree_map(np.shape, initial_params))

    rng = rjax.PRNGKey(10)
    rng, key1, key2, key3 = rjax.split(rng, 4)
    keys = [key1, key2, key3]
    norm_sample = np.column_stack([
        rjax.uniform(rjax.PRNGKey(key[0]), (norm_sample_size,), minval=lo, maxval=hi)\
            for key, (lo, hi) in zip(keys, norm_space)
    ])

    def loglh_loss(model):
        """ loss function for the unbinned maximum likelihood fit """
        return -np.sum(np.log(model(data))) +\
            data.shape[0] * np.log(np.sum(model(norm_sample)))

    model = nn.Model(NN, initial_params)
    adam = optim.Adam(learning_rate=0.03)
    optimizer = adam.create(model)

    for i in range(steps):
        loss, grad = jax.value_and_grad(loglh_loss)(optimizer.target)
        optimizer = optimizer.apply_gradient(grad)
        print(f'{i}/{steps}: loss: {loss:.3f}')

    with open('nn_model_3d.dat', 'wb') as ofile:
        state = TrainState(optimizer=optimizer)
        data = flax.serialization.to_bytes(state)
        print(f'Model serialized, num bytes: {len(data)}')
        ofile.write(data)
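
The loss above is the standard unbinned negative log-likelihood with a Monte-Carlo estimate of the normalization. Writing f for the unnormalized model density,

    -\sum_i \log \frac{f(x_i)}{\int f} \;=\; -\sum_i \log f(x_i) \;+\; N \log \int f,

and the integral is approximated by the sum of f over the uniform norm_sample, so the code's loss equals this quantity up to an additive constant (the log of the sampling volume divided by the sample size), which does not affect the gradient.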
Example #14
    def append(self, state: Tuple) -> None:
        """Append a trace or new elements to the current trace. This is useful
        when performing repeated inference on the same dataset, or using the
        generator runtime. Sequential inference should use different traces for
        each sequence.

        Parameters
        ----------
        state
            A tuple that contains the chain state and the corresponding sampling info.
        """
        sample, sample_info = state
        concatenate = lambda cur, new: jnp.concatenate((cur, new), axis=1)
        concatenate_1d = lambda cur, new: jnp.column_stack((cur, new))

        if self.raw.samples is None:
            stacked_chain = sample
        else:
            try:
                stacked_chain = jax.tree_multimap(concatenate,
                                                  self.raw.samples, sample)
            except TypeError:
                stacked_chain = jax.tree_multimap(concatenate_1d,
                                                  self.raw.samples, sample)

        if self.raw.sampling_info is None:
            stacked_info = sample_info
        else:
            try:
                stacked_info = jax.tree_multimap(concatenate,
                                                 self.raw.sampling_info,
                                                 sample_info)
            except TypeError:
                stacked_info = jax.tree_multimap(concatenate_1d,
                                                 self.raw.sampling_info,
                                                 sample_info)

        self.raw = replace(self.raw,
                           samples=stacked_chain,
                           sampling_info=stacked_info)
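
A minimal illustration of the two combinators used above (shapes are arbitrary): jnp.concatenate with axis=1 extends an existing sample axis of 2-D leaves, while jnp.column_stack also promotes 1-D leaves to columns before stacking.

import jax.numpy as jnp

cur_2d, new_2d = jnp.zeros((4, 10)), jnp.zeros((4, 2))
print(jnp.concatenate((cur_2d, new_2d), axis=1).shape)   # (4, 12)

cur_1d, new_1d = jnp.zeros(4), jnp.zeros(4)
print(jnp.column_stack((cur_1d, new_1d)).shape)          # (4, 2)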
Example #15
def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False):
    shape = (2000, 2000)
    nse = 10000
    size = np.prod(shape)
    rng = np.random.RandomState(1701)
    data = rng.randn(nse)
    indices = np.unravel_index(rng.choice(size, size=nse, replace=False),
                               shape=shape)
    mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)),
                      shape=shape)

    f = lambda mat: mat.todense()
    if jit or compile:
        f = jax.jit(f)

    if compile:
        while state:
            f.lower(mat).compile()
    else:
        f(mat).block_until_ready()
        while state:
            f(mat).block_until_ready()
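
A smaller, self-contained version of the same construction (the shapes here are illustrative), checking the round trip against a dense scatter: jnp.column_stack turns the tuple returned by np.unravel_index into the (nse, ndim) index matrix that BCOO expects.

import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

shape, nse = (6, 4), 5
rng = np.random.RandomState(0)
data = rng.randn(nse)
flat = rng.choice(np.prod(shape), size=nse, replace=False)
indices = np.unravel_index(flat, shape)       # tuple of per-dimension index arrays

mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
dense = jnp.zeros(shape).at[indices].set(data)
print(np.allclose(mat.todense(), dense))      # True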
Example #16
def cartesian_to_cluster(pos: Position, mom: Momentum, R=1000) -> (Cluster):
    """ Make cluster from Position and Momentum """
    # Find intersection with calorimeter (see the geometry sketch)
    pos_mom_scalar_product = pos.x * mom.px + pos.y * mom.py + pos.z * mom.pz
    mom_total = mom.ptot
    mom_total_squared = mom_total**2

    full_computation = False
    if full_computation:
        alpha = (-pos_mom_scalar_product +
                 np.sqrt(pos_mom_scalar_product**2 + mom_total_squared *
                         (R**2 - pos.r**2))) / mom_total_squared
    else:  # takes into account that R >> particle flight length
        alpha = R / mom_total - pos_mom_scalar_product / mom_total_squared

    cluster_position = Position.from_ndarray(pos.as_array +
                                             alpha.reshape(-1, 1) *
                                             mom.as_array)

    return Cluster.from_ndarray(
        np.column_stack(
            [mom_total, cluster_position.costh, cluster_position.phi]))
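
For reference, the two branches solve (or approximate) the same intersection condition, read directly off the code: with x the production position, p the momentum and R the calorimeter radius, the exact root of |x + \alpha p| = R is

    \alpha = \frac{-\,x \cdot p + \sqrt{(x \cdot p)^2 + |p|^2 (R^2 - |x|^2)}}{|p|^2},

and when R is much larger than the flight length the square root is approximately |p| R, which gives the shortcut \alpha \approx R/|p| - (x \cdot p)/|p|^2 used in the default branch.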
Example #17
    def _add_supernode(self, node_features, dense_submat, dense_q):
        """Adds supernode with full incoming and outgoing connectivity.

    Adds a row and column of 1s to `dense_submat`, and normalizes. Also adds a
      row to `node_features`, containing the average of the other node features.
      Adds a weight of 1 at the end of `dense_q`.

    Args:
      node_features: Shape (num_nodes, feature_dim) Matrix of node features.
      dense_submat: Shape (num_nodes, num_nodes) Adjacency matrix.
      dense_q: Shape (num_nodes,) Node weights.

    Returns:
      node_features: Shape (num_nodes + 1, feature_dim) Matrix of node features.
      dense_submat: Shape (num_nodes + 1, num_nodes + 1) Adjacency matrix.
      dense_q: Shape (num_nodes + 1,) Node weights.
    """
        dense_submat = jnp.row_stack(
            (dense_submat, jnp.ones(dense_submat.shape[1])))
        dense_submat = jnp.column_stack(
            (dense_submat, jnp.ones(dense_submat.shape[0])))
        # Normalize nonzero elements
        # The sum is bounded away from 0, so this is always differentiable
        # TODO(gnegiar): Do we want this? It means the supernode gets half the
        # outgoing weights
        dense_submat = dense_submat / dense_submat.sum(axis=-1, keepdims=True)
        # Add a weight to the supernode
        dense_q = jnp.append(dense_q, jnp.mean(dense_q))
        # We embed the supernode using a distinct value.
        # TODO(gnegiar): Should we use another embedding?
        node_features = jnp.append(node_features,
                                   jnp.full((1, node_features.shape[1]),
                                            2,
                                            dtype=int),
                                   axis=0)
        return node_features, dense_submat, dense_q
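
A minimal, self-contained sketch of the grow-and-renormalize step above on a concrete matrix (the values are illustrative; jnp.vstack is used in place of the equivalent jnp.row_stack call):

import jax.numpy as jnp

dense_submat = jnp.full((3, 3), 1.0 / 3.0)      # a row-normalized adjacency matrix
dense_submat = jnp.vstack((dense_submat, jnp.ones(dense_submat.shape[1])))         # supernode row
dense_submat = jnp.column_stack((dense_submat, jnp.ones(dense_submat.shape[0])))   # supernode column
dense_submat = dense_submat / dense_submat.sum(axis=-1, keepdims=True)             # rows sum to 1 again
print(dense_submat.shape)                       # (4, 4)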
Example #18
 def as_array(self) -> (np.ndarray):
     return np.column_stack([self.x, self.y, self.z])
Example #19
def column_stack(arrays):
  arrays = [a.value if isinstance(a, JaxArray) else a for a in arrays]
  return JaxArray(jnp.column_stack(arrays))
Example #20
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  #1024#768#512#256#128#64#32
    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
#         print(params)

    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

#     --------------------------------------
#    training loss functions

    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        # jnp.linalg.norm(diff_g_X,axis=0)

        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

#     ------

    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy),
                        loss_ad_energy) + jnp.vdot(rho_g, loss_jac_energy)
        return loss
#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, rho_g, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, (loss, grad)

#     @jit

    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr,
                               weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g,
                                                      next(batches))
                loss, grad = loss_and_grad

#             f = open(f_out,'a+')
#             print(i,loss,file=f)
#             f.close()

            train_loss.append(loss)
#             params = optimizer.target
#             loss_tot = f_validation(params)

        nn_params = optimizer.target

        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  #, learning_rate_fn, model

        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(
            rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train

        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(
            grad_val)  #, {"learning_rate": lr}
        return optimizer, nn_params, (loss_val, loss_train,
                                      train_loss_iter), (grad_loss_train,
                                                         grad_val)

#     Initialize rho_G

    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)

    rho_G0 = random.uniform(subkey, shape=(2, ), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0  #

    optimizer_out = optim.Adam(learning_rate=2E-4,
                               weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params

    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(
            optimizer_out, f_params)

        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  #unfreeze()

        f = open(f_out, 'a+')
        #         print(i,rho_g, loss0, loss0_tot, (time.time() - start_va_time),file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        #         print(train_loss_iter ,file=f)
        #         print(grad_val,file=f)
        #         print(grad_loss_train,file=f)
        f.close()


#     --------------------------------------
#     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Example #21
def cbind(x, y, backend="cpu"):
    # if len(x.shape) == 1 or len(y.shape) == 1:
    sys_platform = platform.system()
    if backend in ("gpu", "tpu") and (sys_platform in ("Linux", "Darwin")):
        return jnp.column_stack((x, y))
    return np.column_stack((x, y))
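
A usage sketch, assuming cbind is importable as defined above. On a CPU backend it simply falls through to np.column_stack; the jnp branch is only taken for gpu/tpu backends on Linux or macOS.

import numpy as np

x = np.arange(3.0)
y = np.arange(3.0, 6.0)
print(cbind(x, y))
# [[0. 3.]
#  [1. 4.]
#  [2. 5.]]
print(cbind(x, y, backend="gpu").shape)   # (3, 2) regardless of which branch runs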
Example #22
ring = Ring(10, .1)
gauss = Gaussian([0, 0], 9)
ring_target = Setup(ring, gauss)
ring_proposal = Setup(gauss, ring)

target = Ring(10, .1)
proposal = Ring(15, .1)
double_ring = Setup(target, proposal)

target = Squiggle([0, 0], [1, .1])
proposal = Gaussian([-2, 0], [1, 1])
squiggle_target = Setup(target, proposal)

means = np.array([np.exp(2j * np.pi * x) for x in np.linspace(0, 1, 6)[:-1]])
means = np.column_stack((means.real, means.imag))
target = GaussianMixture(means, .03, np.ones(5))
proposal = Gaussian([-2, 0], [.5, .5])
mix_of_gauss = Setup(target, proposal)

target = Gaussian([0, 0], [1e-4, 9])
proposal = Gaussian([0, 0], [1, 1])
thin_target = Setup(target, proposal)

d = 50
variances = np.logspace(-2, 0, num=d)
target = Gaussian(np.zeros(d), variances)
proposal = Gaussian(np.zeros(d), np.ones(d))
high_d_gaussian = Setup(target, proposal)

# rotated gaussian
Example #23
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):

    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    #     --------------------------------------
    #     Data
    n_atoms = 4
    batch_size = 768  #1024#768#512#256#128#64#32
    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(Xtr.shape, gXtr.shape, gXctr.shape, ytr.shape)
    # --------------------------------
    #     BATCHES

    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[
                    batch_idx]

    batches = data_stream()
    # --------------------------------

    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0),
          file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     initialize NN

    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]  #         x = jnp.ones((1,Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

#     Initialize parameters

    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)

        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()
    init_params = params

    #     --------------------------------------
    #     Phys functions

    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
#     WRONG
    @jit
    def f_nac_coup_i(gH_diab,eigvect_): #for a single cartesian dimension
        temp = jnp.dot(gH_diab,eigvect_[:,0])
        return jnp.vdot(eigvect_[:,1],temp)
    @jit
    def f_nac_coup(params,x):
        eigval_, eigvect_ = f_adiab(params,x)
        gy_diab = jac_nn_diab(params,x)
        gy_diab = jnp.reshape(gy_diab.T,(12,2,2))
        g_coup = vmap(f_nac_coup_i,(0,None))(gy_diab,eigvect_)
        return g_coup
    '''

    #     --------------------------------------
    #     Validation loss functions

    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     --------------------------------------
    #    training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  #Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''    
    @jit
    def f_loss_nac(params,batch):
        X_inputs, _,gXc_inputs,y_true = batch
        g_nac_coup = vmap(f_nac_coup,(None,0))(params,x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z 
    '''
    #     ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        #         loss_jac_energy = f_loss_jac(params,batch)
        loss = loss_ad_energy  #+ rho_g*loss_jac_energy
        return loss


#     --------------------------------------
#     Optimization  and Training

#     Perform a single training step.

    @jit
    def train_step(optimizer, batch):  #, learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  #, {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr,
                           weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  #unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()

    #     --------------------------------------
    #     Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points  = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g
    }

    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time =  %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
Example #24
 def as_array(self) -> (np.ndarray):
     return np.column_stack([self.energy, self.costh, self.phi])
Example #25
 def as_array(self) -> (np.ndarray):
     """ Helix parameters as np.ndarray """
     return np.column_stack([
         self.d0, self.phi0, self.omega, self.z0, self.tanl])