示例#1
0
    def dynamics(self, state, action):
        """Advance the learned pressure model by one step.

        state: dict with 'u_ins', 'u_outs' and 'normalized_pressures' histories
        action: (u_in, u_out); u_in is divided by 50 and u_out is mapped via
            2 * u_out - 1 before being appended to the histories.

        Returns a new state dict with each history shifted left by one entry
        and the newest value written into the last slot. The new pressure is
        the output of the MLP stored in self.weights (alternating weight
        matrix / bias entries, tanh on every layer except the last).
        """
        u_in, u_out = action
        u_in /= 50.0
        u_out = u_out * 2.0 - 1.0

        def push(history, value):
            # Drop the first entry and write `value` into the last slot.
            return jnp.roll(history, shift=-1).at[-1].set(value)

        u_ins = push(state['u_ins'], u_in)
        u_outs = push(state['u_outs'], u_out)
        pressures = state['normalized_pressures']

        hidden = jnp.concatenate((u_ins, u_outs, pressures))
        tanh_limit = len(self.weights) - 4
        for layer in range(0, len(self.weights), 2):
            matrix, bias = self.weights[layer], self.weights[layer + 1]
            hidden = matrix @ hidden + bias
            if layer <= tanh_limit:
                hidden = jnp.tanh(hidden)

        return {
            'u_ins': u_ins,
            'u_outs': u_outs,
            'normalized_pressures': push(pressures, hidden.squeeze()),
        }
示例#2
0
    def update(self, cost):
        """Take one perturbation-based gradient step on M and bias, then refresh the noise.

        The gradient is estimated as `cost` times the sum of the stored
        perturbations (`self.eps`, `self.eps_bias`) over their leading axis,
        and the step size decays as lr / t**0.75.

        Args:
            cost: scalar cost observed for the current perturbed parameters.

        Side effects:
            Rebinds self.M, self.bias, self.eps and self.eps_bias.
        """
        # 1. Get gradient estimates
        delta_M = cost * np.sum(self.eps, axis=0)
        delta_bias = cost * np.sum(self.eps_bias, axis=0)

        # 2. Execute updates with a t^(-3/4) step-size schedule
        step_size = self.lr / self.t**0.75
        self.M -= step_size * delta_M
        self.bias -= step_size * delta_bias

        # 3. Ensure norm is of correct size (project onto radius 1 - delta)
        norm = np.linalg.norm(self.M)
        if self.project and norm > (1 - self.delta):
            self.M *= (1 - self.delta) / norm

        # `jax.ops.index_update` was removed in JAX 0.3.2; `.at[...].set` is its
        # replacement. `jax.numpy.asarray` makes this work whether eps is a
        # NumPy or a JAX array (np.roll below may return either).
        # NOTE(review): the sqrt argument is negative (NaN) if the remaining
        # noise has norm > 1 -- presumably _generate_uniform keeps it bounded.

        # 4. Get new epsilon for M
        new_eps = self._generate_uniform(
            shape=(self.H, self.m, self.n),
            norm=np.sqrt(1 - np.linalg.norm(self.eps[1:])**2))
        self.eps = np.roll(
            jax.numpy.asarray(self.eps).at[0].set(new_eps), -1, axis=0)

        # 5. Get new epsilon for bias
        new_eps_bias = self._generate_uniform(
            shape=(self.m, 1),
            norm=np.sqrt(1 - np.linalg.norm(self.eps_bias[1:])**2))
        self.eps_bias = np.roll(
            jax.numpy.asarray(self.eps_bias).at[0].set(new_eps_bias), -1, axis=0)
示例#3
0
    def energy(model, config):
        """Local energy of a J1-J2 spin chain for a batch of configurations.

        Args:
            model: passed through to `log_amplitude`.
            config: spin configurations; the unpacking below requires a rank-3
                array with the site index on axis 1 (entries presumably +-1 --
                confirm against caller).

        Returns:
            E_J1 + E_J2, summed over the site axis (axis 1).

        Relies on `log_amplitude`, `end1`, `end2`, `J1` and `J2` from the
        enclosing scope.
        """
        logpsi = log_amplitude(model, config)

        _, N, _ = config.shape

        @jit
        def amplitude_diff(i, k):
            """compute amplitude ratio of logpsi and logpsi_flipped, where i and i+k
            have their sign flipped."""
            # `jax.ops.index_mul` / `jax.ops.index` were removed in JAX 0.3.2;
            # `.at[...].multiply` is the documented replacement.
            sites = jnp.stack([i, (i + k) % N])
            flipped = jnp.asarray(config).at[:, sites].multiply(-1)
            logpsi_flipped = log_amplitude(model, flipped)
            return jnp.exp(logpsi_flipped - logpsi)

        # Vectorize over the site index i; the offset k is shared.
        batched_amplitude_diff = vmap(amplitude_diff, in_axes=(0, None), out_axes=1)

        idx1 = jnp.arange(end1)
        idx2 = jnp.arange(end2)

        # sz*sz term
        nn1 = config[:, :end1] * jnp.roll(config, -1, axis=1)[:, :end1]
        nn2 = config[:, :end2] * jnp.roll(config, -2, axis=1)[:, :end2]
        # sx*sx + sy*sy gives a contribution iff x[i]!=x[i+1]
        # NOTE(review): mask1 is -2 on anti-aligned bonds while mask2 is +2;
        # presumably a Marshall-sign convention on the J1 bonds -- confirm.
        mask1 = nn1 - 1
        mask2 = 1 - nn2

        E0_J1 = jnp.sum(nn1, axis=1)
        E0_J2 = jnp.sum(nn2, axis=1)
        E1_J1 = jnp.sum(mask1 * batched_amplitude_diff(idx1, 1), axis=1)
        E1_J2 = jnp.sum(mask2 * batched_amplitude_diff(idx2, 2), axis=1)

        E_J1 = 0.25 * J1 * (E0_J1 + E1_J1)
        E_J2 = 0.25 * J2 * (E0_J2 + E1_J2)
        return E_J1 + E_J2
示例#4
0
def _categorical_l2_project(
    z_p: Array,
    probs: Array,
    z_q: Array
) -> Array:
  """Projects a categorical distribution (z_p, p) onto a different support z_q.

  The projection step minimizes an L2-metric over the cumulative distribution
  functions (CDFs) of the source and target distributions.

  Let kq be len(z_q) and kp be len(z_p). This projection works for any
  support z_q, in particular kq need not be equal to kp.

  See "A Distributional Perspective on RL" by Bellemare et al.
  (https://arxiv.org/abs/1707.06887).

  Args:
    z_p: support of distribution p, shape (kp,).
    probs: probability values, shape (kp,).
    z_q: support to project distribution (z_p, probs) onto, shape (kq,).
      The clip below uses z_q[0] and z_q[-1] as the range, so z_q is
      presumably sorted ascending -- confirm at call sites.

  Returns:
    Projection of (z_p, p) onto support z_q under Cramer distance, shape (kq,).
  """
  chex.assert_rank([z_p, probs, z_q], 1)
  chex.assert_type([z_p, probs, z_q], float)

  kp = z_p.shape[0]
  kq = z_q.shape[0]

  # Construct helper arrays from z_q.
  # roll(-1)[i] == z_q[i+1] and roll(1)[i] == z_q[i-1]; the wrap-around
  # entries (last for d_pos, first for d_neg) are discarded further down.
  d_pos = jnp.roll(z_q, shift=-1)
  d_neg = jnp.roll(z_q, shift=1)

  # Clip z_p to be in new support range (vmin, vmax).
  z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
  assert z_p.shape == (1, kp)

  # Get the distance between atom values in support.
  d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
  d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
  z_q = z_q[:, None]
  assert z_q.shape == (kq, 1)

  # Ensure that we do not divide by zero, in case of atoms of identical value.
  # For an ascending z_q this also zeroes the wrap-around gaps, which are
  # non-positive. The arrays now hold inverse gaps (1 / distance).
  d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
  d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

  delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
  # 1 where the source atom lies at or above target atom i, else 0.
  d_sign = (delta_qp >= 0.).astype(probs.dtype)
  assert delta_qp.shape == (kq, kp)
  assert d_sign.shape == (kq, kp)

  # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i].
  # Each distance is scaled by the inverse gap on the side it falls on.
  delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
  probs = probs[None, :]
  assert delta_hat.shape == (kq, kp)
  assert probs.shape == (1, kp)

  # clip(1 - delta_hat, 0, 1) is a triangular (hat) weight of source atom j on
  # target atom i; summing over j spreads each probability mass onto z_q.
  return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)
        def _step(x, delta_i_x, noise, eps):
            """Advance the ARIMA-style recursion by one step.

            Args:
                x: history of (differenced) values.
                delta_i_x: running lower-order differences, or None.
                noise: history of noise terms.
                eps: new noise sample.

            Returns:
                (next_x, next_delta_i_x, next_noise, x_new), where x_new is the
                new differenced value plus sum(delta_i_x).

            Reads self.phi / self.psi (indexed by self.T when they are lists),
            self.c, self.n and the enclosing-scope `d`.
            """
            phi = self.phi[self.T] if isinstance(self.phi, list) else self.phi
            psi = self.psi[self.T] if isinstance(self.psi, list) else self.psi
            x_ar = np.dot(x.T, phi)
            x_ma = np.dot(noise.T, psi)

            x_delta_sum = np.sum(delta_i_x) if delta_i_x is not None else 0.0
            x_delta_new = self.c + x_ar + x_ma + eps
            x_new = x_delta_new + x_delta_sum

            # Shift histories by self.n entries and write the newest value at
            # index 0. `jax.ops.index_update` was removed in JAX 0.3.2;
            # `.at[...].set` is its replacement.
            next_x = jax.numpy.asarray(np.roll(x, self.n)).at[0].set(x_delta_new)
            next_noise = jax.numpy.asarray(np.roll(noise, self.n)).at[0].set(eps)

            # Chain the updates: entry i becomes (new entry i-1) + (old entry i),
            # with x_delta_new feeding entry 0. The previous version rewrote
            # the original delta_i_x on every iteration, so all but the last
            # update were lost.
            next_delta_i_x = None
            if d - 1 > 0:
                next_delta_i_x = jax.numpy.asarray(delta_i_x)
                running = x_delta_new
                for i in range(d - 1):
                    running = running + next_delta_i_x[i]
                    next_delta_i_x = next_delta_i_x.at[i].set(running)

            return (next_x, next_delta_i_x, next_noise, x_new)
  def test_azimuthal_equivariance(self, shift, train,
                                  downsampling_factor=1,
                                  num_filter_params=None):
    """SpinSphericalBlock output rolls along axis 2 when the input does."""
    resolution = 8
    transformer = _get_transformer()
    spins = (0, 1, 2)
    input_shape = (2, resolution, resolution, len(spins), 2)

    sphere, _ = test_utils.get_spin_spherical(transformer, input_shape, spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalBlock(num_channels=2,
                                      spins_in=spins,
                                      spins_out=spins,
                                      downsampling_factor=downsampling_factor,
                                      num_filter_params=num_filter_params,
                                      axis_name=None,
                                      transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    # Add negative bias so that the magnitude nonlinearity is active.
    params = params.unfreeze()
    nonlin_params = params['params']['batch_norm_nonlin']
    for name, layer_params in nonlin_params.items():
      if 'magnitude_nonlin' in name:
        layer_params['bias'] -= 0.1

    def run(inputs):
      out, _ = model.apply(params, inputs, train=train,
                           mutable=['batch_stats'])
      return out

    output = run(sphere)
    rotated_output = run(rotated_sphere)
    # Downsampling shrinks the azimuthal axis, so the expected shift shrinks too.
    shifted_output = jnp.roll(output, shift // downsampling_factor, axis=2)

    self.assertAllClose(rotated_output, shifted_output, atol=1e-6)
def stream_vel_init(n, rhoi, g):
    """Build the initial thickness profile, friction field and grid spacing.

    Args:
        n: number of grid cells.
        rhoi, g: not used by this function.

    Returns:
        (h, beta, dx): h varies linearly from h_left with slope
        (h_right - h_left) / Lx, evaluated at (i - 0.5) * dx for i = 1..n;
        beta is filled with beta_const; dx = Lx / n.

    Reads module-level constants beta_const, Lx, h_left and h_right.
    """
    beta = jnp.full(n, beta_const)
    dx = Lx / (n * 1.0)
    # 1-based cell indices; (hint - 0.5) * dx are presumably cell centres.
    hint = jnp.linspace(start=1, stop=n, num=n, dtype=int)
    h = h_left + (h_right - h_left) / Lx * (hint - 0.5) * dx
    return h, beta, dx
def stream_vel_taud(h, n, dx, rhoi, g):
    """Compute rhoi * g * h * dh/dx on the grid plus a boundary term.

    The derivative is a one-sided difference at the first and last cells and a
    centred difference in between.

    Returns:
        (f, fend): f has one entry per cell; fend = 0.25 * (rhoi*g*h[n-1]**2 -
        rhow*g*R_bed**2), using module-level rhow and R_bed.
    """
    coeff = rhoi * g
    left = jnp.roll(h, 1)
    right = jnp.roll(h, -1)
    inner = slice(1, n - 1)

    f_first = coeff * h[0] * (h[1] - h[0]) / dx
    f_inner = coeff * h[inner] * (right[inner] - left[inner]) / 2. / dx
    f = jnp.append(f_first, f_inner)

    f_last = coeff * h[n - 1] * (h[n - 1] - h[n - 2]) / dx
    f = jnp.append(f, f_last)

    fend = .5 * (coeff * (h[n - 1])**2 - rhow * g * R_bed**2) * .5
    return f, fend
示例#9
0
文件: special.py 项目: jbampton/jax
 def body_fun(i, p_val):
     """Add the i-th recurrence increment to p_val.

     The increment combines p_val rolled by one along axis 1 (scaled by x on
     the last axis and by d0_mask_3d[i]) minus p_val rolled by two along
     axis 1 (scaled by d1_mask_3d[i]). Reads d0_mask_3d, d1_mask_3d and x
     from the enclosing scope.
     """
     rolled_once = jnp.roll(p_val, shift=1, axis=1)
     rolled_twice = jnp.roll(p_val, shift=2, axis=1)
     weighted_once = jnp.einsum('ijk,k->ijk', rolled_once, x)
     first_term = jnp.einsum('ij,ijk->ijk', d0_mask_3d[i], weighted_once)
     second_term = jnp.einsum('ij,ijk->ijk', d1_mask_3d[i], rolled_twice)
     return p_val + (first_term - second_term)
示例#10
0
def apply_model_to_azimuthally_rotated_pairs(
    transformer,
    model,
    resolution,
    spins,
    shift,
    init_args=None,
    apply_args=None,
):
    """Applies model to a rotated pair of inputs and returns comparable outputs.

    Useful for equivariance tests where interpolations due to arbitrary
    rotations cause large errors. Azimuthal rotations by integer shifts can be
    performed exactly.

    The model is initialized and applied to a pair of azimuthally rotated
    inputs. One output is rotated into the other. If the model is azimuthally
    rotation equivariant, outputs must match.

    Args:
      transformer: SpinSphericalFourierTransformer instance.
      model: linen module to evaluate.
      resolution: input spherical grid is (resolution, resolution).
      spins: A sequence of (n_spins,) input and output spin weights.
      shift: Azimuthal rotation, in pixels.
      init_args: extra arguments for `model.init`.
      apply_args: extra arguments for `model.apply`.

    Returns:
      output: result of rotate(model(input)).
      rotated_output: result of model(rotate(input)).
    """
    init_kwargs = {} if init_args is None else init_args
    apply_kwargs = {} if apply_args is None else apply_args

    rng_key = np.array([0, 0], dtype=np.uint32)
    input_shape = (2, resolution, resolution, len(spins), 2)
    sphere, _ = get_spin_spherical(transformer, input_shape, spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    params = model.init(rng_key, sphere, **init_kwargs)
    output = model.apply(params, sphere, **apply_kwargs)
    rotated_output = model.apply(params, rotated_sphere, **apply_kwargs)
    # `apply` returns either `output` or `(output, vars)`.
    if isinstance(output, tuple):
        (output, _), (rotated_output, _) = output, rotated_output

    # Subsampling shrinks the azimuthal axis, so scale the shift by the stride.
    stride = resolution // output.shape[1]
    return jnp.roll(output, shift // stride, axis=2), rotated_output
示例#11
0
    def __call__(self, controller_state, obs):
        """Compute the next u_in and the updated controller state.

        Args:
            controller_state: state object exposing `errs`, `waveform`,
                `fwd_targets`, `time` and `steps`; updated via `.replace(...)`.
            obs: observation exposing `predicted_pressure` and `time`.

        Returns:
            (controller_state, u_in): the updated state and the control value
            clamped to [0.0, self.clip].
        """
        state, t = obs.predicted_pressure, obs.time
        errs, waveform = controller_state.errs, controller_state.waveform
        fwd_targets = controller_state.fwd_targets
        target = waveform.at(t)
        # Time `fwd_history_len` steps ahead; once it reaches the horizon the
        # most recent stored forward target is reused instead of the waveform.
        fwd_t = t + self.fwd_history_len * DEFAULT_DT
        if self.fwd_history_len > 0:
            fwd_target = jax.lax.cond(fwd_t >= self.horizon * DEFAULT_DT,
                                      lambda x: fwd_targets[-1],
                                      lambda x: waveform.at(fwd_t), None)
        # Shift the error (and forward-target) histories left by one and write
        # the newest value into the last slot.
        if self.normalize:
            target_normalized = self.p_normalizer(target).squeeze()
            state_normalized = self.p_normalizer(state).squeeze()
            next_errs = jnp.roll(errs, shift=-1)
            next_errs = next_errs.at[-1].set(target_normalized -
                                             state_normalized)
            if self.fwd_history_len > 0:
                # NOTE(review): fwd_targets stores normalized values, so when
                # the cond above picked fwd_targets[-1] this normalizes it a
                # second time -- confirm this is intended.
                fwd_target_normalized = self.p_normalizer(fwd_target).squeeze()
                next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
                next_fwd_targets = next_fwd_targets.at[-1].set(
                    fwd_target_normalized)
            else:
                next_fwd_targets = jnp.array([])
        else:
            next_errs = jnp.roll(errs, shift=-1)
            next_errs = next_errs.at[-1].set(target - state)
            if self.fwd_history_len > 0:
                next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
                next_fwd_targets = next_fwd_targets.at[-1].set(fwd_target)
            else:
                next_fwd_targets = jnp.array([])
        controller_state = controller_state.replace(
            errs=next_errs, fwd_targets=next_fwd_targets)
        decay = self.decay(waveform, t)

        # Model path: feed the error history followed by the forward targets.
        def true_func(null_arg):
            trajectory = jnp.hstack([next_errs, next_fwd_targets])
            u_in = self.model_apply({"params": self.params}, trajectory)
            return u_in.squeeze().astype(jnp.float64)

        # changed decay compare from None to float(inf) due to cond requirements
        # An infinite decay means "use the model"; otherwise decay is u_in.
        u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                            lambda x: jnp.array(decay), None)
        u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float64),
                             self.clip).squeeze()
        # update controller_state
        # dt is never smaller than DEFAULT_DT.
        new_dt = jnp.max(
            jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
        new_time = t
        new_steps = controller_state.steps + 1
        controller_state = controller_state.replace(time=new_time,
                                                    steps=new_steps,
                                                    dt=new_dt)
        return controller_state, u_in
示例#12
0
        def actor_loss(
            policy_params,
            q_params,
            alpha,
            transitions,
            key,
        ):
            """Mean actor loss for a batch of goal-conditioned transitions.

            With config.use_gcbc this is plain behaviour cloning (negative
            log-likelihood of the stored actions). Otherwise it is
            alpha * log_prob - Q on the diagonal of q_action, optionally mixed
            with a BC term weighted by config.bc_coef. Goals may be replaced by
            goals rolled by one along the batch axis (config.random_goals).
            """
            obs = transitions.observation
            if config.use_gcbc:
                dist_params = networks.policy_network.apply(policy_params, obs)
                log_prob = networks.log_prob(dist_params, transitions.action)
                return jnp.mean(-1.0 * jnp.mean(log_prob))

            state, goal = obs[:, :config.obs_dim], obs[:, config.obs_dim:]
            if config.random_goals == 0.0:
                new_state, new_goal = state, goal
            elif config.random_goals == 0.5:
                # Half the batch keeps its goals, the other half gets rolled goals.
                new_state = jnp.concatenate([state, state], axis=0)
                new_goal = jnp.concatenate(
                    [goal, jnp.roll(goal, 1, axis=0)], axis=0)
            else:
                assert config.random_goals == 1.0
                new_state, new_goal = state, jnp.roll(goal, 1, axis=0)

            new_obs = jnp.concatenate([new_state, new_goal], axis=1)
            dist_params = networks.policy_network.apply(policy_params, new_obs)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            q_action = networks.q_network.apply(q_params, new_obs, action)
            if len(q_action.shape) == 3:  # twin q trick
                assert q_action.shape[2] == 2
                q_action = jnp.min(q_action, axis=-1)
            loss = alpha * log_prob - jnp.diag(q_action)

            assert 0.0 <= config.bc_coef <= 1.0
            if config.bc_coef > 0:
                orig_action = transitions.action
                if config.random_goals == 0.5:
                    orig_action = jnp.concatenate(
                        [orig_action, orig_action], axis=0)
                bc_loss = -1.0 * networks.log_prob(dist_params, orig_action)
                loss = config.bc_coef * bc_loss + (1 - config.bc_coef) * loss

            return jnp.mean(loss)
示例#13
0
def H_ising_1(grid: np.array) -> np.float32:
    """Return the nearest-neighbour Ising energy of a periodic 2-D spin grid.

    Each site is multiplied by its periodic neighbour one step along axis 1
    and one step along axis 0; the energy is minus the sum of those products.

    :param grid: grid with spins
    :type grid: np.array
    :return: value of Hamiltonian
    :rtype: np.float32
    """
    shifted_cols = np.roll(grid, 1, axis=1)
    shifted_rows = np.roll(grid, 1, axis=0)
    bond_total = (np.sum(np.multiply(grid, shifted_cols))
                  + np.sum(np.multiply(grid, shifted_rows)))
    return -bond_total.astype(np.float32)
 def equation_of_motion(
     self,
     x: Array,
     t: int,
 ) -> Array:
   """Lorenz-96 right-hand side with periodic indexing.

   dx[i] = (x[i+1] - x[i-2]) * x[i-1] - x[i] + self.F. `t` is not used.
   """
   ahead = jnp.roll(x, -1)    # x[i+1]
   behind = jnp.roll(x, 1)    # x[i-1]
   behind2 = jnp.roll(x, 2)   # x[i-2]
   return (ahead - behind2) * behind - x + self.F
示例#15
0
文件: arma.py 项目: NeoTim/timecast
    def step(carry, eps):
        """Internal step function for ARMA.

        Args:
            carry: (x, noise) histories.
            eps: new noise sample.

        Returns:
            ((next_x, next_noise), y), where y = c + x.T @ phi + noise.T @ psi
            + eps, and both histories are rolled by `n` with the newest value
            (y, resp. eps) written at index 0. Reads c, phi, psi and n from the
            enclosing scope.
        """
        x, noise = carry
        x_ar = jnp.dot(x.T, phi)
        x_ma = jnp.dot(noise.T, psi)

        y = c + x_ar + x_ma + eps

        # `jax.ops.index_update` was removed in JAX 0.3.2; `.at[...].set` is
        # the documented replacement.
        next_x = jnp.roll(x, n).at[0].set(y)
        next_noise = jnp.roll(noise, n).at[0].set(eps)

        return (next_x, next_noise), y
示例#16
0
def u_ene(x):  # potential energy shape
    """Total potential energy of a closed bead chain.

    `x` is reshaped to (n_beads, 3) positions. The energy is the sum of
    stretching, bending, twisting (via pywrithe.writhe_jax) and tanh-based
    pairwise overlap terms, using module-level constants.
    """
    beads = jnp.reshape(x, (n_beads, 3))
    prev_beads = jnp.roll(beads, 1, axis=0)
    prev2_beads = jnp.roll(beads, 2, axis=0)

    bond_lengths = jnp.sum((beads - prev_beads) ** (2.0), axis=1) ** (0.5)
    extension_ene = spring_constant * 0.5 * jnp.sum(
        (bond_lengths - equilibirum_distance) ** 2)

    second_diff = beads - 2 * prev_beads + prev2_beads
    curvatures = (jnp.sum(second_diff ** (2.0), axis=1) ** (0.5)
                  / (equilibirum_distance ** 2))
    bend_ene = bending_constant * 0.5 * jnp.sum(curvatures ** 2)

    writhe_gap = linking_number - pywrithe.writhe_jax(beads)
    twist_ene = twisting_constant * 0.5 * (4.0 * ma.pi * ma.pi) * writhe_gap ** 2

    sq_pair_dist = jnp.sum(
        (beads[:, jnp.newaxis, :] - beads[jnp.newaxis, :, :]) ** 2, axis=-1)
    overlap_arg = (beam_steric_diameter ** 2 - sq_pair_dist) / (overlap_distance ** 2)
    overlap_ene = overlap_constant * (
        0.5 * jnp.sum(jnp.tanh(overlap_arg) + 1.0) - n_beads)

    return extension_ene + bend_ene + twist_ene + overlap_ene
示例#17
0
def flip_sequences(inputs: Array, lengths: Array) -> Array:
    """Flips a sequence of inputs along the time dimension.

    This function can be used to prepare inputs for the reverse direction of a
    bidirectional LSTM. It solves the issue that, when naively flipping multiple
    padded sequences stored in a matrix, the first elements would be padding
    values for those sequences that were padded. This function keeps the padding
    at the end, while flipping the rest of the elements.

    Example:
    ```python
    inputs = [[1, 0, 0],
              [2, 3, 0],
              [4, 5, 6]]
    lengths = [1, 2, 3]
    flip_sequences(inputs, lengths) = [[1, 0, 0],
                                       [3, 2, 0],
                                       [6, 5, 4]]
    ```

    Args:
      inputs: Either a single sequence <int>[seq_length, ...] together with a
        scalar `lengths` (e.g. when wrapped in jax.vmap), or a batch
        <int>[batch_size, seq_length, ...] together with `lengths` of shape
        [batch_size] (time is then axis 1).
      lengths: The length of each sequence: a scalar or <int>[batch_size].

    Returns:
      An array shaped like `inputs` with the first `lengths` time steps of each
      sequence reversed and the padding kept at the end.
    """
    lengths = jnp.asarray(lengths)
    if lengths.ndim == 0:
        # Single sequence: rolling right by (max_length - length) moves the
        # padding to the front, so the final flip puts it back at the end.
        max_length = inputs.shape[0]
        return jnp.flip(jnp.roll(inputs, max_length - lengths, axis=0), axis=0)

    # Batched: out[b, t] = inputs[b, (max_length - 1 - t + lengths[b]) % max_length],
    # i.e. the per-sequence roll-then-flip above written as one gather. A plain
    # jnp.roll with a vector shift on axis 0 would roll the batch axis instead.
    max_length = inputs.shape[1]
    time_idx = (jnp.arange(max_length - 1, -1, -1)[None, :]
                + lengths[:, None]) % max_length
    return inputs[jnp.arange(inputs.shape[0])[:, None], time_idx]
示例#18
0
    def update(self, c_t, x_new):
        """
        Description: Record the disturbance implied by x_new, then take a
            projected gradient step on M with step size learning_rate / sqrt(T).
        Args:
            c_t: cost at the current step (not read by this method).
            x_new: newly observed state.
        Returns:
            None. Updates self.T, self.w_past, self.x and self.M.
        """

        self.T += 1
        lr = self.learning_rate / np.sqrt(self.T)

        # Disturbance: residual of x_new against A @ x + B @ u.
        w_new = x_new - np.dot(self.A, self.x) - np.dot(self.B, self.u)

        # Shift the disturbance history left by n entries and write w_new into
        # the last slot. `jax.ops.index_update` was removed in JAX 0.3.2;
        # `.at[...].set` is its replacement (asarray handles NumPy inputs).
        shifted = np.roll(self.w_past, -self.n)
        self.w_past = jax.numpy.asarray(shifted).at[-1].set(w_new)

        # set current state
        self.x = x_new

        self.M = self.M - lr * self.grad_fn(self.M, self.w_past)
        # Project M back onto the ball of radius M_norm.
        curr_norm = np.linalg.norm(self.M)
        if curr_norm > self.M_norm:
            self.M *= self.M_norm / curr_norm
示例#19
0
def test_deepsets_error():
    """DeepSetRelDistance rejects invalid configurations.

    Expects a ValueError from `init` with layers_phi=3/layers_rho=3 but
    two-element feature tuples, an AssertionError with features_rho=(10, 2),
    and a ValueError for a non-periodic Hilbert space.
    """
    hilb = nk.hilbert.Particle(N=2, L=1.0, pbc=True)
    x = jnp.hstack([jnp.ones(4), -jnp.ones(4)]).reshape(1, -1)

    ds = nk.models.DeepSetRelDistance(
        hilbert=hilb,
        layers_phi=3,
        layers_rho=3,
        features_phi=(10, 10),
        features_rho=(10, 1),
    )
    with pytest.raises(ValueError):
        ds.init(jax.random.PRNGKey(42), x)

    with pytest.raises(AssertionError):
        ds = nk.models.DeepSetRelDistance(
            hilbert=hilb,
            layers_phi=2,
            layers_rho=2,
            features_phi=(10, 10),
            features_rho=(10, 2),
        )
        ds.init(jax.random.PRNGKey(42), x)

    with pytest.raises(ValueError):
        ds = nk.models.DeepSetRelDistance(
            hilbert=nk.hilbert.Particle(N=2, L=1.0, pbc=False),
            layers_phi=2,
            layers_rho=2,
            features_phi=(10, 10),
            features_rho=(10, 2),
        )
        ds.init(jax.random.PRNGKey(42), x)
示例#20
0
def conv_linear_model_linearize_param_flip_at_one(params):
  """Computes the product of parameters flipping at index 1.

  Starting from params[0], each later parameter is flipped along axis 0,
  rolled by one along axis 0, and circularly convolved into the running
  product via `circ_1d_conv`.
  """
  product = params[0]
  for param in params[1:]:
    flipped = jnp.roll(jnp.flip(param, 0), 1, 0)
    product = circ_1d_conv(flipped, product)
  return product
def ambient_flow_log_prob(params: Sequence[jnp.ndarray],
                          fns: Sequence[Callable],
                          y: jnp.ndarray) -> jnp.ndarray:
    """Compute the log-probability of ambient observations under the transformation
    given by composing RealNVP bijectors and a permutation bijector between
    them. Assumes that the base distribution is a standard multivariate normal.

    The bijectors are inverted from the last layer to the first; after each
    inverse, the forward log-det-Jacobian is evaluated at the recovered input
    and accumulated, then subtracted from the base log-density.

    Args:
        params: List of arrays parameterizing the RealNVP bijectors.
        fns: List of functions that compute the shift and scale of the RealNVP
            affine transformation.
        y: Observations whose likelihood under the composition of bijectors
            should be computed.

    Returns:
        out: The log-probability of the observations given the parameters of the
            bijection composition.
    """
    num_dims = y.shape[-1]
    num_masked = num_dims - 2
    perm = jnp.roll(jnp.arange(num_dims), 1)

    z = y
    log_det = 0.
    for layer in range(args.num_realnvp - 1, -1, -1):
        z = permute.inverse(z, perm)
        log_det += permute.forward_log_det_jacobian()
        z = realnvp.inverse(z, num_masked, params[layer], fns[layer])
        log_det += realnvp.forward_log_det_jacobian(
            z, num_masked, params[layer], fns[layer])

    base_log_prob = jspst.multivariate_normal.logpdf(z, jnp.zeros((num_dims, )), 1.)
    return base_log_prob - log_det
示例#22
0
def alternating_layer_ansatz(params,
                             n_qubits,
                             block_size,
                             n_layers,
                             rot_axis='Y'):
    """Apply an alternating-layer block ansatz and return the resulting state.

    Each layer splits the qubit indices into consecutive groups of
    `block_size` and applies `block` to each group, consuming `n_qubits`
    entries of `params` per layer. Odd layers first roll the qubit order by
    -(block_size // 2), so their groups straddle the previous layer's groups.

    Args:
        params: n_qubits * n_layers parameters, indexed per qubit and layer.
        n_qubits: number of qubits; must be divisible by block_size.
        block_size: qubits per block.
        n_layers: number of layers.
        rot_axis: 'X', 'Y' or 'Z' (case-insensitive), forwarded to `block`.

    Returns:
        The state after the last `block` call; the initial state is a
        complex64 vector of length 2**n_qubits with a single 1 in its last entry.
    """
    # TODO(jdk): Check this function later whether we need to revise for scalability.
    axis = rot_axis.upper()
    assert axis in ('X', 'Y', 'Z')
    assert n_qubits % block_size == 0
    assert len(params) == n_qubits * n_layers

    # Initial state
    state = jnp.zeros(2**n_qubits, dtype=jnp.complex64).at[-1].set(1)

    natural_order = jnp.arange(n_qubits)
    for layer in range(n_layers):
        if layer % 2:
            order = jnp.roll(natural_order, -(block_size // 2))
        else:
            order = natural_order
        offset = layer * n_qubits
        for group in jnp.reshape(order, (-1, block_size)):
            state = block(params=params[group + offset],
                          qubits=group,
                          state=state,
                          n_qubit=n_qubits,
                          rot_axis=axis)

    return state
示例#23
0
    def trans_dist(self, value=None):
        """Set the transition distribution from a probability matrix.

        When `value` is None, the default matrix is the identity rolled by one
        column, so row i puts all its mass on state (i + 1) % word_len.
        """
        assert self.word_len > 0

        probs = value
        if probs is None:
            probs = jnp.roll(jnp.eye(self.word_len), 1, axis=1)  # transition-probability matrix
        self._trans_dist = distrax.Categorical(probs=probs)
示例#24
0
    def energy(model, config):
        """Local energy of a nearest-neighbour spin chain for a batch of configurations.

        Args:
            model: passed through to `log_amplitude`.
            config: spin configurations; the unpacking below requires a rank-3
                array with the site index on axis 1 (entries presumably +-1 --
                confirm against caller).

        Returns:
            0.25 * J * (diagonal + off-diagonal terms), summed over the site
            axis (axis 1). Relies on `log_amplitude`, `end`, `J` and `partial`
            from the enclosing scope.
        """
        @jit
        def amplitude_diff(config, i):
            """compute amplitude ratio of logpsi and logpsi_flipped, where i and i+1
            have their sign flipped."""
            # `jax.ops.index_mul` / `jax.ops.index` were removed in JAX 0.3.2;
            # `.at[...].multiply` is the documented replacement.
            sites = jnp.stack([i, (i + 1) % N])
            flipped = jnp.asarray(config).at[:, sites].multiply(-1)
            logpsi_flipped = log_amplitude(model, flipped)
            return jnp.exp(logpsi_flipped - logpsi)

        # Vectorize over the site index with config held fixed.
        vmap_amplitude_diff = vmap(partial(amplitude_diff, config), out_axes=1)

        logpsi = log_amplitude(model, config)

        _, N, _ = config.shape

        idx = jnp.arange(end)
        # sz*sz term
        nn = config[:, :end] * jnp.roll(config, -1, axis=1)[:, :end]
        # sx*sx + sy*sy gives a contribution iff x[i]!=x[i+1]
        mask = nn - 1

        E0 = jnp.sum(nn, axis=1)
        E1 = jnp.sum(mask * vmap_amplitude_diff(idx), axis=1)

        E = 0.25 * J * (E0 + E1)

        return E
示例#25
0
    def update(self, state: jnp.ndarray, u: jnp.ndarray) -> None:
        """
        Description: update agent internal state.

        Args:
            state (jnp.ndarray): newly observed state.
            u (jnp.ndarray): action paired with the previously stored state.

        Returns:
            None
        """
        # Disturbance: residual of the new state against A @ state + B @ u.
        noise = state - self.A @ self.state - self.B @ u
        # Write into slot 0, then roll so the newest disturbance ends up last.
        # `jax.ops.index_update` was removed in JAX 0.3.2; `.at[...].set`
        # replaces it (asarray also accepts a NumPy history).
        self.noise_history = jnp.asarray(self.noise_history).at[0].set(noise)
        self.noise_history = jnp.roll(self.noise_history, -1, axis=0)

        delta_M, delta_bias = self.grad(self.M, self.noise_history)

        # Step size is lr_scale, divided by (t + 1) when decay is enabled.
        lr = self.lr_scale
        lr *= (1 / (self.t + 1)) if self.decay else 1
        self.M -= lr * delta_M
        self.bias -= lr * delta_bias

        # update state
        self.state = state

        self.t += 1
示例#26
0
def replay_sample_to_sars_transition(
    sample: reverb.ReplaySample,
    is_sequence: bool) -> types.Transition:
  """Converts the replay sample to a types.Transition.

  NB: If is_sequence is True then the last next_observation of each sequence is
  rubbish. Don't train on it.

  Args:
    sample: The replay sample
    is_sequence: If False we expect the sample data to match the
      types.Transition already. Otherwise we expect a batch of sequences of
      steps.

  Returns:
    A types.Transition built from the sample data. The number of leading
    dimensions will be unchanged, so expect 2 for sequence based ([Batch, Time])
    and 1 ([Batch]) otherwise.
    NB: If is_sequence is True then the last next_observation of each sequence
    is rubbish. Don't train on it.
  """
  if is_sequence:
    steps = sample.data
    # next_observation is observation shifted one step along time (axis 1);
    # the last entry wraps around to the first observation and is invalid.
    shifted_observation = jnp.roll(steps.observation, shift=-1, axis=1)
    return types.Transition(
        observation=steps.observation,
        action=steps.action,
        reward=steps.reward,
        discount=steps.discount,
        next_observation=shifted_observation)
  return types.Transition(*sample.data)
示例#27
0
    def update(self, state: jnp.ndarray, action: jnp.ndarray,
               cost: Real) -> None:
        """
        Description: update agent internal state.

        Args:
            state (jnp.ndarray): current state
            action (jnp.ndarray): action taken
            cost (Real): scalar cost received

        Returns:
            None
        """
        # Disturbance: residual of the new state against A @ state + B @ action.
        noise = state - self.A @ self.state - self.B @ action
        self.noise_history = jnp.asarray(self.noise_history).at[0].set(noise)
        self.noise_history = jnp.roll(self.noise_history, -1, axis=0)

        lr = self.lr_scale
        lr *= (1 / (self.t**(3 / 4) + 1)) if self.decay else 1

        delta_M = self.grad(self.M, self.noise_history, cost)
        self.M -= lr * delta_M

        # Draw fresh exploration noise into slot 0 and rotate it to the end.
        # Use jnp.roll: NumPy's roll returns an ndarray, which has no `.at`,
        # so the next call to this method would fail on `self.eps.at[...]`.
        self.eps = jnp.asarray(self.eps).at[0].set(
            generate_uniform((self.H, self.d_action, self.d_state)))
        self.eps = jnp.roll(self.eps, -1, axis=0)

        self.M += self.delta * self.eps[-1]

        # update state
        self.state = state

        self.t += 1
示例#28
0
    def update_params(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        """
        Description: take a projected gradient step on M and record the control.

        Args:
            obs (jnp.ndarray): current observation (not read by this method).
            u (jnp.ndarray): control applied at this step; written to self.us[0].

        Returns:
            None
        """

        # update parameters
        delta_M = self.grad(self.M, self.G, self.y_nat, self.us)

        # lr = lr_scale / (t + 1) with decay, lr_scale otherwise. The false
        # branch previously returned 1.0, silently ignoring lr_scale.
        # `* 1.0` keeps both branches floating-point even for an int lr_scale,
        # since lax.cond requires matching branch output types.
        lr = jax.lax.cond(self.decay,
                          lambda x: x * 1 / (self.t + 1),
                          lambda x: x * 1.0,
                          self.lr_scale)

        self.M -= lr * delta_M

        # Project M back onto the ball of radius RM.
        self.M = jax.lax.cond(
            jnp.linalg.norm(self.M) > self.RM,
            lambda m: m * (self.RM / jnp.linalg.norm(m)),
            lambda m: m,
            self.M,
        )

        # update us: shift the history by one and put the newest control first.
        self.us = jnp.roll(self.us, 1, axis=0)
        self.us = self.us.at[0].set(u)

        self.t += 1
示例#29
0
    def step(self, action):
        """Advance the environment by one time step.

        Returns:
            (observation, reward, done, info): reward is minus the absolute
            gap between the waveform target at the current time and the
            current pressure (measured before the dynamics update); done is
            always False and info is {}.
        """
        u_in, u_out = action
        # Non-positive inflow is recorded as 0.0 in the history.
        u_in = jax.lax.cond(u_in > 0.0, lambda v: v, lambda v: 0.0, u_in)

        self.in_history = jnp.roll(self.in_history, shift=1).at[0].set(u_in)
        self.out_history = jnp.roll(self.out_history, shift=1).at[0].set(u_out)

        self.target = self.waveform.at(self.time)
        reward = -jnp.abs(self.target - self.state["pressure"])

        # NOTE(review): dynamics receives the original `action`, not the
        # clamped u_in recorded above -- confirm this is intended.
        self.state = self.dynamics(self.state, action)
        self.time += 1

        return self.observation, reward, False, {}
  def test_azimuthal_invariance(self, shift):
    """Classifier outputs must not change when the input rolls along axis 2."""
    # Make a simple two-layer classifier with pooling for testing.
    resolutions = [8, 4]
    transformer = _get_transformer()
    spins = [[0, -1], [0, 1, 2]]
    channels = [2, 3]
    input_shape = [2, resolutions[0], resolutions[0], len(spins[0]), channels[0]]
    sphere, _ = test_utils.get_spin_spherical(transformer, input_shape, spins[0])
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalClassifier(num_classes=5,
                                           resolutions=resolutions,
                                           spins=spins,
                                           widths=channels,
                                           axis_name=None,
                                           input_transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    def classify(inputs):
      logits, _ = model.apply(params, inputs, train=True,
                              mutable=['batch_stats'])
      return logits

    output = classify(sphere)
    rotated_output = classify(rotated_sphere)

    # The classifier should be rotation-invariant.
    self.assertAllClose(rotated_output, output, atol=1e-6)