Example #1
 def update(i, g, state):
     x, avg_sq_grad, mom = state
     avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
     mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
     x = x - mom
     return x, avg_sq_grad, mom
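# A hedged sketch of driving the update above: gamma, momentum, eps and
# step_size are free variables closed over by the optimizer factory; the
# values below are illustrative assumptions, not from the source.
import jax.numpy as jnp

gamma, momentum, eps = 0.9, 0.9, 1e-8
step_size = lambda i: 1e-3
x0 = jnp.zeros(3)
state = (x0, jnp.zeros_like(x0), jnp.zeros_like(x0))  # (x, avg_sq_grad, mom)
state = update(0, jnp.array([0.1, -0.2, 0.3]), state)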
Example #2
def lennard_jones(conf, lj_params, cutoff, groups=None):
    """
    Implements a non-periodic LJ612 potential using the Lorentz−Berthelot combining
    rules, where sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    lj_params: shape [num_atoms, 2] np.array
        per-atom Lennard-Jones parameters; column 0 holds sig_i and column 1
        holds eps_i, which are combined pairwise via the rules above

    cutoff: float or None
        distance cutoff; if not None, any pairwise interaction at a distance
        greater than the cutoff is fully discarded

    groups: shape [num_atoms,] np.array of ints, optional
        per-atom group bitmasks; the pairwise flag bitwise_and(g_i, g_j) > 0
        is passed through to the distance computation
    
    """
    box = None
    assert box is None

    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = (sig_i + sig_j)/2
    sig_ij_raw = sig_ij

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = np.sqrt(eps_i * eps_j)

    eps_ij_raw = eps_ij

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    gi = np.expand_dims(groups, axis=0)
    gj = np.expand_dims(groups, axis=1)
    gij = np.bitwise_and(gi, gj) > 0

    # print(gij)
    dij = distance(ri, rj, box, gij)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    N = conf.shape[0]
    keep_mask = np.ones((N,N)) - np.eye(N)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij/dij
    sig2 *= sig2
    sig6 = sig2*sig2*sig2

    eij = 4*eps_ij*(sig6-1.0)*sig6

    # if cutoff is not None:
        # sw = switch_fn(dij, cutoff)
        # eij = eij*sw

    eij = np.where(keep_mask, eij, np.zeros_like(eij))
    return np.sum(eij/2)
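# The function above relies on a `distance` helper that is not shown. A minimal
# non-periodic sketch consistent with how it is called here (box is always None,
# and the group flag gij is accepted but ignored) might look like the following;
# this is an assumption, not the original helper, and uses plain numpy for np.
import numpy as np

def distance(ri, rj, box=None, gij=None):
    # pairwise Euclidean distances between the broadcast coordinate arrays
    diff = ri - rj
    return np.sqrt(np.sum(diff * diff, axis=-1))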
Example #3
def global_norm(updates: Updates) -> Updates:
  return jnp.sqrt(
      sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(updates)]))
Example #4
def _statistics(data, batch_size):
    data = jnp.atleast_1d(data)
    if data.ndim == 1:
        data = data.reshape((1, -1))

    if data.ndim > 2:
        raise NotImplementedError("Statistics are implemented only for ndim<=2")

    mean = _mean(data)
    variance = _var(data)

    ts = _total_size(data)

    bare_var = variance

    batch_var, n_batches = _batch_variance(data)

    l_block = max(1, data.shape[1] // batch_size)

    block_var, n_blocks = _block_variance(data, l_block)

    tau_batch = ((ts / n_batches) * batch_var / bare_var - 1) * 0.5
    tau_block = ((ts / n_blocks) * block_var / bare_var - 1) * 0.5

    batch_good = (tau_batch < 6 * data.shape[1]) * (n_batches >= batch_size)
    block_good = (tau_block < 6 * l_block) * (n_blocks >= batch_size)

    stat_dtype = nkjax.dtype_real(data.dtype)
    # if batch_good:
    #    error_of_mean = jnp.sqrt(batch_var / n_batches)
    #    tau_corr = jnp.max(0, tau_batch)
    # elif block_good:
    #    error_of_mean = jnp.sqrt(block_var / n_blocks)
    #    tau_corr = jnp.max(0, tau_block)
    # else:
    #    error_of_mean = jnp.nan
    #    tau_corr = jnp.nan
    # jax style
    def batch_good_err(args):
        batch_var, tau_batch, *_ = args
        error_of_mean = jnp.sqrt(batch_var / n_batches)
        tau_corr = jnp.clip(tau_batch, 0)
        return jnp.asarray(error_of_mean, dtype=stat_dtype), jnp.asarray(
            tau_corr, dtype=stat_dtype
        )

    def block_good_err(args):
        _, _, block_var, tau_block = args
        error_of_mean = jnp.sqrt(block_var / n_blocks)
        tau_corr = jnp.clip(tau_block, 0)
        return jnp.asarray(error_of_mean, dtype=stat_dtype), jnp.asarray(
            tau_corr, dtype=stat_dtype
        )

    def nan_err(args):
        return jnp.asarray(jnp.nan, dtype=stat_dtype), jnp.asarray(
            jnp.nan, dtype=stat_dtype
        )

    def batch_not_good(args):
        batch_var, tau_batch, block_var, tau_block, block_good = args
        return jax.lax.cond(
            block_good,
            block_good_err,
            nan_err,
            (batch_var, tau_batch, block_var, tau_block),
        )

    error_of_mean, tau_corr = jax.lax.cond(
        batch_good,
        batch_good_err,
        batch_not_good,
        (batch_var, tau_batch, block_var, tau_block, block_good),
    )

    if n_batches > 1:
        N = data.shape[-1]

        # V_loc = _np.var(data, axis=-1, ddof=0)
        # W_loc = _np.mean(V_loc)
        # W = _mean(W_loc)
        # # This approximation seems to hold well enough for larger n_samples
        W = variance

        R_hat = jnp.sqrt((N - 1) / N + batch_var / W)
    else:
        R_hat = jnp.nan

    res = Stats(mean, error_of_mean, variance, tau_corr, R_hat)

    return res
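# The commented-out Python if/elif/else above is what the nested jax.lax.cond
# calls implement, since Python branching on traced values is not allowed under
# jit. A small standalone sketch of that pattern (names and branch bodies here
# are illustrative, not from the original):
import jax
import jax.numpy as jnp

def if_elif_else(pred_a, pred_b, x):
    return jax.lax.cond(
        pred_a,
        lambda v: jnp.sqrt(v),                    # "if" branch
        lambda v: jax.lax.cond(
            pred_b,
            lambda w: w * 2.0,                    # "elif" branch
            lambda w: jnp.full_like(w, jnp.nan),  # "else" branch
            v),
        x,
    )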
Example #5
def decay(y, t, arg1, arg2):
  return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
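# A hedged usage sketch: the right-hand side above follows the odeint calling
# convention f(y, t, *args). Assuming the `np` it uses is jax.numpy, it can be
# integrated with jax.experimental.ode.odeint (argument values are illustrative):
import jax.numpy as jnp
from jax.experimental.ode import odeint

y0 = jnp.array([1.0, 0.5])
ts = jnp.linspace(0.0, 5.0, 50)
ys = odeint(decay, y0, ts, 2.0, 0.1)  # arg1=2.0, arg2=0.1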
Example #6
def _norm_tree(x):
    return jnp.sqrt(_vdot_real_tree(x, x))
Example #7
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
            training: if True, use (and update) statistics computed from the
                current batch; if False, use the stored running averages.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, self.dtype)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            var = jnp.mean(lax.abs(x - mean),
                           axis=reduction_axis,
                           keepdims=False) * jnp.sqrt(jnp.pi / 2)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, var])
                mean, var = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        mean = jnp.asarray(mean, self.dtype)
        var = jnp.asarray(var, self.dtype)
        y = x - mean.reshape(feature_shape)
        mul = lax.reciprocal(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            scale = jnp.asarray(scale, self.dtype)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return jnp.asarray(y, self.dtype)
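# A quick numerical note on the estimator above (a sketch, not from the source):
# for Gaussian data, E|x - mean(x)| = sigma * sqrt(2/pi), so the mean absolute
# deviation scaled by sqrt(pi/2) estimates the standard deviation. That is why
# `var` behaves like sigma here and the code multiplies by 1/(var + epsilon)
# instead of 1/sqrt(var + epsilon).
import jax
import jax.numpy as jnp

x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (100_000,))
print(jnp.mean(jnp.abs(x - jnp.mean(x))) * jnp.sqrt(jnp.pi / 2))  # ~3.0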
Example #8
def run_model(T, M, N, D, init_params, Xzero, N_Iter, learning_rate):

    tot = time.time()
    samples = 5
    params, graph = train(T, M, N, D, init_params, N_Iter, learning_rate,
                          Xzero)
    print("total time:", time.time() - tot, "s")

    np.random.seed(42)
    t_test, W_test = fetch_minibatch(T, M, N, D)

    loss3 = loss_function(params, t_test, W_test,
                          Xzero)  # annoying as it gets calc'd twice
    print("loss3")
    print(loss3)

    X_pred, Y_pred, Y_tilde_pred = vXYpaths2(params, t_test, W_test, Xzero)

    # Y_test = jnp.reshape(u_exact(np.reshape(t_test[0:M, :, :], [-1, 1]), jnp.reshape(X_pred[0:M, :, :], [-1, D])),
    #                     [M, -1, 1])   # fix all these unnecessary reshapes at some point
    Y_test = jnp.reshape(
        u_exact(np.reshape(t_test[0:M, :, :], [-1, 1]),
                jnp.reshape(X_pred[0:M, :, :], [-1, D])),
        [M, 1, -1])  # fix all these unnecessary reshapes at some point

    plt.figure()
    plt.plot(graph[0], graph[1])
    plt.xlabel('Iterations')
    plt.ylabel('Value')
    plt.yscale("log")
    plt.title('Evolution of the training loss')

    plt.figure()
    # plt.plot(t_test[0:1, :, 0].T, Y_pred[0:1, :, 0].T, 'b', label='Learned $u(t,X_t)$')  # <- the dimensions of Y_pred got mixed up somewhere...
    # plt.plot(t_test[0:1, :, 0].T, Y_test[0:1, :, 0].T, 'r--', label='Exact $u(t,X_t)$')
    # plt.plot(t_test[0:1, -1, 0], Y_test[0:1, -1, 0], 'ko', label='$Y_T = u(T,X_T)$')

    plt.plot(t_test[0:1, :, 0].T,
             Y_pred[0:1, 0, :].T,
             'b',
             label='Learned $u(t,X_t)$')
    plt.plot(t_test[0:1, :, 0].T,
             Y_test[0:1, 0, :].T,
             'r--',
             label='Exact $u(t,X_t)$')
    plt.plot(t_test[0:1, -1, 0],
             Y_test[0:1, 0, -1],
             'ko',
             label='$Y_T = u(T,X_T)$')

    # plt.plot(t_test[1:samples, :, 0].T, Y_pred[1:samples, :, 0].T, 'b')
    plt.plot(t_test[1:samples, :, 0].T, Y_pred[1:samples, 0, :].T, 'b')

    # plt.plot(t_test[1:samples, :, 0].T, Y_test[1:samples, :, 0].T, 'r--')
    # plt.plot(t_test[1:samples, -1, 0], Y_test[1:samples, -1, 0], 'ko')
    plt.plot(t_test[1:samples, :, 0].T, Y_test[1:samples, 0, :].T, 'r--')
    plt.plot(t_test[1:samples, -1, 0], Y_test[1:samples, 0, -1], 'ko')

    plt.plot([0], Y_test[0, 0, 0], 'ks', label='$Y_0 = u(0,X_0)$')

    plt.xlabel('$t$')
    plt.ylabel('$Y_t = u(t,X_t)$')
    plt.title(
        str(D) + '-dimensional Black-Scholes-Barenblatt, ' + "FC" + "-" +
        "ReLu" + "_JRJaxvec")
    plt.legend()

    errors = jnp.sqrt((Y_test - Y_pred)**2 / Y_test**2)
    # mean_errors = jnp.mean(errors, 0)
    # std_errors = jnp.std(errors, 0)
    mean_errors = jnp.mean(errors, 0)[0, :]
    std_errors = jnp.std(errors, 0)[0, :]

    plt.figure()
    # plt.plot(t_test[0, :, 0], mean_errors, 'b', label='mean')
    # plt.plot(t_test[0, :, 0], mean_errors + 2 * std_errors, 'r--', label='mean + two standard deviations')
    plt.plot(t_test[0, :, 0], mean_errors, 'b', label='mean')
    plt.plot(t_test[0, :, 0],
             mean_errors + 2 * std_errors,
             'r--',
             label='mean + two standard deviations')

    plt.xlabel('$t$')
    plt.ylabel('relative error')
    plt.title(
        str(D) + '-dimensional-Black-Scholes-Barenblatt-' + "FC" + "-" +
        "ReLu" + "_JRJaxVec")
    plt.legend()
    plt.savefig(
        str(D) + '-dimensional-Black-Scholes-Barenblatt-' + "FC" + "-" +
        "ReLu" + "_JRJaxVec")
    cwd = os.getcwd()
    print(cwd)

    with open("JRJaxVec_Output.txt", "w") as text_file:
        text_file.write(f"where is this file\nhere: {cwd}")
Example #9
 def update_fn(updates, state):
     nu = _update_moment(updates, state.nu, decay, 2)
     updates = tree_multimap(lambda g, n: g / (jnp.sqrt(n + eps)), updates,
                             nu)
     return updates, ScaleByRmsState(nu=nu)
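# The update above leans on an `_update_moment` helper that is not shown. A
# minimal sketch consistent with how it is called (a decayed running estimate
# of the order-th moment of the gradients) could be:
import jax
import jax.numpy as jnp

def _update_moment(updates, moments, decay, order):
    return jax.tree_util.tree_map(
        lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)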
Example #10
 def _sqrt(x):
     return jnp.sqrt(x)
Example #11
def safe_sqrt(x, eps=1e-7):
    safe_x = jnp.where(x == 0, jnp.ones_like(x) * eps, x)
    return jnp.sqrt(safe_x)
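# What the where-guard buys (a small sketch): the derivative of jnp.sqrt is
# infinite at exactly zero, while safe_sqrt stays finite there because the
# selected branch is constant in x.
import jax
import jax.numpy as jnp

print(jax.grad(jnp.sqrt)(0.0))   # inf
print(jax.grad(safe_sqrt)(0.0))  # 0.0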
Example #12
def length_normalized(x, epsilon=1e-6):
    variance = jnp.mean(x**2, axis=-1, keepdims=True)
    norm_inputs = x / jnp.sqrt(variance + epsilon)
    return norm_inputs
Example #13
def dot_product_attention(scope,
                          query,
                          key,
                          value,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.


  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1, dim2,
      ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2,..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
  """
    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1]
            and query.shape[-1] == key.shape[-1])

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is  <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)

    query = query / jnp.sqrt(depth).astype(dtype)
    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = lax.dot_general(query,
                                   key, (((n - 1, ), (n - 1, )),
                                         (batch_dims_t, batch_dims_t)),
                                   precision=precision)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    attn_weights = lax.exp(attn_weights - jax.scipy.special.logsumexp(
        attn_weights, axis=norm_dims, keepdims=True))
    attn_weights = attn_weights.astype(dtype)

    # apply dropout
    if not deterministic and dropout_rate > 0.:
        if dropout_rng is None:
            dropout_rng = scope.make_rng('dropout')
        keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = lax.dot_general(attn_weights,
                        value,
                        (wv_contracting_dims, (batch_dims_t, batch_dims_t)),
                        precision=precision)

    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    return y
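# A hedged shape-check sketch (hypothetical shapes; assumes the helpers this
# function references, such as `onp` and `_invert_perm`, are in scope). With the
# default dropout_rate of 0 the `scope` argument is never used, so None is
# passed purely as a placeholder:
import jax.numpy as jnp
from jax import random

q = random.normal(random.PRNGKey(0), (2, 16, 4, 32))  # [batch, len, heads, ch]
k = random.normal(random.PRNGKey(1), (2, 16, 4, 32))
v = random.normal(random.PRNGKey(2), (2, 16, 4, 32))
out = dot_product_attention(None, q, k, v)
print(out.shape)  # (2, 16, 4, 32)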
Example #14
def l2_norm(tree):
    """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
    leaves, _ = tree_flatten(tree)
    return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
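# A tiny usage sketch: any pytree of arrays works, e.g. a parameter dict.
import jax.numpy as jnp

params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}
print(l2_norm(params))  # sqrt(9 + 0) = 3.0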
Example #15
 def qnorm(carry):
     k, _, q, qnorm_scaled = carry
     _, qnorm = _safe_normalize(q)
     qnorm_scaled = qnorm / jnp.sqrt(2)
     return (k, False, q, qnorm_scaled)
Example #16
 def update_fn(updates, state):
     mu = _update_moment(updates, state.mu, decay, 1)
     nu = _update_moment(updates, state.nu, decay, 2)
     updates = tree_multimap(lambda g, m, n: g / jnp.sqrt(n - m**2 + eps),
                             updates, mu, nu)
     return updates, ScaleByRStdDevState(mu=mu, nu=nu)
Example #17
 def givens_rotation(v1, v2):
     t = jnp.sqrt(v1**2 + v2**2)
     cs = v1 / t
     sn = -v2 / t
     return cs, sn
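# A quick sanity sketch: with the convention above, applying the rotation
# [[cs, -sn], [sn, cs]] to (v1, v2) maps it onto (norm, 0).
cs, sn = givens_rotation(3.0, 4.0)
print(cs * 3.0 - sn * 4.0)  # 5.0
print(sn * 3.0 + cs * 4.0)  # 0.0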
Example #18
def global_norm(items):
    return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(items)]))
Example #19
    data_binned = []
    a = []
    # Gather data points and bin them
    for i in range(0, 160):
        y_bin = (y[i * 8 * 1:(i * 8 * 1) + 8 * 1])
        a = sum(y_bin[first_p:last_p])
        data_binned = np.append(data_binned, a)

    x = np.arange(0, len(data_binned), 1)
    df_Forward = pd.DataFrame({
        'binned PMT data':
        data_binned[0:80] / (b_width * n_of_scans * pps),
        'steps':
        x[0:80],
        'error':
        np.sqrt(data_binned[0:80]) / (b_width * n_of_scans * pps),
    })
    df_Backward = pd.DataFrame({
        'binned PMT data':
        data_binned[80:159] / (b_width * n_of_scans * pps),
        'steps':
        x[80:159],
        'error':
        np.sqrt(data_binned[80:159]) / (b_width * n_of_scans * pps),
    })
    ''' Now to minimize the scans'''

    if sel2 == 'Forward':
        x_step = df_Forward['steps']
        y_data = df_Forward['binned PMT data']
        y_err = df_Forward['error']
Example #20
 def mean(self):
     return np.sqrt(2 / np.pi) * self.scale
Example #21
    def _train_step(self):
        """Runs a single training step."""
        if self._replay.add_count > self.min_replay_history:
            if self.training_steps % self.update_period == 0:
                self._sample_from_replay_buffer()

                if self._replay_scheme == 'prioritized':
                    # The original prioritized experience replay uses a linear exponent
                    # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of
                    # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders)
                    # suggested a fixed exponent actually performs better, except on Pong.
                    probs = self.replay_elements['sampling_probabilities']
                    # Weight the loss by the inverse priorities.
                    loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
                    loss_weights /= jnp.max(loss_weights)
                else:
                    loss_weights = jnp.ones(
                        self.replay_elements['state'].shape[0])

                self.optimizer_state, self.online_params, aux_losses = train(
                    self.network_def, self.online_params,
                    self.target_network_params, self.optimizer,
                    self.optimizer_state, self.replay_elements['state'],
                    self.replay_elements['action'],
                    self.replay_elements['next_state'],
                    self.replay_elements['reward'],
                    self.replay_elements['terminal'], loss_weights,
                    self._support, self.cumulative_gamma, self._mico_weight,
                    self._distance_fn)

                loss = aux_losses.pop('loss')
                if self._replay_scheme == 'prioritized':
                    # Rainbow and prioritized replay are parametrized by an exponent
                    # alpha, but in both cases it is set to 0.5 - for simplicity's sake we
                    # leave it as is here, using the more direct sqrt(). Taking the square
                    # root "makes sense", as we are dealing with a squared loss.  Add a
                    # small nonzero value to the loss to avoid 0 priority items. While
                    # technically this may be okay, setting all items to 0 priority will
                    # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms.
                    self._replay.set_priority(self.replay_elements['indices'],
                                              jnp.sqrt(loss + 1e-10))

                if self._replay_scheme == 'prioritized':
                    probs = self.replay_elements['sampling_probabilities']
                    loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
                    loss_weights /= jnp.max(loss_weights)
                    self._replay.set_priority(self.replay_elements['indices'],
                                              jnp.sqrt(loss + 1e-10))
                    loss = loss_weights * loss
                if self.summary_writer is not None:
                    values = []
                    for k in aux_losses:
                        values.append(
                            tf.compat.v1.Summary.Value(
                                tag=f'Losses/{k}', simple_value=aux_losses[k]))
                    summary = tf.compat.v1.Summary(value=values)
                    self.summary_writer.add_summary(summary,
                                                    self.training_steps)
            if self.training_steps % self.target_update_period == 0:
                self._sync_weights()

        self.training_steps += 1
Example #22
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     normalize_term = np.log(np.sqrt(2 * np.pi) * self.scale)
     value_scaled = (value - self.loc) / self.scale
     return -0.5 * value_scaled**2 - normalize_term
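# The normalize_term / value_scaled split above is the closed-form Gaussian
# log-density; a standalone cross-check against scipy (illustrative values,
# not from the source):
import numpy as np
from scipy.stats import norm

loc, scale, value = 1.0, 2.0, 0.3
z = (value - loc) / scale
manual = -0.5 * z**2 - np.log(np.sqrt(2 * np.pi) * scale)
print(np.isclose(manual, norm.logpdf(value, loc=loc, scale=scale)))  # True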
Example #23
def _BW(m_,w_,Sbc):
    gamma=np.sqrt(m_*m_*(m_*m_+w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_+gamma))
    l = Sbc.shape[0]
    temp = dplex.dconstruct(m_*m_ - Sbc,  -m_*w_*np.ones(l))
    return dplex.ddivide(k, temp)
Example #24
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
    """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, the per-pixel SSIM "map" is returned instead of the mean values.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = jnp.exp(-0.5 * f_i)
    filt /= jnp.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
    filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

    # Vmap the blurs to the tensor size, and then compose them.
    num_dims = len(img0.shape)
    map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
    for d in map_axes:
        filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
        filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
    filt_fn = lambda z: filt_fn1(filt_fn2(z))

    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = jnp.maximum(0., sigma00)
    sigma11 = jnp.maximum(0., sigma11)
    sigma01 = jnp.sign(sigma01) * jnp.minimum(jnp.sqrt(sigma00 * sigma11),
                                              jnp.abs(sigma01))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
    return ssim_map if return_map else ssim
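# A hedged usage sketch (assumes the `jsp` used above refers to jax.scipy):
# two nearly identical images should score close to 1.
import jax
import jax.numpy as jnp

key0, key1 = jax.random.split(jax.random.PRNGKey(0))
img0 = jax.random.uniform(key0, (1, 64, 64, 3))
img1 = jnp.clip(img0 + 0.05 * jax.random.normal(key1, img0.shape), 0.0, 1.0)
print(compute_ssim(img0, img1, max_val=1.0))  # shape (1,), value near 1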
Example #25
def self_energy(conf, charges, alpha):
    return np.sum(ONE_4PI_EPS0 * np.power(charges, 2) * alpha/np.sqrt(np.pi))
Example #26
def plot_cornerplot(results, vars=None, save_name=None):
    rkey0 = random.PRNGKey(123496)
    vars = _get_vars(results, vars)
    ndims = _get_ndims(results, vars)
    figsize = min(20, max(4, int(2 * ndims)))
    fig, axs = plt.subplots(ndims, ndims, figsize=(figsize, figsize))
    if ndims == 1:
        axs = [[axs]]
    nsamples = results.num_samples
    log_p = results.log_p[:results.num_samples]
    nbins = int(jnp.sqrt(results.ESS)) + 1
    lims = {}
    dim = 0
    for key in vars:  # sorted(results.samples.keys()):
        n1 = tuple_prod(results.samples[key].shape[1:])
        for i in range(n1):
            samples1 = results.samples[key][:results.num_samples, ...].reshape(
                (nsamples, -1))[:, i]
            if jnp.std(samples1) == 0.:
                dim += 1
                continue
            weights = jnp.where(jnp.isfinite(samples1), jnp.exp(log_p), 0.)
            log_weights = jnp.where(jnp.isfinite(samples1), log_p, -jnp.inf)
            samples1 = jnp.where(jnp.isfinite(samples1), samples1, 0.)
            # kde1 = gaussian_kde(samples1, weights=weights, bw_method='silverman')
            # samples1_resampled = kde1.resample(size=int(results.ESS))
            rkey0, rkey = random.split(rkey0, 2)
            samples1_resampled = resample(rkey,
                                          samples1,
                                          log_weights,
                                          S=int(results.ESS))
            binsx = jnp.linspace(*jnp.percentile(samples1_resampled, [0, 100]),
                                 2 * nbins)
            dim2 = 0
            for key2 in vars:  # sorted(results.samples.keys()):
                n2 = tuple_prod(results.samples[key2].shape[1:])
                for i2 in range(n2):
                    ax = axs[dim][dim2]
                    if dim2 > dim:
                        dim2 += 1
                        ax.set_xticks([])
                        ax.set_xticklabels([])
                        ax.set_yticks([])
                        ax.set_yticklabels([])
                        continue
                    if n2 > 1:
                        title2 = "{}[{}]".format(key2, i2)
                    else:
                        title2 = "{}".format(key2)
                    if n1 > 1:
                        title1 = "{}[{}]".format(key, i)
                    else:
                        title1 = "{}".format(key)
                    ax.set_title('{} {}'.format(title1, title2))
                    if dim == dim2:
                        # ax.plot(binsx, kde1(binsx))
                        ax.hist(samples1_resampled,
                                bins='auto',
                                fc='None',
                                edgecolor='black',
                                density=True)
                        sample_mean = jnp.average(samples1, weights=weights)
                        sample_std = jnp.sqrt(
                            jnp.average((samples1 - sample_mean)**2,
                                        weights=weights))
                        ax.set_title(
                            "{:.2f}:{:.2f}:{:.2f}\n{:.2f}+-{:.2f}".format(
                                *jnp.percentile(samples1_resampled,
                                                [5, 50, 95]), sample_mean,
                                sample_std))
                        ax.vlines(sample_mean,
                                  *ax.get_ylim(),
                                  linestyles='solid',
                                  colors='red')
                        ax.vlines([
                            sample_mean - sample_std, sample_mean + sample_std
                        ],
                                  *ax.get_ylim(),
                                  linestyles='dotted',
                                  colors='red')
                        ax.set_xlim(binsx.min(), binsx.max())
                        lims[dim] = ax.get_xlim()
                    else:
                        samples2 = results.samples[key2][:results.num_samples,
                                                         ...].reshape(
                                                             (nsamples,
                                                              -1))[:, i2]
                        if jnp.std(samples2) == 0.:
                            dim2 += 1
                            continue
                        weights = jnp.where(jnp.isfinite(samples2),
                                            jnp.exp(log_p), 0.)
                        log_weights = jnp.where(jnp.isfinite(samples2), log_p,
                                                -jnp.inf)
                        samples2 = jnp.where(jnp.isfinite(samples2), samples2,
                                             0.)
                        # kde2 = gaussian_kde(jnp.stack([samples1, samples2], axis=0),
                        #                     weights=weights,
                        #                     bw_method='silverman')
                        # samples2_resampled = kde2.resample(size=int(results.ESS))
                        rkey0, rkey = random.split(rkey0, 2)
                        samples2_resampled = resample(rkey,
                                                      jnp.stack(
                                                          [samples1, samples2],
                                                          axis=-1),
                                                      log_weights,
                                                      S=int(results.ESS))
                        # norm = plt.Normalize(log_weights.min(), log_weights.max())
                        # color = jnp.atleast_2d(plt.cm.jet(norm(log_weights)))
                        ax.hist2d(samples2_resampled[:, 1],
                                  samples2_resampled[:, 0],
                                  bins=(nbins, nbins),
                                  density=True,
                                  cmap=plt.cm.bone_r)
                        # ax.scatter(samples2_resampled[:, 1], samples2_resampled[:, 0], marker='+', c='black', alpha=0.5)
                        # binsy = jnp.linspace(*jnp.percentile(samples2_resampled[:, 1], [0, 100]), 2 * nbins)
                        # X, Y = jnp.meshgrid(binsx, binsy, indexing='ij')
                        # ax.contour(kde2(jnp.stack([X.flatten(), Y.flatten()], axis=0)).reshape((2 * nbins, 2 * nbins)),
                        #            extent=(binsy.min(), binsy.max(),
                        #                    binsx.min(), binsx.max()),
                        #            origin='lower')
                    if dim == ndims - 1:
                        ax.set_xlabel("{}".format(title2))
                    if dim2 == 0:
                        ax.set_ylabel("{}".format(title1))

                    dim2 += 1
            dim += 1
    for dim in range(ndims):
        for dim2 in range(ndims):
            if dim == dim2:
                continue
            ax = axs[dim][dim2] if ndims > 1 else axs[0]
            if dim in lims.keys():
                ax.set_ylim(lims[dim])
            if dim2 in lims.keys():
                ax.set_xlim(lims[dim2])
    if save_name is not None:
        fig.savefig(save_name)
    plt.show()
Example #27
def point_to_coordinate_n(points_n, num_fragments=6):
    """
    Takes points from dihedral_to_point and sequentially converts them into
    coordinates of a 3D structure.
    Reconstruction is done in parallel by independently reconstructing
    num_fragments fragments and then reconstituting the chain at the end in reverse order.
    The core reconstruction algorithm is NeRF, based on
    DOI: 10.1002/jcc.20237 by Parsons et al. 2005.
    The parallelized version is described in
    https://www.biorxiv.org/content/early/2018/08/06/385450.
    :param points_n: Tensor containing points as returned by `dihedral_to_point`.
    Shape [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]                        
    :param num_fragments: Number of fragments in which the sequence is split
    to perform parallel computation.
    :return: Tensor containing correctly transformed atom coordinates.
    Shape [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]
    """

    # Compute optimal number of fragments if needed
    total_num_angles = points_n.shape[0]  # NUM_STEPS x NUM_DIHEDRALS

    if num_fragments is None:
        num_fragments = int(math.sqrt(total_num_angles))

    # Initial three coordinates (specifically chosen to eliminate need for
    # extraneous matmul)
    Triplet = collections.namedtuple('Triplet', 'a, b, c')
    batch_size = points_n.shape[1]
    init_matrix = onp.array(
        [[-onp.sqrt(1.0 / 2.0), onp.sqrt(3.0 / 2.0), 0],
         [-onp.sqrt(2.0), 0, 0], [0, 0, 0]],
        dtype=onp.float32)

    init_coords = [
        onp.tile(row, (num_fragments * batch_size, 1)).reshape(
            num_fragments, batch_size, NUM_DIMENSIONS) for row in init_matrix
    ]
    init_coords = Triplet(
        *init_coords
    )  # NUM_DIHEDRALS x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]

    # Pad points to yield equal-sized fragments
    padding = (
        (num_fragments - (total_num_angles % num_fragments)) % num_fragments
    )  # (NUM_FRAGS x FRAG_SIZE) - (NUM_STEPS x NUM_DIHEDRALS)
    points_n = onp.pad(
        points_n, ((0, padding), (0, 0), (0, 0)),
        mode='constant')  # [NUM_FRAGS x FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
    points_n = points_n.reshape(
        num_fragments, -1, batch_size,
        NUM_DIMENSIONS)  # [NUM_FRAGS, FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
    points_n = onp.transpose(
        points_n,
        (1, 0, 2, 3))  # [FRAG_SIZE, NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]

    # Extension function used for single atom reconstruction and whole fragment
    # alignment
    def extend(prev_three_coords, point, multi_m):
        """
        Aligns an atom or an entire fragment depending on value of `multi_m`
        with the preceding three atoms.
        :param prev_three_coords: Named tuple storing the last three atom
        coordinates ("a", "b", "c") where "c" is the current end of the
        structure (i.e. closest to the atom/ fragment that will be added now).
        Shape NUM_DIHEDRALS x [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMENSIONS].
        First rank depends on value of `multi_m`.
        :param point: Point describing the atom that is added to the structure.
        Shape [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
        First rank depends on value of `multi_m`.
        :param multi_m: If True, a single atom is added to the chain for
        multiple fragments in parallel. If False, a single fragment is added.
        Note the different parameter dimensions.
        :return: Coordinates of the atom/ fragment.
        """
        # Normalize rows: https://necromuralist.github.io/neural_networks/posts/normalizing-with-numpy/
        Xbc = (prev_three_coords.c - prev_three_coords.b)
        bc = Xbc / onp.linalg.norm(Xbc, axis=-1, keepdims=True)

        Xn = onp.cross(prev_three_coords.b - prev_three_coords.a,
                       bc,
                       axisa=-1,
                       axisb=-1,
                       axisc=-1)
        n = Xn / onp.linalg.norm(Xn, axis=-1, keepdims=True)

        if multi_m:  # multiple fragments, one atom at a time
            m = onp.transpose(onp.stack([bc, onp.cross(n, bc), n]),
                              (1, 2, 3, 0))
        else:  # single fragment, reconstructed entirely at once.
            s = point.shape + (3, )  # +
            m = onp.transpose(onp.stack([bc, onp.cross(n, bc), n]), (1, 2, 0))
            m = onp.tile(m, (s[0], 1, 1)).reshape(s)

        coord = onp.squeeze(onp.matmul(m, onp.expand_dims(point, axis=3)),
                            axis=3) + prev_three_coords.c

        return coord

    # Loop over FRAG_SIZE in NUM_FRAGS parallel fragments, sequentially
    # generating the coordinates for each fragment across all batches
    coords_list = [None] * points_n.shape[
        0]  # FRAG_SIZE x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]
    prev_three_coords = init_coords
    for i in range(points_n.shape[0]):  # Iterate over FRAG_SIZE
        coord = extend(prev_three_coords, points_n[i], True)
        coords_list[i] = coord
        prev_three_coords = Triplet(prev_three_coords.b, prev_three_coords.c,
                                    coord)

    coords_pretrans = onp.transpose(onp.stack(coords_list), (1, 0, 2, 3))

    # Loop backwards over NUM_FRAGS to align the individual fragments. For each
    # next fragment, we transform the fragments we have already iterated over
    # (coords_trans) to be aligned with the next fragment
    coords_trans = coords_pretrans[-1]

    for i in reversed(range(coords_pretrans.shape[0] - 1)):
        # Transform the fragments that we have already iterated over to be
        # aligned with the next fragment `coords_trans`
        transformed_coords = extend(
            Triplet(*[di[i] for di in prev_three_coords]), coords_trans, False)
        coords_trans = onp.concatenate(
            [coords_pretrans[i], transformed_coords], 0)
    coords = onp.pad(coords_trans[:total_num_angles - 1],
                     ((1, 0), (0, 0), (0, 0)))

    return coords
Example #28
def _iterative_classical_gram_schmidt(Q, x, max_iterations=2):
    """
  Orthogonalize x against the columns of Q. The process is repeated
  up to `max_iterations` times, or fewer if the condition
  ||r|| < (1/sqrt(2)) ||x|| is met earlier (see below for the meaning
  of r and x).

  Parameters
  ----------
  Q : array or tree of arrays
      A matrix of orthonormal columns.
  x : array or tree of arrays
      A vector. It will be replaced with a new vector q which is orthonormal
      to the columns of Q, such that x in span(col(Q), q).
  max_iterations : int, optional
      Maximum number of orthogonalization passes to perform (default 2).

  Returns
  -------
  q : array or tree of arrays
      A unit vector, orthonormal to each column of Q, such that
      x in span(col(Q), q).
  r : array
      Stores the overlaps of x with each vector in Q.
  """
    # "twice is enough"
    # http://slepc.upv.es/documentation/reports/str1.pdf

    # This assumes that Q's leaves all have the same dimension in the last
    # axis.
    r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))
    q = x
    _, xnorm = _safe_normalize(x)
    xnorm_scaled = xnorm / jnp.sqrt(2)

    def body_function(carry):
        k, q, r, qnorm_scaled = carry
        h = _project_on_columns(Q, q)
        Qh = tree_map(lambda X: _dot_tree(X, h), Q)
        q = _sub(q, Qh)
        r = _add(r, h)

        def qnorm_cond(carry):
            k, not_done, _, _ = carry
            return jnp.logical_and(not_done, k < (max_iterations - 1))

        def qnorm(carry):
            k, _, q, qnorm_scaled = carry
            _, qnorm = _safe_normalize(q)
            qnorm_scaled = qnorm / jnp.sqrt(2)
            return (k, False, q, qnorm_scaled)

        init = (k, True, q, qnorm_scaled)
        _, _, q, qnorm_scaled = lax.while_loop(qnorm_cond, qnorm, init)
        return (k + 1, q, r, qnorm_scaled)

    def cond_function(carry):
        k, _, r, qnorm_scaled = carry
        _, rnorm = _safe_normalize(r)
        return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)

    k, q, r, qnorm_scaled = body_function((0, q, r, xnorm_scaled))
    k, q, r, _ = lax.while_loop(cond_function, body_function,
                                (k, q, r, qnorm_scaled))
    return q, r
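# A dense-matrix sketch of the same "twice is enough" idea (an illustration,
# not the routine above): one classical Gram-Schmidt pass, repeated once more
# only when the orthogonalized vector lost most of its norm (the usual
# 1/sqrt(2) rule of thumb). Assumes Q has orthonormal columns.
import jax.numpy as jnp

def cgs2_dense(Q, x):
    r = Q.T @ x
    q = x - Q @ r
    if jnp.linalg.norm(q) < jnp.linalg.norm(x) / jnp.sqrt(2.0):
        h = Q.T @ q
        q = q - Q @ h
        r = r + h
    return q, r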
Example #29
    if k % 2 == 0:
        temp = jnp.kron(temp, qnnops.PauliBasis[1])
    else:
        temp = jnp.kron(temp, qnnops.PauliBasis[2])
        
    for i in range(int(n_gamma/2) - (k//2) - 1):
        temp = jnp.kron(temp, qnnops.PauliBasis[0])
        
    gamma_matrices.append(temp)

# Number of SYK4 interaction terms
n_terms = int(factorial(n_gamma) / factorial(4) / factorial(n_gamma - 4)) 

# SYK4 random coupling
couplings = jax.random.normal(key=jax.random.PRNGKey(args.seed_SYK),
                              shape=(n_terms, ), dtype=jnp.float64) * jnp.sqrt(6 / (n_gamma ** 3))

ham_matrix = 0
for idx, (x, y, w, z) in enumerate(combinations(range(n_gamma), 4)):
    ham_matrix += (couplings[idx] / 4) * jnp.linalg.multi_dot([gamma_matrices[x], gamma_matrices[y], gamma_matrices[w], gamma_matrices[z]])

expmgr.save_array('hamiltonian_matrix.npy', ham_matrix, upload_to_wandb=False)

eigval, eigvec = jnp.linalg.eigh(ham_matrix)
eigvec = eigvec.T  # Transpose such that eigvec[i] is an eigenvector, rather than eigvec[:, i]
ground_state = eigvec[0]
next_to_ground_state = eigvec[1]

print("The lowest eigenvalues (energy) and corresponding eigenvectors (state)")
for i in range(min(5, len(eigval))):
    print(f'| {i}-th state energy={eigval[i]:.4f}')
Example #30
 def update(i, g, state):
     x, avg_sq_grad = state
     avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
     x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
     return x, avg_sq_grad