Example No. 1
 def jacfn(U, Uold):
     # Build the full Jacobian of fn row by row: pulling back each row of
     # the identity through the VJP gives one row of d(fn)/d(U, Uold).
     y, vjp_fun = jax.vjp(fn, U, Uold)
     J = jax.vmap(jax.jit(vjp_fun))(np.eye(len(U)))
     return J
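A minimal standalone sketch of this pattern, assuming a stand-in `fn` (the original's `fn` is defined elsewhere): vmapping the VJP over the rows of an identity matrix recovers the Jacobian with respect to each argument.

import jax
import jax.numpy as np  # matches the np alias used above

def fn(U, Uold):  # hypothetical residual standing in for the original fn
    return U**2 - Uold

def jacfn(U, Uold):
    _, vjp_fun = jax.vjp(fn, U, Uold)
    # Each identity row selects one output component, so vmap over rows
    # yields (dfn/dU, dfn/dUold) stacked row by row.
    return jax.vmap(vjp_fun)(np.eye(len(U)))

J_U, J_Uold = jacfn(np.array([1., 2.]), np.array([3., 4.]))
# J_U == diag(2 * U); J_Uold == -I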
Example No. 2
    def step(self, s, a):
        """Apply control, damping, boundary, and collision forces.

    Args:
      s: (p, v, misc), where p and v are [n_entities,2] jnp.float32,
         and misc is child defined
      a: [n_agents, dim_a] jnp.float32

    Returns:
      A state tuple (p, v, misc)
    """
        p, v, misc = s  # [n,2], [n,2], [a_shape]
        f = jnp.zeros_like(p)  # [n,2]
        n = p.shape[0]  # number of entities

        # Calculate control forces
        f_control = jnp.pad(a, ((0, n - a.shape[0]), (0, 0)),
                            mode="constant")  # [n, dim_a]
        f += f_control

        # Calculate damping forces
        f_damping = -1.0 * self.damping * v  # [n,2]
        f = f + f_damping

        # Calculate boundary forces
        bounce = (((p + self.radius >= self.max_p) & (v >= 0.0)) |
                  ((p - self.radius <= self.min_p) & (v <= 0.0)))  # [n,2]
        v_new = (-1.0 * bounce + 1.0 * ~bounce) * v  # [n,2]
        f_boundary = self.mass * (v_new - v) / self.dt  # [n,2]
        f = f + f_boundary

        # Calculate shared quantities for later calculations
        # same: [n,n,1], True if i==j
        same = jnp.expand_dims(jnp.eye(n, dtype=jnp.bool_), axis=-1)
        # p2p: [n,n,2], p2p[i,j,:] is the vector from entity i to entity j
        p2p = p - jnp.expand_dims(p, axis=1)
        # dist: [n,n,1], dist[i,j,0] is the distance between i and j
        dist = jnp.linalg.norm(p2p, axis=-1, keepdims=True)
        # overlap: [n,n,1], overlap[i,j,0] is the overlap between i and j
        overlap = ((jnp.expand_dims(self.radius, axis=1) +
                    jnp.expand_dims(self.radius, axis=0)) - dist)
        if self.same_position_check:
            # ontop: [n,n,1], ontop[i,j,0] = True if i is at the exact location of j
            ontop = (dist == 0.0)
            # ontop_dir: [n,n,1], (1,0) above diagonal, (-1,0) below diagonal
            ontop_dir = jnp.stack(
                [jnp.triu(jnp.ones((n, n))) * 2 - 1,
                 jnp.zeros((n, n))],
                axis=-1)
            # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
            # direction of j from i
            contact_dir = (~ontop * p2p +
                           (ontop * ontop_dir)) / (~ontop * dist + ontop * 1.0)
        else:
            # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
            # direction of j from i
            contact_dir = p2p / (dist + same)
        # collideable: [n,n,1], True if i and j are collideable
        collideable = (jnp.expand_dims(self.collideable, axis=1)
                       & jnp.expand_dims(self.collideable, axis=0))
        # overlapping: [n,n,1], True if i,j overlap
        overlapping = overlap > 0

        # Calculate collision forces
        # Assume all entities collide with all entities, then mask out
        # non-collisions.
        #
        # For approaching, colliding entities, apply a force
        # along the direction of collision that results in
        # relative velocities consistent with the coefficient of
        # restitution (c) and preservation of momentum in that
        # direction.
        # momentum: m_a*v_a + m_b*v_b = m_a*v'_a + m_b*v'_b
        # restitution: v'_b - v'_a = -c*(v_b-v_a)
        # solve for v'_a:
        #  v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b-v_a)]/(m_a + m_b)
        #
        # v_contact_dir: [n,n] speed of i in dir of j
        v_contact_dir = jnp.sum(jnp.expand_dims(v, axis=-2) * contact_dir,
                                axis=-1)
        # v_approach: [n,n] speed that i,j are approaching each other
        v_approach = jnp.transpose(v_contact_dir) + v_contact_dir
        # momentum: [n,n] joint momentum in direction of contact (i->j)
        momentum = self.mass * v_contact_dir - jnp.transpose(
            self.mass * v_contact_dir)
        # v_result: [n,n] speed of i in dir of j after collision
        v_result = ((momentum + self.restitution * jnp.transpose(self.mass) *
                     (-v_approach)) / (self.mass + jnp.transpose(self.mass)))
        # f_collision: [n,n] force on i in dir of j to realize acceleration
        f_collision = self.mass * (v_result - v_contact_dir) / self.dt
        # f_collision: [n,n,2] force on i to realize acceleration due to
        # collision with j
        f_collision = jnp.expand_dims(f_collision, axis=-1) * contact_dir
        # collision_mask: [n,n,1]
        collision_mask = (collideable & overlapping & ~same &
                          (jnp.expand_dims(v_approach, axis=-1) > 0))
        # f_collision: [n,2], sum of collision forces on i
        f_collision = jnp.sum(f_collision * collision_mask, axis=-2)
        f = f + f_collision

        # Calculate overlapping spring forces
        # This corrects for any overlap due to discrete steps.
        # f_overlap: [n,n,2], force in the negative contact dir due to overlap
        f_overlap = -1.0 * contact_dir * overlap * self.overlap_spring_constant
        # overlapping_mask: [n,n,1], True if i,j are collideable, overlap,
        # and i != j
        overlapping_mask = collideable & overlapping & ~same
        # f_overlap: [n,2], sum of spring forces on i
        f_overlap = jnp.sum(f_overlap * overlapping_mask, axis=-2)
        f = f + f_overlap

        # apply forces
        v = v + (f / self.mass) * self.dt
        p = p + v * self.dt

        # update misc
        misc = self._update_misc((p, v, misc), a)  # pylint: disable=assignment-from-none

        return (p, v, misc)
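A quick standalone check of the restitution algebra in the comments above: solving the momentum and restitution equations for v'_a swaps the velocities for equal masses with c = 1, and averages them for c = 0.

import jax.numpy as jnp

def v_prime_a(m_a, m_b, v_a, v_b, c):
    # v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b - v_a)] / (m_a + m_b)
    return (m_a * v_a + m_b * v_b + m_b * c * (v_b - v_a)) / (m_a + m_b)

assert jnp.isclose(v_prime_a(1., 1., 2., -1., 1.), -1.0)  # elastic: swap
assert jnp.isclose(v_prime_a(1., 1., 2., -1., 0.), 0.5)   # inelastic: average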
Example No. 3
  def compute_preconditioners_from_statistics(self, states, hps, step):
    """Compute preconditioners for statistics."""
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state in states:
      num_statistics = len(state.statistics)
      num_statistics_per_state.append(num_statistics)
      original_shapes_for_state = []
      if num_statistics > 0:
        for statistic in state.statistics:
          exponents.append(2 * num_statistics
                           if hps.exponent_override == 0
                           else hps.exponent_override)
          original_shapes_for_state.append(statistic.shape)
          max_size = max(max_size, statistic.shape[0])
        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    def pack(mat, max_size):
      """Pack a matrix to a max_size for inverse on TPUs with static shapes.

      Args:
        mat: Matrix for computing inverse pth root.
        max_size: Matrix size to pack to.

      Returns:
        Given M returns [[M, 0], [0, I]]
      """
      size = mat.shape[0]
      assert size <= max_size
      if size == max_size:
        return mat
      pad_size = max_size - size
      zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
      zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
      eye = jnp.eye(pad_size, dtype=mat.dtype)
      mat = jnp.concatenate([mat, zs1], 1)
      mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
      return mat

    if not hps.batch_axis_name:
      num_devices = jax.local_device_count()
    else:
      num_devices = lax.psum(1, hps.batch_axis_name)

    # Pad statistics and exponents to next multiple of num_devices.
    packed_statistics = [pack(stat, max_size) for stat in statistics]
    to_pad = -num_statistics % num_devices
    packed_statistics.extend([
        jnp.eye(max_size, dtype=packed_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])

    # Batch statistics and exponents so that the leading axis is
    # num_devices.
    def _batch(statistics, exponents, num_devices):
      assert len(statistics) == len(exponents)
      n = len(statistics)
      b = int(n / num_devices)
      batched_statistics = [
          jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
      ]
      batched_exponents = [
          jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
      ]
      return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

    # Unbatch values across leading axis and return a list of elements.
    def _unbatch(batched_values):
      b1, b2 = batched_values.shape[0], batched_values.shape[1]
      results = []
      for v_array in jnp.split(batched_values, b1, 0):
        for v in jnp.split(jnp.squeeze(v_array), b2, 0):
          results.append(jnp.squeeze(v))
      return results

    all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                           num_devices)

    def _matrix_inverse_pth_root(xs, ps):
      mi_pth_root = lambda x, y: matrix_inverse_pth_root(  # pylint: disable=g-long-lambda
          x, y, ridge_epsilon=hps.matrix_eps)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    if not hps.batch_axis_name:
      preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)(
          all_statistics, all_exponents)
      preconditioners_flat = _unbatch(preconditioners)
      errors_flat = _unbatch(errors)
    else:

      def _internal_inverse_pth_root_all():
        preconditioners = jnp.array(all_statistics)
        current_replica = lax.axis_index(hps.batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root(
            all_statistics[current_replica], all_exponents[current_replica])
        preconditioners = jax.lax.all_gather(preconditioners,
                                             hps.batch_axis_name)
        errors = jax.lax.all_gather(errors, hps.batch_axis_name)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
        return preconditioners_flat, errors_flat

      if hps.preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Pass statistics in place of preconditioners: they are tensors of
        # the same shape, and because the errors passed in equal the failure
        # threshold, these initial values will be ignored.
        preconditioners_init = packed_statistics
        errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] *
                       len(packed_statistics))
        init_state = [preconditioners_init, errors_init]
        perform_step = step % hps.preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = self.fast_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
      return jnp.logical_or(
          jnp.isnan(error),
          error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD).astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
      return lax.cond(
          _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

    new_preconditioners_flat = []
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
      new_preconditioners_flat.append(
          _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for num_statistics, state in zip(num_statistics_per_state, states):
      if num_statistics == 0:
        preconditioners_for_states.append([])
      else:
        preconditioners_for_state = new_preconditioners_flat[idx:idx +
                                                             num_statistics]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)
        idx += num_statistics
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
      new_states.append(
          _ShampooDefaultParamState(state.diagonal_statistics, state.statistics,
                                    new_preconditioners,
                                    state.diagonal_momentum, state.momentum))

    return new_states
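A standalone re-sketch of the inner `pack` helper above, with hypothetical small inputs: padding with an identity block keeps the padded region inert, because the inverse pth root maps I to I, and the original block is sliced back out afterwards (the `p[:shape[0], :shape[1]]` step above).

import jax.numpy as jnp

def pack(mat, max_size):
    # Same construction as above: given M, return [[M, 0], [0, I]].
    size = mat.shape[0]
    if size == max_size:
        return mat
    pad = max_size - size
    top = jnp.concatenate([mat, jnp.zeros((size, pad), mat.dtype)], axis=1)
    bottom = jnp.concatenate(
        [jnp.zeros((pad, size), mat.dtype), jnp.eye(pad, dtype=mat.dtype)],
        axis=1)
    return jnp.concatenate([top, bottom], axis=0)

M = jnp.diag(jnp.array([4., 9.]))
print(jnp.diag(pack(M, 4)))  # [4. 9. 1. 1.]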
Example No. 4
 def metric(self, point: EuclideanPoint) -> PseudoMetric[EuclideanPoint]:
     # it's easier to provide metric instead of metric_in_chart
     return PseudoMetric(ChartPoint.of_point(point, IdChart()), jnp.eye(point.dim))
Example No. 5
 def kernel_to_state_space(self, R=None):
     var_p = 1.
     ell_p = self.lengthscale_periodic
     a = self.b_fmK_2igrid * ell_p**(-2. * self.igrid) * np.exp(
         -1. / ell_p**2.) * var_p
     q2 = np.sum(a, axis=0)
     # The angular frequency
     omega = 2 * np.pi / self.period
     # The model
     F_p = np.kron(np.diag(np.arange(self.order + 1)),
                   np.array([[0., -omega], [omega, 0.]]))
     L_p = np.eye(2 * (self.order + 1))
     # Qc_p = np.zeros(2 * (self.N + 1))
     Pinf_p = np.kron(np.diag(q2), np.eye(2))
     H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
     lam = 3.0**0.5 / self.lengthscale_matern
     F_m = np.array([[0.0, 1.0], [-lam**2, -2 * lam]])
     L_m = np.array([[0], [1]])
     Qc_m = np.array(
         [[12.0 * 3.0**0.5 / self.lengthscale_matern**3.0 * self.variance]])
     H_m = np.array([[1.0, 0.0]])
     Pinf_m = np.array(
         [[self.variance, 0.0],
          [0.0, 3.0 * self.variance / self.lengthscale_matern**2.0]])
     # F = np.kron(F_p, np.eye(2)) + np.kron(np.eye(14), F_m)
     F = np.kron(F_m, np.eye(2 *
                             (self.order + 1))) + np.kron(np.eye(2), F_p)
     L = np.kron(L_m, L_p)
     Qc = np.kron(Qc_m, Pinf_p)
     H = np.kron(H_m, H_p)
     # Pinf = np.kron(Pinf_m, Pinf_p)
     Pinf = block_diag(
         np.kron(Pinf_m, q2[0] * np.eye(2)),
         np.kron(Pinf_m, q2[1] * np.eye(2)),
         np.kron(Pinf_m, q2[2] * np.eye(2)),
         np.kron(Pinf_m, q2[3] * np.eye(2)),
         np.kron(Pinf_m, q2[4] * np.eye(2)),
         np.kron(Pinf_m, q2[5] * np.eye(2)),
         np.kron(Pinf_m, q2[6] * np.eye(2)),
     )
     return F, L, Qc, H, Pinf
Example No. 6
 def apply_fun(params, inputs, **kwargs):
   del params, kwargs
   # Perform one-hot encoding
   return jnp.eye(depth)[inputs.astype(int)]
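The identity-matrix indexing above is the standard JAX one-hot trick; a standalone illustration:

import jax.numpy as jnp

depth = 4
labels = jnp.array([0, 2, 3])
# Row i of the identity is the one-hot vector for class i, so integer
# indexing produces the whole one-hot batch at once.
print(jnp.eye(depth)[labels])
# [[1. 0. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]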
Example No. 7
def _add_diagonal_regularizer(covariance, diag_reg=0.):
    dimension = covariance.shape[0]
    reg = np.trace(covariance) / dimension
    return covariance + diag_reg * reg * np.eye(dimension)
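Here `diag_reg` scales a ridge proportional to the mean eigenvalue (trace over dimension) rather than an absolute constant; a standalone sketch of the effect:

import numpy as np  # assuming a NumPy-compatible np alias, as above

cov = np.diag(np.array([1.0, 3.0]))        # trace/dimension = 2
reg = _add_diagonal_regularizer(cov, diag_reg=0.1)
print(np.diag(reg))                        # [1.2 3.2]: 0.1 * 2 added per entry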
Example No. 8
def test_CovOp(plot=False):
    from scipy.stats import multivariate_normal

    nsamps = 1000
    samps_unif = None
    regul_C_ref = 0.0001
    D = 1
    import pylab as pl
    if samps_unif is None:
        samps_unif = nsamps
    gk_x = GaussianKernel(0.2)

    targ = mixt(D, [
        multivariate_normal(3 * np.ones(D),
                            np.eye(D) * 0.7**2),
        multivariate_normal(7 * np.ones(D),
                            np.eye(D) * 1.5**2)
    ], [0.5, 0.5])
    out_samps = targ.rvs(nsamps).reshape([nsamps, 1]).astype(float)
    out_fvec = FiniteVec(gk_x, out_samps, np.ones(nsamps))

    #gk_x = LaplaceKernel(3)
    #gk_x = StudentKernel(0.7, 15)
    x = np.linspace(-2.5, 15, samps_unif)[:, np.newaxis].astype(float)
    ref_fvec = FiniteVec(gk_x, x, np.ones(len(x)))
    ref_elem = ref_fvec.sum()

    C_ref = CovOp(
        ref_fvec,
        regul=0.)  # CovOp_compl(out_fvec.k, out_fvec.inspace_points, regul=0.)

    inv_Gram_ref = np.linalg.inv(inner(ref_fvec))
    assert (np.allclose((inv_Gram_ref @ inv_Gram_ref) / C_ref.inv().matr,
                        1.,
                        atol=1e-3))
    #assert(np.allclose(multiply(C_ref.inv(), ref_elem).prefactors, np.sum(np.linalg.inv(inner(ref_fvec)), 0), rtol=1e-02))

    C_samps = CovOp(out_fvec, regul=regul_C_ref)
    unif_obj = multiply(
        C_samps.inv(),
        FiniteVec.construct_RKHS_Elem(out_fvec.k, out_fvec.inspace_points,
                                      out_fvec.prefactors).normalized())
    C_ref = CovOp(ref_fvec, regul=regul_C_ref)
    dens_obj = multiply(
        C_ref.inv(),
        FiniteVec.construct_RKHS_Elem(out_fvec.k, out_fvec.inspace_points,
                                      out_fvec.prefactors)).normalized()

    #dens_obj.prefactors = np.sum(dens_obj.prefactors, 1)
    #dens_obj.prefactors = dens_obj.prefactors / np.sum(dens_obj.prefactors)
    #print(np.sum(dens_obj.prefactors))
    #p = np.sum(inner(dens_obj, ref_fvec), 1)
    targp = np.exp(targ.logpdf(ref_fvec.inspace_points.squeeze())).squeeze()
    estp = np.squeeze(inner(dens_obj, ref_fvec))
    estp2 = np.squeeze(
        inner(dens_obj.unsigned_projection().normalized(), ref_fvec))
    assert (np.abs(targp.squeeze() - estp).mean() < 0.8)
    if plot:
        pl.plot(ref_fvec.inspace_points.squeeze(),
                estp / np.max(estp) * np.max(targp),
                "b--",
                label="scaled estimate")
        pl.plot(ref_fvec.inspace_points.squeeze(),
                estp2 / np.max(estp2) * np.max(targp),
                "g-.",
                label="scaled estimate (uns)")
        pl.plot(ref_fvec.inspace_points.squeeze(), targp, label="truth")

        #pl.plot(ref_fvec.inspace_points.squeeze(), np.squeeze(inner(unif_obj, ref_fvec)), label="unif")
        pl.legend(loc="best")
        pl.show()
    assert (np.std(np.squeeze(inner(unif_obj.normalized(), out_fvec))) < 0.15)
Example No. 9
 def test_gmres_matmul(self):
     A = CustomOperator(2 * jnp.eye(3))
     b = jnp.arange(9.0).reshape(3, 3)
     expected = b / 2
     actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
     self.assertAllClose(expected, actual)
Example No. 10
 def forward(x):
   # The next two lines were obfuscated in the scraped source; they appear
   # to build the oblique reflection O = I - 2 U (V^T U)^-1 V^T and apply it
   # to x. Neither value is used by the return path below.
   O = jnp.eye(x.size) - 2*U@jnp.linalg.inv(VT@U)@VT
   l = x - 2*U@jnp.linalg.inv(VT@U)@VT@x
   z = util.householder_prod(x, VT)
   z = z*jnp.exp(log_s)
   return util.householder_prod(z, U) + b
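Assuming the reconstruction of the garbled lines above is right, P = U (Vᵀ U)⁻¹ Vᵀ is an oblique projection, so O = I − 2P squares to the identity; a standalone check:

import jax
import jax.numpy as jnp

U = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
VT = jax.random.normal(jax.random.PRNGKey(1), (2, 5))

P = U @ jnp.linalg.inv(VT @ U) @ VT
O = jnp.eye(5) - 2 * P
assert jnp.allclose(P @ P, P, atol=1e-4)           # projection: P^2 = P
assert jnp.allclose(O @ O, jnp.eye(5), atol=1e-4)  # reflection: O^2 = I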
Example No. 11
    def build_ilqr_tracking_solver(self, ref_pnts, weight_mats):
        #figure out dimension
        self.T_ = len(ref_pnts)
        self.n_dims_ = len(ref_pnts[0])

        self.ref_array = np.copy(ref_pnts)
        self.weight_array = [mat for mat in weight_mats]
        #clone weight mats if there are not enough weight mats
        for i in range(self.T_ - len(self.weight_array)):
            self.weight_array.append(self.weight_array[-1])

        #build dynamics, second-order linear dynamical system
        self.A_ = np.eye(self.n_dims_*2)
        self.A_[0:self.n_dims_, self.n_dims_:] = np.eye(self.n_dims_) * self.dt_
        self.B_ = np.zeros((self.n_dims_*2, self.n_dims_))
        self.B_[self.n_dims_:, :] = np.eye(self.n_dims_) * self.dt_

        self.plant_dyn_ = lambda x, u, t, aux: np.dot(self.A_, x) + np.dot(self.B_, u)

        #build cost functions, quadratic ones
        def tmp_cost_func(x, u, t, aux):
            err = x[0:self.n_dims_] - self.ref_array[t]
            #autograd does not allow A.dot(B)
            cost = np.dot(np.dot(err, self.weight_array[t]), err) + np.sum(u**2) * self.R_
            if t >= self.T_-1:
                #regularize velocity for the termination point
                #autograd does not allow self increment
                cost = cost + np.sum(x[self.n_dims_:]**2)  * self.R_ * self.Q_vel_ratio_
            return cost
        
        self.cost_ = tmp_cost_func
        self.ilqr_ = pylqr.PyLQR_iLQRSolver(T=self.T_-1, plant_dyn=self.plant_dyn_, cost=self.cost_, use_autograd=self.use_autograd)
        if not self.use_autograd:
            self.plant_dyn_dx_ = lambda x, u, t, aux: self.A_
            self.plant_dyn_du_ = lambda x, u, t, aux: self.B_
            
            def tmp_cost_func_dx(x, u, t, aux):
                err = x[0:self.n_dims_] - self.ref_array[t]
                grad = np.concatenate([2*err.dot(self.weight_array[t]), np.zeros(self.n_dims_)])
                if t >= self.T_-1:
                    grad[self.n_dims_:] = grad[self.n_dims_:] + 2 * self.R_ * self.Q_vel_ratio_ * x[self.n_dims_:]
                return grad

            self.cost_dx_ = tmp_cost_func_dx

            self.cost_du_ = lambda x, u, t, aux: 2 * self.R_ * u

            def tmp_cost_func_dxx(x, u, t, aux):
                hessian = np.zeros((2*self.n_dims_, 2*self.n_dims_))
                hessian[0:self.n_dims_, 0:self.n_dims_] = 2 * self.weight_array[t]

                if t >= self.T_-1:
                    hessian[self.n_dims_:, self.n_dims_:] = 2 * np.eye(self.n_dims_) * self.R_ * self.Q_vel_ratio_
                return hessian

            self.cost_dxx_ = tmp_cost_func_dxx

            self.cost_duu_ = lambda x, u, t, aux: 2 * self.R_ * np.eye(self.n_dims_)
            self.cost_dux_ = lambda x, u, t, aux: np.zeros((self.n_dims_, 2*self.n_dims_))

            #build an iLQR solver based on given functions...
            self.ilqr_.plant_dyn_dx = self.plant_dyn_dx_
            self.ilqr_.plant_dyn_du = self.plant_dyn_du_
            self.ilqr_.cost_dx = self.cost_dx_
            self.ilqr_.cost_du = self.cost_du_
            self.ilqr_.cost_dxx = self.cost_dxx_
            self.ilqr_.cost_duu = self.cost_duu_
            self.ilqr_.cost_dux = self.cost_dux_

        return
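The dynamics assembled above form a discrete-time double integrator; a standalone sketch of the block structure, with hypothetical `n_dims` and `dt`:

import numpy as np

n_dims, dt = 2, 0.1
A = np.eye(n_dims * 2)
A[0:n_dims, n_dims:] = np.eye(n_dims) * dt   # position += velocity * dt
B = np.zeros((n_dims * 2, n_dims))
B[n_dims:, :] = np.eye(n_dims) * dt          # velocity += control * dt

x = np.array([0., 0., 1., 2.])               # [position; velocity]
u = np.array([0.5, -0.5])
print(A @ x + B @ u)                         # [0.1  0.2  1.05 1.95]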
Example No. 12
    def filter(self, x_hist, jump_size, dt):
        """
        Compute the online version of the Kalman-Filter, i.e,
        the one-step-ahead prediction for the hidden state or the
        time update step
        
        Parameters
        ----------
        x_hist: array(timesteps, observation_size)
            
        Returns
        -------
        * array(timesteps, state_size):
            Filtered means mut
        * array(timesteps, state_size, state_size)
            Filtered covariances Sigmat
        * array(timesteps, state_size)
            Filtered conditional means mut|t-1
        * array(timesteps, state_size, state_size)
            Filtered conditional covariances Sigmat|t-1
        """
        I = jnp.eye(self.state_size)
        timesteps, *_ = x_hist.shape
        mu_hist = jnp.zeros((timesteps, self.state_size))
        Sigma_hist = jnp.zeros((timesteps, self.state_size, self.state_size))
        Sigma_cond_hist = jnp.zeros(
            (timesteps, self.state_size, self.state_size))
        mu_cond_hist = jnp.zeros((timesteps, self.state_size))

        # Initial configuration
        K1 = self.Sigma0 @ self.C.T @ inv(self.C @ self.Sigma0 @ self.C.T +
                                          self.R)
        mu1 = self.mu0 + K1 @ (x_hist[0] - self.C @ self.mu0)
        Sigma1 = (I - K1 @ self.C) @ self.Sigma0

        mu_hist = index_update(mu_hist, 0, mu1)
        Sigma_hist = index_update(Sigma_hist, 0, Sigma1)
        mu_cond_hist = index_update(mu_cond_hist, 0, self.mu0)
        Sigma_cond_hist = index_update(Sigma_cond_hist, 0, self.Sigma0)

        Sigman = Sigma1.copy()
        mun = mu1.copy()
        for n in range(1, timesteps):
            # Runge-Kutta integration step
            for _ in range(jump_size):
                k1 = self.A @ mun
                k2 = self.A @ (mun + dt * k1)
                mun = mun + dt * (k1 + k2) / 2

                k1 = self.A @ Sigman @ self.A.T + self.Q
                k2 = self.A @ (Sigman + dt * k1) @ self.A.T + self.Q
                Sigman = Sigman + dt * (k1 + k2) / 2

            Sigman_cond = Sigman.copy()
            St = self.C @ Sigman_cond @ self.C.T + self.R
            Kn = Sigman_cond @ self.C.T @ inv(St)

            mu_update = mun.copy()
            x_update = self.C @ mun
            mun = mu_update + Kn @ (x_hist[n] - x_update)
            Sigman = (I - Kn @ self.C) @ Sigman_cond

            mu_hist = index_update(mu_hist, n, mun)
            Sigma_hist = index_update(Sigma_hist, n, Sigman)
            mu_cond_hist = index_update(mu_cond_hist, n, mu_update)
            Sigma_cond_hist = index_update(Sigma_cond_hist, n, Sigman_cond)

        return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist
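The correction inside the loop is the standard Kalman measurement update; a standalone single-step sketch with a hypothetical observation model `C`, noise `R`, and prior:

import jax.numpy as jnp
from jax.numpy.linalg import inv

C = jnp.array([[1.0, 0.0]])      # observe the first state component
R = jnp.array([[0.25]])          # observation noise covariance
mu = jnp.array([0.0, 1.0])       # one-step-ahead (conditional) mean
Sigma = jnp.eye(2)               # one-step-ahead covariance

S = C @ Sigma @ C.T + R          # innovation covariance
K = Sigma @ C.T @ inv(S)         # Kalman gain
y = jnp.array([0.5])             # incoming observation
mu_post = mu + K @ (y - C @ mu)
Sigma_post = (jnp.eye(2) - K @ C) @ Sigma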
Example No. 13
        fpr, tpr, _ = skm.roc_curve(true_labels, pred_labels)
        roc_auc = skm.auc(fpr, tpr)

        return fpr, tpr, roc_auc

    seed = 1234
    rng_key = jax.random.PRNGKey(seed)

    num_data = 1000
    data_dim = 5

    for offset in np.linspace(0, 0.5, 3):

        mean = np.zeros(data_dim)
        cov = np.eye(data_dim)
        a_samples = jax.random.multivariate_normal(rng_key,
                                                   mean=mean,
                                                   cov=cov,
                                                   shape=(num_data, ))
        b_samples = jax.random.multivariate_normal(rng_key,
                                                   mean=mean + offset,
                                                   cov=cov,
                                                   shape=(num_data, ))

        # Trained discriminator
        fpr, tpr, roc_auc = ROC_AUC(rng_key,
                                    a_samples,
                                    b_samples,
                                    train_nsteps=1000)
        opt_fpr, opt_tpr, opt_roc_auc = ROC_of_true_discriminator(
Example No. 14
    def run(self, output_h5parm, ncpu, avg_direction_spacing,
            field_of_view_diameter, duration, time_resolution, start_time,
            array_name, phase_tracking):

        Nd = get_num_directions(
            avg_direction_spacing,
            field_of_view_diameter,
        )
        Nf = 2  # 8000
        Nt = int(duration / time_resolution) + 1
        min_freq = 700.
        max_freq = 2000.
        dp = create_empty_datapack(
            Nd,
            Nf,
            Nt,
            pols=None,
            field_of_view_diameter=field_of_view_diameter,
            start_time=start_time,
            time_resolution=time_resolution,
            min_freq=min_freq,
            max_freq=max_freq,
            array_file=ARRAYS[array_name],
            phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
            save_name=output_h5parm,
            clobber=True)

        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            axes = dp.axes_tec
            patch_names, directions = dp.get_directions(axes['dir'])
            antenna_labels, antennas = dp.get_antennas(axes['ant'])
            timestamps, times = dp.get_times(axes['time'])
            ref_ant = antennas[0]
            ref_time = times[0]

        Na = len(antennas)
        Nd = len(directions)
        Nt = len(times)

        logger.info(f"Number of directions: {Nd}")
        logger.info(f"Number of antennas: {Na}")
        logger.info(f"Number of times: {Nt}")
        logger.info(f"Reference Ant: {ref_ant}")
        logger.info(f"Reference Time: {ref_time.isot}")

        # Plot Antenna Layout in East North Up frame
        ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)

        _antennas = ac.ITRS(*antennas.cartesian.xyz,
                            obstime=ref_time).transform_to(ref_frame)
        # plt.scatter(_antennas.east, _antennas.north, marker='+')
        # plt.xlabel(f"East (m)")
        # plt.ylabel(f"North (m)")
        # plt.show()

        x0 = ac.ITRS(
            *antennas[0].cartesian.xyz,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        earth_centre_x = ac.ITRS(
            x=0 * au.m, y=0 * au.m, z=0. * au.m,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        self._kernel = TomographicKernel(x0,
                                         earth_centre_x,
                                         M32(),
                                         S_marg=20,
                                         compute_tec=False)

        k = directions.transform_to(ref_frame).cartesian.xyz.value.T

        t = times.mjd * 86400.
        t -= t[0]

        X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])

        logger.info("Computing coordinates in frame ...")

        for i, time in tqdm(enumerate(times)):
            x = ac.ITRS(*antennas.cartesian.xyz,
                        obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                            au.km).value.T
            ref_ant_x = ac.ITRS(
                *ref_ant.cartesian.xyz,
                obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                    au.km).value

            X = make_coord_array(x,
                                 k,
                                 t[i:i + 1, None],
                                 ref_ant_x[None, :],
                                 flat=True)

            X1.x.append(X[:, 0:3])
            X1.k.append(X[:, 3:6])
            X1.t.append(X[:, 6:7])
            X1.ref_x.append(X[:, 7:8])

        X1 = X1._replace(
            x=jnp.concatenate(X1.x, axis=0),
            k=jnp.concatenate(X1.k, axis=0),
            t=jnp.concatenate(X1.t, axis=0),
            ref_x=jnp.concatenate(X1.ref_x, axis=0),
        )

        logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

        def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
            K = self._kernel(X1,
                             X2,
                             self._bottom,
                             self._width,
                             self._fed_sigma,
                             self._fed_kernel_params,
                             wind_velocity=self._wind_vector)  # 1, N
            return K[0, :]

        covariance_row = lambda X: compute_covariance_row(
            tree_map(lambda x: x.reshape((1, -1)), X), X1)

        mean = jit(lambda X1: self._kernel.mean_function(X1,
                                                         self._bottom,
                                                         self._width,
                                                         self._fed_mu,
                                                         wind_velocity=self.
                                                         _wind_vector))(X1)

        cov = chunked_pmap(covariance_row,
                           X1,
                           batch_size=X1.x.shape[0],
                           chunksize=ncpu)

        plt.imshow(cov)
        plt.show()

        Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1),
                          dtype=cov.dtype)

        t0 = default_timer()
        jitter = 1e-6
        logger.info(f"Computing Cholesky with jitter: {jitter}")
        L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
        if np.any(np.isnan(L)):
            logger.info("Numerically instable. Using SVD.")
            L = msqrt(cov)

        logger.info(f"Cholesky took {default_timer() - t0} seconds.")

        dtec = (L @ Z + mean[:, None])[:, 0].reshape((Na, Nd, Nt)).transpose(
            (1, 0, 2))

        logger.info(f"Saving result to {output_h5parm}")
        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            dp.tec = np.asarray(dtec[None])
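The jitter-plus-Cholesky draw above is the usual way to sample a correlated Gaussian field; a standalone sketch:

import jax.numpy as jnp
from jax import random

cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])
jitter = 1e-6
# Cholesky of the jittered covariance; columns of L @ Z are ~ N(0, cov).
L = jnp.linalg.cholesky(cov + jitter * jnp.eye(2))
Z = random.normal(random.PRNGKey(42), (2, 100000))
samples = L @ Z
print(jnp.cov(samples))   # approximately cov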
Example No. 15
def _one_hot(indices, depth):
  indices = np.asarray(indices)
  flat_indices = indices.reshape([-1])
  flat_ret = np.eye(depth)[flat_indices]
  return flat_ret.reshape(indices.shape + (depth,))
Example No. 16
def build_H_two_body(pairs, L, H=None, sxsx=True, sysy=True, szsz=True):
    Sx = np.array([[0., 1.], [1., 0.]])
    Sy = np.array([[0., -1j], [1j, 0.]])
    Sz = np.array([[1., 0.], [0., -1.]])

    # S = [Sx, Sy, Sz]
    # if H is None:
    #     H = np.sparse.csr_matrix((2 ** L, 2 ** L))
    # else:
    #     pass

    # for i, j, V in pairs:
    #     if i > j:
    #         i, j = j, i

    #     print("building", i, j)
    #     if sxsx:
    #         hx = scipy.sparse.kron(scipy.sparse.eye(2 ** (i - 1)), Sx)
    #         hx = scipy.sparse.kron(hx, scipy.sparse.eye(2 ** (j - i - 1)))
    #         hx = scipy.sparse.kron(hx, Sx)
    #         hx = scipy.sparse.kron(hx, scipy.sparse.eye(2 ** (L - j)))
    #         H = H + V * hx

    #     if sysy:
    #         hy = scipy.sparse.kron(scipy.sparse.eye(2 ** (i - 1)), Sy)
    #         hy = scipy.sparse.kron(hy, scipy.sparse.eye(2 ** (j - i - 1)))
    #         hy = scipy.sparse.kron(hy, Sy)
    #         hy = scipy.sparse.kron(hy, scipy.sparse.eye(2 ** (L - j)))
    #         H = H + V * hy

    #     if szsz:
    #         hz = scipy.sparse.kron(scipy.sparse.eye(2 ** (i - 1)), Sz)
    #         hz = scipy.sparse.kron(hz, scipy.sparse.eye(2 ** (j - i - 1)))
    #         hz = scipy.sparse.kron(hz, Sz)
    #         hz = scipy.sparse.kron(hz, scipy.sparse.eye(2 ** (L - j)))
    #         H = H + V * hz

    # H = scipy.sparse.csr_matrix(H)
    # return H

    # # S = [Sx, Sy, Sz]
    if H is None:
        H = np.zeros((2**L, 2**L))
    else:
        pass

    for i, j, V in pairs:
        if i > j:
            i, j = j, i

        print("building", i, j)
        if sxsx:
            hx = np.kron(np.eye(2**(i - 1)), Sx)
            # use dense identities; np.kron silently mishandles sparse inputs
            hx = np.kron(hx, np.eye(2**(j - i - 1)))
            hx = np.kron(hx, Sx)
            hx = np.kron(hx, np.eye(2**(L - j)))
            H = H + V * hx

        if sysy:
            hy = np.kron(np.eye(2**(i - 1)), Sy)
            hy = np.kron(hy, np.eye(2**(j - i - 1)))
            hy = np.kron(hy, Sy)
            hy = np.kron(hy, np.eye(2**(L - j)))
            H = H + V * hy

        if szsz:
            hz = np.kron(np.eye(2**(i - 1)), Sz)
            hz = np.kron(hz, np.eye(2**(j - i - 1)))
            hz = np.kron(hz, Sz)
            hz = np.kron(hz, np.eye(2**(L - j)))
            H = H + V * hz
    return H
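A standalone sketch of calling the builder for a two-site chain (site indices are 1-based, as the `2**(i - 1)` padding implies); for two coupled Pauli operators the spectrum splits into a triplet at +1 and a singlet at −3:

import numpy as np

H = build_H_two_body([(1, 2, 1.0)], L=2)   # couple sites 1 and 2, V = 1
assert np.allclose(H, H.conj().T)          # Hermitian
print(np.round(np.linalg.eigvalsh(H), 6))  # [-3.  1.  1.  1.]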
Example No. 17
def _add_diagonal_regularizer(A: np.ndarray, diag_reg: float,
                              diag_reg_absolute_scale: bool) -> np.ndarray:
    dimension = A.shape[0]
    if not diag_reg_absolute_scale:
        diag_reg *= np.trace(A) / dimension
    return A + diag_reg * np.eye(dimension)
Example No. 18
File: bfgs.py Project: wayfeng/jax
def minimize_bfgs(
    fun: Callable,
    x0: jnp.ndarray,
    maxiter: Optional[int] = None,
    norm=jnp.inf,
    gtol: float = 1e-5,
    line_search_maxiter: int = 10,
) -> _BFGSResults:
    """Minimize a function using BFGS.

  Implements the BFGS algorithm from
    Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg.
    136-143.

  Args:
    fun: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    x0: initial guess.
    maxiter: maximum number of iterations.
    norm: order of norm for convergence check. Default inf.
    gtol: terminates minimization when |grad|_norm < g_tol.
    line_search_maxiter: maximum number of linesearch iterations.

  Returns:
    Optimization result.
  """

    if maxiter is None:
        maxiter = jnp.size(x0) * 200

    d = x0.shape[0]

    initial_H = jnp.eye(d, dtype=x0.dtype)
    f_0, g_0 = jax.value_and_grad(fun)(x0)
    state = _BFGSResults(
        converged=jnp.linalg.norm(g_0, ord=norm) < gtol,
        failed=False,
        k=0,
        nfev=1,
        ngev=1,
        nhev=0,
        x_k=x0,
        f_k=f_0,
        g_k=g_0,
        H_k=initial_H,
        old_old_fval=f_0 + jnp.linalg.norm(g_0) / 2,
        status=0,
        line_search_status=0,
    )

    def cond_fun(state):
        return (jnp.logical_not(state.converged)
                & jnp.logical_not(state.failed)
                & (state.k < maxiter))

    def body_fun(state):
        p_k = -_dot(state.H_k, state.g_k)
        line_search_results = line_search(
            fun,
            state.x_k,
            p_k,
            old_fval=state.f_k,
            old_old_fval=state.old_old_fval,
            gfk=state.g_k,
            maxiter=line_search_maxiter,
        )
        state = state._replace(
            nfev=state.nfev + line_search_results.nfev,
            ngev=state.ngev + line_search_results.ngev,
            failed=line_search_results.failed,
            line_search_status=line_search_results.status,
        )
        s_k = line_search_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = line_search_results.f_k
        g_kp1 = line_search_results.g_k
        y_k = g_kp1 - state.g_k
        rho_k = jnp.reciprocal(_dot(y_k, s_k))

        sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :]
        w = jnp.eye(d) - rho_k * sy_k
        H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w) +
                 rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :])
        H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(
            converged=converged,
            k=state.k + 1,
            x_k=x_kp1,
            f_k=f_kp1,
            g_k=g_kp1,
            H_k=H_kp1,
            old_old_fval=state.f_k,
        )
        return state

    state = lax.while_loop(cond_fun, body_fun, state)
    status = jnp.where(
        state.converged,
        0,  # converged
        jnp.where(
            state.k == maxiter,
            1,  # max iters reached
            jnp.where(
                state.failed,
                2 + state.line_search_status,  # ls failed (+ reason)
                -1,  # undefined
            )))
    state = state._replace(status=status)
    return state
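In released JAX this minimizer is reached through `jax.scipy.optimize.minimize` with `method="BFGS"`; a standalone usage sketch on the Rosenbrock function:

import jax.numpy as jnp
from jax.scipy.optimize import minimize

rosenbrock = lambda x: jnp.sum(100. * (x[1:] - x[:-1]**2)**2
                               + (1. - x[:-1])**2)
result = minimize(rosenbrock, jnp.array([-1.2, 1.0]), method="BFGS")
print(result.x, result.success)   # x close to [1. 1.]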
Example No. 19
def one_hot(x, k):
  """Create a one-hot encoding of x of size k."""
  return jnp.eye(k)[x]
Example No. 20
def h(const, var):
    return const.t * K + const.lbd * RS - 0.5 * const.u * jnp.eye(hsize)
Example No. 21
    def __init__(self,
                 obs_interval,
                 num_steps_per_obs,
                 num_obs_per_subseq,
                 y_seq,
                 dim_z,
                 dim_x,
                 dim_v,
                 forward_func,
                 generate_x_0,
                 generate_z,
                 obs_func,
                 metric=None):
        """
        Args:
            obs_interval (float): Interobservation time interval.
            num_steps_per_obs (int): Number of discrete time steps to simulate
                between each observation time.
            num_obs_per_subseq (int): Average number of observations per
                partitioned subsequence. Must be a factor of `len(y_seq)`.
            y_seq (array): Two-dimensional array containing observations at
                equally spaced time intervals, with first axis of array
                corresponding to observation time index (in order of increasing
                time) and second axis corresponding to dimension of each
                (vector-valued) observation.
            dim_z(int): Dimension of parameter vector `z`.
            dim_x (int): Dimension of state vector `x`.
            dim_v (int): Dimension of noise vector `v` consumed by
                `forward_func` to approximate time step.
            forward_func (Callable[[array, array, array, float], array]):
                Function implementing the forward step of a time-discretisation
                of the diffusion, such that `forward_func(z, x, v, δ)`, for
                parameter vector `z`, current state `x` at time `t`, standard
                normal vector `v` and small timestep `δ`, is distributed
                approximately according to `X(t + δ) | X(t) = x, Z = z`.
            generate_x_0 (Callable[[array, array], array]): Generator function
                for the initial state such that `generator_x_0(z, v_0)` for
                parameter vector `z` and standard normal vector `v_0` is
                distributed according to prior distribution on `X(0) | Z = z`.
            generate_z (Callable[[array], array]): Generator function
                for parameter vector such that `generator_z(u)` for standard
                normal vector `u` is distributed according to prior distribution
                on parameter vector `Z`.
            obs_func (Callable[[array], array]): Function mapping from state
                vector `x` at an observation time to the corresponding observed
                vector `y = obs_func(x)`.
            metric (Matrix): Metric matrix representation. Should be either an
                `mici.matrices.IdentityMatrix` or
                `mici.matrices.SymmetricBlockDiagonalMatrix` instance; in the
                latter case the matrix has two blocks on the diagonal, the
                leftmost of size `dim_z x dim_z` and the rightmost positive
                diagonal. Defaults to `mici.matrices.IdentityMatrix`.
        """

        if metric is None or isinstance(metric, IdentityMatrix):
            metric_1 = np.eye(dim_z)
            log_det_sqrt_metric_1 = 0
        elif (isinstance(metric, SymmetricBlockDiagonalMatrix)
              and isinstance(metric.blocks[1], PositiveDiagonalMatrix)):
            metric_1 = metric.blocks[0].array
            log_det_sqrt_metric_1 = metric.blocks[0].log_abs_det_sqrt
            metric_2_diag = metric.blocks[1].diagonal
        else:
            raise NotImplementedError(
                'Only identity and block diagonal metrics with diagonal lower '
                'right block currently supported.')

        num_obs, dim_y = y_seq.shape
        δ = obs_interval / num_steps_per_obs
        dim_q = dim_z + dim_x + num_obs * dim_v * num_steps_per_obs
        if num_obs % num_obs_per_subseq != 0:
            raise NotImplementedError(
                'Only cases where num_obs_per_subseq is a factor of num_obs '
                'supported.')
        num_subseq = num_obs // num_obs_per_subseq
        obs_indices = slice(num_steps_per_obs - 1, None, num_steps_per_obs)
        num_step_per_subseq = num_obs_per_subseq * num_steps_per_obs
        y_subseqs_p0 = np.reshape(y_seq, (num_subseq, num_obs_per_subseq, -1))
        y_subseqs_p1 = split(
            y_seq, (num_obs_per_subseq // 2, num_obs - num_obs_per_subseq))
        y_subseqs_p1[1] = np.reshape(
            y_subseqs_p1[1], (num_subseq - 1, num_obs_per_subseq, dim_y))

        super().__init__(neg_log_dens=standard_normal_neg_log_dens,
                         grad_neg_log_dens=standard_normal_grad_neg_log_dens,
                         metric=metric)

        @api.jit
        def step_func(z, x, v):
            x_n = forward_func(z, x, v, δ)
            return (x_n, x_n)

        @api.jit
        def generate_x_obs_seq(q):
            u, v_0, v_seq_flat = split(q, (dim_z, dim_x))
            z = generate_z(u)
            x_0 = generate_x_0(z, v_0)
            v_seq = np.reshape(v_seq_flat, (-1, dim_v))
            _, x_seq = lax.scan(lambda x, v: step_func(z, x, v), x_0, v_seq)
            return x_seq[obs_indices]

        @api.partial(api.jit, static_argnums=(3, ))
        def partition_into_subseqs(v_seq, v_0, x_obs_seq, partition=0):
            """Partition noise increment and observation sequences.

            Partitition sequences in to either `num_subseq` equally sized
            subsequences (`partition == 0`)  or `num_subseq - 1` equally sized
            subsequences plus initial and final 'half' subsequences.
            """
            if partition == 0:
                v_subseqs = v_seq.reshape(
                    (num_subseq, num_step_per_subseq, dim_v))
                v_subseqs = (v_subseqs[0], v_subseqs[1:-1], v_subseqs[-1])
                x_obs_subseqs = x_obs_seq.reshape(
                    (num_subseq, num_obs_per_subseq, dim_x))
                w_inits = (v_0, x_obs_subseqs[:-2, -1], x_obs_subseqs[-2, -1])
                y_bars = (np.concatenate(
                    (y_subseqs_p0[0, :-1].flatten(), x_obs_subseqs[0, -1])),
                          np.concatenate((y_subseqs_p0[1:-1, :-1].reshape(
                              (num_subseq - 2, -1)), x_obs_subseqs[1:-1, -1]),
                                         -1), y_subseqs_p0[-1].flatten())
            else:
                v_subseqs = split(
                    v_seq, ((num_obs_per_subseq // 2) * num_steps_per_obs,
                            num_step_per_subseq * (num_subseq - 1)))
                v_subseqs[1] = v_subseqs[1].reshape(
                    (num_subseq - 1, num_step_per_subseq, dim_v))
                x_obs_subseqs = split(
                    x_obs_seq,
                    (num_obs_per_subseq // 2, num_obs - num_obs_per_subseq))
                x_obs_subseqs[1] = x_obs_subseqs[1].reshape(
                    (num_subseq - 1, num_obs_per_subseq, dim_x))
                w_inits = (v_0,
                           np.concatenate(
                               (x_obs_subseqs[0][-1:], x_obs_subseqs[1][:-1,
                                                                        -1]),
                               0), x_obs_subseqs[1][-1, -1])
                y_bars = (np.concatenate(
                    (y_subseqs_p1[0][:-1].flatten(), x_obs_subseqs[0][-1])),
                          np.concatenate((
                              y_subseqs_p1[1][:, :-1].reshape(
                                  (num_subseq - 1, -1)),
                              x_obs_subseqs[1][:, -1],
                          ), -1), y_subseqs_p1[2].flatten())
            return v_subseqs, w_inits, y_bars

        def generate_y_bar(z, w_0, v_seq, b):
            x_0 = generate_x_0(z, w_0) if b == 0 else w_0
            _, x_seq = lax.scan(lambda x, v: step_func(z, x, v), x_0, v_seq)
            y_seq = obs_func(x_seq[obs_indices])
            return y_seq.flatten() if b == 2 else np.concatenate(
                (y_seq[:-1].flatten(), x_seq[-1]))

        @api.partial(api.jit, static_argnums=(2, ))
        def constr(q, x_obs_seq, partition=0):
            """Calculate constraint function for current partition."""
            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_seq = v_seq_flat.reshape((-1, dim_v))
            z = generate_z(u)
            (v_subseqs, w_inits,
             y_bars) = partition_into_subseqs(v_seq, v_0, x_obs_seq, partition)
            gen_funcs = (generate_y_bar,
                         api.vmap(generate_y_bar,
                                  (None, 0, 0, None)), generate_y_bar)
            return np.concatenate([
                (gen_funcs[b](z, w_inits[b], v_subseqs[b], b) -
                 y_bars[b]).flatten() for b in range(3)
            ])

        @api.jit
        def init_objective(q, x_obs_seq, reg_coeff):
            """Optimisation objective to find initial state on manifold."""
            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_subseqs = v_seq_flat.reshape((num_obs, num_steps_per_obs, dim_v))
            z = generate_z(u)
            x_0 = generate_x_0(z, v_0)
            x_inits = np.concatenate((x_0[None], x_obs_seq[:-1]), 0)

            def generate_final_state(z, v_seq, x_0):
                _, x_seq = lax.scan(lambda x, v: step_func(z, x, v), x_0,
                                    v_seq)
                return x_seq[-1]

            c = api.vmap(generate_final_state, in_axes=(None, 0, 0))(
                z, v_subseqs, x_inits) - x_obs_seq
            return 0.5 * np.mean(c**2) + 0.5 * reg_coeff * np.mean(q**2), c

        @api.partial(api.jit, static_argnums=(2, ))
        def jacob_constr_blocks(q, x_obs_seq, partition=0):
            """Return non-zero blocks of constraint function Jacobian.

            Input state q can be decomposed into q = [u, v₀, v₁, v₂]
            where global latent state (parameters) are determined by u,
            initial subsequence by v₀, middle subsequences by v₁ and final
            subsequence by v₂.

            Constraint function can then be decomposed as

                c(q) = [c₀(u, v₀), c₁(u, v₁), c₂(u, v₂)]

            Constraint Jacobian ∂c(q) has block structure

                ∂c(q) = [[∂₀c₀(u, v₀), ∂₁c₀(u, v₀),     0,     ,     0      ]
                         [∂₀c₁(u, v₁),     0      , ∂₁c₁(u, v₁),     0      ]
                         [∂₀c₂(u, v₀),     0      ,     0      , ∂₁c₂(u, v₂)]]

            """
            def g_y_bar(u, v, w_0, b):
                z = generate_z(u)
                if b == 0:
                    w_0, v = split(v, (dim_x, ))
                v_seq = np.reshape(v, (-1, dim_v))
                return generate_y_bar(z, w_0, v_seq, b)

            u, v_0, v_seq_flat = split(q, (
                dim_z,
                dim_x,
            ))
            v_seq = np.reshape(v_seq_flat, (-1, dim_v))
            (v_subseqs, w_inits,
             y_bars) = partition_into_subseqs(v_seq, v_0, x_obs_seq, partition)
            v_bars = (np.concatenate([v_0, v_subseqs[0].flatten()]),
                      np.reshape(v_subseqs[1], (v_subseqs[1].shape[0], -1)),
                      v_subseqs[2].flatten())
            jac_g_y_bar = api.jacrev(g_y_bar, (0, 1))
            jacob_funcs = (jac_g_y_bar,
                           api.vmap(jac_g_y_bar,
                                    (None, 0, 0, None)), jac_g_y_bar)
            return tuple(
                zip(*[
                    jacob_funcs[b](u, v_bars[b], w_inits[b], b)
                    for b in range(3)
                ]))

        @api.jit
        def chol_gram_blocks(dc_du, dc_dv):
            """Calculate Cholesky factors of decomposition of Gram matrix. """
            if isinstance(metric, IdentityMatrix):
                D = tuple(
                    np.einsum('...ij,...kj', dc_dv[i], dc_dv[i])
                    for i in range(3))
            else:
                m_v = split(
                    metric_2_diag,
                    (dc_dv[0].shape[1], dc_dv[1].shape[0] * dc_dv[1].shape[2]))
                m_v[1] = m_v[1].reshape((dc_dv[1].shape[0], dc_dv[1].shape[2]))
                D = tuple(
                    np.einsum('...ij,...kj', dc_dv[i] /
                              m_v[i][..., None, :], dc_dv[i])
                    for i in range(3))
            chol_D = tuple(nla.cholesky(D[i]) for i in range(3))
            D_inv_dc_du = tuple(
                sla.cho_solve((chol_D[i], True), dc_du[i]) for i in range(3))
            chol_C = nla.cholesky(metric_1 + (
                dc_du[0].T @ D_inv_dc_du[0] +
                np.einsum('ijk,ijl->kl', dc_du[1], D_inv_dc_du[1]) +
                dc_du[2].T @ D_inv_dc_du[2]))
            return chol_C, chol_D

        @api.jit
        def log_det_sqrt_gram_from_chol(chol_C, chol_D):
            """Calculate log-det of Gram matrix from Cholesky factors."""
            return (sum(
                np.log(np.abs(chol_D[i].diagonal(0, -2, -1))).sum()
                for i in range(3)) + np.log(np.abs(chol_C.diagonal())).sum() -
                    log_det_sqrt_metric_1)

        @api.partial(api.jit, static_argnums=(2, ))
        def log_det_sqrt_gram(q, x_obs_seq, partition=0):
            """Calculate log-determinant of constraint Jacobian Gram matrix."""
            dc_du, dc_dv = jacob_constr_blocks(q, x_obs_seq, partition)
            chol_C, chol_D = chol_gram_blocks(dc_du, dc_dv)
            return (log_det_sqrt_gram_from_chol(chol_C, chol_D),
                    ((dc_du, dc_dv), (chol_C, chol_D)))

        @api.jit
        def lmult_by_jacob_constr(dc_du, dc_dv, vct):
            """Left-multiply vector by constraint Jacobian matrix."""
            vct_u, vct_v = split(vct, (dim_z, ))
            j0, j1, j2 = dc_dv[0].shape[1], dc_dv[1].shape[0], dc_dv[2].shape[
                1]
            return (np.vstack((dc_du[0], dc_du[1].reshape(
                (-1, dim_z)), dc_du[2])) @ vct_u + np.concatenate(
                    (dc_dv[0] @ vct_v[:j0],
                     np.einsum('ijk,ik->ij', dc_dv[1],
                               np.reshape(vct_v[j0:-j2], (j1, -1))).flatten(),
                     dc_dv[2] @ vct_v[-j2:])))

        @api.jit
        def rmult_by_jacob_constr(dc_du, dc_dv, vct):
            """Right-multiply vector by constraint Jacobian matrix."""
            vct_parts = split(
                vct,
                (dc_du[0].shape[0], dc_du[1].shape[0] * dc_du[1].shape[1]))
            vct_parts[1] = np.reshape(vct_parts[1], dc_du[1].shape[:2])
            return np.concatenate([
                vct_parts[0] @ dc_du[0] +
                np.einsum('ij,ijk->k', vct_parts[1], dc_du[1]) +
                vct_parts[2] @ dc_du[2], vct_parts[0] @ dc_dv[0],
                np.einsum('ij,ijk->ik', vct_parts[1],
                          dc_dv[1]).flatten(), vct_parts[2] @ dc_dv[2]
            ])

        @api.jit
        def lmult_by_inv_gram(dc_du, dc_dv, chol_C, chol_D, vct):
            """Left-multiply vector by inverse Gram matrix."""
            vct_parts = split(
                vct,
                (dc_du[0].shape[0], dc_du[1].shape[0] * dc_du[1].shape[1]))
            vct_parts[1] = np.reshape(vct_parts[1], dc_du[1].shape[:2])
            D_inv_vct = [
                sla.cho_solve((chol_D[i], True), vct_parts[i])
                for i in range(3)
            ]
            dc_du_T_D_inv_vct = sum(
                np.einsum('...jk,...j->k', dc_du[i], D_inv_vct[i])
                for i in range(3))
            C_inv_dc_du_T_D_inv_vct = sla.cho_solve((chol_C, True),
                                                    dc_du_T_D_inv_vct)
            return np.concatenate([
                sla.cho_solve((chol_D[i], True), vct_parts[i] -
                              dc_du[i] @ C_inv_dc_du_T_D_inv_vct).flatten()
                for i in range(3)
            ])

        @api.jit
        def normal_space_component(vct, dc_du, dc_dv, chol_C, chol_D):
            return rmult_by_jacob_constr(
                dc_du, dc_dv,
                lmult_by_inv_gram(dc_du, dc_dv, chol_C, chol_D,
                                  lmult_by_jacob_constr(dc_du, dc_dv, vct)))

        @api.partial(api.jit, static_argnums=(2, 7, 8, 9, 10))
        def quasi_newton_projection(q, x_obs_seq, partition, dc_du_prev,
                                    dc_dv_prev, chol_C_prev, chol_D_prev,
                                    convergence_tol, position_tol,
                                    divergence_tol, max_iters):

            norm = lambda x: np.max(np.abs(x))

            def body_func(val):
                q, i, _, _ = val
                c = constr(q, x_obs_seq, partition)
                error = norm(c)
                delta_q = rmult_by_jacob_constr(
                    dc_du_prev, dc_dv_prev,
                    lmult_by_inv_gram(dc_du_prev, dc_dv_prev, chol_C_prev,
                                      chol_D_prev, c))
                q -= delta_q
                i += 1
                return q, i, norm(delta_q), error

            def cond_func(val):
                q, i, norm_delta_q, error, = val
                diverged = np.logical_or(error > divergence_tol,
                                         np.isnan(error))
                converged = np.logical_and(error < convergence_tol,
                                           norm_delta_q < position_tol)
                return np.logical_not(
                    np.logical_or((i >= max_iters),
                                  np.logical_or(diverged, converged)))

            return lax.while_loop(cond_func, body_func, (q, 0, np.inf, -1.))

        self._generate_x_obs_seq = generate_x_obs_seq
        self._constr = constr
        self._jacob_constr_blocks = jacob_constr_blocks
        self._chol_gram_blocks = chol_gram_blocks
        self._log_det_sqrt_gram_from_chol = log_det_sqrt_gram_from_chol
        self._grad_log_det_sqrt_gram = api.jit(
            api.value_and_grad(log_det_sqrt_gram, has_aux=True), (2, ))
        self.value_and_grad_init_objective = api.jit(
            api.value_and_grad(init_objective, (0, ), has_aux=True))
        self._normal_space_component = normal_space_component
        self.quasi_newton_projection = quasi_newton_projection
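`chol_gram_blocks` and `log_det_sqrt_gram_from_chol` above rely on the determinant identity det(D + Jᵤ M₁⁻¹ Jᵤᵀ) = det(D) · det(M₁ + Jᵤᵀ D⁻¹ Jᵤ) / det(M₁); a standalone numeric check with hypothetical small matrices:

import numpy as np

rng = np.random.default_rng(0)
J_u = rng.standard_normal((6, 3))        # Jacobian block w.r.t. u
J_v = rng.standard_normal((6, 6))        # Jacobian block w.r.t. v
M1 = 2.0 * np.eye(3)                     # metric block for u

D = J_v @ J_v.T
gram = J_u @ np.linalg.inv(M1) @ J_u.T + D
C = M1 + J_u.T @ np.linalg.solve(D, J_u)
lhs = np.linalg.slogdet(gram)[1]
rhs = (np.linalg.slogdet(D)[1] + np.linalg.slogdet(C)[1]
       - np.linalg.slogdet(M1)[1])
assert np.isclose(lhs, rhs)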
Example No. 22
def inv(P):
    """
    Compute the inverse of a PSD matrix using the Cholesky factorisation
    """
    L = cho_factor(P)
    return cho_solve(L, np.eye(P.shape[0]))
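A standalone check (taking `cho_factor`/`cho_solve` from jax.scipy.linalg, which is an assumption about this snippet's imports):

import jax.numpy as np
from jax.scipy.linalg import cho_factor, cho_solve

P = np.array([[4.0, 1.0], [1.0, 3.0]])
P_inv = cho_solve(cho_factor(P), np.eye(2))
print(np.allclose(P @ P_inv, np.eye(2), atol=1e-5))  # True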
Example No. 23
 def stationary_covariance(self):
     Pinf_mat = np.array([[self.variance]])
     Pinf = np.kron(Pinf_mat, np.eye(2))
     return Pinf
Example No. 24
 def dynamics(t, x, u):
   '''moves to the next standard basis vector'''
   idx = (position(x) + u[0]) % num_states
   return lax.dynamic_slice_in_dim(np.eye(num_states), idx, 1)[0]
Example No. 25
def gen_samples_A_L(keys_A_ii_0, keys_C_i, samples_nb, T, S_C_hat, C_hat,
                    X_T_X_inv):
    """
    Generate MC samples for the matrix form of
    :math:`A\left(L\right)y\left(t\right)` in the following structural VAR
    model using Algorithm 1 in Zha (1999):

    .. math::

        A\left(L\right)y\left(t\right)=\epsilon\left(t\right)

    Parameters
    -----------
    keys_A_ii_0 : list
        A list of PRNG keys, one for each block, used to sample
        :math:`A_{ii}\left(0\right)`.

    keys_C_i : list
        A list of PRNG keys, one for each block, used to sample :math:`C_{i}`.

    samples_nb : scalar(int)
        Number of samples to generate for each block.

    T : scalar(int)
        Length of the time series.

    S_C_hat : list
        A list of arrays containing
        :math:`S_{i}\left(\hat{\boldsymbol{C}}_{i}\right)`, one for each
        block. Each element must be of dimension :math:`m_{i}` by
        :math:`m_{i}`.

    C_hat : list
        A list of arrays containing :math:`\hat{\boldsymbol{C}}_{i}`, one for
        each block. Each element must be of dimension :math:`k_{i}` by
        :math:`m_{i}`.

    X_T_X_inv : list
        A list of arrays containing
        :math:`\left(X'X\right)^{-1}`, one for each block.
        Each element must be of dimension :math:`k_{i}` by :math:`k_{i}`.

    Returns
    -----------
    A_L : ndarray(float)
        An array containing MC samples for the matrix form of
        :math:`A\left(L\right)y\left(t\right)`.

    References
    -----------

    .. [1] Zha, Tao. 1999. Block recursion and structural vector
           autoregressions. Journal of Econometrics 90 (2):291–316.

    """

    n = len(S_C_hat)

    if n != len(keys_A_ii_0):
        raise ValueError("Number of keys for sampling A_ii_0 must match the" +
                         " length of S_C_hat.")

    if n != len(keys_C_i):
        raise ValueError("Number of keys for sampling C_i must match the" +
                         " length of S_C_hat.")

    if n != len(C_hat):
        raise ValueError("The length of C_hat must match the length of" +
                         " S_C_hat.")

    if n != len(X_T_X_inv):
        raise ValueError("The length of X_T_X_inv must match the length of" +
                         " S_C_hat.")

    A_L = []
    m_i_minus = 0

    for i in range(n):
        # Unpack parameters
        S_i_C_hat_i = S_C_hat[i]
        C_i_hat = C_hat[i]
        X_i_T_X_i_inv = X_T_X_inv[i]

        key_A_ii_0 = keys_A_ii_0[i]
        key_C_i = keys_C_i[i]

        # Check parameters' validity
        _check_valid_params(S_i_C_hat_i, C_i_hat, X_i_T_X_i_inv)

        # Step (a): draw A_ii_0
        A_ii_0, A_ii_0_T_A_ii_0 = gen_samples_A_ii_0(key_A_ii_0, samples_nb,
                                                     T, S_i_C_hat_i)

        # Step (b): draw C_i
        C_i = gen_samples_C_i(key_C_i, samples_nb, C_i_hat, A_ii_0_T_A_ii_0,
                              X_i_T_X_i_inv)

        # Modified step (c): compute A_i(L) (see design notes)
        m_i, k_i = S_i_C_hat_i.shape[0], X_i_T_X_i_inv.shape[0]

        A_L.append(A_ii_0 * (np.eye(m_i, M=k_i, k=m_i_minus) -
                             C_i.reshape((samples_nb, m_i, k_i))))

        m_i_minus += m_i

    return A_L
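
A hypothetical two-block invocation; gen_samples_A_ii_0, gen_samples_C_i and _check_valid_params are assumed to be defined elsewhere in the module, and all shapes below are purely illustrative:

from jax import random
import jax.numpy as np

samples_nb, T = 500, 100
keys = random.split(random.PRNGKey(0), 4)
S_C_hat = [np.eye(2), np.eye(3)]              # m_i x m_i
C_hat = [np.zeros((4, 2)), np.zeros((6, 3))]  # k_i x m_i
X_T_X_inv = [np.eye(4) / T, np.eye(6) / T]    # k_i x k_i
A_L = gen_samples_A_L(keys[:2], keys[2:], samples_nb, T,
                      S_C_hat, C_hat, X_T_X_inv)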
Example no. 26
    def precision_matrix(self):
        identity = np.broadcast_to(np.eye(self.scale_tril.shape[-1]),
                                   self.scale_tril.shape)
        return cho_solve((self.scale_tril, True), identity)
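
Since scale_tril is the lower Cholesky factor L of the covariance, cho_solve against a broadcast identity yields the precision matrix (L L^T)^{-1}. A quick standalone check of that identity, assuming cho_solve is jax.scipy.linalg.cho_solve:

import jax.numpy as np
from jax.scipy.linalg import cho_solve

L = np.array([[2.0, 0.0],
              [0.5, 1.0]])   # lower Cholesky factor of the covariance
cov = L @ L.T
prec = cho_solve((L, True), np.eye(2))
assert np.allclose(prec @ cov, np.eye(2), atol=1e-6)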
Example no. 27
def matrix_inverse_pth_root(mat_g,
                            p,
                            iter_count=100,
                            error_tolerance=1e-6,
                            ridge_epsilon=1e-6):
  """Computes mat_g^(-1/p), where p is a positive integer.

  Coupled newton iterations for matrix inverse pth root.

  Args:
    mat_g: the symmetric PSD matrix whose inverse pth root is to be computed.
    p: the exponent; must be a positive integer.
    iter_count: Maximum number of iterations.
    error_tolerance: Error indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^(-1/p)
  """
  mat_g_size = mat_g.shape[0]
  alpha = jnp.asarray(-1.0 / p, _INVERSE_PTH_ROOT_DATA_TYPE)
  identity = jnp.eye(mat_g_size, dtype=_INVERSE_PTH_ROOT_DATA_TYPE)
  _, max_ev, _ = power_iter(mat_g, mat_g.shape[0], 100)
  ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

  def _unrolled_mat_pow_1(mat_m):
    """Computes mat_m^1."""
    return mat_m

  def _unrolled_mat_pow_2(mat_m):
    """Computes mat_m^2."""
    return jnp.matmul(mat_m, mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)

  def _unrolled_mat_pow_4(mat_m):
    """Computes mat_m^4."""
    mat_pow_2 = _unrolled_mat_pow_2(mat_m)
    return jnp.matmul(
        mat_pow_2, mat_pow_2, precision=_INVERSE_PTH_ROOT_PRECISION)

  def _unrolled_mat_pow_8(mat_m):
    """Computes mat_m^4."""
    mat_pow_4 = _unrolled_mat_pow_4(mat_m)
    return jnp.matmul(
        mat_pow_4, mat_pow_4, precision=_INVERSE_PTH_ROOT_PRECISION)

  def mat_power(mat_m, p):
    """Computes mat_m^p, for p == 1, 2, 4 or 8.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
    # We unrolled the loop for performance reasons.
    exponent = jnp.round(jnp.log2(p))
    return lax.switch(
        jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], mat_m)

  def _iter_condition(state):
    (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
     run_step) = state
    error_above_threshold = jnp.logical_and(
        error > error_tolerance, run_step)
    return jnp.logical_and(i < iter_count, error_above_threshold)

  def _iter_body(state):
    (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
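    # Coupled Newton step: mat_m_i = (1 - alpha) I + alpha mat_m drives mat_m
    # towards the identity, while mat_h accumulates the inverse pth root.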
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    new_mat_m = jnp.matmul(
        mat_power(mat_m_i, p), mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)
    new_mat_h = jnp.matmul(
        mat_h, mat_m_i, precision=_INVERSE_PTH_ROOT_PRECISION)
    new_error = jnp.max(jnp.abs(new_mat_m - identity))
    # sometimes error increases after an iteration before decreasing and
    # converging. 1.2 factor is used to bound the maximal allowed increase.
    return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
            new_error < error * 1.2)

  if mat_g_size == 1:
    resultant_mat_h = (mat_g + ridge_epsilon)**alpha
    error = 0
  else:
    damped_mat_g = mat_g + ridge_epsilon * identity

    z = (1 + p) / (2 * jnp.linalg.norm(damped_mat_g))
    new_mat_m_0 = damped_mat_g * z
    new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
    new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
    init_state = tuple(
        [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
    _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
        _iter_condition, _iter_body, init_state)
    error = jnp.max(jnp.abs(mat_m - identity))
    is_converged = jnp.asarray(convergence, old_mat_h.dtype)
    resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
    resultant_mat_h = jnp.asarray(resultant_mat_h, mat_g.dtype)
  return resultant_mat_h, error
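
An illustrative call on a matrix whose inverse square root is known in closed form; power_iter and the _INVERSE_PTH_ROOT_* constants are referenced but not shown in the snippet, so this assumes they are defined as in the surrounding module:

import jax.numpy as jnp

A = jnp.diag(jnp.array([4.0, 9.0, 16.0]))    # PSD with known eigenvalues
A_inv_sqrt, err = matrix_inverse_pth_root(A, p=2)
# Up to ridge_epsilon, A_inv_sqrt should approximate diag(1/2, 1/3, 1/4).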
Example no. 28
    def _d_a_part(self, t: float, param: np.ndarray) -> np.ndarray:
        p_z, p_check_z = self.projectors(t, param)
        a = self.a(t, param)
        identity = np.eye(self.nn)
        da_part = (identity - p_check_z) @ a @ (identity - p_z)
        return np.linalg.pinv(da_part)
Example no. 29
def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    deltaXsq = np.power((X[:, None] - Z) / length, 2.0)
    k = var * np.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * np.eye(X.shape[0])
    return k
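
A minimal sketch of building the squared-exponential Gram matrix on a 1-D grid, assuming np is jax.numpy as in the other examples:

import jax.numpy as np

X = np.linspace(0.0, 1.0, 5)
K = kernel(X, X, var=1.0, length=0.2, noise=0.1)
# K is 5x5 and symmetric, with var + noise + jitter on the diagonal.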
Example no. 30
def nonbonded_v3(conf, params, box, lamb, charge_rescale_mask, lj_rescale_mask,
                 scales, beta, cutoff, lambda_plane_idxs, lambda_offset_idxs):

    N = conf.shape[0]

    conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                         cutoff)

    # make the 4th dimension of the box large enough that it is effectively aperiodic
    if box is not None:
        box_4d = np.eye(4) * 1000
        box_4d = index_update(box_4d, index[:3, :3], box)
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)

    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, np.zeros_like(eij_lj))

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    keep_mask = 1 - np.eye(conf.shape[0])
    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))

    # erfc(beta*dij)/dij diverges as dij -> 0, so mask out the (zeroed) diagonal
    eij_charge = np.where(keep_mask,
                          qij * erfc(beta * dij) / dij,
                          np.zeros_like(dij))  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, np.zeros_like(eij_charge),
                              eij_charge)

    eij_total = (eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask)
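    # eij_total sums over ordered pairs (i, j), counting each interaction
    # twice; dividing by 2 recovers each pairwise energy once.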

    return np.sum(eij_total / 2)