def test_top_k_dist_warper(self):
        input_ids = None
        vocab_size = 10
        batch_size = 2

        # create ramp distribution
        ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
        ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size

        top_k_warp = FlaxTopKLogitsWarper(3)

        scores = top_k_warp(input_ids, ramp_logits, cur_len=None)

        # check that correct tokens are filtered
        self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
        self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])

        # check special case
        length = 5
        top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)

        ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
        scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)

        # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
        self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
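For reference, a minimal sketch of the filtering the warper is expected to apply (assumed behavior, not the transformers implementation): keep the k largest logits per row and push the rest to the filter value.

import jax.numpy as jnp

def top_k_filter(logits, k, filter_value=-jnp.inf):
    threshold = jnp.sort(logits, axis=-1)[..., -k, None]  # k-th largest per row
    return jnp.where(logits < threshold, filter_value, logits)

logits = jnp.arange(10.0)[None, :]
print(jnp.isinf(top_k_filter(logits, 3)).tolist())  # [[True] * 7 + [False] * 3]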
Example #2
def _normalize(x, y):
    z = jnp.where(jnp.isinf(x), x, x + y)
    zz = jnp.where(
        lax.abs(x) > lax.abs(y),
        x - z + y,
        y - z + x,
    )
    return z, jnp.where(jnp.isinf(z), 0, zz)
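`_normalize` is the classic two-sum error-free transformation used in double-double arithmetic: `z` is the rounded sum and `zz` captures the rounding error, so `z + zz == x + y` exactly whenever the sum is finite. A quick check (a sketch, assuming 64-bit mode is enabled):

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp

# 1e16 + 1.0 rounds to 1e16 in float64; the lost 1.0 survives in zz
z, zz = _normalize(jnp.asarray(1e16), jnp.asarray(1.0))
print(z, zz)  # 1e+16 1.0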
Example #3
def adam(gradval,
         params,
         eta=0.01,
         beta1=0.9,
         beta2=0.9,
         eps=1e-7,
         c=0.01,
         R=500,
         per=100,
         disp=None,
         log=False):
    params = {k: np.array(v) for k, v in params.items()}
    gavg = {k: np.zeros_like(v) for k, v in params.items()}
    grms = {k: np.zeros_like(v) for k, v in params.items()}
    n = len(params)

    if log:
        hist = {k: [] for k in params}

    # iterate to max
    for j in range(R + 1):
        val, grad = gradval(params)

        # test for nans
        vnan = np.isnan(val) or np.isinf(val)
        gnan = np.array(
            [np.isnan(g).any() or np.isinf(g).any() for g in grad.values()])
        if vnan or gnan.any():
            print('Encountered nan/inf!')
            if disp is not None:
                disp(j, val, params)
            summary_naninf(grad)
            return params

        # clip gradients
        gradc = {k: clip2(grad[k], c) for k in params}

        # early bias correction
        etat = eta * (np.sqrt(1 - beta2**(j + 1))) / (1 - beta1**(j + 1))

        for k in params:
            gavg[k] += (1 - beta1) * (gradc[k] - gavg[k])
            grms[k] += (1 - beta2) * (gradc[k]**2 - grms[k])
            params[k] += etat * gavg[k] / (np.sqrt(grms[k]) + eps)

            if log:
                hist[k].append(params[k])

        if disp is not None and j % per == 0:
            disp(j, val, params)

    if log:
        hist = {k: np.stack(v) for k, v in hist.items()}
        return hist

    return params
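Note that the loop performs gradient ascent (`params[k] += ...`), and it relies on two helpers not shown in this snippet: `clip2` (elementwise gradient clipping, see the `c` argument) and `summary_naninf` (Example #23). A hypothetical usage sketch with a stand-in `clip2`:

import numpy as np

def clip2(g, c):  # stand-in for the unshown helper: clip gradients to [-c, c]
    return np.clip(g, -c, c)

def gradval(params):  # maximize -||x - 3||^2, i.e. drive x toward 3
    x = params['x']
    return -np.sum((x - 3.0) ** 2), {'x': -2.0 * (x - 3.0)}

result = adam(gradval, {'x': np.zeros(3)}, eta=0.1, R=200, per=50,
              disp=lambda j, val, params: print(j, val))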
Example #4
def _contains_query(vals, query):
    if isinstance(query, tuple):
        # force evaluation: map() is lazy in Python 3, so the recursive
        # checks would never run without materializing the result
        return list(map(partial(_contains_query, vals), query))

    if np.isnan(query):
        if np.any(np.isnan(vals)):
            raise FoundValue('NaN')
    elif np.isinf(query):
        if np.any(np.isinf(vals)):
            raise FoundValue('Found Inf')
    elif np.isscalar(query):
        if np.any(vals == query):
            raise FoundValue(str(query))
    else:
        raise ValueError('Malformed Query: {}'.format(query))
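`FoundValue` is raised to signal a hit; it is not defined in this snippet, so the sketch below assumes it is a plain `Exception` subclass.

from functools import partial
import numpy as np

class FoundValue(Exception):  # assumed minimal definition
    pass

vals = np.array([1.0, np.inf, 3.0])
try:
    _contains_query(vals, np.inf)
except FoundValue as e:
    print('hit:', e)  # hit: Found Inf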
Example #5
def infer(model, state, unroll=1, eta=0.001, jit=True):
    init, update, params = adam(eta)
    opt = init(state.xs)
    window = collections.deque(maxlen=windowlen(5000, unroll))
    def step(rng, opt, t):
        keys = random.split(rng, unroll)
        for i in range(0, unroll):
            xs = params(opt)
            L, dxs = jax.value_and_grad(energy, argnums=1)(model, xs)
            opt = update(t+i, dxs, opt)
        return L, opt
    if jit:
        step = jax.jit(step)
    try:
        while True:
            state.rng, key = random.split(state.rng)
            L, opt = step(key, opt, state.t)
            if np.isnan(L) or np.isinf(L):
                raise Exception("Invalid loss")
            state.xs = params(opt)
            state.t += unroll
            window.append(-L.item())
            print("\r[ %e, %.2e iterations   " % (statistics.mean(window), state.t), end = "")
    except KeyboardInterrupt:
        pass
    return state
Example #6
    def test_min_length_dist_processor(self):
        vocab_size = 20
        batch_size = 4
        eos_token_id = 0

        min_dist_processor = FlaxMinLengthLogitsProcessor(
            min_length=10, eos_token_id=eos_token_id)

        # check that min length is applied at length 5
        input_ids = ids_tensor((batch_size, 20), vocab_size=20)
        cur_len = 5
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores_before_min_length = min_dist_processor(input_ids,
                                                      scores,
                                                      cur_len=cur_len)
        self.assertListEqual(
            scores_before_min_length[:, eos_token_id].tolist(),
            4 * [-float("inf")])

        # check that min length is not applied anymore at length 15
        scores = self._get_uniform_logits(batch_size, vocab_size)
        cur_len = 15
        scores_before_min_length = min_dist_processor(input_ids,
                                                      scores,
                                                      cur_len=cur_len)
        self.assertFalse(jnp.isinf(scores_before_min_length).any())
Example #7
def tree_count_infs_nans(tree, psum_axis_name=None):
    leaves = jax.tree_leaves(tree)
    num_infs = sum(jnp.sum(jnp.isinf(x)) for x in leaves)
    num_nans = sum(jnp.sum(jnp.isnan(x)) for x in leaves)
    if psum_axis_name:
        num_infs, num_nans = jax.lax.psum((num_infs, num_nans),
                                          axis_name=psum_axis_name)
    return num_infs, num_nans
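A quick usage sketch on a small pytree (note that newer JAX versions expose `jax.tree_leaves` as `jax.tree_util.tree_leaves`):

import jax.numpy as jnp

tree = {'w': jnp.array([1.0, jnp.inf]),
        'b': jnp.array([jnp.nan, 0.0, jnp.nan])}
num_infs, num_nans = tree_count_infs_nans(tree)
print(int(num_infs), int(num_nans))  # 1 2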
Example #8
def create_model(yT, yC, num_components):
    # Constants
    nC = yC.shape[0]
    nT = yT.shape[0]
    zC = jnp.isinf(yC).sum().item()
    zT = jnp.isinf(yT).sum().item()
    yT_finite = yT[~jnp.isinf(yT)]
    yC_finite = yC[~jnp.isinf(yC)]
    K = num_components
    
    p = numpyro.sample('p', dist.Beta(.5, .5))
    gammaC = numpyro.sample('gammaC', dist.Beta(1, 1))
    gammaT = numpyro.sample('gammaT', dist.Beta(1, 1))

    etaC = numpyro.sample('etaC', dist.Dirichlet(jnp.ones(K) / K))
    etaT = numpyro.sample('etaT', dist.Dirichlet(jnp.ones(K) / K))
    
    with numpyro.plate('mixture_components', K):
        nu = numpyro.sample('nu', dist.LogNormal(3.5, 0.5))
        mu = numpyro.sample('mu', dist.Normal(0, 3))
        sigma = numpyro.sample('sigma', dist.LogNormal(0, .5))
        phi = numpyro.sample('phi', dist.Normal(0, 3))


    gammaT_star = simulate_data.compute_gamma_T_star(gammaC, gammaT, p)
    etaT_star = simulate_data.compute_eta_T_star(etaC, etaT, p, gammaC, gammaT,
                                                 gammaT_star)

    with numpyro.plate('y_C', nC - zC):
        numpyro.sample('finite_obs_C',
                       Mix(nu[None, :],
                           mu[None, :],
                           sigma[None, :],
                           phi[None, :],
                           etaC[None, :]), obs=yC_finite[:, None])

    with numpyro.plate('y_T', nT - zT):
        numpyro.sample('finite_obs_T',
                       Mix(nu[None, :],
                           mu[None, :],
                           sigma[None, :],
                           phi[None, :],
                           etaT_star[None, :]), obs=yT_finite[:, None])

    numpyro.sample('N_C', dist.Binomial(nC, gammaC), obs=zC)
    numpyro.sample('N_T', dist.Binomial(nT, gammaT_star), obs=zT)
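A sketch of fitting this model with NUTS; `Mix` and `simulate_data` are assumed importable from the surrounding project, and `yT`/`yC` use `inf` to encode the zero observations counted by `zT`/`zC`.

import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

yC = jnp.array([0.1, -0.4, jnp.inf])     # inf encodes a zero observation
yT = jnp.array([0.2, jnp.inf, 1.3, -0.5])

mcmc = MCMC(NUTS(create_model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), yT, yC, 3)
posterior = mcmc.get_samples()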
Example #9
def run_hmc_steps(theta, eps, Lmax, key, log_posterior,
                  log_posterior_grad_theta, diagonal_mass_matrix):
    # Diagonal mass matrix: diagonal entries of M (a vector)

    inverse_diag_mass = 1. / diagonal_mass_matrix

    key, subkey = random.split(key)

    # Location-scale transform to get the right variance
    # TODO: Check!
    phi = random.normal(
        subkey, shape=(theta.shape[0], )) * np.sqrt(diagonal_mass_matrix)

    start_theta = theta
    start_phi = phi

    cur_grad = log_posterior_grad_theta(theta)

    key, subkey = random.split(key)

    L = np_classic.random.randint(1, Lmax)

    for cur_l in range(L):
        phi = phi + 0.5 * eps * cur_grad
        theta = theta + eps * inverse_diag_mass * phi
        cur_grad = log_posterior_grad_theta(theta)
        phi = phi + 0.5 * eps * cur_grad

    # Compute (log) acceptance probability
    proposed_log_post = log_posterior(theta)
    previous_log_post = log_posterior(start_theta)

    proposed_log_phi = np.sum(
        norm.logpdf(phi, scale=np.sqrt(diagonal_mass_matrix)))
    previous_log_phi = np.sum(
        norm.logpdf(start_phi, scale=np.sqrt(diagonal_mass_matrix)))

    print(f'Proposed log posterior is: {proposed_log_post}. '
          f'Previous was {previous_log_post}.')

    if (np.isinf(proposed_log_post) or np.isnan(proposed_log_post)
            or np.isneginf(proposed_log_post)):
        # Reject
        was_accepted = False
        new_theta = start_theta
        # FIXME: What number to put here?
        log_r = -10
        return was_accepted, log_r, new_theta

    log_r = (proposed_log_post + proposed_log_phi - previous_log_post -
             previous_log_phi)

    was_accepted, new_theta = acceptance_step(log_r, theta, start_theta, key)

    return was_accepted, log_r, new_theta
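`acceptance_step`, `norm`, and `np_classic` come from the surrounding module; the sketch below only shows the expected call shape on a standard-normal target (hypothetical, assuming those helpers are in scope and `np` is jax.numpy).

import jax.numpy as np  # the snippet uses `np` for jax.numpy
from jax import grad, random

log_post = lambda theta: -0.5 * np.sum(theta ** 2)
was_accepted, log_r, theta = run_hmc_steps(
    theta=np.zeros(2), eps=0.1, Lmax=10, key=random.PRNGKey(0),
    log_posterior=log_post,
    log_posterior_grad_theta=grad(log_post),
    diagonal_mass_matrix=np.ones(2))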
Example #10
    def __call__(self, controller_state, obs):
        state, t = obs.predicted_pressure, obs.time
        errs, waveform = controller_state.errs, controller_state.waveform
        fwd_targets = controller_state.fwd_targets
        target = waveform.at(t)
        fwd_t = t + self.fwd_history_len * DEFAULT_DT
        if self.fwd_history_len > 0:
            fwd_target = jax.lax.cond(fwd_t >= self.horizon * DEFAULT_DT,
                                      lambda x: fwd_targets[-1],
                                      lambda x: waveform.at(fwd_t), None)
        if self.normalize:
            target_normalized = self.p_normalizer(target).squeeze()
            state_normalized = self.p_normalizer(state).squeeze()
            next_errs = jnp.roll(errs, shift=-1)
            next_errs = next_errs.at[-1].set(target_normalized -
                                             state_normalized)
            if self.fwd_history_len > 0:
                fwd_target_normalized = self.p_normalizer(fwd_target).squeeze()
                next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
                next_fwd_targets = next_fwd_targets.at[-1].set(
                    fwd_target_normalized)
            else:
                next_fwd_targets = jnp.array([])
        else:
            next_errs = jnp.roll(errs, shift=-1)
            next_errs = next_errs.at[-1].set(target - state)
            if self.fwd_history_len > 0:
                next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
                next_fwd_targets = next_fwd_targets.at[-1].set(fwd_target)
            else:
                next_fwd_targets = jnp.array([])
        controller_state = controller_state.replace(
            errs=next_errs, fwd_targets=next_fwd_targets)
        decay = self.decay(waveform, t)

        def true_func(null_arg):
            trajectory = jnp.hstack([next_errs, next_fwd_targets])
            u_in = self.model_apply({"params": self.params}, trajectory)
            return u_in.squeeze().astype(jnp.float64)

        # changed decay compare from None to float(inf) due to cond requirements
        u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                            lambda x: jnp.array(decay), None)
        u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float64),
                             self.clip).squeeze()
        # update controller_state
        new_dt = jnp.max(
            jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
        new_time = t
        new_steps = controller_state.steps + 1
        controller_state = controller_state.replace(time=new_time,
                                                    steps=new_steps,
                                                    dt=new_dt)
        return controller_state, u_in
Example #11
def log_normal_with_outliers(x, mean, cov, sigma):
    """
    Computes log-Normal density with outliers removed.

    Args:
        x: RV value
        mean: mean of Gaussian
        cov: covariance of underlying, minus the obs. covariance
        sigma: stddev's of obs. error, inf encodes an outlier.

    Returns: a normal density for all points not of inf stddev obs. error.
    """
    C = cov / (sigma[:, None] * sigma[None, :]) + jnp.eye(cov.shape[0])
    L = jnp.linalg.cholesky(C)
    Ls = sigma[:, None] * L
    log_det = jnp.sum(jnp.where(jnp.isinf(sigma), 0., jnp.log(jnp.diag(Ls))))
    dx = (x - mean)
    dx = solve_triangular(L, dx / sigma, lower=True)
    maha = dx @ dx
    log_likelihood = -0.5 * jnp.sum(~jnp.isinf(sigma)) * jnp.log(2. * jnp.pi) \
                     - log_det \
                     - 0.5 * maha
    return log_likelihood
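The snippet assumes `from jax.scipy.linalg import solve_triangular`. Marking an observation with `inf` stddev removes it from both the Mahalanobis term and the normalizing constant, as in this small sketch:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular  # used by the function above

x = jnp.array([0.1, 50.0, -0.2])          # 50.0 is a gross outlier
mean = jnp.zeros(3)
cov = 0.5 * jnp.eye(3)
sigma = jnp.array([0.1, jnp.inf, 0.1])    # inf flags the outlier
print(log_normal_with_outliers(x, mean, cov, sigma))  # finite, outlier ignored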
Example #12
    def __call__(self, controller_state, obs):
        state, t = obs.predicted_pressure, obs.time
        errs, waveform = controller_state.errs, controller_state.waveform
        target = waveform.at(t)
        if self.normalize:
            target_normalized = self.p_normalizer(target).squeeze()
            state_normalized = self.p_normalizer(state).squeeze()
            next_errs = jnp.roll(errs, shift=-1)
            next_errs = next_errs.at[-1].set(target_normalized -
                                             state_normalized)
        else:
            next_errs = jnp.roll(errs, shift=-1)
            next_errs = next_errs.at[-1].set(target - state)
        controller_state = controller_state.replace(errs=next_errs)
        decay = self.decay(waveform, t)

        def true_func(null_arg):
            trajectory = jnp.expand_dims(next_errs[-self.history_len:],
                                         axis=(0, 1))
            input_val = jnp.reshape((trajectory @ self.featurizer),
                                    (1, self.history_len, 1))
            u_in = self.model_apply({"params": self.params}, input_val)
            return u_in.squeeze().astype(jnp.float32)

        # changed decay compare from None to float(inf) due to cond requirements
        u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                            lambda x: jnp.array(decay), None)
        # Implementing "leaky" clamp to solve the zero gradient problem
        if self.use_leaky_clamp:
            u_in = jax.lax.cond(u_in < 0.0, lambda x: x * 0.01, lambda x: x,
                                u_in)
            u_in = jax.lax.cond(u_in > self.clip,
                                lambda x: self.clip + x * 0.01, lambda x: x,
                                u_in)
        else:
            u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float32),
                                 self.clip).squeeze()
        # update controller_state
        new_dt = jnp.max(
            jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
        new_time = t
        new_steps = controller_state.steps + 1
        controller_state = controller_state.replace(time=new_time,
                                                    steps=new_steps,
                                                    dt=new_dt)
        return controller_state, u_in
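The "leaky" branch keeps a small slope outside `[0, clip]` so the clamp does not zero out gradients; a standalone sketch of the same idea:

import jax.numpy as jnp

def leaky_clamp(u, clip, slope=0.01):
    # mirror of the cond logic above: small slope instead of a flat clamp
    u = jnp.where(u < 0.0, u * slope, u)
    return jnp.where(u > clip, clip + u * slope, u)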
Example #13
    def annotated_with_grad(flat_theta, summary):

        flat_theta = jnp.array(flat_theta)

        obj, grad = with_grad(flat_theta)

        print(obj, jnp.linalg.norm(grad))

        if jnp.isnan(obj) or jnp.isinf(obj) or jnp.any(jnp.isnan(grad)):
            import ipdb

            problem = reconstruct(flat_theta, summary, jnp.reshape)

            ipdb.set_trace()

        return np.array(obj).astype(np.float64), np.array(grad).astype(
            np.float64)
Example #14
 def loss(params, obs):
     diags = model(params, obs)
     # obs, _, _, info = env.step(diags)
     # norm_res = jax.vmap(lambda x: jnp.linalg.norm(x[1, :], jnp.inf))(obs)
     # return jnp.mean(norm_res)
     _, _, _, info = env.step(diags)
     # assert jnp.allclose(info['residual'] * info['niter'],
     #                     jax.vmap(lambda x, y: x * y)(
     #                         info['residual'], info['niter']))
     norm_res = jnp.mean(info['residual'] * info['niter'])
     # norm_res = jnp.mean(jnp.array(list(map(
     #     lambda i: info[0]['residual' + str(i)],
     #     range(args.batch_size)
     # ))))
     if jnp.isnan(norm_res):
         raise ValueError('encountered NaNs')
     if jnp.isinf(norm_res):
         raise ValueError('encountered infs')
     return norm_res
Example #15
    def test_forced_bos_token_logits_processor(self):
        vocab_size = 20
        batch_size = 4
        bos_token_id = 0

        logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)

        # check that all scores are -inf except the bos_token_id score
        input_ids = ids_tensor((batch_size, 1), vocab_size=20)
        cur_len = 1
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores, cur_len=cur_len)
        self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all())
        self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0])  # score for bos_token_id should be zero

        # check that bos_token_id is not forced if current length is greater than 1
        cur_len = 3
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores, cur_len=cur_len)
        self.assertFalse(jnp.isinf(scores).any())
Example #16
    def test_forced_eos_token_logits_processor(self):
        vocab_size = 20
        batch_size = 4
        eos_token_id = 0
        max_length = 5

        logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)

        # check that all scores are -inf except the eos_token_id when max_length is reached
        input_ids = ids_tensor((batch_size, 4), vocab_size=20)
        cur_len = 4
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores, cur_len=cur_len)
        self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all())
        self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0])  # score for eos_token_id should be zero

        # check that eos_token_id is not forced if max_length is not reached
        cur_len = 3
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores, cur_len=cur_len)
        self.assertFalse(jnp.isinf(scores).any())
Example #17
    def metric(points, centers, K):
        # N
        cluster_id = masked_cluster_id(points, centers, K)
        # N,N
        dist = jnp.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
        # N,N
        in_group = cluster_id[:, None] == cluster_id[None, :]
        in_group_dist, w = jnp.average(dist, weights=in_group, axis=-1, returned=True)
        in_group_dist *= w / (w - 1.)

        # max_K, N, N
        out_group = (~in_group) & (jnp.arange(max_K)[:, None, None] == cluster_id[None, None, :])
        # max_K, N
        out_group_dist = jnp.sum(dist * out_group, axis=-1) / jnp.sum(out_group, axis=-1)
        out_group_dist = jnp.where(jnp.isnan(out_group_dist), jnp.inf, out_group_dist)
        # N
        out_group_dist = jnp.min(out_group_dist, axis=0)
        out_group_dist = jnp.where(jnp.isinf(out_group_dist), jnp.max(in_group_dist), out_group_dist)
        silhouette = (out_group_dist - in_group_dist) / jnp.maximum(in_group_dist, out_group_dist)
        # condition for pos def cov
        silhouette = jnp.where(w < points.shape[1], -jnp.inf, silhouette)
        return jnp.mean(silhouette), cluster_id
Example #18
    def log_prob(self, value: Array) -> Array:
        """Calculates the log probability of an event.

    This implementation differs slightly from the one in TFP, as it returns
    `-jnp.inf` on non-integer values instead of returning the log prob of the
    floor of the input. In addition, this implementation also returns `-jnp.inf`
    on inputs that are outside the support of the distribution (as opposed to
    `nan`, like TFP does). On other integer values, both implementations are
    identical.

    Args:
      value: An event.

    Returns:
      The log probability log P(value).
    """
        is_integer = jnp.where(value > jnp.floor(value), False, True)
        log_cdf = self.log_cdf(value)
        log_cdf_m1 = self.log_cdf(value - 1.)
        log_probs = math.log_expbig_minus_expsmall(log_cdf, log_cdf_m1)
        return jnp.where(jnp.isinf(log_cdf), -jnp.inf,
                         jnp.where(is_integer, log_probs, -jnp.inf))
Example #19
File: jax.py  Project: yibit/eagerpy
 def isinf(self: TensorType) -> TensorType:
     return type(self)(np.isinf(self.raw))
Example #20
 def to_array(self, dtype=None):
     head, tail = self._tup
     if dtype is not None:
         head = head.astype(dtype)
         tail = tail.astype(dtype)
     return head + jnp.where(jnp.isinf(head), 0, tail)
Example #21
imag = utils.copy_docstring(tf.math.imag,
                            lambda input, name=None: np.imag(input))

# in_top_k = utils.copy_docstring(
#     tf.math.in_top_k,
#     lambda targets, predictions, k, name=None: np.in_top_k)

# TODO(b/256095991): Add unit-test.
invert_permutation = utils.copy_docstring(tf.math.invert_permutation,
                                          lambda x, name=None: np.argsort(x))

is_finite = utils.copy_docstring(tf.math.is_finite,
                                 lambda x, name=None: np.isfinite(x))

is_inf = utils.copy_docstring(tf.math.is_inf, lambda x, name=None: np.isinf(x))

is_nan = utils.copy_docstring(tf.math.is_nan, lambda x, name=None: np.isnan(x))

is_non_decreasing = utils.copy_docstring(
    tf.math.is_non_decreasing, lambda x, name=None: np.all(x[1:] >= x[:-1]))

is_strictly_increasing = utils.copy_docstring(
    tf.math.is_strictly_increasing,
    lambda x, name=None: np.all(x[1:] > x[:-1]))

l2_normalize = utils.copy_docstring(
    tf.math.l2_normalize,
    lambda x, axis=None, epsilon=1e-12, name=None: (  # pylint: disable=g-long-lambda
        x / np.linalg.norm(x, ord=2, axis=axis, keepdims=True)))
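Quick sanity checks of the substitutes above (a sketch; `np` here is the backend's plain NumPy):

import numpy as np

print(is_inf(np.array([1.0, np.inf])))           # [False  True]
print(is_non_decreasing(np.array([1, 2, 2])))    # True
print(invert_permutation(np.array([2, 0, 1])))   # [1 2 0]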
Example #22
def solve_and_smooth(gain_outliers, phase_obs, times, freqs):
    logger.info("Performing solve for tec and const from phases.")
    Nd, Na, Nf, Nt = phase_obs.shape

    logger.info("Number of nan: {}".format(jnp.sum(jnp.isnan(phase_obs))))
    logger.info("Number of inf: {}".format(jnp.sum(jnp.isinf(phase_obs))))

    # blocksize chosen to maximise Fisher information, which is 2 for tec+const, and 3 for tec+const+clock
    blocksize = 2

    remainder = Nt % blocksize
    if remainder != 0:
        # padding below only makes sense when at least one full block exists
        if Nt < blocksize:
            raise ValueError(
                f"Block size {blocksize} too big for number of timesteps {Nt}."
            )
        (gain_outliers, phase_obs) = tree_map(
            lambda x: jnp.concatenate(
                [x, jnp.repeat(x[..., -1:], remainder, axis=-1)], axis=-1),
            (gain_outliers, phase_obs))
        Nt = Nt + remainder
        times = jnp.concatenate([
            times, times[-1] +
            jnp.arange(1, 1 + remainder) * jnp.mean(jnp.diff(times))
        ])

    size_dict = dict(d=Nd, a=Na, f=Nf, b=blocksize)

    # [Nd*Na*(Nt//blocksize), blocksize, Nf]
    gain_outliers = axes_move(gain_outliers, ['d', 'a', 'f', 'tb'],
                              ['dat', 'b', 'f'],
                              size_dict=size_dict)
    phase_obs = axes_move(phase_obs, ['d', 'a', 'f', 'tb'], ['dat', 'b', 'f'],
                          size_dict=size_dict)

    T = Nd * Na * (Nt // blocksize)  # Nd * Na * (Nt // blocksize)
    keys = random.split(random.PRNGKey(int(1000 * default_timer())), T)

    # [Nd*Na*(Nt//blocksize), blocksize], [# Nd*Na*(Nt//blocksize), blocksize]
    tec_mean, tec_std, const_mean, const_std, uncert_mean = chunked_pmap(
        lambda *args: unconstrained_solve(freqs, *args), keys, phase_obs,
        gain_outliers)  # Nd*Na*(Nt//blocksize), blocksize

    const_weights = 1. / const_std**2

    def smooth(y, weights):
        y = axes_move(y, ['dat', 'b'], ['da', 'tb'], size_dict=size_dict)
        weights = axes_move(weights, ['dat', 'b'], ['da', 'tb'],
                            size_dict=size_dict)
        y = chunked_pmap(
            lambda y, weights: poly_smooth(times, y, deg=5, weights=weights),
            y, weights)
        y = axes_move(y, ['da', 'tb'], ['dat', 'b'], size_dict=size_dict)
        return y

    logger.info("Smoothing and outlier rejection of const (a weak prior).")
    # Nd,Na,Nt/blocksize, blocksize
    const_real_mean = smooth(jnp.cos(const_mean),
                             const_weights)  # Nd*Na*(Nt//blocksize), blocksize
    const_imag_mean = smooth(jnp.sin(const_mean),
                             const_weights)  # Nd*Na*(Nt//blocksize), blocksize
    const_mean_smoothed = jnp.arctan2(
        const_imag_mean, const_real_mean)  # Nd*Na*(Nt//blocksize), blocksize

    # empirically determined uncertainty point where sigma(tec - tec_true) > 6 mTECU
    which_reprocess = jnp.any(uncert_mean > 0.,
                              axis=1)  # Nd*Na*(Nt//blocksize)
    replace_map = jnp.where(which_reprocess)

    logger.info("Performing refined tec-only solve, with fixed const.")
    keys = random.split(random.PRNGKey(int(1000 * default_timer())),
                        jnp.sum(which_reprocess))
    # [Nd*Na*(Nt//blocksize), blocksize]
    (tec_mean_constrained, tec_std_constrained, const_mean_constrained, const_std_constrained) = \
        chunked_pmap(lambda *args: constrained_solve(freqs, *args),
                     keys,
                     phase_obs[which_reprocess],
                     gain_outliers[which_reprocess],
                     const_mean_smoothed[which_reprocess],
                     const_std[which_reprocess]
                     )
    tec_mean = tec_mean.at[replace_map].set(tec_mean_constrained)
    tec_std = tec_std.at[replace_map].set(tec_std_constrained)
    const_std = const_std.at[replace_map].set(const_std_constrained)
    const_mean = const_mean.at[replace_map].set(const_mean_constrained)

    (tec_mean, tec_std, const_mean,
     const_std) = tree_map(lambda x: x.reshape((Nd, Na, Nt)),
                           (tec_mean, tec_std, const_mean, const_std))

    # Nd, Na, Nt
    logger.info("Performing outlier detection on tec values.")
    tec_est, tec_outliers = detect_tec_outliers(times, tec_mean, tec_std)
    tec_std = jnp.where(tec_outliers, jnp.inf, tec_std)

    # remove remainder at the end
    if remainder != 0:
        (tec_mean, tec_std, const_mean,
         const_std) = tree_map(lambda x: x[..., :Nt - remainder],
                               (tec_mean, tec_std, const_mean, const_std))

    # compute phase mean with outlier-suppressed tec.
    phase_mean = tec_mean[..., None, :] * (
        TEC_CONV / freqs[:, None]) + const_mean[..., None, :]
    phase_uncert = jnp.sqrt((tec_std[..., None, :] *
                             (TEC_CONV / freqs[:, None]))**2 +
                            (const_std[..., None, :])**2)

    return phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std
Example #23
def summary_naninf(d):
    for k, v in d.items():
        if (na := np.isnan(v)).any():
            print(f'{k} nan: {na.sum()}/{na.size}')
        if (nf := np.isinf(v)).any():
            print(f'{k} inf: {nf.sum()}/{nf.size}')
Example #24
            def single_acceptance(args):
                """Draws a proposal, simulates and compresses, checks distance

                A new proposal is drawn from a truncated multivariate normal
                distribution whose mean is centred on the parameter to move and
                the covariance is set by the population. From this proposed
                parameter value a simulation is made and compressed and the
                distance from the target is calculated. If this distance is
                less than the current position then the proposal is accepted.

                Parameters
                ----------
                args : tuple
                    see loop variable in `single_iteration`

                Returns
                -------
                bool:
                    True if proposal not accepted and number of attempts to get
                    an accepted proposal not yet reached

                Todo
                ----
                Parallel sampling is currently commented out
                """
                (rng, loc, scale, summ, dis, draws, accepted,
                 acceptance_counter) = args
                rng, key = jax.random.split(rng)
                proposed, summaries = self.get_samples(
                    key,
                    None,
                    dist=tmvn(loc,
                              scale,
                              self.prior.low,
                              self.prior.high,
                              max_counter=max_samples))
                distances = np.squeeze(
                    self.distance_measure(np.expand_dims(summaries, 0), target,
                                          F))
                # if n_parallel_simulations is not None:
                #     min_distance_index = np.argmin(distances)
                #     min_distance = distances[min_distance_index]
                #     closer = np.less(min_distance, ϵ)
                #     loc = jax.lax.cond(
                #         closer,
                #         lambda _ : proposed[min_distance_index],
                #         lambda _ : loc,
                #         None)
                #     summ = jax.lax.cond(
                #         closer,
                #         lambda _ : summaries[min_distance_index],
                #         lambda _ : summ,
                #         None)
                #     dis = jax.lax.cond(
                #         closer,
                #         lambda _ : distances[min_distance_index],
                #         lambda _ : dis,
                #         None)
                #     iteration_draws = n_parallel_simulations \
                #         - np.isinf(distances).sum()
                #     draws += iteration_draws
                #     accepted = closer.sum()
                # else:
                closer = np.less(distances, np.min(dis))
                loc = jax.lax.cond(closer, lambda _: proposed, lambda _: loc,
                                   None)
                summ = jax.lax.cond(closer, lambda _: summaries,
                                    lambda _: summ, None)
                dis = jax.lax.cond(closer, lambda _: distances, lambda _: dis,
                                   None)
                iteration_draws = 1 - np.isinf(distances).sum()
                draws += iteration_draws
                accepted = closer.sum()
                return (rng, loc, scale, summ, dis, draws, accepted,
                        acceptance_counter + 1)
Example #25
def main(data_dir, working_dir, obs_num, ncpu, plot_results):
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    logger.info("Performing data smoothing via tec+const+clock inference.")
    dds4_h5parm = os.path.join(data_dir,
                               'L{}_DDS4_full_merged.h5'.format(obs_num))
    dds5_h5parm = os.path.join(working_dir,
                               'L{}_DDS5_full_merged.h5'.format(obs_num))
    linked_dds5_h5parm = os.path.join(
        data_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    logger.info("Looking for {}".format(dds4_h5parm))
    link_overwrite(dds5_h5parm, linked_dds5_h5parm)
    prepare_soltabs(dds4_h5parm, dds5_h5parm)
    gain_outliers, phase_obs, amp, times, freqs = get_data(
        solution_file=dds4_h5parm)
    phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std = \
        solve_and_smooth(gain_outliers, phase_obs, times, freqs)
    # exit(0)
    logger.info("Storing smoothed phase, amplitudes, tec, const, and clock")
    with DataPack(dds5_h5parm, readonly=False) as h:
        h.current_solset = 'sol000'
        # h.select(pol=slice(0, 1, 1), ant=slice(50, 51), dir=slice(0, None, 1), time=slice(0, 100, 1))
        h.select(pol=slice(0, 1, 1))
        h.phase = np.asarray(phase_mean)[None, ...]
        h.weights_phase = np.asarray(phase_uncert)[None, ...]
        h.amplitude = np.asarray(amp)[None, ...]
        h.tec = np.asarray(tec_mean)[None, ...]
        h.tec_outliers = np.asarray(tec_outliers)[None, ...]
        h.weights_tec = np.asarray(tec_std)[None, ...]
        h.const = np.asarray(const_mean)[None, ...]
        axes = h.axes_phase
        patch_names, _ = h.get_directions(axes['dir'])
        antenna_labels, _ = h.get_antennas(axes['ant'])

    if plot_results:

        diagnostic_data_dir = os.path.join(working_dir, 'diagnostic')
        os.makedirs(diagnostic_data_dir, exist_ok=True)

        logger.info("Plotting results.")
        data_plot_dir = os.path.join(working_dir, 'data_plots')
        os.makedirs(data_plot_dir, exist_ok=True)
        Nd, Na, Nf, Nt = phase_mean.shape
        for ia in range(Na):
            for id in range(Nd):
                fig, axs = plt.subplots(3, 1, sharex=True)
                axs[0].plot(times, tec_mean[id, ia, :], c='black', label='tec')
                ylim = axs[0].get_ylim()
                axs[0].vlines(times[tec_outliers[id, ia, :]],
                              *ylim,
                              colors='red',
                              label='outliers',
                              alpha=0.5)
                axs[0].set_ylim(*ylim)

                axs[1].plot(times,
                            const_mean[id, ia, :],
                            c='black',
                            label='const')
                axs[1].fill_between(
                    times,
                    const_mean[id, ia, :] - const_std[id, ia, :],
                    const_mean[id, ia, :] + const_std[id, ia, :],
                    color='black',
                    alpha=0.2)
                ylim = axs[1].get_ylim()
                axs[1].vlines(times[tec_outliers[id, ia, :]],
                              *ylim,
                              colors='red',
                              label='outliers',
                              alpha=0.5)
                axs[1].set_ylim(*ylim)

                axs[2].plot(times,
                            tec_std[id, ia, :],
                            c='black',
                            label='tec_std')
                ylim = axs[2].get_ylim()
                axs[2].vlines(times[tec_outliers[id, ia, :]],
                              *ylim,
                              colors='red',
                              label='outliers',
                              alpha=0.5)
                axs[2].set_ylim(*ylim)

                axs[0].legend()
                axs[1].legend()
                axs[2].legend()

                axs[0].set_ylabel("DTEC [mTECU]")
                axs[1].set_ylabel("const [rad]")
                axs[2].set_ylabel("DTEC uncert [mTECU]")
                axs[2].set_xlabel("time [s]")

                fig.savefig(
                    os.path.join(
                        data_plot_dir,
                        'solutions_ant{:02d}_dir{:02d}.png'.format(ia, id)))
                plt.close("all")

                fig, axs = plt.subplots(4, 1, sharex=True, sharey=True)
                # phase data with input outliers
                # phase posterior with tec outliers
                # dphase with no outliers
                # phase uncertainty

                axs[0].imshow(phase_obs[id, ia, :, :],
                              vmin=-jnp.pi,
                              vmax=jnp.pi,
                              cmap='twilight',
                              aspect='auto',
                              origin='lower',
                              interpolation='nearest')
                axs[0].imshow(jnp.where(gain_outliers[id, ia, :, :], 1.,
                                        jnp.nan),
                              vmin=0.,
                              vmax=1.,
                              cmap='bone',
                              aspect='auto',
                              origin='lower',
                              interpolation='nearest')
                add_colorbar_to_axes(axs[0],
                                     "twilight",
                                     vmin=-jnp.pi,
                                     vmax=jnp.pi)

                axs[1].imshow(phase_mean[id, ia, :, :],
                              vmin=-jnp.pi,
                              vmax=jnp.pi,
                              cmap='twilight',
                              aspect='auto',
                              origin='lower',
                              interpolation='nearest')
                axs[1].imshow(jnp.where(jnp.isinf(phase_uncert[id, ia, :, :]),
                                        1., jnp.nan),
                              vmin=0.,
                              vmax=1.,
                              cmap='bone',
                              aspect='auto',
                              origin='lower',
                              interpolation='nearest')
                add_colorbar_to_axes(axs[1],
                                     "twilight",
                                     vmin=-jnp.pi,
                                     vmax=jnp.pi)

                dphase = wrap(wrap(phase_mean) - phase_obs)
                vmin = -0.5
                vmax = 0.5

                axs[2].imshow(dphase[id, ia, :, :],
                              vmin=vmin,
                              vmax=vmax,
                              cmap='PuOr',
                              aspect='auto',
                              origin='lower',
                              interpolation='nearest')
                add_colorbar_to_axes(axs[2], "PuOr", vmin=vmin, vmax=vmax)

                vmin = 0.
                vmax = 0.8

                axs[3].imshow(phase_uncert[id, ia, :, :],
                              vmin=vmin,
                              vmax=vmax,
                              cmap='PuOr',
                              aspect='auto',
                              origin='lower',
                              interpolation='nearest')
                add_colorbar_to_axes(axs[3], "PuOr", vmin=vmin, vmax=vmax)

                axs[0].set_ylabel("freq [MHz]")
                axs[1].set_ylabel("freq [MHz]")
                axs[2].set_ylabel("freq [MHz]")
                axs[3].set_ylabel("freq [MHz]")
                axs[3].set_xlabel("time [s]")

                axs[0].set_title("phase data [rad]")
                axs[1].set_title("phase model [rad]")
                axs[2].set_title("phase diff. [rad]")
                axs[3].set_title("phase uncert [rad]")

                fig.savefig(
                    os.path.join(
                        data_plot_dir,
                        'data_comparison_ant{:02d}_dir{:02d}.png'.format(
                            ia, id)))
                plt.close("all")
        # exit(0)

        d = os.path.join(working_dir, 'tec_plots')
        animate_datapack(dds5_h5parm,
                         d,
                         num_processes=(ncpu * 2) // 3,
                         vmin=-60,
                         vmax=60.,
                         observable='tec',
                         phase_wrap=False,
                         plot_crosses=False,
                         plot_facet_idx=True,
                         labels_in_radec=True,
                         per_timestep_scale=True,
                         solset='sol000',
                         cmap=plt.cm.PuOr)
        # os.makedirs(d, exist_ok=True)
        # DatapackPlotter(dds5_h5parm).plot(
        #     fignames=[os.path.join(d, "fig-{:04d}.png".format(j)) for j in range(Nt)],
        #     vmin=-60,
        #     vmax=60., observable='tec', phase_wrap=False, plot_crosses=False,
        #     plot_facet_idx=True, labels_in_radec=True, per_timestep_scale=True,
        #     solset='sol000', cmap=plt.cm.PuOr)
        # make_animation(d, prefix='fig', fps=4)

        d = os.path.join(working_dir, 'const_plots')
        animate_datapack(dds5_h5parm,
                         d,
                         num_processes=(ncpu * 2) // 3,
                         vmin=-np.pi,
                         vmax=np.pi,
                         observable='const',
                         phase_wrap=False,
                         plot_crosses=False,
                         plot_facet_idx=True,
                         labels_in_radec=True,
                         per_timestep_scale=True,
                         solset='sol000',
                         cmap=plt.cm.PuOr)

        # os.makedirs(d, exist_ok=True)
        # DatapackPlotter(dds5_h5parm).plot(
        #     fignames=[os.path.join(d, "fig-{:04d}.png".format(j)) for j in range(Nt)],
        #     vmin=-np.pi,
        #     vmax=np.pi, observable='const', phase_wrap=False, plot_crosses=False,
        #     plot_facet_idx=True, labels_in_radec=True, per_timestep_scale=True,
        #     solset='sol000', cmap=plt.cm.PuOr)
        # make_animation(d, prefix='fig', fps=4)

        d = os.path.join(working_dir, 'clock_plots')
        animate_datapack(dds5_h5parm,
                         d,
                         num_processes=(ncpu * 2) // 3,
                         vmin=None,
                         vmax=None,
                         observable='clock',
                         phase_wrap=False,
                         plot_crosses=False,
                         plot_facet_idx=True,
                         labels_in_radec=True,
                         per_timestep_scale=True,
                         solset='sol000',
                         cmap=plt.cm.PuOr)

        # os.makedirs(d, exist_ok=True)
        # DatapackPlotter(dds5_h5parm).plot(
        #     fignames=[os.path.join(d, "fig-{:04d}.png".format(j)) for j in range(Nt)],
        #     vmin=None,
        #     vmax=None,
        #     observable='clock', phase_wrap=False, plot_crosses=False,
        #     plot_facet_idx=True, labels_in_radec=True, per_timestep_scale=True,
        #     solset='sol000', cmap=plt.cm.PuOr)
        # make_animation(d, prefix='fig', fps=4)

        d = os.path.join(working_dir, 'amplitude_plots')
        animate_datapack(dds5_h5parm,
                         d,
                         num_processes=(ncpu * 2) // 3,
                         log_scale=True,
                         observable='amplitude',
                         phase_wrap=False,
                         plot_crosses=False,
                         plot_facet_idx=True,
                         labels_in_radec=True,
                         per_timestep_scale=True,
                         solset='sol000',
                         cmap=plt.cm.PuOr)
Example #26
def has_inf_or_nan(x):
    return jnp.isinf(x).any() or jnp.isnan(x).any()
Example #27
 def isinf(self, a):
     return jnp.isinf(a)
Example #28
def isinf(x):
  if isinstance(x, JaxArray):
    return JaxArray(jnp.isinf(x.value))
  else:
    return jnp.isinf(x)
Example #29
def linesearch(cost_and_grad, x, d, f0, df0, g0, aold=None, dfold=None, fold=None, ls_pars=None):

    if ls_pars is None:
        ls_pars = LineSearchParameter()

    # Wolfe conditions
    def wolfe_one(ai, fi):
        return fi > f0 + ls_pars.ls_suff_decr * ai * df0

    def wolfe_two(dfi):
        return jnp.abs(dfi) <= - ls_pars.ls_curvature * df0

    state = _LineSearchState(
        done=False,
        failed=False,
        i=1,
        a_i1=0.,
        phi_i1=f0,
        dphi_i1=df0,
        nfev=0,
        ngev=0,
        a_star=0,
        phi_star=f0,
        dphi_star=df0,
        g_star=g0,
        )

    if ls_pars.ls_verbosity >= 1:
        print('\tStarting linesearch...')

    if (aold is not None) and (dfold is not None):
        alpha_0 = aold * dfold / df0
        initial_step_length = alpha_0
        initial_step_length = jnp.where(alpha_0 > ls_pars.ls_initial_step,
                                        ls_pars.ls_initial_step,
                                        alpha_0)
    elif (fold is not None) and (~jnp.isinf(fold)):
        candidate = 1.01 * 2 * jnp.abs((f0 - fold) / df0)
        candidate = jnp.where(candidate < 1e-8, 1e-8, candidate)
        initial_step_length = jnp.where(candidate > 1.2 * ls_pars.ls_initial_step,
                                        ls_pars.ls_initial_step,
                                        candidate)
        if ls_pars.ls_verbosity >= 3:
            print(f'\tcandidate: {candidate:.2e}, accepted: {initial_step_length:.2e}')
    else:
        initial_step_length = ls_pars.ls_initial_step

    while ((~state.done) & (state.i <= ls_pars.ls_maxiter) & (~state.failed)):
        # no amax in this version, we just double as in scipy.
        # unlike original algorithm we do our choice at the start of this loop
        ai = jnp.where(
            state.i == 1,
            initial_step_length,
            state.a_i1 * ls_pars.ls_optimism
            )
        
        fi, gri, dfi = cost_and_grad(ai)
        state = state._replace(
            nfev=state.nfev + 1,
            ngev=state.ngev + 1
            )
        # if dfi > 0:
        #     state._replace(
        #         failed=True,
        #     )
        #     break
        while jnp.isnan(fi):
            ai = ai / 10.
            fi, gri, dfi = cost_and_grad(ai)
            state = state._replace(
                nfev=state.nfev + 1,
                ngev=state.ngev + 1
            )

        if ls_pars.ls_verbosity >= 2:
            print("\titer: {}\n\t\talpha: {:.2e} "
                  "f(alpha): {:.5e}".format(state.i, ai, fi))

        if wolfe_one(ai, fi) or ((fi > state.phi_i1) and state.i > 1):
            if ls_pars.ls_verbosity >= 2:
                print('\t\tEntering zoom1...')
            zoom1 = _zoom(cost_and_grad, wolfe_one, wolfe_two,
                              state.a_i1, state.phi_i1, state.dphi_i1, ai, fi, dfi,
                              gri, ls_pars)
            state = state._replace(
                done=(zoom1.done or state.done),
                failed=(zoom1.failed or state.failed),
                a_star=zoom1.a_star,
                phi_star=zoom1.phi_star,
                dphi_star=zoom1.dphi_star,
                g_star=zoom1.g_star,
                nfev=state.nfev + zoom1.nfev,
                ngev=state.ngev + zoom1.ngev
                )
        elif wolfe_two(dfi):
            if ls_pars.ls_verbosity >= 2:
                print('\t\tWolfe two condition met, stopping')
            state = state._replace(
                done=(True or state.done),
                a_star=ai,
                phi_star=fi,
                dphi_star=dfi,
                g_star=gri
                )
        elif dfi >= 0:
            if ls_pars.ls_verbosity >= 2:
                print('\t\tEntering zoom2')
            zoom2 = _zoom(cost_and_grad, wolfe_one, wolfe_two,
                              ai, fi, dfi, state.a_i1, state.phi_i1, state.dphi_i1,
                              gri, ls_pars)
            state = state._replace(
                done=(zoom2.done or state.done),
                failed=(zoom2.failed or state.failed),
                a_star=zoom2.a_star,
                phi_star=zoom2.phi_star,
                dphi_star=zoom2.dphi_star,
                g_star=zoom2.g_star,
                nfev=state.nfev + zoom2.nfev,
                ngev=state.ngev + zoom2.ngev
                )

        state = state._replace(
            i=state.i + 1,
            a_i1=ai,
            phi_i1=fi,
            dphi_i1=dfi)

    status = jnp.where(
        state.failed,
        jnp.array(2),  # zoom failed
        jnp.where(
            state.i > ls_pars.ls_maxiter,
            jnp.array(1),  # maxiter reached
            jnp.array(0),  # passed (should be)
            ),
        )
    result = _LineSearchResult(
        failed=state.failed,
        nit=state.i - 1,  # because iterations started at 1
        nfev=state.nfev,
        ngev=state.ngev,
        k=state.i,
        a_k=state.a_i1 if status==1 else state.a_star,
        f_k=state.phi_star,
        g_k=state.g_star,
        status=status,
        )
    if ls_pars.ls_verbosity >= 1:
        print('\tLinesearch {}, alpha star = {:.2e}'.format(
            'failed' if state.failed else 'done', result.a_k))

    return result
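`linesearch` expects `cost_and_grad(a)` to return the cost, gradient, and directional derivative at `x + a*d`. A sketch of building that closure with JAX (hypothetical helper, assuming `LineSearchParameter`, `_LineSearchState`, `_zoom`, and `_LineSearchResult` are in scope):

import jax.numpy as jnp
from jax import value_and_grad

def make_cost_and_grad(cost, x, d):
    def cost_and_grad(a):
        f, g = value_and_grad(cost)(x + a * d)
        return f, g, jnp.dot(g, d)  # cost, gradient, directional derivative
    return cost_and_grad

cost = lambda x: jnp.sum(x ** 2)
x = jnp.array([1.0, -2.0])
g0 = 2.0 * x                       # gradient at x
d = -g0                            # steepest-descent direction
res = linesearch(make_cost_and_grad(cost, x, d), x, d,
                 f0=cost(x), df0=jnp.dot(g0, d), g0=g0)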
Example #30
def check():
    if np.isnan(lls[-1]) or np.isinf(lls[-1]):
        raise Exception("LL check failed.")