def shift_right(x, axis=1):
    """Shift the input to the right by padding and slicing on axis."""
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[axis] = (1, 0)
    padded = jnp.pad(x, pad_widths, mode='constant',
                     constant_values=x.dtype.type(0))
    return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis)
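# A minimal usage sketch of `shift_right` above (inputs here are hypothetical):
# padding by one on the left and slicing off the last element shifts values to
# the right along the chosen axis, inserting zeros at the front.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)     # [[0, 1, 2], [3, 4, 5]]
shifted = shift_right(x, axis=1)    # [[0, 0, 1], [0, 3, 4]]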
def update(batch_idx, __opt_state):
    """Update func for gradients, includes gradient clipping."""
    kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)
    batch_data = lax.dynamic_slice_in_dim(epoch_data, batch_idx * BATCH_SIZE,
                                          BATCH_SIZE, axis=0)
    batch_data = batch_data.astype(np.float32)
    params = get_params(__opt_state)
    grads = grad(loss_fn)(params, batch_data, next(batch_keys), BATCH_SIZE,
                          ic_prior, VAR_MIN, kl_warmup, L2_REG)
    clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)
    return opt_update(batch_idx, clipped_grads, __opt_state)
def _use_qr(u, m, n, params):
    """QDWH iteration using QR decomposition.

    Args:
      u: a matrix, with static (padded) shape M x N.
      m, n: the dynamic shape of the matrix, where m <= M and n <= N.
      params: the QDWH parameters.
    """
    a, b, c = params
    M, N = u.shape
    y = _dynamic_concat(jnp.sqrt(c) * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
    q, _ = lax_linalg.qr(y, full_matrices=False)
    # q1 = q[:m, :]
    q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
    # q2 = (q[m:, :]).T.conj()
    q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
    q2 = _mask(q2, (n, n)).T.conj()
    e = b / c
    u = e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2)
    return u
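# A standalone sketch (not part of the QDWH code above) of the dynamic row-block
# slice used to build q2: `m` may be a traced value, so ordinary `q[m:, :]`
# indexing would not work under jit. Shapes here are made up for illustration.
q_demo = jnp.arange(21.0).reshape(7, 3)    # stacked (m + N) x N factor, padded
m_demo, N_demo = 4, 3
q2_demo = lax.dynamic_slice_in_dim(q_demo, m_demo, N_demo, axis=0)  # rows m .. m + N - 1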
def svd_truncated(A, chi_max=None, cutoff=DEFAULT_CUTOFF, epsilon=DEFAULT_EPS,
                  return_norm_change=False):
    """Like `svd`, but keeps at most `chi_max` singular values and ignores
    singular values below `cutoff`."""
    if return_norm_change:
        U, S, Vh, norm_change = svd_reduced(A, tolerance=cutoff, epsilon=epsilon,
                                            return_norm_change=True)
        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            old_S_norm = np.linalg.norm(S)
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            norm_change *= np.linalg.norm(S) / old_S_norm  # FIXME is this correct?
        return U, S, Vh, norm_change
    else:
        U, S, Vh = svd_reduced(A, tolerance=cutoff, epsilon=epsilon,
                               return_norm_change=False)
        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
        return U, S, Vh
def svd_reduced(A, tolerance=1e-12, epsilon=DEFAULT_EPS, return_norm_change=False):
    """Like `svd`, but ignores singular values <= `tolerance`."""
    U, S, Vh = svd(A, epsilon)
    if tolerance > 0.:
        if return_norm_change:
            old_S_norm = np.linalg.norm(S)
            k = np.sum(S > tolerance)
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            norm_change = np.linalg.norm(S) / old_S_norm
            return U, S, Vh, norm_change
        else:
            k = np.sum(S > tolerance)
            # U = U[:, :k]
            # S = S[:k]
            # Vh = Vh[:k, :]
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            return U, S, Vh
    else:
        if return_norm_change:
            return U, S, Vh, 1.
        else:
            return U, S, Vh
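# A minimal standalone sketch of the truncation pattern used by `svd_reduced` /
# `svd_truncated` above (values are hypothetical; `dynamic_slice_in_dim` is
# `jax.lax.dynamic_slice_in_dim`): keep only the leading k singular values.
S_demo = jnp.array([3.0, 1.0, 1e-14])
k_demo = int(jnp.sum(S_demo > 1e-12))                    # values above the cutoff
S_kept = lax.dynamic_slice_in_dim(S_demo, 0, k_demo, 0)  # -> [3.0, 1.0]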
def get_batch(i=0, idxs=idxs):
    ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size)
    return tuple(
        np.take(a, ret_idx, axis=0) if isinstance(a, list)
        else lax.index_take(a, (ret_idx,), axes=(0,))
        for a in arrays)
def unravel_list(arr):
    return [
        jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                    m.shape).astype(m.dtype)
        for i, m in enumerate(leaves_metadata)
    ]
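# A minimal standalone sketch of the unflatten pattern above, with made-up leaf
# metadata (two leaves of shapes (2,) and (2, 2)) in place of `leaves_metadata`:
flat_demo = jnp.arange(6.0)
starts, sizes, shapes = [0, 2], [2, 4], [(2,), (2, 2)]
leaves_demo = [jnp.reshape(lax.dynamic_slice_in_dim(flat_demo, s, n), shp)
               for s, n, shp in zip(starts, sizes, shapes)]
# -> [array([0., 1.]), array([[2., 3.], [4., 5.]])]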
# Run one full epoch
rng = random.split(random.fold_in(rng, epoch), 1)[0]
opt_state = run_epoch(rng, opt_state, epoch)

# Calculate losses
# First get params we need
params = get_params(opt_state)
kl_warmup = kl_warmup_fun(epoch * num_batches)

# Run loss_fn on all training data.
rng, train_keys = utils.keygen(rng, num_batches)
epoch_train_loss = []
for batch_ix in range(num_batches):
    eval_train_dat = lax.dynamic_slice_in_dim(X_train, batch_ix * BATCH_SIZE,
                                              BATCH_SIZE, axis=0)
    eval_train_dat = eval_train_dat.astype(np.float32)
    _train_loss = loss_fn(params, eval_train_dat, next(train_keys), BATCH_SIZE,
                          ic_prior, VAR_MIN, kl_warmup, L2_REG)
    epoch_train_loss.append(_train_loss)
epoch_train_loss = onp.mean(epoch_train_loss)
training_loss.append(epoch_train_loss)

# Run loss_fn on validation data.
n_valid_batches = int(X_valid.shape[0] * EPOCHS / BATCH_SIZE)
rng, valid_keys = utils.keygen(rng, n_valid_batches)
epoch_valid_loss = []
for batch_ix in range(n_valid_batches):
    eval_val_dat = lax.dynamic_slice_in_dim(X_valid, batch_ix * BATCH_SIZE,
                                            BATCH_SIZE, axis=0)
    # (the rest of the validation-loss loop mirrors the training loop above)
def dynamics(t, x, u):
    """Moves to the next standard basis vector."""
    idx = (position(x) + u[0]) % num_states
    return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]
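# A standalone sketch of the basis-vector lookup above, with a made-up state
# count in place of the closed-over `num_states` and `position`:
num_states_demo = 4
idx_demo = 2
e_idx = lax.dynamic_slice_in_dim(jnp.eye(num_states_demo), idx_demo, 1)[0]
# -> [0., 0., 1., 0.], the basis vector selected by a (possibly traced) index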
def binarize_batch(rng, i, images):
    i = i % num_batches
    batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
    return random.bernoulli(rng, batch)
def body_fun(carry, i):
    # x = xs[i*statedim:(i+1)*statedim]
    # u = us[i*inputdim:(i+1)*inputdim]
    x = dynamic_slice_in_dim(xs, i * statedim, statedim, 0)
    u = dynamic_slice_in_dim(us, i * inputdim, inputdim, 0)
    return carry + self.ocp.L(x, u), 0
def zero_padded_controls_window(U, t):
    U_pad = jnp.vstack((U, jnp.zeros(U.shape)))
    return lax.dynamic_slice_in_dim(U_pad, t, T, axis=0)
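# A standalone sketch of the zero-padded sliding window above, with a made-up
# horizon T and control sequence U (in the original, T is closed over):
T_demo = 3
U_demo = jnp.arange(6.0).reshape(T_demo, 2)             # a T x 2 control sequence
U_pad_demo = jnp.vstack((U_demo, jnp.zeros(U_demo.shape)))
window = lax.dynamic_slice_in_dim(U_pad_demo, 2, T_demo, axis=0)
# -> rows 2..4 of U_pad: the last control row followed by two rows of zeros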
def fetch_batch(i, images):
    i = i % num_batches
    batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
    return batch
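# A minimal standalone sketch of the minibatching pattern shared by
# `binarize_batch` and `fetch_batch` above (sizes are hypothetical):
images_demo = jnp.arange(12.0).reshape(6, 2)   # 6 examples, batch_size 2 -> 3 batches
batch_size_demo, num_batches_demo = 2, 3
i_demo = 5                                     # step counter; may exceed num_batches
batch_demo = lax.dynamic_slice_in_dim(
    images_demo, (i_demo % num_batches_demo) * batch_size_demo, batch_size_demo)
# -> rows 4 and 5, i.e. the last batch; the modulo wraps the step counter around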
def unravel_list_batched(arr):
    return [
        np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.event_size,
                                            axis=batch_dims),
                   m.shape).astype(m.dtype)
        for i, m in enumerate(leaves_metadata)
    ]
def onerow(i):
    return jnp.diag(lax.dynamic_slice_in_dim(fmp, i, _W, axis=0))
def get_batch(i, idxs):
    ret_idx = lax.dynamic_slice_in_dim(idxs, (i + 1) * batch_size, batch_size)
    return tuple(
        lax.index_take(a, (ret_idx,), axes=(0,)) for a in arrays)