[0.10, 0.90]])

# observation matrix
B = jnp.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
])

pi = jnp.array([1, 1]) / 2

casino = HMMJax(A, B, pi)
num_hidden, num_obs = 2, 6

seed = 0
rng_key = PRNGKey(seed)
rng_key, rng_sample = split(rng_key)

n_obs_seq, max_len = 4, 5000
num_epochs = 400

observations, lens = pad_sequences(*hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample))
optimizer = optimizers.momentum(step_size=1e-3, mass=0.95)

# Mini Batch Gradient Descent
batch_size = 2
params_mbgd, losses_mbgd = fit(observations,
                               lens,
                               num_hidden,
                               num_obs,
                               batch_size,
                               optimizer,
Example #2
def siren_layer_params(key, scale, m, n, dtype = jnp.float32):
	w_key, b_key = random.split(key)
	return random.uniform(w_key, (m, n), dtype, minval = -scale, maxval = scale), jnp.zeros((n, ), dtype)
Example #3
def init_siren_params(key, layers, c0, w0, w1, dtype = jnp.float32):
	keys = random.split(key, len(layers))
	weights = [w0] + [1.0]*(len(layers)-3) + [w1]
	return [siren_layer_params(k, w*jnp.sqrt(c0/m), m, n) for k, w, m, n in zip(keys, weights, layers[:-1], layers[1:])]
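A minimal sketch of how these SIREN parameters could be consumed, assuming a forward pass with sine activations on the hidden layers and a plain affine output layer; the function name siren_forward and the layer sizes below are illustrative assumptions, not part of the original snippets.

import jax.numpy as jnp
from jax import random

def siren_forward(params, x):
    # sine activation on every hidden layer, affine output layer
    for w, b in params[:-1]:
        x = jnp.sin(x @ w + b)
    w, b = params[-1]
    return x @ w + b

# assumed hyperparameters: 2-D input, three hidden layers of width 64, scalar output
layers = [2, 64, 64, 64, 1]
params = init_siren_params(random.PRNGKey(0), layers, c0=6.0, w0=30.0, w1=1.0)
y = siren_forward(params, jnp.ones((10, 2)))  # y has shape (10, 1)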
Example #4
def main(args):
    rng_key, _ = random.split(random.PRNGKey(3))
    design_matrix, response = make_dataset(rng_key)
    run_inference(design_matrix, response, rng_key, args.num_warmup,
                  args.num_samples, args.num_chains, args.interval_size)
Example #5
 def model(key):
     k1, k2 = random.split(key)
     z = random_variable(random_normal, name='z')(k1)
     return random_variable(lambda key: random_normal(key) + z,
                            name='x')(k2)
Example #6
import os, sys

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__)))))

from jaxmeta.model_init import init_siren_params, init_tanh_params
from models import simple_model, normalized_model, tanh_model
from jaxmeta.loss import l1_regularization, l2_regularization
from jaxmeta.grad import jacobian_fn, hessian_fn

from data import domain, epsilon
import config

key, *subkeys = random.split(config.key, 3)
direct_params = init_siren_params(subkeys[0], config.direct_layers,
                                  config.direct_c0, config.direct_w0)
inverse_params = jnp.array([2.0])

direct_model = normalized_model(domain)
jacobian = jacobian_fn(direct_model)


@jax.jit
def inverse_model(params, x):
    return 1 + params[0] / jnp.pi * jnp.cos(2 * jnp.pi * x)


@jax.jit
def rhs(params, xt):
Example #7
def sample_development(key, state, recovery_probabilities):
  """Individuals who are in a transitional state either progress or recover."""
  key, subkey = random.split(key)
  is_recovered = random.bernoulli(subkey, recovery_probabilities[state])
  return key, (state + 1) * (1 - is_recovered) + RECOVERED * is_recovered
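To run this helper over a whole population it can be vmapped over per-individual keys and states. The sketch below is a hedged demo: RECOVERED and the recovery probabilities are placeholder values, not taken from the original project.

import jax
import jax.numpy as jnp
from jax import random

RECOVERED = 0                                             # assumed state code
recovery_probabilities = jnp.array([0.0, 0.2, 0.5, 0.8])  # assumed per-state values

keys = random.split(random.PRNGKey(0), 5)
states = jnp.array([1, 2, 3, 1, 2])
_, next_states = jax.vmap(sample_development, in_axes=(0, 0, None))(
    keys, states, recovery_probabilities)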
Example #8
 def apply_fun(params, inputs, **kwargs):
   rng = kwargs.pop('rng', None)
   rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
   for fun, param, rng in zip(apply_funs, params, rngs):
     inputs = fun(param, inputs, rng=rng, **kwargs)
   return inputs
Example #9
 def init_fun(rng, input_shape):
   rngs = random.split(rng, nlayers)
   return zip(*[init(rng, shape) for init, rng, shape
                in zip(init_funs, rngs, input_shape)])
Example #10
    def main(self):

        @jax.jit
        def train_step(state, batch, z_rng):
            def loss_fn(params):
                recon_x, mean, logvar = self.model().apply({'params': params}, batch, z_rng)
                bce_loss = mse(recon_x, batch).mean()
                kld_loss = kl_divergence(mean, logvar).mean()
                loss = bce_loss + self.kl_coeff*kld_loss
                return loss
            grads = jax.grad(loss_fn)(state.params)
            return state.apply_gradients(grads=grads)

        @jax.jit
        def eval(params, images, z, z_rng):
            def eval_model(vae):
                recon_images, mean, logvar = vae(images, z_rng)
                comparison = jnp.concatenate([images[:8].reshape(-1, self.image_size, self.image_size, 3),
                                            recon_images[:8].reshape(-1, self.image_size, self.image_size, 3)])
                generate_images = vae.generate(z)
                generate_images = generate_images.reshape(-1, self.image_size, self.image_size, 3)
                metrics = self.compute_metrics(recon_images, images, mean, logvar)
                return metrics, comparison, generate_images

            return nn.apply(eval_model, self.model())({'params': params})

        # Make sure tf does not allocate gpu memory.
        tf.config.experimental.set_visible_devices([], 'GPU')

        rng = random.PRNGKey(0)
        rng, key = random.split(rng)

        @tf.autograph.experimental.do_not_convert
        def decode_fn(s):
            img = tf.io.decode_jpeg(tf.io.read_file(s))
            img.set_shape([218, 178, 3])
            img = tf.cast(img, tf.float32) / 255.0
            img = tf.image.resize(img, (self.image_size, self.image_size), antialias=True)
            return tf.cast(img, dtype = jnp.dtype("float32"))

        dataset_celeba = tf.data.Dataset.list_files(self.data_dir+'/img_align_celeba/img_align_celeba/*.jpg', shuffle=False)
        train_dataset_celeba = (dataset_celeba
                                .map(decode_fn)
                                .map(tf.image.random_flip_left_right)
                                .shuffle(self.batch_size*16)
                                .batch(self.batch_size)
                                .repeat())
        test_ds = next(iter(tfds.as_numpy(train_dataset_celeba)))
        train_ds = iter(tfds.as_numpy(train_dataset_celeba))

        init_data = jnp.ones([self.batch_size, self.image_size, self.image_size, 3], jnp.float32)

        state = train_state.TrainState.create(
            apply_fn=self.model().apply,
            params=self.model().init(key, init_data, rng)['params'],
            tx=optax.adam(self.learning_rate),
        )

        rng, z_key, eval_rng = random.split(rng, 3)
        z = random.normal(z_key, (64, self.latents))

        steps_per_epoch = 50000 // self.batch_size

        for epoch in range(self.num_epochs):
            for _ in range(steps_per_epoch):
                batch = next(train_ds)
                rng, key = random.split(rng)
                state = train_step(state, batch, key)

            metrics, comparison, sample = eval(state.params, test_ds, z, eval_rng)
            save_image(comparison, f'{self.figdir}/reconstruction_{epoch}.png', nrow=8)
            save_image(sample, f'{self.figdir}/sample_{epoch}.png', nrow=8)
            print('eval epoch: {}, loss: {:.4f}, MSE: {:.4f}, KLD: {:.4f}'.format(
                epoch + 1, metrics['loss'], metrics['mse'], metrics['kld']
            ))
        checkpoints.save_checkpoint(".", state, epoch, "celeba_vae_checkpoint_")
Example #11
 def init_fun(rng, input_shape):
   shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
   k1, k2 = random.split(rng)
   beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
   return input_shape, (beta, gamma)
Example #12
 def split(self) -> "RNGWrapper":
     orig_key = np.array([self.int_1, self.int_2], dtype=np.uint32)
     key, subkey = random.split(orig_key)
     return RNGWrapper(int_1=subkey[0], int_2=subkey[1])
Example #13
def _predictive(
        rng_key,
        model,
        posterior_samples,
        batch_shape,
        return_sites=None,
        infer_discrete=False,
        parallel=True,
        model_args=(),
        model_kwargs={},
):
    masked_model = numpyro.handlers.mask(model, mask=False)
    if infer_discrete:
        # inspect the model to get some structure
        rng_key, subkey = random.split(rng_key)
        batch_ndim = len(batch_shape)
        prototype_sample = tree_map(
            lambda x: jnp.reshape(x, (-1, ) + jnp.shape(x)[batch_ndim:])[0],
            posterior_samples,
        )
        prototype_trace = trace(
            seed(substitute(masked_model, prototype_sample),
                 subkey)).get_trace(*model_args, **model_kwargs)
        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

    def single_prediction(val):
        rng_key, samples = val
        if infer_discrete:
            from numpyro.contrib.funsor import config_enumerate
            from numpyro.contrib.funsor.discrete import _sample_posterior

            model_trace = prototype_trace
            temperature = 1
            pred_samples = _sample_posterior(
                config_enumerate(condition(model, samples)),
                first_available_dim,
                temperature,
                rng_key,
                *model_args,
                **model_kwargs,
            )
        else:
            model_trace = trace(
                seed(substitute(masked_model, samples),
                     rng_key)).get_trace(*model_args, **model_kwargs)
            pred_samples = {
                name: site["value"]
                for name, site in model_trace.items()
            }

        if return_sites is not None:
            if return_sites == "":
                sites = {
                    k
                    for k, site in model_trace.items()
                    if site["type"] != "plate"
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples) or (
                    site["type"] == "deterministic")
            }
        return {
            name: value
            for name, value in pred_samples.items() if name in sites
        }

    num_samples = int(np.prod(batch_shape))
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    rng_key = rng_key.reshape((*batch_shape, 2))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples),
                     len(batch_shape), chunk_size)
Example #14
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                if isinstance(model, substitute) and isinstance(
                        model.fn, enum_handler):
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey),
                                      substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(
                *model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (v["type"] == "sample" and not v["is_observed"]
                        and not v["fn"].support.is_discrete):
                    constrained_values[k] = v["value"]
                    with helpful_support_errors(v):
                        inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(
                inv_transforms,
                {k: v
                 for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey,
                                               jnp.shape(v),
                                               minval=-radius,
                                               maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy,
                               model,
                               model_args,
                               model_kwargs,
                               enum=enum)
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
Example #15
# u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
# I = np.eye(v_t.shape[-1])
# I_add = npr.normal(0.0, 0.002, size=I.shape)
# noisy_I = I + I_add

init_fun, conv_net = stax.serial(
    Conv(32, (5, 5), (2, 2), padding="SAME"),
    BatchNorm(),
    Relu,
    Conv(10, (3, 3), (2, 2), padding="SAME"),
    Relu,
    Flatten,
    Dense(num_classes),
    LogSoftmax,
)
_, key = random.split(random.PRNGKey(0))


class DataTopologyAE(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod
    def objective(params, bparam, batch) -> float:
        x, _ = batch
        x = np.reshape(x, (x.shape[0], -1))
        logits = predict_fun(params, x, bparam=bparam[0], rng=key)
        keep = random.bernoulli(key, bparam[0], x.shape)
        inputs_d = np.where(keep, x, 0)

        loss = np.mean(np.square((np.subtract(logits, inputs_d))))
Example #16
 def apply_fun(params, inputs, **kwargs):
   rng = kwargs.pop('rng', None)
   rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
   return [f(p, x, rng=r, **kwargs) for f, p, x, r in zip(apply_funs, params, inputs, rngs)]
Example #17
File: train.py Project: pschuh/flax
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  def encode_strings(strs, max_len):
    tokenized_batch = np.zeros((len(strs), max_len), np.int32)
    for i, s in enumerate(strs):
      toks = encoder.tokenize(s).numpy()
      # Remove EOS token in prompt.
      tokenized_batch[i, :toks.shape[0]-1] = toks[:-1]
    return tokenized_batch

  tokenized_prompts = encode_strings(
      [config.prompts], config.max_predict_length)

  logging.info("Initializing model, optimizer, and step functions.")
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  rng, inference_rng = random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.TransformerLM(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")

  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config,
          temperature=config.sampling_temperature,
          top_k=config.sampling_top_k),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary["perplexity"] = jnp.clip(
              jnp.exp(summary["loss"]), a_max=1.0e4)
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          # (clipped) perplexity after averaging log-perplexities
          eval_results["perplexity"] = jnp.clip(
              jnp.exp(eval_results["loss"]), a_max=1.0e4)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("generate_text"):
          exemplars = generate_prediction(
              p_pred_step=p_pred_step,
              target=optimizer.target,
              tokenized_prompts=tokenized_prompts,
              eos_id=eos_id,
              inference_rng=inference_rng,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
Example #18
 def init_fun(rng, input_shape):
   output_shape = input_shape[:-1] + (out_dim,)
   k1, k2 = random.split(rng)
   W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
   return output_shape, (W, b)
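In the stax layer protocol this init_fun is paired with an apply_fun; below is a minimal sketch of the matching dense apply (the standard stax behaviour, written out here for clarity rather than copied from this project).

import jax.numpy as jnp

def apply_fun(params, inputs, **kwargs):
    # affine transform with the (W, b) pair produced by init_fun above
    W, b = params
    return jnp.dot(inputs, W) + b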
Example #19
def interaction_sampler(key, w):
  key, subkey = random.split(key)
  return key, random.bernoulli(subkey, w).astype(np.int32)
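A quick hedged demo (the weights are made up): each entry of w is an independent Bernoulli probability, and the PRNG key is threaded through so the sampler can be called repeatedly.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
key, draws = interaction_sampler(key, jnp.array([0.1, 0.5, 0.9]))  # draws is an int32 0/1 vector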
Example #20
    # Plot some information about the training.
    plotting.plot_losses(opt_details_dict['tlosses'],
                         opt_details_dict['elosses'],
                         sampled_every=print_every)
    plt.savefig(os.path.join(figure_dir, 'losses.png'))

    # Plot a bunch of examples of eval trials run through LFADS.
    nexamples_to_save = 10
    for eidx in range(nexamples_to_save):
        bidx = onp.random.randint(eval_data.shape[0])
        psa_example = eval_data[bidx, :, :].astype(np.float32)
        # Make an entire batch of a single example, and then
        # randomize the VAE with batch_size number of keys.
        examples = onp.repeat(np.expand_dims(psa_example, axis=0),
                              batch_size,
                              axis=0)
        skeys = random.split(key, batch_size)
        lfads_dict = lfads.batch_lfads_jit(trained_params, lfads_hps, skeys,
                                           examples, 1.0)
        # posterior sample and average
        psa_example_dict = utils.average_lfads_batch(lfads_dict)

        # The ii_scale may need to be flipped or rescaled, as there is an
        # identifiability issue in the scale and sign of the inferred input.
        plotting.plot_lfads(psa_example,
                            psa_example_dict,
                            data_dict,
                            eval_data_offset + bidx,
                            ii_scale=1.0)
        plt.savefig(os.path.join(figure_dir, 'lfads_output_%d.png' % (bidx)))
Example #21
                                             (1, 2, 3, 4))(step_rng, thetax,
                                                           thetay, thetad,
                                                           paramsm, netm,
                                                           num_samples)
        opt_state = opt_update(it, loss_grad, opt_state)
        return opt_state, loss_val

    params = (thetax, thetay, thetad, paramsm)
    opt_state, trace = lax.scan(step, opt_init(params), jnp.arange(num_steps))
    thetax, thetay, thetad, paramsm = get_params(opt_state)
    return (thetax, thetay, thetad, paramsm), trace


# Set random number generation seeds.
rng = random.PRNGKey(args.seed)
rng, rng_netm = random.split(rng, 2)
rng, rng_thetax, rng_thetay, rng_thetad = random.split(rng, 4)
rng, rng_ms, rng_train = random.split(rng, 3)
rng, rng_xobs = random.split(rng, 2)

paramsm, netm = network_factory(rng_netm, 1, args.num_mobius * 2,
                                args.num_hidden)
thetax = random.uniform(rng_thetax, [args.num_spline])
thetay = random.uniform(rng_thetay, [args.num_spline])
thetad = random.uniform(rng_thetad, [args.num_spline - 1])

# Compute number of parameters.
count = lambda x: jnp.prod(jnp.array(x.shape))
num_paramsm = jnp.array(
    tree_util.tree_map(count,
                       tree_util.tree_flatten(paramsm)[0])).sum()
Example #22
 def init_fun(rng, input_shape):
     shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]])
     k1, k2 = random.split(rng)
     bias, scale = _bias_init(k1, shape), _scale_init(k2, shape)
     return input_shape, (bias, scale)
Example #23
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    batch_size = FLAGS.batch_size
    learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    eval_freq = FLAGS.eval_frequency
    random_seed = FLAGS.random_seed

    if not FLAGS.dev:
        raise app.UsageError('Please provide path to dev set.')
    if not FLAGS.train:
        raise app.UsageError('Please provide path to training set.')
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    device_batch_size = batch_size // jax.device_count()

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

    # create the training and development dataset
    vocabs = input_pipeline.create_vocabs(FLAGS.train)
    config = models.TransformerConfig(vocab_size=len(vocabs['forms']),
                                      output_vocab_size=len(vocabs['xpos']),
                                      max_len=FLAGS.max_length)

    attributes_input = [input_pipeline.CoNLLAttributes.FORM]
    attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
    train_ds = input_pipeline.sentence_dataset_dict(FLAGS.train,
                                                    vocabs,
                                                    attributes_input,
                                                    attributes_target,
                                                    batch_size=batch_size,
                                                    bucket_size=config.max_len)
    train_iter = iter(train_ds)

    eval_ds = input_pipeline.sentence_dataset_dict(FLAGS.dev,
                                                   vocabs,
                                                   attributes_input,
                                                   attributes_target,
                                                   batch_size=batch_size,
                                                   bucket_size=config.max_len,
                                                   repeat=1)

    model = models.Transformer(config)

    rng = random.PRNGKey(random_seed)
    rng, init_rng = random.split(rng)

    # call a jitted initialization function to get the initial parameter tree
    @jax.jit
    def initialize_variables(init_rng):
        init_batch = jnp.ones((config.max_len, 1), jnp.float32)
        init_variables = model.init(init_rng, inputs=init_batch, train=False)
        return init_variables

    init_variables = initialize_variables(init_rng)

    optimizer_def = optim.Adam(learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=1e-1)
    optimizer = optimizer_def.create(init_variables['params'])
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate)

    p_train_step = jax.pmap(functools.partial(
        train_step, model=model, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')

    def eval_step(params, batch):
        """Calculate evaluation metrics on a batch."""
        inputs, targets = batch['inputs'], batch['targets']
        weights = jnp.where(targets > 0, 1.0, 0.0)
        logits = model.apply({'params': params}, inputs=inputs, train=False)
        return compute_metrics(logits, targets, weights)

    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())
    metrics_all = []
    tick = time.time()
    best_dev_score = 0
    for step, batch in zip(range(num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        if (step + 1) % eval_freq == 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            logging.info('train in step: %d, loss: %.4f', step,
                         summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                train_summary_writer.scalar('steps per second', steps_per_sec,
                                            step)
                for key, val in summary.items():
                    train_summary_writer.scalar(key, val, step)
                train_summary_writer.flush()

            metrics_all = [
            ]  # reset metric accumulation for next evaluation cycle.

            eval_metrics = []
            eval_iter = iter(eval_ds)

            for eval_batch in eval_iter:
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = eval_batch['inputs'].shape[0]
                if cur_pred_batch_size != batch_size:
                    # pad up to batch size
                    eval_batch = jax.tree_map(
                        lambda x: pad_examples(x, batch_size), eval_batch)
                eval_batch = common_utils.shard(eval_batch)

                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])

            if best_dev_score < eval_summary['accuracy']:
                best_dev_score = eval_summary['accuracy']
                # TODO: save model.
            eval_summary['best_dev_score'] = best_dev_score
            logging.info('best development model score %.4f', best_dev_score)
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(key, val, step)
                eval_summary_writer.flush()
Example #24
 def process_message(self, msg):
     if msg['type'] == 'sample' and not msg['is_observed']:
         self.rng, rng_sample = random.split(self.rng)
         msg['kwargs']['random_state'] = rng_sample
Example #25
def init_fn(key, n_features):
    key1, key2 = random.split(key)
    w = random.normal(key1, [n_features])
    b = random.normal(key2)
    return w, b
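The returned (w, b) pair is enough for a simple linear model; the sketch below (predict is an assumed helper, not part of the original example) shows one way to use it.

import jax.numpy as jnp
from jax import random

def predict(w, b, x):
    # dot product plus scalar bias for a batch of feature vectors
    return x @ w + b

w, b = init_fn(random.PRNGKey(0), n_features=3)
y_hat = predict(w, b, jnp.ones((5, 3)))  # shape (5,)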
Example #26
def flax_module(
    name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
    """
    Declare a :mod:`~flax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a flax ``nn_module``, to evaluate the module with a given set of
    parameters in flax we use ``nn_module.apply(params, x)``.
    In a NumPyro model, the pattern will be::

        net = flax_module("net", nn_module)
        y = net(x)

    or with dropout layers::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

    :param str name: name of the module to be registered.
    :param flax.linen.Module nn_module: a `flax` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
        ``nn_module``. For example, when ``nn_module`` includes dropout layers, we
        need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
        rng key is needed. Please see
        `Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information on how Flax deals with stochastic layers like dropout.
    :param list mutable: A list to indicate mutable states of ``nn_module``. For example,
        if your module has BatchNorm layer, we will need to define ``mutable=["batch_stats"]``.
        See the above `Flax Linen Intro` tutorial for more information.
    :param kwargs: optional keyword arguments to initialize flax neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use flax to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `flax` to be able to use this feature. "
            "It can be installed with `pip install flax`."
        ) from e
    module_key = name + "$params"
    nn_params = numpyro.param(module_key)

    if mutable:
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        # feed in dummy data to init params
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        rng_key = numpyro.prng_key()
        # split rng_key into a dict of rng_kind: rng_key
        rngs = {}
        if apply_rng:
            assert isinstance(apply_rng, list)
            for kind in apply_rng:
                rng_key, subkey = random.split(rng_key)
                rngs[kind] = subkey
        rngs["params"] = rng_key

        nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs))
        if "params" not in nn_vars:
            raise ValueError(
                "Your nn_module does not have any parameter. Currently, it is not"
                " supported in NumPyro. Please make a github issue if you need"
                " that feature."
            )
        nn_params = nn_vars["params"]
        if mutable:
            nn_state = {k: v for k, v in nn_vars.items() if k != "params"}
            assert set(mutable) == set(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        params = {"params": params, **nn_state}
        out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    def apply_without_state(params, *args, **kwargs):
        return nn_module.apply({"params": params}, *args, **kwargs)

    apply_fn = apply_with_state if mutable else apply_without_state
    return partial(apply_fn, nn_params)
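To connect the docstring's usage pattern to a complete model, here is a minimal hedged sketch of flax_module inside a NumPyro regression model; the module choice and site names are illustrative only, not part of the library source above.

import flax.linen as nn
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import flax_module

def model(x, y=None):
    # registers the Dense parameters as a numpyro.param site named "net$params"
    net = flax_module("net", nn.Dense(features=1), input_shape=(1, x.shape[-1]))
    mean = net(x).squeeze(-1)
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    numpyro.sample("obs", dist.Normal(mean, sigma), obs=y)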
Example #27
def tanh_layer_params(key, m, n, dtype = jnp.float32):
	w_key, b_key = random.split(key)
	w_init_fn = jax.nn.initializers.glorot_normal()
	return w_init_fn(w_key, (m, n), dtype), jnp.zeros((n, ), dtype)
Example #28
def random_samples(rng, images, labels, num_samples):
    _, rng = split(rng, 2)
    sampled_idxs = random.choice(rng,
                                 jnp.arange(images.shape[0]), (num_samples, ),
                                 replace=False)
    return images[sampled_idxs], labels[sampled_idxs]
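A short hedged usage example with toy shapes; note that the helper relies on split being jax.random.split imported at module level in the original file.

import jax.numpy as jnp
from jax.random import PRNGKey

images = jnp.zeros((100, 28, 28))
labels = jnp.arange(100)
x_sub, y_sub = random_samples(PRNGKey(0), images, labels, num_samples=10)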
Example #29
def init_tanh_params(key, layers):
	keys = random.split(key, len(layers))
	return [tanh_layer_params(k, m, n) for (k, m, n) in zip(keys, layers[:-1], layers[1:])]
Example #30
# *** Define initial configuration ***
g = 10
dt = 0.015
qc = 0.06
Q = jnp.array([[qc * dt**3 / 3, qc * dt**2 / 2], [qc * dt**2 / 2, qc * dt]])

fx_vmap = jax.vmap(fx)
fz_vec = jax.vmap(lambda x: fz(x, g=g, dt=dt))

nsteps = 200
Rt = jnp.eye(1) * 0.02
x0 = jnp.array([1.5, 0.0]).astype(float)
time = jnp.arange(0, nsteps * dt, dt)

key = random.PRNGKey(3141)
key_samples, key_pf, key_noisy = random.split(key, 3)
model = ds.NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt)
sample_state, sample_obs = model.sample(key, x0, nsteps)

# *** Perturbed data ***
key_noisy, key_values = random.split(key_noisy)
sample_obs_noise = sample_obs.copy()
samples_map = random.bernoulli(key_noisy, 0.5, (nsteps, ))
replacement_values = random.uniform(key_values, (samples_map.sum(), ),
                                    minval=-2,
                                    maxval=2)
sample_obs_noise = index_update(sample_obs_noise.ravel(), samples_map,
                                replacement_values)
colors = ["tab:red" if samp else "tab:blue" for samp in samples_map]

# *** Perform filtering ***