Example #1
 def concatenate(self, key_arrs, axis):
     axis = axis % len(self.shape)
     arrs = [self._keys, *[k._keys for k in key_arrs]]
     return PRNGKeyArray(self.impl, jnp.stack(arrs, axis))
Example #2
def pad_trajectories(trajectories, boundary=20):
    """Pad trajectories to a bucket length that is a multiple of boundary.

  Args:
    trajectories: list[(observations, actions, rewards, infos)], where each
      observation is shaped (t+1,) + OBS and actions & rewards are shaped
      (t,), with the length of the list being B (batch size).
    boundary: int, bucket length, the actions and rewards are padded to integer
      multiples of boundary.

  Returns:
    tuple: (padded_lengths, reward_mask, padded_observations, padded_actions,
        padded_rewards, stacked_padded_infos) where padded_observations is
        shaped (B, T+1) + OBS and padded_actions, padded_rewards & reward_mask
        are shaped (B, T), with T being max(t) rounded up to an integer
        multiple of boundary. padded_lengths is how much padding was added to
        each trajectory, reward_mask is 1s for actual rewards and 0s for the
        padding, and stacked_padded_infos stacks the padded info values per
        key (or is None if there are no infos).
  """

    # Let's compute max(t) over all trajectories.
    t_max = max(r.shape[0] for (_, _, r, _) in trajectories)

    # t_max is rounded to the next multiple of `boundary`
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max) / boundary))

    # All obs will be padded to bucket_length + 1; actions and rewards to
    # bucket_length.
    padded_observations = []
    padded_actions = []
    padded_rewards = []
    padded_infos = collections.defaultdict(list)
    padded_lengths = []
    reward_masks = []

    for (o, a, r, i) in trajectories:
        # Determine the amount to pad; the same amount applies to obs,
        # actions, and rewards.
        num_to_pad = bucket_length + 1 - o.shape[0]
        padded_lengths.append(num_to_pad)
        if num_to_pad == 0:
            padded_observations.append(o)
            padded_actions.append(a)
            padded_rewards.append(r)
            reward_masks.append(onp.ones_like(r, dtype=np.int32))
            if i:
                for k, v in i.items():
                    padded_infos[k].append(v)
            continue

        # First pad observations.
        padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] *
                               (o.ndim - 1))

        padding_value = get_padding_value(o.dtype)
        action_padding_value = get_padding_value(a.dtype)
        reward_padding_value = get_padding_value(r.dtype)

        padded_obs = lax.pad(o, padding_value, padding_config)
        padded_observations.append(padded_obs)

        # Now pad actions and rewards.
        assert a.ndim == 1 and r.ndim == 1
        padding_config = ((0, num_to_pad, 0), )

        padded_action = lax.pad(a, action_padding_value, padding_config)
        padded_actions.append(padded_action)
        padded_reward = lax.pad(r, reward_padding_value, padding_config)
        padded_rewards.append(padded_reward)

        # Also create the mask to use later.
        reward_mask = onp.ones_like(r, dtype=np.int32)
        reward_masks.append(lax.pad(reward_mask, 0, padding_config))

        if i:
            for k, v in i.items():
                # Create a padding configuration for this value.
                padding_config = [(0, num_to_pad, 0)
                                  ] + [(0, 0, 0)] * (v.ndim - 1)
                padded_infos[k].append(lax.pad(v, 0.0, tuple(padding_config)))

    # Now stack these padded_infos if they exist.
    stacked_padded_infos = None
    if padded_infos:
        stacked_padded_infos = {
            k: np.stack(v)
            for k, v in padded_infos.items()
        }

    return padded_lengths, np.stack(reward_masks), np.stack(
        padded_observations), np.stack(padded_actions), np.stack(
            padded_rewards), stacked_padded_infos
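
A minimal usage sketch for pad_trajectories (hypothetical toy data; assumes the module-level helpers above, with `np` being jax.numpy and `onp` plain numpy):

t1, t2 = 3, 5
trajectories = [
    (onp.zeros((t1 + 1, 2)), onp.zeros(t1, dtype=onp.int32),
     onp.zeros(t1, dtype=onp.float32), {}),
    (onp.zeros((t2 + 1, 2)), onp.zeros(t2, dtype=onp.int32),
     onp.zeros(t2, dtype=onp.float32), {}),
]
lengths, mask, obs, acts, rews, infos = pad_trajectories(trajectories, boundary=4)
# With max(t) = 5 and boundary = 4, T = 8: obs is (2, 9, 2), while acts,
# rews and mask are (2, 8); infos is None since no infos were given.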
Example #3
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     return self._dirichlet.log_prob(np.stack([value, 1. - value], -1))
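
The stack in log_prob above encodes the Beta-as-Dirichlet trick: [value, 1 - value] is a point on the 2-simplex, so a two-dimensional Dirichlet can score a scalar in (0, 1). A standalone sketch (assuming numpyro-style distributions):

import jax.numpy as jnp
import numpyro.distributions as dist

v = jnp.array(0.3)
beta_lp = dist.Beta(2., 5.).log_prob(v)
dir_lp = dist.Dirichlet(jnp.array([2., 5.])).log_prob(jnp.stack([v, 1. - v], -1))
# beta_lp and dir_lp agree up to floating-point error.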
Example #4
def cartesian_prod(x, y):
    return jnp.stack([jnp.tile(x, len(y)), jnp.repeat(y, len(x))]).T
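
A worked example of the helper above: jnp.tile cycles through x while jnp.repeat holds each y fixed, so the transpose enumerates all pairs.

# cartesian_prod(jnp.array([1, 2]), jnp.array([10, 20]))
# -> [[ 1, 10],
#     [ 2, 10],
#     [ 1, 20],
#     [ 2, 20]]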
Example #5
 def test_fn(x):
     return np.stack([x[..., 0], x[..., 0] * x[..., 1], x[..., 0] * x[..., 1] * x[..., 2]], -1)
Example #6
 def stack_plus_minus(x):
     return np.stack([x, -x])
Example #7
    def __init__(self,
                 target_data,
                 prior,
                 simulator,
                 compressor,
                 gridsize=100,
                 F=None,
                 distance_measure=None):
        """Constructor method

        Parameters
        ----------
        target_data: float(n_targets, input_data)
            Data (or batch of data) to infer parameter values at
        prior: fn
            A prior distribution which can be evaluated and sampled from
            (should also contain a ``low`` and a ``high`` attribute with
            appropriate ranges)
        simulator: fn
            A function which takes in a batch of parameter values (and random
            seeds) and returns a simulation made at each parameter value
        compressor: fn
            A function which takes a batch of simulations and returns their
            compressed summaries for each simulation (can be identity function
            for no compression)
        gridsize : int or list, default=100
            The number of grid points to evaluate the marginal distribution on
            for every parameter (int) or each parameter (list)
        F : list or float({n_targets}, n_params, n_params) or None
            The Fisher information matrix to rescale parameter directions for
            measuring the distance. If None, this is set to the identity matrix
        distance_measure : fn or None
            A distance measuring function between a batch of summaries and a
            single set of summaries for a target. If None, ``F_distance``
            (Euclidean scaled by the Fisher information) is used.
        """
        super().__init__(prior=prior, gridsize=gridsize)
        self.simulator = simulator
        self.compressor = compressor
        if len(target_data.shape) == 1:
            target_data = np.expand_dims(target_data, 0)
        self.target_summaries = self.compressor(target_data)
        self.n_summaries = self.target_summaries.shape[-1]
        self.n_targets = self.target_summaries.shape[0]
        if F is not None:
            if self.n_summaries != self.n_params:
                raise ValueError(
                    "If using the Fisher information to scale the distance " +
                    "then the compressor must return parameter estimates. " +
                    "The compressor returns summaries with shape " +
                    f"{(self.n_summaries,)}, but should return estamites " +
                    f"with shape {(self.n_params)} if the prior is correct.")
            if isinstance(F, list):
                self.F = np.stack([
                    _check_input(f, (self.n_params, self.n_params), "F")
                    for f in F
                ], 0)
            else:
                if F.shape == (self.n_params, self.n_params):
                    self.F = np.expand_dims(F, 0)
                    if self.n_targets > 1:
                        self.F = np.repeat(self.F, self.n_targets, axis=0)
                else:
                    self.F = _check_input(
                        F, (self.n_targets, self.n_params, self.n_params), "F")
            self.distance_measure = self.F_distance

        if distance_measure is not None:
            self.distance_measure = distance_measure
        elif F is None:
            self.distance_measure = self.euclidean_distance
            self.F = np.zeros(self.n_targets)

        self.parameters = container()
        self.parameters.n_params = self.n_params
        self.summaries = container()
        self.summaries.n_summaries = self.n_summaries
        self.distances = container()
        self.distances.n_targets = self.n_targets
Example #8
def dtoq_miles(data):
    return np.stack([np.cos(data*np.pi/2.0), np.sin(data*np.pi/2.0)], axis=2)
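
The function above appears to amplitude-encode data in [0, 1] as points on the unit circle; a small check of the invariant (hypothetical input shapes):

import numpy as np

data = np.random.rand(4, 3)  # (batch, features)
q = dtoq_miles(data)         # (4, 3, 2)
assert np.allclose((q ** 2).sum(axis=2), 1.0)  # cos^2 + sin^2 == 1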
Example #9
 def dz_dt(z, t):
     return jnp.stack([z[0], z[1]])
Example #10
def test_fun4(p, t, w1, w2):
    X = jnp.stack((p, t), axis=1)
    z1 = jnp.matmul(X, w1)
    z2 = jnp.matmul(z1, w2)
    return z2.squeeze()
Example #11
    def warmup(
        self,
        rng_key: jax.numpy.ndarray,
        initial_state: HMCState,
        kernel_factory: Callable,
        num_chains: int,
        num_warmup_steps: int = 1000,
        accelerate: bool = False,
        initial_step_size: float = 0.1,
    ) -> Tuple[HMCState, HMCParameters, Optional[StanWarmupState]]:
        """I don't like having a ton of warmup logic in here."""

        if not self.needs_warmup:
            parameters = HMCParameters(
                jnp.ones(initial_state.position.shape[0], dtype=jnp.int32) *
                self.parameters.num_integration_steps,
                jnp.ones(initial_state.position.shape[0]) *
                self.parameters.step_size,
                jnp.array([
                    self.parameters.inverse_mass_matrix
                    for _ in range(initial_state.position.shape[0])
                ]),
            )
            return initial_state, parameters, None

        hmc_factory = jax.partial(kernel_factory,
                                  self.parameters.num_integration_steps)
        init, update, final = stan_hmc_warmup(hmc_factory,
                                              self.is_mass_matrix_diagonal)

        rng_keys = jax.random.split(rng_key, num_chains)
        chain_state = initial_state
        warmup_state = jax.vmap(init,
                                in_axes=(0, 0, None))(rng_keys, chain_state,
                                                      initial_step_size)

        schedule = jnp.array(stan_warmup_schedule(num_warmup_steps))

        if accelerate:

            print(
                f"sampler: warmup {num_chains:,} chains for {num_warmup_steps:,} iterations.",
                end=" ",
            )
            start = datetime.now()

            @jax.jit
            def update_chain(carry, interval):
                rng_key, chain_state, warmup_state = carry
                stage, is_middle_window_end = interval

                _, rng_key = jax.random.split(rng_key)
                keys = jax.random.split(rng_key, num_chains)
                chain_state, warmup_state, chain_info = jax.vmap(
                    update,
                    in_axes=(0, None, None, 0, 0))(keys, stage,
                                                   is_middle_window_end,
                                                   chain_state, warmup_state)

                return (
                    (rng_key, chain_state, warmup_state),
                    (chain_state, warmup_state, chain_info),
                )

            last_state, warmup_chain = jax.lax.scan(
                update_chain, (rng_key, chain_state, warmup_state), schedule)
            _, last_chain_state, last_warmup_state = last_state

            print(
                f"Done in {(datetime.now()-start).total_seconds():.1f} seconds."
            )

        else:

            @jax.jit
            def update_fn(rng_key, interval, chain_state, warmup_state):
                rng_keys = jax.random.split(rng_key, num_chains)
                stage, is_middle_window_end = interval
                chain_state, warmup_state, chain_info = jax.vmap(
                    update,
                    in_axes=(0, None, None, 0, 0))(rng_keys, stage,
                                                   is_middle_window_end,
                                                   chain_state, warmup_state)
                return chain_state, warmup_state, chain_info

            chain = []
            with tqdm(schedule, unit="samples") as progress:
                progress.set_description(
                    f"Warming up {num_chains} chains for {num_warmup_steps} steps",
                    refresh=False,
                )
                for interval in progress:
                    _, rng_key = jax.random.split(rng_key)
                    chain_state, warmup_state, chain_info = update_fn(
                        rng_key, interval, chain_state, warmup_state)
                    chain.append((chain_state, warmup_state, chain_info))

            last_chain_state, last_warmup_state, _ = chain[-1]

            # The composition between scan and for loop is identical for the
            # warmup and the sampling. Should we generalize this to a single
            # `scan` function?
            stack = lambda y, *ys: jnp.stack((y, *ys))
            warmup_chain = jax.tree_multimap(stack, *chain)

        step_size, inverse_mass_matrix = jax.vmap(
            final, in_axes=(0, ))(last_warmup_state)
        num_integration_steps = self.parameters.num_integration_steps

        parameters = HMCParameters(
            jnp.ones(initial_state.position.shape[0], dtype=jnp.int32) *
            num_integration_steps,
            step_size,
            inverse_mass_matrix,
        )

        return last_chain_state, parameters, warmup_chain
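
The `stack` lambda combined with a tree map, as used in the non-accelerated branch above, is a generic way to turn a list of pytrees into a single pytree of stacked arrays. A standalone sketch (using the multi-argument jax.tree_util.tree_map, the modern spelling of tree_multimap):

import jax
import jax.numpy as jnp

chain = [{"a": jnp.ones(2) * i, "b": jnp.zeros(3)} for i in range(4)]
stacked = jax.tree_util.tree_map(lambda y, *ys: jnp.stack((y, *ys)), *chain)
# stacked["a"].shape == (4, 2); stacked["b"].shape == (4, 3)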
Example #12
 def attention_query_fn(n):
     return tree.tree_map(lambda n_: jnp.stack([n_, n_, n_], axis=2), n)
Example #13
 def update_edge_fn(e, unused_sn, unused_rn, unused_g):
     return tree.tree_map(lambda e_: jnp.stack([e_, e_, e_]), e)
Example #14
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
Example #15
def main(seed, device, dynamic_gamma, dynamic_preference, mc_type):

    import jax.numpy as jnp
    import numpy as onp  # jax.numpy has no savez; use plain numpy for saving
    from numpyro.infer import MCMC, NUTS
    from jax import random, nn, devices, device_put, vmap, jit

    # import utility functions for model inversion and belief estimation
    from utils import estimate_beliefs, mixture_model

    # import data loader
    from stats import load_data

    outcomes_data, responses_data, mask_data, ns, _, _ = load_data()

    print(seed, device, dynamic_gamma, dynamic_preference)

    model = lambda *args: mixture_model(*args, dynamic_gamma=dynamic_gamma, dynamic_preference=dynamic_preference)

    m_data = jnp.array(mask_data)
    r_data = jnp.array(responses_data).astype(jnp.int32)
    o_data = jnp.array(outcomes_data).astype(jnp.int32)

    rng_key = random.PRNGKey(seed)
    cutoff_up = 1000
    cutoff_down = 400

    priors = []
    params = []

    if mc_type == 'nu_max':
        M_rng = list(range(1, 11))  # model comparison for regular condition
    else:
        M_rng = [1,] + list(range(11, 20))  # model comparison for irregular condition

    for M in M_rng:
        if M <= 10:
            seq, _ = estimate_beliefs(o_data, r_data, device, mask=m_data, nu_max=M)
        else:
            seq, _ = estimate_beliefs(o_data, r_data, device, mask=m_data, nu_max=10, nu_min=M-10)
        
        priors.append(seq['beliefs'][0][cutoff_down:cutoff_up])
        params.append(seq['beliefs'][1][cutoff_down:cutoff_up])
    
    device = devices(device)[0]

    # init preferences
    c0 = jnp.sum(nn.one_hot(outcomes_data[:cutoff_down], 4) * jnp.expand_dims(mask_data[:cutoff_down], -1), 0)

    if dynamic_gamma:
        num_warmup = 1000
        num_samples = 1000
        num_chains = 1
    else:
        num_warmup = 200
        num_samples = 200
        num_chains = 5

    def inference(belief_sequences, obs, mask, rng_key):
        nuts_kernel = NUTS(model, dense_mass=True)
        mcmc = MCMC(
            nuts_kernel, 
            num_warmup=num_warmup, 
            num_samples=num_samples, 
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False
        )

        mcmc.run(
            rng_key, 
            belief_sequences, 
            obs, 
            mask, 
            extra_fields=('potential_energy',)
        )

        samples = mcmc.get_samples()
        potential_energy = mcmc.get_extra_fields()['potential_energy'].mean()
        # mcmc.print_summary()

        return samples, potential_energy

    seqs = device_put(
                    (
                        jnp.stack(priors, 0), 
                        jnp.stack(params, 0), 
                        o_data[cutoff_down:cutoff_up], 
                        c0
                    ), 
                device)

    y = device_put(r_data[cutoff_down:cutoff_up], device)
    mask = device_put(m_data[cutoff_down:cutoff_up].astype(bool), device)

    n = mask.shape[-1]
    rng_keys = random.split(rng_key, n)

    samples, potential_energy = jit(vmap(inference, in_axes=((2, 2, 1, 0), 1, 1, 0)))(seqs, y, mask, rng_keys)

    print('potential_energy', potential_energy)     

    onp.savez('fit_data/fit_sample_mixture_gamma{}_pref{}_{}.npz'.format(int(dynamic_gamma), int(dynamic_preference), mc_type), samples=samples)
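
The nested in_axes above maps a different axis of each element of the tuple argument. A minimal standalone sketch of per-leaf batching with jax.vmap:

import jax
import jax.numpy as jnp

def f(pair, y):
    a, b = pair
    return (a + b + y).sum()

a = jnp.ones((3, 5))  # batched along axis 1
b = jnp.ones((5, 3))  # batched along axis 0
y = jnp.ones((5,))    # batched along axis 0
out = jax.vmap(f, in_axes=((1, 0), 0))((a, b), y)  # out.shape == (5,)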
Example #16
        pkl.dump({'xx': xx, 'yy': yy, 'L_v_grid': L_v_grid}, f)
    # ------------------------------------------------------------------

iterations = []
lam_trajectory = []
eps_trajectory = []

theta = jnp.array([args.initial_theta])
key = jax.random.PRNGKey(args.seed)
w = jax.random.normal(key, (x_train.shape[1], ))
inner_optim_params = reset_inner_optim_params(w)

# Initialize PES stuff
if args.estimate in ['pes', 'pes-analytic']:
    perturbation_accums = jnp.zeros((args.N, len(theta)))
    ws = jnp.stack([w] * args.N)

# Outer optimization
# =======================================================================
if args.outer_optimizer == 'adam':
    outer_optim_params = {
        'lr': args.outer_lr,
        'b1': args.outer_b1,
        'b2': args.outer_b2,
        'eps': args.outer_eps,
        'm': jnp.zeros(len(theta)),
        'v': jnp.zeros(len(theta)),
    }

    @jax.jit
    def outer_optimizer_step(params, grads, optim_params, t):
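        # The original snippet is truncated here. A standard Adam update
        # consistent with the outer_optim_params dict above would look
        # roughly like this (a sketch, not the original body):
        m = optim_params['b1'] * optim_params['m'] + (1 - optim_params['b1']) * grads
        v = optim_params['b2'] * optim_params['v'] + (1 - optim_params['b2']) * grads**2
        mhat = m / (1 - optim_params['b1']**(t + 1))
        vhat = v / (1 - optim_params['b2']**(t + 1))
        new_params = params - optim_params['lr'] * mhat / (jnp.sqrt(vhat) + optim_params['eps'])
        return new_params, {**optim_params, 'm': m, 'v': v}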
Example #17
 def log_prob(self, value):
     return self._dirichlet.log_prob(np.stack([value, 1. - value], -1))
Example #18
def rotate_every_two(tensor):
    rotate_half_tensor = jnp.stack(
        (-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
    rotate_half_tensor = rotate_half_tensor.reshape(
        rotate_half_tensor.shape[:-2] + (-1, ))
    return rotate_half_tensor
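
A worked example of the interleaving above (the building block of rotary position embeddings): for a last dimension [x0, x1, x2, x3] the result is [-x1, x0, -x3, x2].

import jax.numpy as jnp

t = jnp.arange(4.0).reshape(1, 1, 1, 4)  # last dim: [0., 1., 2., 3.]
rotate_every_two(t)                      # last dim: [-1., 0., -3., 2.]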
Example #19
def apply_mutations(samples, mutation_types, pos_masks, mutations,
                    use_assignment_mutations=False):
  """Apply the mutations specified by mutation types to the batch of strings.

  Args:
    samples: Batch of strings [batch x str_length]
    mutation_types: IDs of mutation types to be applied to each string
      [Batch x num_mutations]
    pos_masks: One-hot encoding [Batch x num_mutations x str_length]
      of the positions to be mutated in each string.
      "num_mutations" positions will be mutated per string.
    mutations: A list of possible mutation functions.
      Functions should follow the format: fn(x, domain, pos_mask).
    use_assignment_mutations: bool. Whether mutations are defined as
      "Set position X to character C". If use_assignment_mutations=True,
      then vectorize procedure of applying mutations to the string.
      The index of mutation type should be equal to the index of the character.
      Gives considerable speed-up to this function.

  Returns:
    perturbed_samples: Strings perturbed according to the mutation list.
  """
  batch_size = samples.shape[0]
  assert len(mutation_types) == batch_size
  assert len(pos_masks) == batch_size

  str_length = samples.shape[1]
  assert pos_masks.shape[-1] == str_length

  # Check that the number of mutations is consistent between mutation_types
  # and pos_masks.
  assert mutation_types.shape[1] == pos_masks.shape[1]

  num_mutations = mutation_types.shape[1]

  # List of batched samples with 0,1,2,... mutations
  # First element of the list contains original samples
  # Last element has samples with all mutations applied to the string
  perturbed_samples_with_i_mutations = [samples]
  for i in range(num_mutations):

    perturbed_samples = []
    samples_to_perturb = perturbed_samples_with_i_mutations[-1]

    if use_assignment_mutations:
      perturbed_samples = samples_to_perturb.copy()
      mask = pos_masks[:, i].astype(int)
      # Assumes mutations are defined as "Set position to the character C"
      perturbed_samples[np.array(mask) == 1] = mutation_types[:, i]
    else:
      for j in range(batch_size):
        sample = samples_to_perturb[j].copy()

        pos = pos_masks[j, i]
        mut_id = mutation_types[j, i]

        mutation = mutations[int(mut_id)]
        perturbed_samples.append(mutation(sample, pos))
      perturbed_samples = np.stack(perturbed_samples)

    assert perturbed_samples.shape == samples.shape
    perturbed_samples_with_i_mutations.append(perturbed_samples)

  states = jnp.stack(perturbed_samples_with_i_mutations, 0)
  assert states.shape == (num_mutations + 1,) + samples.shape
  return states
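
A minimal usage sketch of the fast path above (hypothetical toy values; assignment-style mutations, so no mutation functions are needed):

import numpy as np

samples = np.array([[0, 1, 2, 3],
                    [3, 2, 1, 0]])        # [batch=2 x str_length=4]
mutation_types = np.array([[9], [7]])     # one mutation per string
pos_masks = np.array([[[0, 1, 0, 0]],
                      [[0, 0, 0, 1]]])    # one-hot positions per mutation
states = apply_mutations(samples, mutation_types, pos_masks, mutations=[],
                         use_assignment_mutations=True)
# states[0] is the original batch; states[1] sets position 1 of the first
# string to 9 and position 3 of the second string to 7.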
Example #20
def joint_mi_criterion_mg(particle_weights, particles, cur_group,
                          cur_positives, previous_groups_prob_particles_states,
                          previous_groups_cumcond_entropy, sensitivity,
                          specificity, backtracking):
    """Compares the benefit of adding one group to previously selected ones.

  Groups are formed iteratively by considering all possible individuals
  that can be considered to add (or remove if backtracking).

  If the sensitivity and/or specificity parameters are group size dependent,
  we take that into account in our optimization.
  Here all groups have the same size, hence they all share the same
  specificity / sensitivity setting. We just replace the vector by its value at
  the appropriate coordinate.

  The size of the group considered here will be the size of cur_group + 1 if
  going forward / -1 if backtracking.

  Args:
   particle_weights: weights of particles
   particles: particles summarizing belief about infection status
   cur_group: group currently considered to add to former groups.
   cur_positives: stores which particles would test positive w.r.t cur_group
   previous_groups_prob_particles_states: particles x test outcome probabilities
   previous_groups_cumcond_entropy: previous conditional entropies
   sensitivity: value (vector) of sensitivity(-ies depending on group size).
   specificity: value (vector) of specificity(-ies depending on group size).
   backtracking: (bool), True if removing rather than adding individuals.

  Returns:
    cur_group : group updated with best choice
    cur_positives : bool vector keeping track of whether particles
                             would test positive or not
    new_objective : MI reached with this new group
    prob_particles_states : if cur_group were to be selected, this matrix
      would keep track of probability of seeing one of 2^j possible test
      outcomes across all particles.
    new_cond_entropy : if cur_group were to be selected, this constant would be
      added to the running sum of conditional entropies of all tests carried
      out thus far
  """
    group_size = np.atleast_1d(np.sum(cur_group) + 1 - 2 * backtracking)
    sensitivity = utils.select_from_sizes(sensitivity, group_size)
    specificity = utils.select_from_sizes(specificity, group_size)
    if backtracking:
        # if backtracking, we recompute the truth table for all proposed groups,
        # namely run the np.dot below
        # TODO(cuturi)? If we switch to integer arithmetic we may be able to
        # save on this iteration by keeping track of how many positives there
        # are, and not just on whether there is or not one positive.
        candidate_groups = np.logical_not(
            add_ones_to_line(np.logical_not(cur_group)))
        positive_in_groups = np.dot(candidate_groups,
                                    np.transpose(particles)) > 0
    else:
        # in forward mode, candidate groups are recovered by adding
        # a 1 instead of zeros. Therefore, we can use previous vector of positive
        # in groups to simply compute all positive in groups for candidates
        indices_of_false_in_cur_group, = np.where(np.logical_not(cur_group))
        positive_in_groups = np.logical_or(
            cur_positives[:, np.newaxis],
            particles[:, indices_of_false_in_cur_group])
        # recover a candidates x n_particles matrix
        positive_in_groups = np.transpose(positive_in_groups)

    entropy_spec = metrics.binary_entropy(specificity)
    gamma = metrics.binary_entropy(sensitivity) - entropy_spec
    cond_entropy = previous_groups_cumcond_entropy + entropy_spec + gamma * np.sum(
        particle_weights[np.newaxis, :] * positive_in_groups, axis=1)
    rho = specificity + sensitivity - 1

    # positive_in_groups defines probability of two possible outcomes for the test
    # of each new candidate group.
    probabilities_new_test = np.stack(
        (specificity - rho * positive_in_groups,
         1 - specificity + rho * positive_in_groups),
        axis=-1)
    # we now incorporate previous probability of all previous groups added so far
    # and expand x 2 the state space of possible test results.
    new_plus_previous_groups_prob_particles_states = np.concatenate(
        (probabilities_new_test[:, :, 0][:, :, np.newaxis] *
         previous_groups_prob_particles_states[np.newaxis, :, :],
         probabilities_new_test[:, :, 1][:, :, np.newaxis] *
         previous_groups_prob_particles_states[np.newaxis, :, :]),
        axis=2)

    # average over particles to recover probability of all 2^j possible
    # test results
    new_plus_previous_groups_prob_states = np.sum(
        particle_weights[np.newaxis, :, np.newaxis] *
        new_plus_previous_groups_prob_particles_states,
        axis=1)

    whole_entropy = metrics.entropy(new_plus_previous_groups_prob_states,
                                    axis=1)

    # exhaustive way to compute cond entropy, useful to check
    # computations.
    # cond_entropy_old = np.sum(
    #     particle_weights[np.newaxis, :] *
    #     entropy(new_plus_previous_groups_prob_particles_states, axis=2),
    #     axis=1)
    objectives = whole_entropy - cond_entropy

    # greedy selection of largest/smallest value
    index = np.argmax(objectives)

    if backtracking:
        # return most promising group by recovering it from the matrix directly
        logging.info('backtracking, candidate_groups shape: %s',
                     candidate_groups.shape)
        cur_group = candidate_groups[index, :]

    else:
        # return most promising group by adding a 1
        cur_group = jax.ops.index_update(cur_group,
                                         indices_of_false_in_cur_group[index],
                                         True)

    # refresh the status of vector positives
    cur_positives = positive_in_groups[index, :]
    new_objective = objectives[index]
    prob_particles_states = new_plus_previous_groups_prob_particles_states[
        index, :, :]
    new_cond_entropy = cond_entropy[index]
    return (cur_group, cur_positives, new_objective, prob_particles_states,
            new_cond_entropy)
Example #21
    def step(self, s, a):
        """Apply control, damping, boundary, and collision forces.

    Args:
      s: (p, v, misc), where p and v are [n_entities,2] jnp.float32,
         and misc is child defined
      a: [n_agents, dim_a] jnp.float32

    Returns:
      A state tuple (p, v, misc)
    """
        p, v, misc = s  # [n,2], [n,2], [a_shape]
        f = jnp.zeros_like(p)  # [n,2]
        n = p.shape[0]  # number of entities

        # Calculate control forces
        f_control = jnp.pad(a, ((0, n - a.shape[0]), (0, 0)),
                            mode="constant")  # [n, dim_a]
        f += f_control

        # Calculate damping forces
        f_damping = -1.0 * self.damping * v  # [n,2]
        f = f + f_damping

        # Calculate boundary forces
        bounce = (((p + self.radius >= self.max_p) & (v >= 0.0)) |
                  ((p - self.radius <= self.min_p) & (v <= 0.0)))  # [n,2]
        v_new = (-1.0 * bounce + 1.0 * ~bounce) * v  # [n,2]
        f_boundary = self.mass * (v_new - v) / self.dt  # [n,2]
        f = f + f_boundary

        # Calculate shared quantities for later calculations
        # same: [n,n,1], True if i==j
        same = jnp.expand_dims(jnp.eye(n, dtype=jnp.bool_), axis=-1)
        # p2p: [n,n,2], p2p[i,j,:] is the vector from entity i to entity j
        p2p = p - jnp.expand_dims(p, axis=1)
        # dist: [n,n,1], dist[i,j,0] is the distance between i and j
        dist = jnp.linalg.norm(p2p, axis=-1, keepdims=True)
        # overlap: [n,n,1], overlap[i,j,0] is the overlap between i and j
        overlap = ((jnp.expand_dims(self.radius, axis=1) +
                    jnp.expand_dims(self.radius, axis=0)) - dist)
        if self.same_position_check:
            # ontop: [n,n,1], ontop[i,j,0] = True if i is at the exact location of j
            ontop = (dist == 0.0)
            # ontop_dir: [n,n,1], (1,0) above diagonal, (-1,0) below diagonal
            ontop_dir = jnp.stack(
                [jnp.triu(jnp.ones((n, n))) * 2 - 1,
                 jnp.zeros((n, n))],
                axis=-1)
            # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
            # direction of j from i
            contact_dir = (~ontop * p2p +
                           (ontop * ontop_dir)) / (~ontop * dist + ontop * 1.0)
        else:
            # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
            # direction of j from i
            contact_dir = p2p / (dist + same)
        # collideable: [n,n,1], True if i and j are collideable
        collideable = (jnp.expand_dims(self.collideable, axis=1)
                       & jnp.expand_dims(self.collideable, axis=0))
        # overlapping: [n,n,1], True if i,j overlap
        overlapping = overlap > 0

        # Calculate collision forces
        # Assume all entities collide with all entities, then mask out
        # non-collisions.
        #
        # For approaching, colliding entities, apply a force
        # along the direction of collision that results in
        # relative velocities consistent with the coefficient of
        # restitution (c) and preservation of momentum in that
        # direction.
        # momentum: m_a*v_a + m_b*v_b = m_a*v'_a + m_b*v'_b
        # restitution: v'_b - v'_a = -c*(v_b-v_a)
        # solve for v'_a:
        #  v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b-v_a)]/(m_a + m_b)
        #
        # v_contact_dir: [n,n] speed of i in dir of j
        v_contact_dir = jnp.sum(jnp.expand_dims(v, axis=-2) * contact_dir,
                                axis=-1)
        # v_approach: [n,n] speed that i,j are approaching each other
        v_approach = jnp.transpose(v_contact_dir) + v_contact_dir
        # momentum: [n,n] joint momentum in direction of contact (i->j)
        momentum = self.mass * v_contact_dir - jnp.transpose(
            self.mass * v_contact_dir)
        # v_result: [n,n] speed of i in dir of j after collision
        v_result = ((momentum + self.restitution * jnp.transpose(self.mass) *
                     (-v_approach)) / (self.mass + jnp.transpose(self.mass)))
        # f_collision: [n,n] force on i in dir of j to realize acceleration
        f_collision = self.mass * (v_result - v_contact_dir) / self.dt
        # f_collision: [n,n,2] force on i to realize acceleration due to
        # collision with j
        f_collision = jnp.expand_dims(f_collision, axis=-1) * contact_dir
        # collision_mask: [n,n,1]
        collision_mask = (collideable & overlapping & ~same &
                          (jnp.expand_dims(v_approach, axis=-1) > 0))
        # f_collision: [n,2], sum of collision forces on i
        f_collision = jnp.sum(f_collision * collision_mask, axis=-2)
        f = f + f_collision

        # Calculate overlapping spring forces
        # This corrects for any overlap due to discrete steps.
        # f_overlap: [n,n,2], force in the negative contact dir due to overlap
        f_overlap = -1.0 * contact_dir * overlap * self.overlap_spring_constant
        # overlapping_mask: [n,n,1], True if i,j are collideable, overlap,
        # and i != j
        overlapping_mask = collideable & overlapping & ~same
        # f_overlap: [n,2], sum of spring forces on i
        f_overlap = jnp.sum(f_overlap * overlapping_mask, axis=-2)
        f = f + f_overlap

        # apply forces
        v = v + (f / self.mass) * self.dt
        p = p + v * self.dt

        # update misc
        misc = self._update_misc((p, v, misc), a)  # pylint: disable=assignment-from-none

        return (p, v, misc)
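
A quick numeric check of the restitution formula in the comments above: for equal masses and c = 1, approaching entities simply exchange velocities along the contact direction.

# v'_a = (m_a*v_a + m_b*v_b + m_b*c*(v_b - v_a)) / (m_a + m_b)
m_a = m_b = 1.0
c = 1.0
v_a, v_b = 2.0, -1.0
v_a_new = (m_a * v_a + m_b * v_b + m_b * c * (v_b - v_a)) / (m_a + m_b)
# v_a_new == -1.0 == v_b: a full velocity swap, as expected for an
# elastic collision between equal masses.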
Example #22
    def __init__(self,
                 m,
                 n,
                 dim,
                 alpha=None,
                 sigma=None,
                 n_epoch=2,
                 sched='linear',
                 device=jax.devices()[0],
                 precompute=True,
                 periodic=False,
                 metric=None,
                 centroids=None):

        # topology of the som
        super(SOM, self).__init__()
        self.device = device
        self.step = 0
        self.m = m
        self.n = n
        self.grid_size = m * n
        self.dim = dim
        self.periodic = periodic
        if metric is None:
            self.metric = jax_cdist
        else:
            self.metric = metric

        # optimization parameters
        self.sched = sched
        self.n_epoch = n_epoch
        if alpha is not None:
            self.alpha = float(alpha)
        else:
            self.alpha = alpha
        if sigma is None:
            self.sigma = np.sqrt(self.m * self.n) / 2.0
        else:
            self.sigma = float(sigma)

        if centroids is None:
            self.centroids = jax.device_put(np.abs(np.random.randn(m * n,
                                                                   dim)),
                                            device=device)
        else:
            self.centroids = centroids

        locs = np.stack(
            [np.array([i, j]) for i in range(self.m) for j in range(self.n)])
        self.locations = jax.device_put(locs, device=device)
        self.precompute = precompute
        if self.precompute:
            # Fast computation is only right for the periodic topology
            if self.periodic:
                self.distance_mat = self.compute_all()
            else:
                self.distance_mat = jnp.stack([
                    self.get_bmu_distance_squares(loc)
                    for loc in self.locations
                ])
            self.distance_mat = jax.device_put(self.distance_mat,
                                               device=device)
        self.umat = None

        # Clustering parameters
        self.cluster_att = None
        self.clusters_user = None
Example #23
def get_deterministic_policies(n_states, n_actions):
    simplicies = list([np.eye(n_actions)[i] for i in range(n_actions)])
    pis = list(itertools.product(*[simplicies for _ in range(n_states)]))
    return [np.stack(p) for p in pis]
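
A worked example of the enumeration above: with n_states = 2 and n_actions = 2 there are n_actions ** n_states = 4 deterministic policies, each a (2, 2) matrix whose rows are one-hot action choices.

# get_deterministic_policies(2, 2) returns four (2, 2) arrays:
# [[1,0],[1,0]], [[1,0],[0,1]], [[0,1],[1,0]], [[0,1],[0,1]]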
Example #24
def kohn_sham(
    locations,
    nuclear_charges,
    num_electrons,
    num_iterations,
    grids,
    xc_energy_density_fn,
    interaction_fn,
    initial_density=None,
    alpha=0.5,
    alpha_decay=0.9,
    enforce_reflection_symmetry=False,
    num_mixing_iterations=2,
    density_mse_converge_tolerance=-1.):
  """Runs Kohn-Sham to solve ground state of external potential.

  Args:
    locations: Float numpy array with shape (num_nuclei,), the locations of
        atoms.
    nuclear_charges: Float numpy array with shape (num_nuclei,), the nuclear
        charges.
    num_electrons: Integer, the number of electrons in the system. The first
        num_electrons states are occupied.
    num_iterations: Integer, the number of Kohn-Sham iterations.
    grids: Float numpy array with shape (num_grids,).
    xc_energy_density_fn: function takes density (num_grids,) and returns
        the energy density (num_grids,).
    interaction_fn: function takes displacements and returns
        float numpy array with the same shape of displacements.
    initial_density: Float numpy array with shape (num_grids,), initial guess
        of the density for Kohn-Sham calculation. Default None, the initial
        density is non-interacting solution from the external_potential.
    alpha: Float between 0 and 1, density linear mixing factor, the fraction
        of the output of the k-th Kohn-Sham iteration.
        If 0, the input density to the k-th Kohn-Sham iteration is fed into
        the (k+1)-th iteration. The output of the k-th Kohn-Sham iteration is
        completely ignored.
        If 1, the output density from the k-th Kohn-Sham iteration is fed into
        the (k+1)-th iteration, equivalent to no density mixing.
    alpha_decay: Float between 0 and 1, the decay factor of alpha. The mixing
        factor after k-th iteration is alpha * alpha_decay ** k.
    enforce_reflection_symmetry: Boolean, whether to enforce reflection
        symmetry. If True, the density is symmetric with respect to the center.
    num_mixing_iterations: Integer, the number of density differences in the
        previous iterations to mix the density.
    density_mse_converge_tolerance: Float, the stopping criteria. When the MSE
        density difference between two iterations is smaller than this value,
        the Kohn Sham iterations finish. The outputs of the rest of the steps
        are padded by the output of the converged step. Set this value to
        negative to disable early stopping.

  Returns:
    KohnShamState, the states of all the Kohn-Sham iteration steps.
  """
  external_potential = utils.get_atomic_chain_potential(
      grids=grids,
      locations=locations,
      nuclear_charges=nuclear_charges,
      interaction_fn=interaction_fn)
  if initial_density is None:
    # Use the non-interacting solution from the external_potential as initial
    # guess.
    initial_density, _, _ = solve_noninteracting_system(
        external_potential=external_potential,
        num_electrons=num_electrons,
        grids=grids)
  # Create initial state.
  state = KohnShamState(
      density=initial_density,
      total_energy=jnp.inf,
      locations=locations,
      nuclear_charges=nuclear_charges,
      external_potential=external_potential,
      grids=grids,
      num_electrons=num_electrons)
  states = []
  differences = None
  converged = False
  for _ in range(num_iterations):
    if converged:
      states.append(state)
      continue

    old_state = state
    state = kohn_sham_iteration(
        state=old_state,
        num_electrons=num_electrons,
        xc_energy_density_fn=xc_energy_density_fn,
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    density_difference = state.density - old_state.density
    if differences is None:
      differences = jnp.array([density_difference])
    else:
      differences = jnp.vstack([differences, density_difference])
    if jnp.mean(
        jnp.square(density_difference)) < density_mse_converge_tolerance:
      converged = True
    state = state._replace(converged=converged)
    # Density mixing.
    state = state._replace(
        density=old_state.density
        + alpha * jnp.mean(differences[-num_mixing_iterations:], axis=0))
    states.append(state)
    alpha *= alpha_decay

  return tree_util.tree_multimap(lambda *x: jnp.stack(x), *states)
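
A minimal sketch of the linear density mixing described in the docstring (hypothetical values): the density fed into iteration k+1 is a convex combination of the input and output densities of iteration k.

import jax.numpy as jnp

old_density = jnp.array([1.0, 2.0, 3.0])  # input to iteration k
new_density = jnp.array([2.0, 2.0, 2.0])  # output of iteration k
alpha = 0.5
mixed = old_density + alpha * (new_density - old_density)
# mixed == (1 - alpha) * old_density + alpha * new_density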
Example #25
 def get_stacked_net(x):
     y = get_net(x)
     return jnp.stack([y, 2.0 * y])
Example #26
 def stack(x, y):
   return jnp.stack([x, y], 0)
Example #27
def rewards_to_go(rewards, mask, gamma=0.99):
    r"""Computes rewards to go.

  Reward-to-go is the discounted reward that we have yet to collect going
  forward from this point, i.e.:

  r2g_t = \sum_{l=0}^{\infty} (\gamma^{l} * reward_{t+l})

  Args:
    rewards: np.ndarray of shape (B, T) of rewards.
    mask: np.ndarray of shape (B, T) of mask for the rewards.
    gamma: float, discount factor.

  Returns:
    rewards to go, np.ndarray of shape (B, T).
  """
    B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

    masked_rewards = rewards * mask  # (B, T)

    # The lax.scan version of this is slow, but we still show it here for
    # completeness.
    #   rewards_rev = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
    #   rrt = np.transpose(rewards_rev)  # (T, B) transpose to scan over time.
    #
    #   def discounting_add(carry, reward):
    #     x = reward + (gamma * carry)
    #     return x, x
    #
    #   _, ys = lax.scan(discounting_add,
    #                    np.zeros_like(rrt[0], dtype=np.float32),
    #                    rrt.astype(np.float32))
    #
    #   # ys is (T, B) and T is in reverse order.
    #   return np.flip(np.transpose(ys), axis=1)

    # We use the following recurrence relation, derived from the equation above:
    #
    # r2g[t+1] = (r2g[t] - r[t]) / gamma
    #
    # This means we'll need to calculate r2g[0] first, then r2g[1], and so on.
    #
    # **However** this leads to overflows for long sequences: r2g[t] - r[t] > 0
    # and gamma < 1.0, so the division keeps increasing.
    #
    # So we just run the recurrence in reverse, i.e.
    #
    # r2g[t] = r[t] + (gamma*r2g[t+1])
    #
    # This is much better, but may lose precision, since the (small) rewards
    # at earlier time-steps are added to a (potentially very) large sum.

    # Compute r2g_{T-1} at the start and then compute backwards in time.
    r2gs = [masked_rewards[:, -1]]

    # Go from T-2 down to 0.
    for t in reversed(range(T - 1)):
        r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

    # The list should have length T.
    assert T == len(r2gs)

    # First we stack them in the correct way to make it (B, T), but these are
    # still from newest (T-1) to oldest (0), so then we flip it on time axis.
    return np.flip(np.stack(r2gs, axis=1), axis=1)
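
A small worked example of the backward recurrence above (hypothetical values):

import numpy as np

rewards = np.array([[1.0, 1.0, 1.0]])
mask = np.ones_like(rewards)
rewards_to_go(rewards, mask, gamma=0.5)
# -> [[1.75, 1.5, 1.0]]: r2g[2] = 1, r2g[1] = 1 + 0.5 * 1,
#    r2g[0] = 1 + 0.5 * 1.5.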
Example #28
  def validation_fn(model):
    """Iterates over the full validation set and computes metrics."""
    valid_iterator = valid_iterator_factory()
    accumulator = None
    for batch in valid_iterator:
      new_values = flax.jax_utils.unreplicate(
          batch_helper(model, batch.example, batch.mask))
      if accumulator is None:
        accumulator = new_values
      else:
        accumulator = jax.tree_multimap(operator.add, accumulator, new_values)

    (
        loss_sum,
        example_count,
        batch_metrics_non_nan,
        batch_metrics_non_nan_counts,
        (
            count_t_target_t_pred,
            count_t_target_f_pred,
            count_f_target_t_pred,
            count_f_target_f_pred,
        ),
    ) = accumulator

    metrics = {}
    metrics["loss"] = float(loss_sum / example_count)
    for k in batch_metrics_non_nan:
      metrics[k] = float(batch_metrics_non_nan[k] /
                         batch_metrics_non_nan_counts[k])

    precision_at_thresholds = jnp.nan_to_num(
        count_t_target_t_pred / (count_t_target_t_pred + count_f_target_t_pred))
    recall_at_thresholds = jnp.nan_to_num(
        count_t_target_t_pred / (count_t_target_t_pred + count_t_target_f_pred))
    f1_at_thresholds = jnp.nan_to_num(
        2 * (precision_at_thresholds * recall_at_thresholds) /
        (precision_at_thresholds + recall_at_thresholds))

    best_threshold_index = jnp.argmax(f1_at_thresholds)
    logging.info("F1 score across thresholds: %s",
                 jnp.stack([candidate_thresholds, f1_at_thresholds]))
    threshold = candidate_thresholds[best_threshold_index]
    precision = precision_at_thresholds[best_threshold_index]
    recall = recall_at_thresholds[best_threshold_index]
    f1 = f1_at_thresholds[best_threshold_index]

    metrics["best_threshold"] = float(threshold)
    metrics["flipped_precision"] = float(1 - precision)
    metrics["flipped_recall"] = float(1 - recall)
    metrics["flipped_f1"] = float(1 - f1)

    if full_evaluation:
      # Add (possibly non-scalar) detailed metrics
      metrics["example_count"] = example_count
      metrics["threshold_curves"] = {
          "thresholds": candidate_thresholds,
          "count_t_target_t_pred": count_t_target_t_pred,
          "count_t_target_f_pred": count_t_target_f_pred,
          "count_f_target_t_pred": count_f_target_t_pred,
          "count_f_target_f_pred": count_f_target_f_pred,
          "precision_at_thresholds": precision_at_thresholds,
          "recall_at_thresholds": recall_at_thresholds,
          "f1_at_thresholds": f1_at_thresholds,
      }

    return metrics["flipped_f1"], metrics
Example #29
    npz_data = np.load(filename)
    out = {
        "data_grid_search":npz_data['train_data'] / 255.,
        "data_test":npz_data['test_data'] / 255.,
    }
    return out

datasets = {}
if load_div2k:
    datasets['div2k'] = load_dataset('data_div2k.npz', '1TtwlEDArhOMoH18aUyjIMSZ3WODFmUab')
if load_text:
    datasets['text'] = load_dataset('data_2d_text.npz', '1V-RQJcMuk9GD4JCUn70o7nwQE0hEzHoT')

x1 = np.linspace(0, 1, RES//2+1)[:-1]
x_train = np.stack(np.meshgrid(x1,x1), axis=-1)

x1_t = np.linspace(0, 1, RES+1)[:-1]
x_test = np.stack(np.meshgrid(x1_t,x1_t), axis=-1)
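
# The stacked meshgrids pack (x, y) coordinates into the last axis: with
# RES = 512, for example, x_train.shape == (256, 256, 2) and
# x_test.shape == (512, 512, 2), i.e. the test grid is twice as dense.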

def plot_dataset(dataset):
    plt.imshow(dataset['data_test'][0,:,:,:])
    plt.show()

if visualize:
    for dataset in datasets:
        print(f'Dataset {dataset}')
        plot_dataset(datasets[dataset])

def make_network(num_layers, num_channels):
  layers = []
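  # The snippet is truncated here; a plausible completion (a sketch, assuming
  # the usual stax-style MLP such scripts build; not the original body):
  from jax.example_libraries import stax
  for _ in range(num_layers - 1):
    layers.append(stax.Dense(num_channels))
    layers.append(stax.Relu)
  layers.append(stax.Dense(3))  # e.g. an RGB output layer
  return stax.serial(*layers)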
Example #30
def fuse_fn(*args):
    return np.stack(args)