Example #1
def shift_right(x, axis=1):
    """Shift the input to the right by padding and slicing on axis."""
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[axis] = (1, 0)
    padded = jnp.pad(x,
                     pad_widths,
                     mode='constant',
                     constant_values=x.dtype.type(0))
    return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis)
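
A minimal usage sketch (the array below is illustrative, not from the original source): the pad adds a zero column at the start of the axis and the slice drops the last column, so every row is shifted right by one.

import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
shift_right(x)  # [[0 1 2]
                #  [0 4 5]]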
Example #2
        def update(batch_idx, __opt_state):
            """Update func for gradients, includes gradient clipping."""
            kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)

            batch_data = lax.dynamic_slice_in_dim(epoch_data,
                                                  batch_idx * BATCH_SIZE,
                                                  BATCH_SIZE,
                                                  axis=0)
            batch_data = batch_data.astype(np.float32)

            params = get_params(__opt_state)
            grads = grad(loss_fn)(params, batch_data, next(batch_keys),
                                  BATCH_SIZE, ic_prior, VAR_MIN, kl_warmup,
                                  L2_REG)
            clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)

            return opt_update(batch_idx, clipped_grads, __opt_state)
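
Slicing with lax.dynamic_slice_in_dim (rather than epoch_data[batch_idx * BATCH_SIZE : ...]) is what allows batch_idx to be a traced value, e.g. when the epoch is driven by lax.fori_loop. A one-line sketch of such a driver (an assumption; the surrounding loop is not shown in the original):

        opt_state = lax.fori_loop(0, num_batches, update, opt_state)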
Example #3
def _use_qr(u, m, n, params):
    """QDWH iteration using QR decomposition.

  Args:
  u: a matrix, with static (padded) shape M x N.
  m, n: the dynamic shape of the matrix, where m <= M and n <= N.
  params: the QDWH parameters.
  """
    a, b, c = params
    M, N = u.shape

    y = _dynamic_concat(jnp.sqrt(c) * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
    q, _ = lax_linalg.qr(y, full_matrices=False)
    # q1 = q[:m, :]
    q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
    # q2 = (q[m:, :]).T.conj()
    q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
    q2 = _mask(q2, (n, n)).T.conj()
    e = b / c
    u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2))
    return u
Example #4
def svd_truncated(A,
                  chi_max=None,
                  cutoff=DEFAULT_CUTOFF,
                  epsilon=DEFAULT_EPS,
                  return_norm_change=False):
    """
    Like `svd`, but keeps at most `chi_max` many singular values and ignores singular values below `cutoff`
    """
    if return_norm_change:
        U, S, Vh, norm_change = svd_reduced(A,
                                            tolerance=cutoff,
                                            epsilon=epsilon,
                                            return_norm_change=True)

        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            old_S_norm = np.linalg.norm(S)
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            norm_change *= np.linalg.norm(S) / old_S_norm  # FIXME is this correct?

        return U, S, Vh, norm_change

    else:
        U, S, Vh = svd_reduced(A,
                               tolerance=cutoff,
                               epsilon=epsilon,
                               return_norm_change=False)

        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)

        return U, S, Vh
Example #5
def svd_reduced(A,
                tolerance=1e-12,
                epsilon=DEFAULT_EPS,
                return_norm_change=False):
    """
    Like `svd`, but ignores singular values <= `tolerance`.
    """
    U, S, Vh = svd(A, epsilon)

    if tolerance > 0.:

        if return_norm_change:
            old_S_norm = np.linalg.norm(S)

            k = np.sum(S > tolerance)
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            norm_change = np.linalg.norm(S) / old_S_norm

            return U, S, Vh, norm_change

        else:
            k = np.sum(S > tolerance)
            # U = U[:, :k]
            # S = S[:k]
            # Vh = Vh[:k, :]
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            return U, S, Vh

    else:
        if return_norm_change:
            return U, S, Vh, 1.
        else:
            return U, S, Vh
Example #6
 def get_batch(i=0, idxs=idxs):
     ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size)
     return tuple(
         np.take(a, ret_idx, axis=0) if isinstance(a, list)
         else lax.index_take(a, (ret_idx,), axes=(0,))
         for a in arrays)
Example #7
 def unravel_list(arr):
     return [
         jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                     m.shape).astype(m.dtype)
         for i, m in enumerate(leaves_metadata)
     ]
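
Here leaves_idx and leaves_metadata come from the surrounding flatten/unflatten utility and are not shown. A self-contained sketch of the same pattern (the reconstruction of those offsets below is an assumption for illustration only):

import numpy as np
import jax.numpy as jnp
from jax import lax

leaves = [jnp.zeros((2, 3)), jnp.zeros((4,))]
sizes = [leaf.size for leaf in leaves]            # [6, 4]
leaves_idx = np.cumsum([0] + sizes)[:-1]          # start offset of each leaf in the flat vector
flat = jnp.arange(sum(sizes), dtype=jnp.float32)  # the flat vector to unravel
unraveled = [jnp.reshape(lax.dynamic_slice_in_dim(flat, int(leaves_idx[i]), leaf.size), leaf.shape)
             for i, leaf in enumerate(leaves)]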
Example #8
        # Run one full epoch
        rng = random.split(random.fold_in(rng, epoch), 1)[0]
        opt_state = run_epoch(rng, opt_state, epoch)

        # Calculate losses
        # First get params we need
        params = get_params(opt_state)
        kl_warmup = kl_warmup_fun(epoch * num_batches)

        # Run loss_fn on all training data
        rng, train_keys = utils.keygen(rng, num_batches)
        epoch_train_loss = []
        for batch_ix in range(num_batches):
            eval_train_dat = lax.dynamic_slice_in_dim(X_train,
                                                      batch_ix * BATCH_SIZE,
                                                      BATCH_SIZE,
                                                      axis=0)
            eval_train_dat = eval_train_dat.astype(np.float32)
            _train_loss = loss_fn(params, eval_train_dat, next(train_keys),
                                  BATCH_SIZE, ic_prior, VAR_MIN, kl_warmup,
                                  L2_REG)
            epoch_train_loss.append(_train_loss)
        epoch_train_loss = onp.mean(epoch_train_loss)
        training_loss.append(epoch_train_loss)

        # Run loss_fn on validation data.
        n_valid_batches = int(X_valid.shape[0] * EPOCHS / BATCH_SIZE)
        rng, valid_keys = utils.keygen(rng, n_valid_batches)
        epoch_valid_loss = []
        for batch_ix in range(n_valid_batches):
            eval_val_dat = lax.dynamic_slice_in_dim(X_valid,
                                                    batch_ix * BATCH_SIZE,
                                                    BATCH_SIZE,
                                                    axis=0)
Example #9
 def dynamics(t, x, u):
     '''moves to the next standard basis vector'''
     idx = (position(x) + u[0]) % num_states
     return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]
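
dynamic_slice_in_dim here pulls out a single row of the identity matrix, i.e. a one-hot state vector, while keeping idx traceable. A standalone check (num_states is illustrative):

import jax.numpy as jnp
from jax import lax

num_states = 4
lax.dynamic_slice_in_dim(jnp.eye(num_states), 2, 1)[0]  # [0. 0. 1. 0.]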
Example #10
 def binarize_batch(rng, i, images):
   i = i % num_batches
   batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
   return random.bernoulli(rng, batch)
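
A self-contained version of this batching pattern (the data and sizes below are made up): the modulus wraps the batch counter so the slice start always stays in range.

import jax.numpy as jnp
from jax import lax, random

images = jnp.linspace(0.0, 1.0, 12).reshape(6, 2)  # 6 tiny "images" of pixel probabilities
batch_size, num_batches = 2, 3
i = 7 % num_batches                                # wraps to 1
batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)  # rows 2 and 3
binary = random.bernoulli(random.PRNGKey(0), batch)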
Example #11
 def body_fun(carry, i):
     x = dynamic_slice_in_dim(xs, i * statedim, statedim, 0)
     u = dynamic_slice_in_dim(us, i * inputdim, inputdim, 0)
     #x = xs[i*statedim:(i+1)*statedim]
     #u = us[i*inputdim:(i+1)*inputdim]
     return carry + self.ocp.L(x, u), 0
Example #12
 def zero_padded_controls_window(U, t):
     U_pad = jnp.vstack((U, jnp.zeros(U.shape)))
     return lax.dynamic_slice_in_dim(U_pad, t, T, axis=0)
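
The vstack with a zero block is what keeps the window well defined near the end of the horizon: dynamic_slice clamps out-of-range start indices, so without the padding a window starting near the end of U would be shifted back instead of reading zeros. A sketch with made-up shapes:

import jax.numpy as jnp
from jax import lax

T = 3
U = jnp.arange(8.0).reshape(4, 2)            # 4 control steps
U_pad = jnp.vstack((U, jnp.zeros(U.shape)))  # append a zero block of the same shape
lax.dynamic_slice_in_dim(U_pad, 3, T, axis=0)  # rows 3, 4, 5; rows 4 and 5 are zeros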
Example #13
 def fetch_batch(i, images):
     i = i % num_batches
     batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
     return batch
Example #14
 def unravel_list_batched(arr):
     return [np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.event_size, axis=batch_dims),
                        m.shape).astype(m.dtype)
             for i, m in enumerate(leaves_metadata)]
Example #15
 def onerow(i):
     return jnp.diag(lax.dynamic_slice_in_dim(fmp, i, _W, axis=0))
Example #16
 def get_batch(i, idxs):
     ret_idx = lax.dynamic_slice_in_dim(idxs, (i + 1) * batch_size,
                                        batch_size)
     return tuple(
         lax.index_take(a, (ret_idx, ), axes=(0, )) for a in arrays)
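
For reference, every snippet above uses the same call shape: lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0). The slice length must be static, while the start index may be a traced value and is clamped so the slice stays within bounds. A quick check of the clamping behavior (values are illustrative):

import jax.numpy as jnp
from jax import lax

a = jnp.arange(10)
lax.dynamic_slice_in_dim(a, 7, 5)  # start is clamped to 5 -> [5 6 7 8 9]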