Example No. 1
 def model():
     x = numpyro.sample("x", dist.Normal(0, 1))
     numpyro.deterministic("x2", x * 2)
     with numpyro.plate("N", 10, subsample_size=5):
         numpyro.sample("obs", dist.Normal(x, 1), obs=jnp.ones(5))
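
A minimal inference sketch for the model above (not part of the original snippet; it assumes the model and the usual numpyro imports are in scope). Because the plate subsamples the observations, stochastic variational inference with an autoguide is a natural way to drive it:

from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)  # the guide only needs to cover the latent site "x"
svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000)
print(svi_result.params)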
Example No. 2
def main():
    rng = random.PRNGKey(0)
    x_dim = 2
    T = 20.0

    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )

    x0 = jnp.ones(x_dim)

    A, B, Q, R, N = fixed_env(x_dim)
    print("System dynamics:")
    print(f"  A = {A}")
    print(f"  B = {B}")
    print(f"  Q = {Q}")
    print(f"  R = {R}")
    print(f"  N = {N}")
    print()

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Evaluate LQR solution to get a sense of optimal cost.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jnp.array(K)
    opt_policy_cost_fn = policy_cost_and_grad(dynamics_fn,
                                              cost_fn,
                                              lambda KK, x: -KK @ x,
                                              example_x=x0)
    opt_loss, _opt_K_grad = opt_policy_cost_fn(K, x0, T)

    # This is true for longer time horizons, but not true for shorter time
    # horizons due to the LQR solution being an infinite-time solution.
    # assert jnp.allclose(opt_K_grad, 0)

    ### Training loop.
    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = policy_cost_and_grad(dynamics_fn,
                                         cost_fn,
                                         policy,
                                         example_x=x0)

    loss_per_iter = []
    elapsed_per_iter = []
    for iteration in range(10000):
        t0 = time.time()
        loss, g = loss_and_grad(opt.value, x0, T)
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss)
        elapsed_per_iter.append(elapsed)

        print(f"Iteration {iteration}")
        print(f"    excess loss = {loss - opt_loss}")
        print(f"    elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "opt_loss": opt_loss
    })

    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(loss_per_iter, color="tab:blue", label="Total rollout cost")
    plt.axhline(opt_loss, linestyle="--", color="gray")
    ax1.legend(loc="upper left")
    plt.title("Combined fwd-bwd BVP problem")
    blt.show()
Example No. 3
 def rescale(outputs, inputs):
     one = np.ones(inputs.shape[1:3], dtype=inputs.dtype)
     window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides,
                                      padding)
     return outputs / window_sizes
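
For context, a self-contained sketch of the same normalization trick (the snippet's dims, strides, and padding come from its enclosing pooling layer; the values below are hypothetical): dividing a window-summed array by the per-position window sizes turns a padded sum-pool into an average pool.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # NHWC input
dims, strides, padding = (1, 3, 3, 1), (1, 1, 1, 1), "SAME"

summed = lax.reduce_window(x, 0.0, lax.add, dims, strides, padding)

# Count how many real (non-padding) elements fall into each spatial window.
one = jnp.ones(x.shape[1:3], dtype=x.dtype)
window_sizes = lax.reduce_window(one, 0.0, lax.add, dims[1:3], strides[1:3], padding)

avg_pooled = summed / window_sizes[None, :, :, None]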
Example No. 4
def main(seed=0, dataset='mnist'):
    rng = set_reproducibility(seed)

    config = get_config(dataset)
    l2_penalty = config['l2_penalty']
    num_updates = config['num_updates']
    num_rounds = config['num_rounds']
    alpha = config['alpha']
    perfect = config['perfect']
    learning_rates = config['learning_rates']

    strong = l2_penalty
    smooth = 4 - l2_penalty
    diameter = 2
    lipshitz = 1 + l2_penalty
    learning_rates += [2 / (strong + smooth)]

    X_train, y_train, X_test, y_test = get_dataset(dataset)

    # Add intercept
    if dataset != 'adult':
        X_train_ones = np.ones((X_train.shape[0], 1))
        X_train = np.concatenate((X_train_ones, X_train), 1)

        X_test_ones = np.ones((X_test.shape[0], 1))
        X_test = np.concatenate((X_test_ones, X_test), 1)

    num_train, dim = X_train.shape
    num_test = X_test.shape[0]

    epsilons = [0.25, 0.5, 0.75, 1]
    delta = 1. / (num_train**1.)
    delta_alpha = 2 * alpha
    sigmas = [
        compute_sigma(epsilon, delta, delta_alpha) for epsilon in epsilons
    ]

    unpublished_accuracies_across_rounds = []
    published_accuracies_across_rounds = []
    retrain_accuracies_across_rounds = []

    for round_i in range(num_rounds):
        print('Starting round {}...'.format(round_i))

        temp, rng = random.split(rng)
        W_init = init_W(temp, dim)

        print('Training on full dataset...')
        W, init_iterations = train(W_init, X_train, y_train, learning_rates,
                                   l2_penalty, alpha)

        print('Initialization iterations: {}'.format(init_iterations))

        # Delete first row `num_updates` times in sequence.
        updates = [lambda X, y: delete_index(X, y) for i in range(num_updates)]

        unpublished_accuracies = []
        published_accuracies = defaultdict(list)
        retrain_accuracies = []
        max_iterations = None

        # For each update...
        for i, update in enumerate(updates):
            print('Processing update {}...'.format(i))

            # Apply update
            X_train, y_train = update(X_train, y_train)

            # Finetune on remaining points
            W, iterations = train(W, X_train, y_train, learning_rates,
                                  l2_penalty, alpha)
            unpublished_accuracy = accuracy(W, X_test, y_test)
            unpublished_accuracies.append(unpublished_accuracy)
            print('Accuracy:              {:.4f}'.format(unpublished_accuracy))

            if not max_iterations or iterations > max_iterations:
                max_iterations = iterations
            print('Max Iterations:            {}'.format(max_iterations))

            # Record performance of published model for varying epsilons
            for epsilon, sigma in zip(epsilons, sigmas):
                temp, rng = random.split(rng)
                W_published = publish(temp, W, sigma)
                published_accuracy = accuracy(W_published, X_test, y_test)
                published_accuracies[epsilon].append(published_accuracy)
                print('Accuracy (published, ε = {:.2f}):  {:.4f}'.format(
                    epsilon, published_accuracy))

            # Record performance of retraining from initial point
            W_retrain, _ = train(W_init, X_train, y_train, learning_rates,
                                 l2_penalty, alpha, max_iterations)
            retrain_accuracy = accuracy(W_retrain, X_test, y_test)
            retrain_accuracies.append(retrain_accuracy)
            print('Accuracy (retrain):    {:.4f}'.format(retrain_accuracy))

            print('-' * 20)

        unpublished_accuracies_across_rounds.append(unpublished_accuracies)
        published_accuracies_across_rounds.append(published_accuracies)
        retrain_accuracies_across_rounds.append(retrain_accuracies)
Example No. 5
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)
    target_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("translate_and_bleu"):
                    exemplars, bleu_score = translate_and_calculate_bleu(
                        p_pred_step=p_pred_step,
                        p_init_cache=p_init_cache,
                        target=optimizer.target,
                        predict_ds=predict_ds,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_scalars(step, {"bleu": bleu_score})
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_every_steps steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
Example No. 6
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel
        >>> import jax.numpy as jnp
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "gpt2"
        ... )

        >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
        >>> encoder_outputs = model.encode(pixel_values)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((pixel_values.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]

        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, then the cache is already initialized and a
        # private flag, init_cache, has to be passed down to ensure the cache is used.
        # The cache must also be marked as mutable so that it can be changed by the
        # FlaxBartAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask,
                             decoder_position_ids, encoder_hidden_states,
                             **kwargs):

            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # optionally project encoder_hidden_states
            if projection_module is not None:
                encoder_hidden_states = projection_module(
                    encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask,
                                             dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask,
                                             dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]), ) + outputs[1:]

        return outputs
Example No. 7
 def model(labels):
     coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
     logits = np.sum(coefs * data, axis=-1)
     return sample('obs', dist.Bernoulli(logits=logits), obs=labels)
Example No. 8
def main(unused_argv):
    # EX 1
    print_ex(jax2tex(lambda a, b: a + b, 1, 2))

    # EX 2
    print_ex(jax2tex(lambda a, b: a + b / a, 1, 2))

    # EX 3
    f = lambda a, b: a + b / a
    print_ex(jax2tex(grad(f), 1., 2.))

    # EX 4
    def fn(a, b, c):
        return a + a * (b + c) / a

    print_ex(jax2tex(fn, np.array([[1, 2], [2, 4], [3, 7]]), 2, 3))

    # EX 5
    # pylint: disable=function-redefined
    # pylint: disable=invalid-name
    def fn(a, b, c):
        return a + a * (b + c)

    print_ex(jax2tex(grad(fn), 4., 2., 3.))

    # EX 6
    def fn(a, b):
        return a * (a - b) / (a + b) + b

    print_ex(jax2tex(grad(fn), 1., 1.))

    # EX 7
    print_ex(jax2tex(lambda W, x: W @ x, np.ones((3, 3)), np.ones((3, ))))

    # EX 8
    print_ex(jax2tex(lambda W, x: W @ x, np.ones((3, 2)), np.ones((2, 3))))

    # EX 9
    def fn(W, x):
        return (W + W) @ (x * x)

    print_ex(jax2tex(fn, np.ones((3, 2)), np.ones((2, 3))))

    # EX 10
    def fn(W, x):
        return (W + W) @ (x * x)

    print_ex(jax2tex(grad(fn), np.ones((2, )), np.ones((2, ))))

    # EX 11
    def fn(W, x):
        z = tex_var(W @ x, 'z')
        return z * z

    print_ex(jax2tex(fn, np.ones((
        4,
        2,
    )), np.ones((2, ))))

    # EX 12
    def fn(W, x):
        z1 = tex_var(W @ x, 'z^1')
        z2 = tex_var(W @ z1, 'z^2')
        return z2 @ z2

    print_ex(jax2tex(grad(fn), np.ones((
        2,
        2,
    )), np.ones((2, ))))

    # EX 13
    def fn(W, x):
        z1 = tex_var(W @ x, 'z^1')
        z2 = tex_var(W @ z1, 'z^2')
        return np.sqrt(z2 @ z2)

    print_ex(jax2tex(fn, np.ones((
        2,
        2,
    )), np.ones((2, ))))

    # EX 14
    def fn(x):
        return lax.broadcast_in_dim(x, (2, 3), (1, ))

    print_ex(jax2tex(fn, np.ones((3, ))))

    # EX 15
    def fn(c, x, y):
        return np.where(c, x, y)

    print_ex(jax2tex(fn, np.ones((3, ), bool), np.ones((3, )), np.ones((3, ))))

    # EX 16
    def fn(c, x, y):
        return np.where(c, x, y)

    print_ex(jax2tex(fn, True, np.ones((3, )), np.ones((3, ))))

    # EX 17
    def fn(x):
        return np.transpose(x)

    print_ex(jax2tex(fn, np.ones((3, 2))))

    # EX 18
    def E(dr):
        idr = (tex_var(1, '\\sigma') / dr)
        idr6 = idr**6
        idr12 = idr**12
        return 4 * tex_var(1, '\\epsilon') * (idr12 - idr6)

    print_ex(jax2tex(E, np.ones((3, 3))))

    # Stax Examples
    def TexVar(layer, name, param_names=(), explicit_depends=False):
        init_fn, apply_fn = layer

        def tex_apply_fn(params, xs, **kwargs):
            if param_names:
                assert len(param_names) == len(params)
                params = tuple(
                    tex_var(p, name, True)
                    for p, name in zip(params, param_names))
            return tex_var(apply_fn(params, xs, **kwargs),
                           name,
                           depends_on=xs if explicit_depends else ())

        return init_fn, tex_apply_fn

    init_fn, apply_fn = stax.serial(
        TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1')),
        TexVar(stax.Relu, 'y^1'), TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))

    # EX 19
    def f(params, x):
        return apply_fn(params, tex_var(x, 'x', True))

    _, params = init_fn(random.PRNGKey(0), (-1, 5))
    print_ex(jax2tex(f, params, np.ones((3, 5))))

    # pylint: disable=too-many-function-args
    def L(params, x, y_hat):
        y_hat = tex_var(y_hat, '\\hat y', True)
        return tex_var(-np.sum(y_hat * jax.nn.log_softmax(f(params, x))), 'L')

    # EX 20
    print_ex(jax2tex(L, params, np.ones((3, 5)), np.ones((3, 3))))
    # EX 21
    print_ex(jax2tex(grad(L), params, np.ones((3, 5)), np.ones((3, 3))))

    # EX 22
    init_fn, apply_fn = stax.serial(
        TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1'), True),
        TexVar(stax.Relu, 'y^1'), TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))

    def f(params, x):
        return apply_fn(params, tex_var(x, 'x', True))

    _, params = init_fn(random.PRNGKey(0), (-1, 5))
    print_ex(jax2tex(f, params, np.ones((3, 5))))

    # EX 23
    def nngp(params, x1, x2):
        x1 = tex_var(x1, 'x^1', True)
        x2 = tex_var(x2, 'x^2', True)
        return tex_var(
            apply_fn(params, x1) @ apply_fn(params, x2).T, '\\mathcal K')

    _, params = init_fn(random.PRNGKey(0), (-1, 5))
    print_ex(jax2tex(nngp, params, np.ones((3, 5)), np.ones((3, 5))))

    # Forward Mode vs Reverse Mode
    f = lambda a, b: a + tex_var(b / a, 'z')
    # EX 24
    print_ex(jax2tex(f, 1., 1.))
    # EX 25
    print_ex(jax2tex(grad(f), 1., 1.))
    # EX 26
    # pylint: disable=g-long-lambda
    print_ex(
        jax2tex(lambda a, b: jvp(lambda a: f(a, b), (a, ), (1., ))[1], 1., 1.))

    # EX 27
    def f(x, y):
        def g(r):
            return tex_var(r**2, 'z', depends_on=r)

        return g(x) + g(y)

    print_ex(jax2tex(f, 1., 1.))

    # EX 28
    def f(x_and_y):
        x, y = x_and_y
        return x * y

    print_ex(jax2tex(f, (1., 1.)))

    # EX 29
    def f(x_and_y):
        x, y = x_and_y
        return tex_var(x, 'x') * tex_var(y, 'y')

    print_ex(jax2tex(f, (1., 1.)))

    # EX 30
    def f(x_and_y):
        x, y = x_and_y
        return tex_var(x, 'x', True) * tex_var(y, 'y', True)

    print_ex(jax2tex(f, (1., 1.)))

    def f(x):
        return np.sin(x)

    # EX 31
    print_ex(jax2tex(grad(bind_names(f)), 1.))
    # EX 32
    print_ex(jax2tex(grad(f), 1.))
Example No. 9
 def test_jit_several_together(self):
     arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
     with hcb.outfeed_receiver(receiver_name=self._testMethodName):
         api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(
             arg, jnp.ones(100, dtype=jnp.int32))
Example No. 10
def ones(shape, dtype_str='float32', dev_str=None):
    dtype = _jnp.__dict__[dtype_str]
    return to_dev(_jnp.ones(shape, dtype), dev_str)
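
A standalone sketch of the same idea using plain JAX calls (`_jnp` and `to_dev` are the library's own helpers, so this is an assumption about their behavior): look the dtype up by name, build the array, and optionally place it on a device.

import jax
import jax.numpy as jnp

def ones_sketch(shape, dtype_str="float32", device=None):
    arr = jnp.ones(shape, getattr(jnp, dtype_str))
    return jax.device_put(arr, device) if device is not None else arr

x = ones_sketch((2, 3), "int32")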
Example No. 11
                                dropout=(None, not self.broadcast_dropout))

        attn = Attn(attn_module=self.attn_module,
                    qkv_features=qkv_features // self.num_heads,
                    out_features=out_features)

        # evaluate multi-headed-attention.
        y = attn(inputs_q, inputs_kv, bias)
        return y.mean(axis=-2)


# run it.

if __name__ == '__main__':

    inputs = jnp.ones((8, 97, 256))
    rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
    model = MultiHeadDotProductAttention(
        broadcast_dropout=False,
        qkv_features=256,
        out_features=256,
        attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1),
        num_heads=8,
        batch_axes=(0, ),
    )

    y, params = model.init_with_output(rngs, inputs, inputs)

    print('input shape: ', inputs.shape)
    print('parameter shapes:')
    pprint(jax.tree_map(jnp.shape, unfreeze(params)))
Example No. 12
def pmds_MAP2(
        p_dists,
        n_samples,
        n_components=2,
        batch_size=0,
        random_state=42,
        lr=1e-3,
        epochs=20,
        debug_D_squareform=None,
        fixed_points=[],
        init_mu=None,
        hard_fix=False,
        method="LV",
        sigma_local=1e-3,
        alpha=None,  # contribution of log prior
        sigma_fix=None,  # sigma^2_{fix}
):
    # can multiply the log prior with factor N
    alpha = alpha or n_samples - 1

    # local variance for each point
    sigma_local = jnp.ones((2, 1)) * sigma_local

    # init latent vars and params
    mu, mu0, sigma0 = init_params(
        n_samples,
        n_components,
        random_state=random_state,
        init_mu=init_mu,
        fixed_points=fixed_points,
        sigma_fix=sigma_fix,
    )

    # optimizer
    opt_init, opt_update, get_params = adam(step_size=lr)
    opt_state = opt_init([mu])

    @jax.jit
    def update(epoch, opt_state, dists, i0, i1, mu0, sigma0, sigma_local,
               alpha):
        params = get_params(opt_state)
        loss, grads = loss_and_grads_MAP(params, dists, i0, i1, mu0, sigma0,
                                         sigma_local, alpha)
        opt_state = opt_update(epoch, grads, opt_state)
        return loss, opt_state

    losses = []
    for epoch in range(epochs):
        # shuffle the observed pairs in each epoch
        batch = random.sample(p_dists, k=len(p_dists))
        # unpack pairwise distances and indices of points in each pair
        dists, pair_indices = list(zip(*batch))
        dists = jnp.array(dists).reshape(-1, 1)
        i0, i1 = map(jnp.array, zip(*pair_indices))

        loss, opt_state = update(epoch, opt_state, dists, i0, i1, mu0, sigma0,
                                 sigma_local, alpha)
        losses.append(float(loss))
        print(f"{epoch}, {loss:,.4f}")

    mu = get_params(opt_state)[0]
    return mu, [losses, [],
                []]  # old code: loss, loss_log_llh and loss_log_prior
Example No. 13
 def sufficient_statistics(data):
     return (np.ones(data.shape[:-1]), data,
             np.einsum('...i,...j->...ij', data, data))
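
These are the per-sample statistics (1, x, x x^T) of a multivariate normal. A quick shape check, assuming `np` in the snippet is numpy or jax.numpy:

import numpy as np

data = np.ones((10, 3))
n, s1, s2 = sufficient_statistics(data)
print(n.shape, s1.shape, s2.shape)  # (10,) (10, 3) (10, 3, 3)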
Example No. 14
 def model(data):
     f = numpyro.sample("beta",
                        dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
     with numpyro.plate("N", N):
         numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)
Example No. 15
    def __init__(self):
        super().__init__()
        # Data properties
        self.n_train: int = 1
        self.n_test: int = 1000
        self.observed_dims: int = 5
        self.t_steps: int = 300
        self.t_steps_test: int = 300
        self.delta_t: float = 0.01
        self.known_drift_diffusion: bool = True

        scale = 500
        np.random.seed(42)

        x_train = np.zeros(shape=(self.t_steps * scale, 2))
        x_test = np.zeros(shape=(self.t_steps_test * scale, 2))
        y_train = np.zeros(shape=(self.t_steps * scale, 5))
        y_test = np.zeros(shape=(self.t_steps_test * scale, 5))

        drift_train = np.zeros(shape=(self.t_steps * scale, 5))
        drift_test = np.zeros(shape=(self.t_steps_test * scale, 5))

        diffusion_train = np.zeros(shape=(self.t_steps * scale, 5))
        diffusion_test = np.zeros(shape=(self.t_steps_test * scale, 5))

        x_train[0, 1] = 1
        x_test[0, 1] = 1

        y_train[0] = self.drift_y(x_train[0])
        y_test[0] = self.drift_y(x_test[0])

        for i in range(0, self.t_steps * scale - 1):
            noise = np.random.normal(size=(2, ))
            noise_y = np.random.normal(size=(5, ))

            drift_train[i] = self.drift_y(x_train[i])
            drift_test[i] = self.drift_y(x_test[i])
            diffusion_train[i] = self.diffusion_y(x_train[i])
            diffusion_test[i] = self.diffusion_y(x_test[i])

            x_train[i+1] = \
                x_train[i] + self.drift(x_train[i]) * self.delta_t/scale + \
                np.sqrt(self.delta_t/scale) * self.diffusion(x_train[i]) * noise
            y_train[i+1] = \
                y_train[i] + self.drift_y(x_train[i]) * self.delta_t / scale + \
                np.sqrt(self.delta_t/scale) * self.diffusion_y(x_train[i]) * noise_y
            x_test[i+1] = \
                x_test[i] + self.drift(x_test[i]) * self.delta_t/scale + \
                np.sqrt(self.delta_t/scale) * self.diffusion(x_test[i]) * noise
            y_test[i+1] = \
                y_test[i] + self.drift_y(x_test[i]) * self.delta_t / scale + \
                np.sqrt(self.delta_t/scale) * self.diffusion_y(x_test[i]) * noise_y

        y_train = np.tile(y_train[None, ::scale], (self.n_train, 1, 1))
        y_test = np.tile(y_test[None, ::scale], (self.n_test, 1, 1))

        drift_train = np.tile(drift_train[None, ::scale], (self.n_train, 1, 1))
        drift_test = np.tile(drift_test[None, ::scale], (self.n_test, 1, 1))

        diffusion_train = np.tile(diffusion_train[None, ::scale],
                                  (self.n_train, 1, 1))
        diffusion_test = np.tile(diffusion_test[None, ::scale],
                                 (self.n_test, 1, 1))

        self.ys_train = jnp.array(y_train)
        self.ys_test = jnp.array(y_test)

        self.drifts_train = jnp.array(drift_train)
        self.drifts_test = jnp.array(drift_test)

        self.diffusions_train = jnp.array(diffusion_train)
        self.diffusions_test = jnp.array(diffusion_test)

        self.masks_train = jnp.ones(
            (self.n_train, self.t_steps, self.observed_dims), dtype=jnp.int32)
        self.masks_test = jnp.ones(
            (self.n_test, self.t_steps_test, self.observed_dims),
            dtype=jnp.int32)
Example No. 16
 def ones(self: TensorType, shape: ShapeOrScalar) -> TensorType:
     return type(self)(np.ones(shape, dtype=self.raw.dtype))
Example No. 17
 def __init__(self):
     super().__init__()
     self.w = jnp.ones([])
Example No. 18
 def get_arr(scale):
     return scale + np.ones((2, 2))
Example No. 19
 def _predict(params, x):
     x = x[np.newaxis, :]
     X = np.hstack((np.ones((x.shape[0], 1)), x.reshape(x.shape[0], -1))).reshape(
         x.shape[0], -1, 1
     )
     return np.tensordot(X, params["W"], axes=(1, 0))
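
A small usage sketch for the helper above (hypothetical shapes; it assumes `np` is numpy or jax.numpy and that params["W"] has one leading bias row plus one row per input feature):

import numpy as np

d, k = 4, 2
params = {"W": np.ones((d + 1, k))}  # bias row + one row per feature
x = np.arange(d, dtype=float)        # a single unbatched input of length d
y = _predict(params, x)
print(y.shape)  # (1, 1, 2): the prepended column of ones acts as the intercept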
Example No. 20
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if config.batch_size % n_devices:
        raise ValueError(
            "Batch size must be divisible by the number of devices")

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=config.dataset_name,
        eval_dataset_name=config.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        vocab_path=vocab_path,
        target_vocab_size=config.vocab_size,
        batch_size=config.batch_size,
        max_corpus_chars=config.max_corpus_chars,
        max_length=config.max_target_length,
        max_eval_length=config.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    input_shape = (config.batch_size, config.max_target_length)
    target_shape = (config.batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    metrics_all = []
    with metric_writers.ensure_flushes(writer):
        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            # Shard data to devices and do a training step.
            batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
            optimizer, metrics, dropout_rngs = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            metrics_all.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Save a checkpoint on one host after every checkpoint_freq steps.
            if (config.save_checkpoints and step % config.checkpoint_freq == 0
                    and step > 0 and jax.host_id() == 0):
                checkpoints.save_checkpoint(workdir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

            # Periodic metric handling.
            if step % config.eval_frequency != 0 and step > 0:
                continue

            # Training Metrics
            logging.info("Gathering training metrics.")
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop("learning_rate").mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop("denominator")
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary["learning_rate"] = lr
            summary = {"train_" + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            metrics_all = []

            # Eval Metrics
            logging.info("Gathering evaluation metrics.")
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop("denominator")
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            eval_summary = {"eval_" + k: v for k, v in eval_summary.items()}
            writer.write_scalars(step, eval_summary)

            # Translation and BLEU Score.
            logging.info("Translating evaluation dataset.")
            t_inference_start = time.time()
            sources, references, predictions = [], [], []
            for pred_batch in predict_ds:
                pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch["inputs"].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_map(
                        lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                        pred_batch)
                pred_batch = common_utils.shard(pred_batch)
                cache = p_init_cache(pred_batch["inputs"])
                predicted = p_pred_step(pred_batch["inputs"], optimizer.target,
                                        cache, eos_id,
                                        config.max_predict_length)
                predicted = tohost(predicted)
                inputs = tohost(pred_batch["inputs"])
                targets = tohost(pred_batch["targets"])
                # Iterate through non-padding examples of batch.
                for i, s in enumerate(predicted[:cur_pred_batch_size]):
                    sources.append(decode_tokens(inputs[i]))
                    references.append(decode_tokens(targets[i]))
                    predictions.append(decode_tokens(s))
            logging.info(
                "Translation: %d predictions %d references %d sources.",
                len(predictions), len(references), len(sources))
            logging.info("Translation time: %.4f s step %d.",
                         time.time() - t_inference_start, step)

            # Calculate BLEU score for translated eval corpus against reference.
            bleu_matches = bleu.bleu_partial(references, predictions)
            all_bleu_matches = per_host_sum_pmap(bleu_matches)
            bleu_score = bleu.complete_bleu(*all_bleu_matches)
            # Save translation samples for tensorboard.
            exemplars = ""
            for n in np.random.choice(np.arange(len(predictions)), 8):
                exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
            writer.write_scalars(step, {"bleu": bleu_score})
            writer.write_texts(step, {"samples": exemplars})
Example No. 21
def ones_with_scale(key, shape, scale, dtype=jnp.float32):
    return jnp.ones(shape, dtype) * scale
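
Partially applying `scale` turns this into a `(key, shape)`-style initializer; a hypothetical usage:

import functools
from jax import random

init_fn = functools.partial(ones_with_scale, scale=0.5)
w = init_fn(random.PRNGKey(0), (3, 3))  # a (3, 3) array filled with 0.5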
Example No. 22
def test_is_complex():
    assert is_complex(jnp.ones(1, dtype=jnp.complex_))
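
A hypothetical `is_complex` consistent with this test (the project's real implementation may differ); it only inspects the dtype:

import jax.numpy as jnp

def is_complex(x):
    return jnp.issubdtype(jnp.asarray(x).dtype, jnp.complexfloating)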
Example No. 23
 def __init__(self, nin, nout):
     self.v1 = objax.TrainVar(jn.ones((nin, nout)))
     self.var_list = [
         objax.TrainVar(jn.zeros(nin)),
         objax.TrainVar(jn.zeros(nout))
     ]
Example No. 24
def test_resample():
    x = random.normal(key=random.PRNGKey(0), shape=(50,))
    logits = -jnp.ones(50)
    samples = {'x':x}
    assert jnp.all(resample(random.PRNGKey(0), samples, logits)['x'] == resample(random.PRNGKey(0), x, logits))
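
A hypothetical `resample` that would satisfy this test: draw indices in proportion to the log-weights with a fixed key, then index either a dict of samples or a bare array the same way.

from jax import random

def resample(key, samples, logits):
    # Sketch only; the project's actual implementation may differ.
    idx = random.categorical(key, logits, shape=(logits.shape[0],))
    if isinstance(samples, dict):
        return {k: v[idx] for k, v in samples.items()}
    return samples[idx]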
Example No. 25
        if use_bias:
            bias = jnp.empty([out_features])
        else:
            bias = None
        return cls(in_features=in_features,
                   out_features=out_features,
                   use_bias=use_bias,
                   weight=weight,
                   bias=bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)


class Dense(Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = Parameter((out_features, in_features),
                                init_method=XavierUniform())

    def forward(self, x: Tensor) -> Tensor:
        return jnp.dot(x, self.weight.data.T)


if __name__ == "__main__":
    dense = Dense(5, 10).init()
    x = jnp.ones((2, 5))
    y, new_buffers = dense.pure_forward(dense.params, dense.buffers, x)
    print(y)
Example No. 26
    def __init__(self):
        super().__init__()
        # Data properties
        self.n_train: int = 1
        self.n_test: int = 1000
        self.observed_dims: int = 1
        self.t_steps: int = 800
        self.t_steps_test: int = 800
        self.delta_t: float = 0.01
        self.known_drift_diffusion: bool = True
        self.y_lims: Optional[Sequence[Sequence[float]]] = [[-2, 2]]

        np.random.seed(0)
        scale = 500

        x_train = np.zeros(shape=(self.t_steps * scale, 1))
        x_test = np.zeros(shape=(self.t_steps_test * scale, 1))

        drift_train = np.zeros(shape=(self.t_steps * scale, 1))
        drift_test = np.zeros(shape=(self.t_steps_test * scale, 1))

        diffusion_train = np.zeros(shape=(self.t_steps * scale, 1))
        diffusion_test = np.zeros(shape=(self.t_steps_test * scale, 1))

        for i in range(0, self.t_steps * scale - 1):
            drift_train[i] = self.drift(x_train[i])
            drift_test[i] = self.drift(x_test[i])
            diffusion_train[i] = self.diffusion(x_train[i])
            diffusion_test[i] = self.diffusion(x_test[i])

            noise = np.random.normal(size=(1, ))
            x_train[i+1] = \
                x_train[i] + drift_train[i] * self.delta_t/scale + \
                np.sqrt(self.delta_t/scale) * diffusion_train[i] * noise
            x_test[i+1] = \
                x_test[i] + drift_test[i] * self.delta_t/scale + \
                np.sqrt(self.delta_t/scale) * diffusion_test[i] * noise

        x_train = np.tile(x_train[None, ::scale], (self.n_train, 1, 1))
        x_test = np.tile(x_test[None, ::scale], (self.n_test, 1, 1))

        drift_train = np.tile(drift_train[None, ::scale], (self.n_train, 1, 1))
        drift_test = np.tile(drift_test[None, ::scale], (self.n_test, 1, 1))

        diffusion_train = np.tile(diffusion_train[None, ::scale],
                                  (self.n_train, 1, 1))
        diffusion_test = np.tile(diffusion_test[None, ::scale],
                                 (self.n_test, 1, 1))

        self.ys_train = jnp.array(x_train)
        self.ys_test = jnp.array(x_test)

        self.drifts_train = jnp.array(drift_train)
        self.drifts_test = jnp.array(drift_test)

        self.diffusions_train = jnp.array(diffusion_train)
        self.diffusions_test = jnp.array(diffusion_test)

        self.masks_train = jnp.ones(
            (self.n_train, self.t_steps, self.observed_dims), dtype=jnp.int32)
        self.masks_test = jnp.ones(
            (self.n_test, self.t_steps_test, self.observed_dims),
            dtype=jnp.int32)
Example No. 27
 def dummy_output(fx_struct):
     return np.ones(fx_struct.shape, fx_struct.dtype)
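
How such a helper is typically used (an assumption about the surrounding context): `fx_struct` would be a jax.ShapeDtypeStruct, e.g. the output of jax.eval_shape, and `np` here stands for jax.numpy.

import jax
import jax.numpy as np

def f(x):
    return np.sin(x) * 2.0

fx_struct = jax.eval_shape(f, np.ones((3, 4)))  # shapes/dtypes only, no real compute
dummy = np.ones(fx_struct.shape, fx_struct.dtype)
print(dummy.shape, dummy.dtype)  # (3, 4) float32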
Example No. 28
    def __init__(self):
        super().__init__()
        # Data properties
        self.n_train: int = 1
        self.n_test: int = 1000
        self.observed_dims: int = 2
        self.t_steps: int = 500
        self.t_steps_test: int = 500
        self.delta_t: float = 0.01
        self.known_drift_diffusion: bool = True

        np.random.seed(42)

        scale = 500
        x_train = np.zeros(shape=(self.t_steps_test * scale, 2))
        x_test = np.zeros(shape=(self.t_steps_test * scale, 2))

        drift_train = np.zeros(shape=(self.t_steps_test * scale, 2))
        drift_test = np.zeros(shape=(self.t_steps_test * scale, 2))

        diffusion_train = np.zeros(shape=(self.t_steps_test * scale, 2))
        diffusion_test = np.zeros(shape=(self.t_steps_test * scale, 2))

        x_train[0, 1] = 1
        x_test[0, 1] = 1

        for i in range(0, self.t_steps_test * scale - 1):
            drift_train[i] = self.drift(x_train[i])
            drift_test[i] = self.drift(x_test[i])
            diffusion_train[i] = self.diffusion(x_train[i])
            diffusion_test[i] = self.diffusion(x_test[i])

            noise = np.random.normal(size=(2, ))
            x_train[i+1] = \
                x_train[i] + drift_train[i] * self.delta_t/scale + \
                np.sqrt(self.delta_t/scale) * diffusion_train[i] * noise
            x_test[i+1] = \
                x_test[i] + drift_test[i] * self.delta_t/scale + \
                np.sqrt(self.delta_t/scale) * diffusion_test[i] * noise

        x_train = np.tile(x_train[None, :self.t_steps * scale:scale],
                          (self.n_train, 1, 1))
        x_test = np.tile(x_test[None, ::scale], (self.n_test, 1, 1))

        drift_train = np.tile(drift_train[None, :self.t_steps * scale:scale],
                              (self.n_train, 1, 1))
        drift_test = np.tile(drift_test[None, ::scale], (self.n_test, 1, 1))

        diffusion_train = np.tile(
            diffusion_train[None, :self.t_steps * scale:scale],
            (self.n_train, 1, 1))
        diffusion_test = np.tile(diffusion_test[None, ::scale],
                                 (self.n_test, 1, 1))

        self.ys_train = jnp.array(x_train)
        self.ys_test = jnp.array(x_test)

        self.drifts_train = jnp.array(drift_train)
        self.drifts_test = jnp.array(drift_test)

        self.diffusions_train = jnp.array(diffusion_train)
        self.diffusions_test = jnp.array(diffusion_test)

        self.masks_train = jnp.ones(
            (self.n_train, self.t_steps, self.observed_dims), dtype=jnp.int32)
        self.masks_test = np.ones(
            (self.n_test, self.t_steps_test, self.observed_dims),
            dtype=np.int32)
        self.masks_test[:, self.t_steps:] = 0
        self.masks_test = jnp.array(self.masks_test)
        self.masks_train = np.ones(
            (self.n_train, self.t_steps, self.observed_dims), dtype=np.int32)
        self.masks_train[:, self.t_steps // 5:2 * self.t_steps // 5, 0] = 0
        self.masks_train[:, 3 * self.t_steps // 5:4 * self.t_steps // 5, 1] = 0

        self.masks_test = np.ones(
            (self.n_test, self.t_steps_test, self.observed_dims),
            dtype=np.int32)
        self.masks_test[:, self.t_steps_test // 5:2 * self.t_steps_test // 5,
                        0] = 0
        self.masks_test[:,
                        3 * self.t_steps_test // 5:4 * self.t_steps_test // 5,
                        1] = 0

        self.masks_train = jnp.array(self.masks_train)
        self.masks_test = jnp.array(self.masks_test)
Example No. 29
 def testRmspropVector(self):
   def loss(x): return jnp.dot(x, x)
   x0 = jnp.ones(2)
   num_iters = 100
   step_size = 0.1
   self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_size)
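
`_CheckOptimizer` is the test harness's own helper; a standalone sketch of the same loop with the stock optimizers module (assuming the modern jax.example_libraries.optimizers import path):

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss(x):
    return jnp.dot(x, x)

opt_init, opt_update, get_params = optimizers.rmsprop(step_size=0.1)
opt_state = opt_init(jnp.ones(2))
for i in range(100):
    g = jax.grad(loss)(get_params(opt_state))
    opt_state = opt_update(i, g, opt_state)
print(get_params(opt_state))  # approaches the minimizer at the origin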
Example No. 30
 def model(data, labels):
     coefs = numpyro.sample("coefs",
                            dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
     offset = numpyro.sample("offset", dist.Uniform(-1, 1))
     logits = offset + jnp.sum(coefs * data, axis=-1)
     return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
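
A hypothetical driver for the model above, in the spirit of NumPyro's logistic-regression examples (it assumes the snippet's free variable `dim` is defined and the usual numpyro imports are in scope):

import jax
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

dim, N = 3, 2000
data = random.normal(random.PRNGKey(1), (N, dim))
true_coefs = jnp.arange(1.0, dim + 1.0)
labels = random.bernoulli(random.PRNGKey(2),
                          jax.nn.sigmoid(data @ true_coefs)).astype(jnp.int32)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), data, labels)
mcmc.print_summary()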