def dynamics(self, state, action):
    """
    state: (u_in, u_out, normalized pressure) histories
    action: (u_in, u_out)
    """
    u_ins, u_outs, normalized_pressures = state['u_ins'], state['u_outs'], state['normalized_pressures']
    u_in, u_out = action

    u_in /= 50.0
    u_out = u_out * 2.0 - 1.0

    u_ins = jnp.roll(u_ins, shift=-1)
    u_ins = u_ins.at[-1].set(u_in)
    u_outs = jnp.roll(u_outs, shift=-1)
    u_outs = u_outs.at[-1].set(u_out)

    normalized_pressure = jnp.concatenate((u_ins, u_outs, normalized_pressures))
    for i in range(0, len(self.weights), 2):
        normalized_pressure = self.weights[i] @ normalized_pressure + self.weights[i + 1]
        if i <= len(self.weights) - 4:
            normalized_pressure = jnp.tanh(normalized_pressure)

    normalized_pressures = jnp.roll(normalized_pressures, shift=-1)
    normalized_pressures = normalized_pressures.at[-1].set(normalized_pressure.squeeze())

    return {'u_ins': u_ins, 'u_outs': u_outs, 'normalized_pressures': normalized_pressures}
def update(self, cost):
    # 1. Get gradient estimates
    delta_M = cost * np.sum(self.eps, axis=0)
    delta_bias = cost * np.sum(self.eps_bias, axis=0)

    # 2. Execute updates
    self.M -= self.lr / self.t**0.75 * delta_M
    self.bias -= self.lr / self.t**0.75 * delta_bias

    # 3. Ensure norm is of correct size
    norm = np.linalg.norm(self.M)
    if self.project and norm > (1 - self.delta):
        self.M *= (1 - self.delta) / norm

    # 4. Get new epsilon for M
    self.eps = jax.ops.index_update(
        self.eps, 0,
        self._generate_uniform(
            shape=(self.H, self.m, self.n),
            norm=np.sqrt(1 - np.linalg.norm(self.eps[1:])**2)))
    self.eps = np.roll(self.eps, -1, axis=0)

    # 5. Get new epsilon for bias
    self.eps_bias = jax.ops.index_update(
        self.eps_bias, 0,
        self._generate_uniform(
            shape=(self.m, 1),
            norm=np.sqrt(1 - np.linalg.norm(self.eps_bias[1:])**2)))
    self.eps_bias = np.roll(self.eps_bias, -1, axis=0)
def energy(model, config):
    @jit
    def amplitude_diff(i, k):
        """compute amplitude ratio of logpsi and logpsi_flipped, where i and
        i+k have their sign flipped."""
        flipped = jax.ops.index_mul(config, jax.ops.index[:, [i, (i + k) % N]], -1)
        logpsi_flipped = log_amplitude(model, flipped)
        return jnp.exp(logpsi_flipped - logpsi)

    amplitude_diff = vmap(amplitude_diff, in_axes=(0, None), out_axes=1)

    logpsi = log_amplitude(model, config)

    _, N, _ = config.shape

    idx1 = jnp.arange(end1)
    idx2 = jnp.arange(end2)

    # sz*sz term
    nn1 = config[:, :end1] * jnp.roll(config, -1, axis=1)[:, :end1]
    nn2 = config[:, :end2] * jnp.roll(config, -2, axis=1)[:, :end2]

    # sx*sx + sy*sy gives a contribution iff x[i]!=x[i+1]
    mask1 = nn1 - 1
    mask2 = 1 - nn2

    E0_J1 = jnp.sum(nn1, axis=1)
    E0_J2 = jnp.sum(nn2, axis=1)
    E1_J1 = jnp.sum(mask1 * amplitude_diff(idx1, 1), axis=1)
    E1_J2 = jnp.sum(mask2 * amplitude_diff(idx2, 2), axis=1)

    E_J1 = 0.25 * J1 * (E0_J1 + E1_J1)
    E_J2 = 0.25 * J2 * (E0_J2 + E1_J2)

    return E_J1 + E_J2
def _categorical_l2_project(
    z_p: Array,
    probs: Array,
    z_q: Array
) -> Array:
    """Projects a categorical distribution (z_p, p) onto a different support z_q.

    The projection step minimizes an L2-metric over the cumulative distribution
    functions (CDFs) of the source and target distributions.

    Let kq be len(z_q) and kp be len(z_p). This projection works for any
    support z_q, in particular kq need not be equal to kp.

    See "A Distributional Perspective on RL" by Bellemare et al.
    (https://arxiv.org/abs/1707.06887).

    Args:
      z_p: support of distribution p.
      probs: probability values.
      z_q: support to project distribution (z_p, probs) onto.

    Returns:
      Projection of (z_p, p) onto support z_q under Cramer distance.
    """
    chex.assert_rank([z_p, probs, z_q], 1)
    chex.assert_type([z_p, probs, z_q], float)

    kp = z_p.shape[0]
    kq = z_q.shape[0]

    # Construct helper arrays from z_q.
    d_pos = jnp.roll(z_q, shift=-1)
    d_neg = jnp.roll(z_q, shift=1)

    # Clip z_p to be in new support range (vmin, vmax).
    z_p = jnp.clip(z_p, z_q[0], z_q[-1])[None, :]
    assert z_p.shape == (1, kp)

    # Get the distance between atom values in support.
    d_pos = (d_pos - z_q)[:, None]  # z_q[i+1] - z_q[i]
    d_neg = (z_q - d_neg)[:, None]  # z_q[i] - z_q[i-1]
    z_q = z_q[:, None]
    assert z_q.shape == (kq, 1)

    # Ensure that we do not divide by zero, in case of atoms of identical value.
    d_neg = jnp.where(d_neg > 0, 1. / d_neg, jnp.zeros_like(d_neg))
    d_pos = jnp.where(d_pos > 0, 1. / d_pos, jnp.zeros_like(d_pos))

    delta_qp = z_p - z_q  # clip(z_p)[j] - z_q[i]
    d_sign = (delta_qp >= 0.).astype(probs.dtype)
    assert delta_qp.shape == (kq, kp)
    assert d_sign.shape == (kq, kp)

    # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i].
    delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg)
    probs = probs[None, :]
    assert delta_hat.shape == (kq, kp)
    assert probs.shape == (1, kp)

    return jnp.sum(jnp.clip(1. - delta_hat, 0., 1.) * probs, axis=-1)
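# Hedged usage sketch (not part of the function above): project a 3-atom
# categorical distribution onto a finer 5-atom support. Because every source
# atom lies inside [z_q[0], z_q[-1]], the projection preserves total mass.
import jax.numpy as jnp

z_p = jnp.array([-1.0, 0.0, 1.0])
probs = jnp.array([0.2, 0.5, 0.3])
z_q = jnp.linspace(-2.0, 2.0, 5)  # [-2., -1., 0., 1., 2.]
projected = _categorical_l2_project(z_p, probs, z_q)
# The source atoms land exactly on target atoms, so: [0., 0.2, 0.5, 0.3, 0.]
assert jnp.isclose(projected.sum(), 1.0)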
def _step(x, delta_i_x, noise, eps):
    if type(self.phi) is list:
        x_ar = np.dot(x.T, self.phi[self.T])
    else:
        x_ar = np.dot(x.T, self.phi)
    if type(self.psi) is list:
        x_ma = np.dot(noise.T, self.psi[self.T])
    else:
        x_ma = np.dot(noise.T, self.psi)

    if delta_i_x is not None:
        x_delta_sum = np.sum(delta_i_x)
    else:
        x_delta_sum = 0.0

    x_delta_new = self.c + x_ar + x_ma + eps
    x_new = x_delta_new + x_delta_sum

    next_x = np.roll(x, self.n)
    next_noise = np.roll(noise, self.n)

    next_x = jax.ops.index_update(next_x, 0, x_delta_new)  # equivalent to self.x[0] = x_delta_new
    next_noise = jax.ops.index_update(next_noise, 0, eps)  # equivalent to self.noise[0] = eps

    next_delta_i_x = None
    for i in range(d - 1):
        if i == 0:
            next_delta_i_x = jax.ops.index_update(delta_i_x, i, x_delta_new + delta_i_x[i])
        else:
            next_delta_i_x = jax.ops.index_update(next_delta_i_x, i,
                                                  next_delta_i_x[i - 1] + next_delta_i_x[i])

    return (next_x, next_delta_i_x, next_noise, x_new)
def test_azimuthal_equivariance(self, shift, train,
                                downsampling_factor=1,
                                num_filter_params=None):
    resolution = 8
    transformer = _get_transformer()
    spins = (0, 1, 2)
    shape = (2, resolution, resolution, len(spins), 2)
    sphere, _ = test_utils.get_spin_spherical(transformer, shape, spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalBlock(num_channels=2,
                                      spins_in=spins,
                                      spins_out=spins,
                                      downsampling_factor=downsampling_factor,
                                      num_filter_params=num_filter_params,
                                      axis_name=None,
                                      transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    # Add negative bias so that the magnitude nonlinearity is active.
    params = params.unfreeze()
    for key, value in params['params']['batch_norm_nonlin'].items():
        if 'magnitude_nonlin' in key:
            value['bias'] -= 0.1

    output, _ = model.apply(params, sphere, train=train,
                            mutable=['batch_stats'])
    rotated_output, _ = model.apply(params, rotated_sphere, train=train,
                                    mutable=['batch_stats'])
    shifted_output = jnp.roll(output, shift // downsampling_factor, axis=2)

    self.assertAllClose(rotated_output, shifted_output, atol=1e-6)
def stream_vel_init(n, rhoi, g):
    h = jnp.zeros(n)
    beta = jnp.full(n, beta_const)
    dx = Lx / (n * 1.0)

    hint = jnp.linspace(start=1, stop=n, num=n, dtype=int)
    hintm1 = jnp.roll(hint, 1)
    hintp1 = jnp.roll(hint, -1)

    h = h_left + (h_right - h_left) / Lx * (hint - 0.5) * dx

    return h, beta, dx
def stream_vel_taud(h, n, dx, rhoi, g):
    h_minus1 = jnp.roll(h, 1)
    h_plus1 = jnp.roll(h, -1)

    f = jnp.append(
        rhoi * g * h[0] * (h[1] - h[0]) / dx,
        rhoi * g * h[1:n - 1] * (h_plus1[1:n - 1] - h_minus1[1:n - 1]) / 2. / dx)
    f = jnp.append(f, rhoi * g * h[n - 1] * (h[n - 1] - h[n - 2]) / dx)

    fend = .5 * (rhoi * g * (h[n - 1])**2 - rhow * g * R_bed**2) * .5

    return f, fend
def body_fun(i, p_val):
    coeff_0 = d0_mask_3d[i]
    coeff_1 = d1_mask_3d[i]
    h = (jnp.einsum('ij,ijk->ijk', coeff_0,
                    jnp.einsum('ijk,k->ijk',
                               jnp.roll(p_val, shift=1, axis=1), x)) -
         jnp.einsum('ij,ijk->ijk', coeff_1,
                    jnp.roll(p_val, shift=2, axis=1)))
    p_val = p_val + h
    return p_val
def apply_model_to_azimuthally_rotated_pairs(
    transformer,
    model,
    resolution,
    spins,
    shift,
    init_args=None,
    apply_args=None,
):
    """Applies model to rotated pair of inputs and returns rotated outputs.

    Useful for equivariance tests where interpolations due to arbitrary
    rotations cause large errors. Azimuthal rotations by integer shifts can be
    performed exactly.

    The model is initialized and applied to a pair of azimuthally rotated
    inputs. One output is rotated into the other. If the model is azimuthally
    rotation equivariant, outputs must match.

    Args:
      transformer: SpinSphericalFourierTransformer instance.
      model: linen module to evaluate.
      resolution: input spherical grid is (resolution, resolution).
      spins: A sequence of (n_spins,) input and output spin weights.
      shift: Azimuthal rotation, in pixels.
      init_args: extra arguments for `model.init`.
      apply_args: extra arguments for `model.apply`.

    Returns:
      output: result of rotate(model(input)).
      rotated_output: result of model(rotate(input)).
    """
    if init_args is None:
        init_args = {}
    if apply_args is None:
        apply_args = {}
    key = np.array([0, 0], dtype=np.uint32)
    shape = (2, resolution, resolution, len(spins), 2)
    sphere, _ = get_spin_spherical(transformer, shape, spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    params = model.init(key, sphere, **init_args)
    # `apply` returns either `output` or `(output, vars)`.
    output = model.apply(params, sphere, **apply_args)
    rotated_output = model.apply(params, rotated_sphere, **apply_args)
    if isinstance(output, tuple):
        output, _ = output
        rotated_output, _ = rotated_output

    # If there was subsampling, change shift accordingly.
    stride = resolution // output.shape[1]
    output = jnp.roll(output, shift // stride, axis=2)

    return output, rotated_output
def __call__(self, controller_state, obs):
    state, t = obs.predicted_pressure, obs.time
    errs, waveform = controller_state.errs, controller_state.waveform
    fwd_targets = controller_state.fwd_targets
    target = waveform.at(t)

    fwd_t = t + self.fwd_history_len * DEFAULT_DT
    if self.fwd_history_len > 0:
        fwd_target = jax.lax.cond(fwd_t >= self.horizon * DEFAULT_DT,
                                  lambda x: fwd_targets[-1],
                                  lambda x: waveform.at(fwd_t),
                                  None)

    if self.normalize:
        target_normalized = self.p_normalizer(target).squeeze()
        state_normalized = self.p_normalizer(state).squeeze()
        next_errs = jnp.roll(errs, shift=-1)
        next_errs = next_errs.at[-1].set(target_normalized - state_normalized)
        if self.fwd_history_len > 0:
            fwd_target_normalized = self.p_normalizer(fwd_target).squeeze()
            next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
            next_fwd_targets = next_fwd_targets.at[-1].set(fwd_target_normalized)
        else:
            next_fwd_targets = jnp.array([])
    else:
        next_errs = jnp.roll(errs, shift=-1)
        next_errs = next_errs.at[-1].set(target - state)
        if self.fwd_history_len > 0:
            next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
            next_fwd_targets = next_fwd_targets.at[-1].set(fwd_target)
        else:
            next_fwd_targets = jnp.array([])

    controller_state = controller_state.replace(errs=next_errs,
                                                fwd_targets=next_fwd_targets)

    decay = self.decay(waveform, t)

    def true_func(null_arg):
        trajectory = jnp.hstack([next_errs, next_fwd_targets])
        u_in = self.model_apply({"params": self.params}, trajectory)
        return u_in.squeeze().astype(jnp.float64)

    # changed decay compare from None to float(inf) due to cond requirements
    u_in = jax.lax.cond(jnp.isinf(decay),
                        true_func,
                        lambda x: jnp.array(decay),
                        None)

    u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float64), self.clip).squeeze()

    # update controller_state
    new_dt = jnp.max(
        jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
    new_time = t
    new_steps = controller_state.steps + 1
    controller_state = controller_state.replace(time=new_time,
                                                steps=new_steps,
                                                dt=new_dt)
    return controller_state, u_in
def actor_loss(
    policy_params,
    q_params,
    alpha,
    transitions,
    key,
):
    obs = transitions.observation

    if config.use_gcbc:
        dist_params = networks.policy_network.apply(policy_params, obs)
        log_prob = networks.log_prob(dist_params, transitions.action)
        actor_loss = -1.0 * jnp.mean(log_prob)
    else:
        state = obs[:, :config.obs_dim]
        goal = obs[:, config.obs_dim:]

        if config.random_goals == 0.0:
            new_state = state
            new_goal = goal
        elif config.random_goals == 0.5:
            new_state = jnp.concatenate([state, state], axis=0)
            new_goal = jnp.concatenate(
                [goal, jnp.roll(goal, 1, axis=0)], axis=0)
        else:
            assert config.random_goals == 1.0
            new_state = state
            new_goal = jnp.roll(goal, 1, axis=0)

        new_obs = jnp.concatenate([new_state, new_goal], axis=1)
        dist_params = networks.policy_network.apply(policy_params, new_obs)
        action = networks.sample(dist_params, key)
        log_prob = networks.log_prob(dist_params, action)
        q_action = networks.q_network.apply(q_params, new_obs, action)
        if len(q_action.shape) == 3:  # twin q trick
            assert q_action.shape[2] == 2
            q_action = jnp.min(q_action, axis=-1)
        actor_loss = alpha * log_prob - jnp.diag(q_action)

        assert 0.0 <= config.bc_coef <= 1.0
        if config.bc_coef > 0:
            orig_action = transitions.action
            if config.random_goals == 0.5:
                orig_action = jnp.concatenate(
                    [orig_action, orig_action], axis=0)

            bc_loss = -1.0 * networks.log_prob(dist_params, orig_action)
            actor_loss = (config.bc_coef * bc_loss +
                          (1 - config.bc_coef) * actor_loss)

    return jnp.mean(actor_loss)
def H_ising_1(grid: np.array) -> np.float32:
    """Calculates Hamiltonian for an Ising model with first-order neighbors

    :param grid: grid with spins
    :type grid: np.array
    :return: value of Hamiltonian
    :rtype: np.float32
    """
    x = np.roll(grid, 1, axis=1)
    y = np.roll(grid, 1, axis=0)
    x = np.sum(np.multiply(grid, x))  # Ising
    y = np.sum(np.multiply(grid, y))
    return -(x + y).astype(np.float32)
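# Hedged usage sketch for H_ising_1 above, assuming `np` is the numpy (or
# jax.numpy) alias used by the function: a 3x3 grid of aligned spins has
# 2 * 9 nearest-neighbour bonds under the periodic boundaries implied by
# np.roll, so the energy is -18.
import numpy as np

grid = np.ones((3, 3))
print(H_ising_1(grid))  # -18.0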
def equation_of_motion(
    self,
    x: Array,
    t: int,
) -> Array:
    """ODEs to be integrated."""
    x_plus_1 = jnp.roll(x, -1)
    x_minus_1 = jnp.roll(x, 1)
    x_minus_2 = jnp.roll(x, 2)

    dx = (x_plus_1 - x_minus_2) * x_minus_1 - x
    dx = dx + self.F

    return dx
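# Minimal integration sketch (an assumption-labelled rewrite of the
# Lorenz-96-style right-hand side above as a free function, with forcing F
# and a plain forward-Euler loop chosen purely for illustration).
import jax.numpy as jnp

def lorenz96_dx(x, F=8.0):
    # Same stencil as equation_of_motion: (x[i+1] - x[i-2]) * x[i-1] - x[i] + F.
    return (jnp.roll(x, -1) - jnp.roll(x, 2)) * jnp.roll(x, 1) - x + F

x = jnp.ones(40).at[0].add(0.01)   # small perturbation on one site
for _ in range(100):
    x = x + 0.01 * lorenz96_dx(x)  # Euler step with dt = 0.01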
def step(carry, eps):
    """Internal step function for ARMA"""
    x, noise = carry

    x_ar = jnp.dot(x.T, phi)
    x_ma = jnp.dot(noise.T, psi)
    y = c + x_ar + x_ma + eps

    next_x = jnp.roll(x, n)
    next_noise = jnp.roll(noise, n)

    next_x = jax.ops.index_update(next_x, 0, y)
    next_noise = jax.ops.index_update(next_noise, 0, eps)

    return (next_x, next_noise), y
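# Hedged usage sketch: `step` already has the (carry, x) -> (carry, y)
# signature expected by jax.lax.scan, so an ARMA(p, q) series can be generated
# by scanning it over a noise sequence. phi, psi, c and n are free variables
# of `step`; the values below are illustrative only, and jax.ops.index_update
# requires an older JAX release (newer ones spell it `.at[0].set(...)`).
import jax
import jax.numpy as jnp

p, q, n = 3, 2, 1
phi = jnp.array([0.5, 0.2, 0.1])   # AR coefficients
psi = jnp.array([0.3, 0.1])        # MA coefficients
c = 0.0

init_carry = (jnp.zeros(p), jnp.zeros(q))
noise = jax.random.normal(jax.random.PRNGKey(0), (100,))
_, series = jax.lax.scan(step, init_carry, noise)  # series: (100,) samples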
def u_ene(x):
    # potential energy shape
    locations = jnp.reshape(x, (n_beads, 3))
    distances = jnp.sum(
        (locations - jnp.roll(locations, 1, axis=0))**2.0, axis=1)**0.5
    extension_ene = spring_constant * 0.5 * jnp.sum(
        (distances - equilibirum_distance)**2)
    curvatures = jnp.sum(
        (locations - 2 * jnp.roll(locations, 1, axis=0)
         + jnp.roll(locations, 2, axis=0))**2.0,
        axis=1)**0.5 / (equilibirum_distance**2)
    bend_ene = bending_constant * 0.5 * jnp.sum(curvatures**2)
    twist_ene = twisting_constant * 0.5 * (4.0 * ma.pi * ma.pi) * (
        linking_number - pywrithe.writhe_jax(locations))**2
    overlap_ene = overlap_constant * (0.5 * jnp.sum(jnp.tanh(
        (beam_steric_diameter**2 - jnp.sum(
            (locations[:, jnp.newaxis, :] - locations[jnp.newaxis, :, :])**2,
            axis=-1)) / (overlap_distance**2)) + 1.0) - n_beads)
    return extension_ene + bend_ene + twist_ene + overlap_ene
def flip_sequences(inputs: Array, lengths: Array) -> Array:
    """Flips a sequence of inputs along the time dimension.

    This function can be used to prepare inputs for the reverse direction of a
    bidirectional LSTM. It solves the issue that, when naively flipping multiple
    padded sequences stored in a matrix, the first elements would be padding
    values for those sequences that were padded. This function keeps the padding
    at the end, while flipping the rest of the elements.

    Example:
    ```python
    inputs = [[1, 0, 0],
              [2, 3, 0],
              [4, 5, 6]]
    lengths = [1, 2, 3]
    flip_sequences(inputs, lengths) = [[1, 0, 0],
                                       [3, 2, 0],
                                       [6, 5, 4]]
    ```

    Args:
      inputs: An array of input IDs <int>[batch_size, seq_length].
      lengths: The length of each sequence <int>[batch_size].

    Returns:
      An ndarray with the flipped inputs.
    """
    max_length = inputs.shape[0]
    return jnp.flip(jnp.roll(inputs, max_length - lengths, axis=0), axis=0)
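# Hedged usage sketch: the roll-then-flip trick above operates on a single
# sequence along axis 0, so a padded batch like the docstring example is
# handled by vmapping over the batch dimension.
import jax
import jax.numpy as jnp

inputs = jnp.array([[1, 0, 0],
                    [2, 3, 0],
                    [4, 5, 6]])
lengths = jnp.array([1, 2, 3])
flipped = jax.vmap(flip_sequences)(inputs, lengths)
# flipped == [[1, 0, 0], [3, 2, 0], [6, 5, 4]]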
def update(self, c_t, x_new):
    """
    Description: Updates internal parameters using the newly observed state.

    Args:
        c_t: current cost
        x_new: new state observation

    Returns:
        None
    """
    self.T += 1
    lr = self.learning_rate / np.sqrt(self.T)

    # get new noise
    w_new = x_new - np.dot(self.A, self.x) - np.dot(self.B, self.u)

    # update past noises
    self.w_past = np.roll(self.w_past, -self.n)
    self.w_past = jax.ops.index_update(self.w_past, -1, w_new)

    # set current state
    self.x = x_new

    self.M = self.M - lr * self.grad_fn(self.M, self.w_past)
    curr_norm = np.linalg.norm(self.M)
    if curr_norm > self.M_norm:
        self.M *= self.M_norm / curr_norm
def test_deepsets_error():
    hilb = nk.hilbert.Particle(N=2, L=1.0, pbc=True)
    sdim = len(hilb.extent)
    x = jnp.hstack([jnp.ones(4), -jnp.ones(4)]).reshape(1, -1)
    xp = jnp.roll(x, sdim)

    ds = nk.models.DeepSetRelDistance(
        hilbert=hilb,
        layers_phi=3,
        layers_rho=3,
        features_phi=(10, 10),
        features_rho=(10, 1),
    )
    with pytest.raises(ValueError):
        p = ds.init(jax.random.PRNGKey(42), x)

    with pytest.raises(AssertionError):
        ds = nk.models.DeepSetRelDistance(
            hilbert=hilb,
            layers_phi=2,
            layers_rho=2,
            features_phi=(10, 10),
            features_rho=(10, 2),
        )
        p = ds.init(jax.random.PRNGKey(42), x)

    with pytest.raises(ValueError):
        ds = nk.models.DeepSetRelDistance(
            hilbert=nk.hilbert.Particle(N=2, L=1.0, pbc=False),
            layers_phi=2,
            layers_rho=2,
            features_phi=(10, 10),
            features_rho=(10, 2),
        )
        p = ds.init(jax.random.PRNGKey(42), x)
def conv_linear_model_linearize_param_flip_at_one(params):
    """Computes the product of parameters flipping at index 1."""
    wl = params[0]
    for i in range(1, len(params)):
        wi = jnp.roll(jnp.flip(params[i], 0), 1, 0)
        wl = circ_1d_conv(wi, wl)
    return wl
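# Small illustration (a standalone sketch, independent of circ_1d_conv):
# flipping and then rolling by one keeps index 0 fixed and reverses the
# remaining entries, i.e. it maps w[i] -> w[-i mod n], the kernel reversal
# used when composing circular 1-D convolutions.
import jax.numpy as jnp

w = jnp.array([0., 1., 2., 3.])
print(jnp.roll(jnp.flip(w, 0), 1, 0))  # [0., 3., 2., 1.]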
def ambient_flow_log_prob(params: Sequence[jnp.ndarray],
                          fns: Sequence[Callable],
                          y: jnp.ndarray) -> jnp.ndarray:
    """Compute the log-probability of ambient observations under the
    transformation given by composing RealNVP bijectors and a permutation
    bijector between them. Assumes that the base distribution is a standard
    multivariate normal.

    Args:
        params: List of arrays parameterizing the RealNVP bijectors.
        fns: List of functions that compute the shift and scale of the RealNVP
            affine transformation.
        y: Observations whose likelihood under the composition of bijectors
            should be computed.

    Returns:
        out: The log-probability of the observations given the parameters of
            the bijection composition.
    """
    num_dims = y.shape[-1]
    num_masked = num_dims - 2
    perm = jnp.roll(jnp.arange(num_dims), 1)
    fldj = 0.
    for i in reversed(range(args.num_realnvp)):
        y = permute.inverse(y, perm)
        fldj += permute.forward_log_det_jacobian()
        y = realnvp.inverse(y, num_masked, params[i], fns[i])
        fldj += realnvp.forward_log_det_jacobian(y, num_masked, params[i], fns[i])
    logprob = jspst.multivariate_normal.logpdf(y, jnp.zeros((num_dims,)), 1.)
    return logprob - fldj
def alternating_layer_ansatz(params, n_qubits, block_size, n_layers, rot_axis='Y'):
    # TODO(jdk): Check later whether this function needs to be revised for scalability.
    rot_axis = rot_axis.upper()
    assert rot_axis in ('X', 'Y', 'Z')
    assert n_qubits % block_size == 0
    assert len(params) == n_qubits * n_layers

    # Initial state
    state = jnp.array([0] * (2**n_qubits - 1) + [1], dtype=jnp.complex64)
    for d in range(n_layers):
        block_indices = jnp.arange(n_qubits)
        if d % 2:
            block_indices = jnp.roll(block_indices, -(block_size // 2))
        block_indices = jnp.reshape(block_indices, (-1, block_size))
        for block_idx in block_indices:
            state = block(params=params[block_idx + d * n_qubits],
                          qubits=block_idx,
                          state=state,
                          n_qubit=n_qubits,
                          rot_axis=rot_axis)
    return state
def trans_dist(self, value=None):
    assert self.word_len > 0
    if value is None:
        value = jnp.eye(self.word_len)  # transition-probability matrix
        value = jnp.roll(value, 1, axis=1)
    self._trans_dist = distrax.Categorical(probs=value)
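# Illustration of the default transition matrix built above (a standalone
# sketch): rolling the identity by one column gives a deterministic
# "advance to the next position" transition, one categorical row per state.
import jax.numpy as jnp

print(jnp.roll(jnp.eye(3), 1, axis=1))
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]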
def energy(model, config):
    @jit
    def amplitude_diff(config, i):
        """compute amplitude ratio of logpsi and logpsi_flipped, where i and
        i+1 have their sign flipped."""
        flipped = jax.ops.index_mul(config, jax.ops.index[:, [i, (i + 1) % N]], -1)
        logpsi_flipped = log_amplitude(model, flipped)
        return jnp.exp(logpsi_flipped - logpsi)

    vmap_amplitude_diff = vmap(partial(amplitude_diff, config), out_axes=1)

    logpsi = log_amplitude(model, config)

    _, N, _ = config.shape

    idx = jnp.arange(end)

    # sz*sz term
    nn = config[:, :end] * jnp.roll(config, -1, axis=1)[:, :end]

    # sx*sx + sy*sy gives a contribution iff x[i]!=x[i+1]
    mask = nn - 1

    E0 = jnp.sum(nn, axis=1)
    E1 = jnp.sum(mask * vmap_amplitude_diff(idx), axis=1)

    E = 0.25 * J * (E0 + E1)
    return E
def update(self, state: jnp.ndarray, u: jnp.ndarray) -> None:
    """
    Description: update agent internal state.

    Args:
        state (jnp.ndarray): current state
        u (jnp.ndarray): action taken

    Returns:
        None
    """
    noise = state - self.A @ self.state - self.B @ u
    self.noise_history = jax.ops.index_update(self.noise_history, 0, noise)
    self.noise_history = jnp.roll(self.noise_history, -1, axis=0)

    delta_M, delta_bias = self.grad(self.M, self.noise_history)

    lr = self.lr_scale
    lr *= (1 / (self.t + 1)) if self.decay else 1
    self.M -= lr * delta_M
    self.bias -= lr * delta_bias

    # update state
    self.state = state
    self.t += 1
def replay_sample_to_sars_transition(
        sample: reverb.ReplaySample,
        is_sequence: bool) -> types.Transition:
    """Converts the replay sample to a types.Transition.

    NB: If is_sequence is True then the last next_observation of each sequence
    is rubbish. Don't train on it.

    Args:
      sample: The replay sample.
      is_sequence: If False we expect the sample data to match the
        types.Transition already. Otherwise we expect a batch of sequences of
        steps.

    Returns:
      A types.Transition built from the sample data. The number of leading
      dimensions will be unchanged, so expect 2 for sequence based
      ([Batch, Time]) and 1 ([Batch]) otherwise.
      NB: If is_sequence is True then the last next_observation of each
      sequence is rubbish. Don't train on it.
    """
    if not is_sequence:
        return types.Transition(*sample.data)
    # Note that the last next_observation is invalid.
    steps = sample.data
    return types.Transition(
        observation=steps.observation,
        action=steps.action,
        reward=steps.reward,
        discount=steps.discount,
        next_observation=jnp.roll(steps.observation, shift=-1, axis=1))
def update(self, state: jnp.ndarray, action: jnp.ndarray, cost: Real) -> None:
    """
    Description: update agent internal state.

    Args:
        state (jnp.ndarray): current state
        action (jnp.ndarray): action taken
        cost (Real): scalar cost received

    Returns:
        None
    """
    noise = state - self.A @ self.state - self.B @ action
    self.noise_history = self.noise_history.at[0].set(noise)
    self.noise_history = jnp.roll(self.noise_history, -1, axis=0)

    lr = self.lr_scale
    lr *= (1 / (self.t**(3 / 4) + 1)) if self.decay else 1

    delta_M = self.grad(self.M, self.noise_history, cost)
    self.M -= lr * delta_M

    self.eps = self.eps.at[0].set(
        generate_uniform((self.H, self.d_action, self.d_state)))
    self.eps = np.roll(self.eps, -1, axis=0)
    self.M += self.delta * self.eps[-1]

    # update state
    self.state = state
    self.t += 1
def update_params(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
    """
    Description: update agent internal parameters.

    Args:
        obs (jnp.ndarray): current observation
        u (jnp.ndarray): action taken

    Returns:
        None
    """
    # update parameters
    delta_M = self.grad(self.M, self.G, self.y_nat, self.us)

    lr = self.lr_scale
    # lr *= (1 / (self.t + 1)) if self.decay else 1
    lr = jax.lax.cond(self.decay,
                      lambda x: x * 1 / (self.t + 1),
                      lambda x: 1.0,
                      lr)
    self.M -= lr * delta_M

    # if jnp.linalg.norm(self.M) > self.RM:
    #     self.M *= (self.RM / jnp.linalg.norm(self.M))
    self.M = jax.lax.cond(
        jnp.linalg.norm(self.M) > self.RM,
        lambda x: x * (self.RM / jnp.linalg.norm(self.M)),
        lambda x: x,
        self.M,
    )

    # update us
    self.us = jnp.roll(self.us, 1, axis=0)
    self.us = self.us.at[0].set(u)
    self.t += 1
def step(self, action):
    u_in, u_out = action
    u_in = jax.lax.cond(u_in > 0.0, lambda x: x, lambda x: 0.0, u_in)

    self.in_history = jnp.roll(self.in_history, shift=1)
    self.in_history = self.in_history.at[0].set(u_in)
    self.out_history = jnp.roll(self.out_history, shift=1)
    self.out_history = self.out_history.at[0].set(u_out)

    self.target = self.waveform.at(self.time)

    reward = -jnp.abs(self.target - self.state["pressure"])

    self.state = self.dynamics(self.state, action)
    self.time += 1

    return self.observation, reward, False, {}
def test_azimuthal_invariance(self, shift):
    # Make a simple two-layer classifier with pooling for testing.
    resolutions = [8, 4]
    transformer = _get_transformer()
    spins = [[0, -1], [0, 1, 2]]
    channels = [2, 3]
    shape = [2, resolutions[0], resolutions[0], len(spins[0]), channels[0]]
    sphere, _ = test_utils.get_spin_spherical(transformer, shape, spins[0])
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalClassifier(num_classes=5,
                                           resolutions=resolutions,
                                           spins=spins,
                                           widths=channels,
                                           axis_name=None,
                                           input_transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    output, _ = model.apply(params, sphere, train=True,
                            mutable=['batch_stats'])
    rotated_output, _ = model.apply(params, rotated_sphere, train=True,
                                    mutable=['batch_stats'])

    # The classifier should be rotation-invariant.
    self.assertAllClose(rotated_output, output, atol=1e-6)