Example #1
xR = float(N - 5)

#Define the parameters for motion of the left wall
xL_f = [6.32]

neurons1 = 100
neurons2 = 100
neurons3 = 100
#layer2_sizes=[5,7,10,15,20]
LR = 2e-3  #learning rate for the optimizer
episodes = 1

final_fidelity_neurons2 = []
for i in range(len(xL_f)):

    net_init, net_apply = stax.serial(Dense(neurons1),
                                      Relu, Dense(neurons2), Relu,
                                      Dense(neurons3), Relu, Dense(1), Sigmoid)

    # Initialize parameters, not committing to a batch shape

    rng = random.PRNGKey(1)
    in_shape = (
        -1,
        1,
    )
    out_shape, net_params = net_init(rng, in_shape)
    net_params = [[a.astype(np.float64) for a in d]
                  for d in net_params]  # convert data type
    inputs = np.array([[k] for k in tlist[1:len(tlist)] / T_total])


def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)), Activator,
    MaxPool((4, 4)), Flatten,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    
    # training/test split
    (train_images, train_labels), (test_images, test_labels) = mnist_data.tiny_mnist(flatten=False)
    num_train = train_images.shape[0]
Example #3

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(Dense(1024), Relu, Dense(1024), Relu,
                                          Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)

    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)
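
The excerpt stops just before the optimization loop. A minimal sketch of how these pieces are typically wired together, assuming jax.grad/jax.jit and jax.experimental.optimizers are in scope and that a data_stream() helper yields shuffled (inputs, targets) minibatches, as in the stock JAX MNIST example:

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
batches = data_stream()  # assumed helper yielding minibatches

step = 0
for epoch in range(num_epochs):
    for _ in range(num_batches):
        opt_state = update(step, opt_state, next(batches))
        step += 1
    params = get_params(opt_state)
    print(f"Epoch {epoch}: test accuracy "
          f"{accuracy(params, (test_images, test_labels)):.4f}")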
Example #4
from jax.experimental.stax import Dense, Dropout, Tanh, Relu, randn
from jax.experimental import optimizers
import pandas as pd
import lifelike.losses as losses
from lifelike import Model
from lifelike.callbacks import *
from datasets.loaders import *


x_train, t_train, e_train = get_generated_churn_dataset()

model = Model([Dense(8), Relu, Dense(12), Relu, Dense(16), Relu])

model.compile(
    optimizer=optimizers.adam,
    optimizer_kwargs={"step_size": optimizers.exponential_decay(0.001, 1, 0.9995)},
    weight_l2=0.00,
    smoothing_l2=100.,
    loss=losses.NonParametric()
)

print(model)

model.fit(
    x_train,
    t_train,
    e_train,
    epochs=10000,
    batch_size=10000,
    validation_split=0.1,
    callbacks=[
Example #5
# train_images, labels, _, _ = mnist(permute_train=True, resize=True)
# del _
# inputs = train_images[:data_size]
#
# del train_images

# u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
# I = np.eye(v_t.shape[-1])
# I_add = npr.normal(0.0, 0.002, size=I.shape)
# noisy_I = I + I_add

init_fun, predict_fun = stax.serial(
    HomotopyDropout(rate=0.0),
    # Dense(4, b_init=zeros),
    Dense(4, W_init=orthogonal(scale=1.0), b_init=zeros),
    Sigmoid,
    # Dense(4, b_init=zeros),
    Dense(out_dim=input_shape[-1], W_init=orthogonal(scale=1.0), b_init=zeros),
)
_, key = random.split(random.PRNGKey(0))


class DataTopologyAE(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod
    def objective(params, bparam, batch) -> float:
        x, _ = batch
        x = np.reshape(x, (x.shape[0], -1))
Example #6
def main():
    total_secs = 10.0
    gamma = 0.9
    rng = random.PRNGKey(0)

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu

    A = jp.eye(2)
    B = jp.eye(2)
    Q = jp.eye(2)
    R = jp.eye(2)
    N = jp.zeros((2, 2))

    # rngA, rngB, rngQ, rngR, rng = random.split(rng, 5)
    # # A = random.normal(rngA, (2, 2))
    # A = -1 * random_psd(rngA, 2)
    # B = random.normal(rngB, (2, 2))
    # Q = random_psd(rngQ, 2) + 0.1 * jp.eye(2)
    # R = random_psd(rngR, 2) + 0.1 * jp.eye(2)
    # N = jp.zeros((2, 2))

    # x_dim, u_dim = B.shape

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    rng_eval, rng = random.split(rng)
    x0_eval = random.normal(rng_eval, (1000, 2))
    opt_all_costs = vmap(lambda x0: policy_integrate_cost(
        dynamics_fn, cost_fn, lambda _, x: -K @ x, gamma)
                         (None, x0, total_secs))(x0_eval)
    opt_cost = jp.mean(opt_all_costs)
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(2),
    )
    # policy_init, policy = DenseNoBias(2)

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (2, ))

    cost_and_grad = jit(
        value_and_grad(
            policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

    def multiple_steps(num_steps):
        """Return a jit-able function that runs `num_steps` iterations."""
        def body(_, stuff):
            rng, _, opt = stuff
            rng_x0, rng = random.split(rng)
            x0 = random.normal(rng_x0, (2, ))
            cost, g = cost_and_grad(opt.value, x0, total_secs)

            # Gradient clipping
            # g = tree_map(lambda x: jp.clip(x, -10, 10), g)
            # g = optimizers.clip_grads(g, 64)

            return rng, cost, opt.update(g)

        return lambda rng, opt: lax.fori_loop(0, num_steps, body,
                                              (rng, jp.zeros(()), opt))

    multi_steps = 1
    run = jit(multiple_steps(multi_steps))

    ### Main optimization loop.
    costs = []
    for i in range(25000):
        t0 = time.time()
        rng, cost, opt = run(rng, opt)
        print(f"Episode {(i + 1) * multi_steps}:")
        print(f"    excess cost = {cost - opt_cost}")
        print(f"    elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")
    # print(f"Gradient at opt solution: {opt_g}")

    # Print the identified and optimal policy. Note that layers multiply on
    # the right instead of the left, so we need a transpose.
    print(f"Est solution parameters: {opt.value}")
    print(f"Opt solution parameters: {K.T}")

    est_all_costs = vmap(
        lambda x0: policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)
        (opt.value, x0, total_secs))(x0_eval)

    ### Scatter plot of learned policy performance vs optimal policy performance.
    plt.figure()
    plt.scatter(est_all_costs, opt_all_costs)
    plt.plot([-100, 100], [-100, 100], color="gray")
    plt.xlim(0, jp.max(est_all_costs))
    plt.ylim(0, jp.max(opt_all_costs))
    plt.xlabel("Learned policy cost")
    plt.ylabel("Optimal cost")
    plt.title("Performance relative to the direct LQR solution")

    ### Plot performance per iteration, incl. average optimal policy performance.
    plt.figure()
    plt.plot(costs)
    plt.axhline(opt_cost, linestyle="--", color="gray")
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel(f"Cost (T = {total_secs}s)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("ODE control of LQR problem")

    ### Example rollout plots (learned policy vs optimal policy).
    x0 = jp.array([1.0, 2.0])
    framerate = 30
    timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))
    est_policy_rollout_states = ode.odeint(
        lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
    est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
        est_policy_rollout_states)

    opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                           y0=x0,
                                           t=timesteps)
    opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
        opt_policy_rollout_states)

    plt.figure()
    plt.plot(est_policy_rollout_states[:, 0],
             est_policy_rollout_states[:, 1],
             marker='.')
    plt.plot(opt_policy_rollout_states[:, 0],
             opt_policy_rollout_states[:, 1],
             marker='.')
    plt.xlabel("x_1")
    plt.ylabel("x_2")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Phase space trajectory")

    plt.figure()
    plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2,
                                       axis=-1)))
    plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2,
                                       axis=-1)))
    plt.xlabel("time")
    plt.ylabel("control input (L2 norm)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Policy control over time")

    ### Plot quiver field showing dynamics under learned policy.
    plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

    plt.show()
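
The transpose in the parameter printout above comes from stax's row-vector convention: Dense computes outputs = inputs @ W + b, so a linear policy u = -Kx corresponds to a weight matrix of roughly -K.T. A standalone check of that convention (the shapes here are made up):

import jax.numpy as jnp
from jax import random
from jax.experimental.stax import Dense  # jax.example_libraries.stax in newer JAX

dense_init, dense_apply = Dense(2)
_, (W, b) = dense_init(random.PRNGKey(0), (-1, 2))
x = jnp.ones((3, 2))
# Dense applies the weights on the right: outputs = inputs @ W + b
assert jnp.allclose(dense_apply((W, b), x), x @ W + b)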
Example #7
from jax.experimental.stax import Dense, Relu
from jax.experimental.optimizers import l2_norm, momentum, adam, sgd
from jax.nn.initializers import glorot_normal, normal
from jax.nn import relu, softplus, log_sigmoid, sigmoid
from collections import namedtuple
import itertools
from tqdm import tqdm
import logging
from absl import app, flags

from common import INPUTS, LABELS
from common import load_bibtex, data_stream, evaluate, compute_f1, compute_accuracy

logger = logging.getLogger(__file__)

init_mlp, apply_mlp = stax.serial(Dense(150), Relu, Dense(200), Relu,
                                  Dense(LABELS))


def sigmoid_cross_entropy(x, y):
    pos_logprob = log_sigmoid(x) * y
    neg_logprob = -softplus(x) * (1 - y)
    return np.sum(pos_logprob + neg_logprob, axis=1)
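
# Aside (not in the original source): the negative term above relies on the
# identity log(1 - sigmoid(x)) = log_sigmoid(-x) = -softplus(x), which stays
# numerically stable for large |x|. With np as jax.numpy, as in this file:
_x_check = np.linspace(-20.0, 20.0, 9)
assert np.allclose(log_sigmoid(-_x_check) + softplus(_x_check), 0.0, atol=1e-6)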


def cross_entropy_loss(params, x, y, lamb=0.001):
    neglogprob = -np.mean(sigmoid_cross_entropy(-apply_mlp(params, x), y))
    return neglogprob + lamb * l2_norm(params)


@jit
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Dense(config.num_hidden), Relu,
    Dense(config.num_hidden), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":

    tensorboard = config.tensorboard

    rng = random.PRNGKey(0)

    num_epochs = 10

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, config.batch_size)
    num_batches = num_complete_batches + bool(leftover)
Example #9
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(Conv(10, (5, 5), (1, 1)), Activator,
                                          MaxPool((4, 4)), Flatten, Dense(24),
                                          LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)

    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)

    # training/test split
    (train_images,
Example #10
def MyDense(*args, **kwargs):
    return Dense(*args, **kwargs)
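
Since MyDense simply forwards its arguments, it is a drop-in replacement for Dense and composes the same way; a one-line usage sketch, assuming the usual stax imports (the layer sizes are made up):

init_fun, apply_fun = stax.serial(MyDense(128), Relu, MyDense(10), LogSoftmax)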
Example #11
from research.estop.pendulum.env import viz_pendulum_rollout
from research.estop.utils import Scalarify
from research.statistax import Deterministic, Normal
from research.utils import make_optimizer
from research.estop import mdp

tau = 1e-4
buffer_size = 2**15
batch_size = 64
num_eval_rollouts = 128
opt_init = make_optimizer(optimizers.adam(step_size=1e-3))
init_noise = Normal(jp.array(0.0), jp.array(0.0))
noise = lambda _1, _2: Normal(jp.array(0.0), jp.array(0.5))

actor_init, actor = stax.serial(
    Dense(64),
    Relu,
    Dense(1),
    Tanh,
    stax.elementwise(lambda x: config.max_torque * x),
)

critic_init, critic = stax.serial(
    FanInConcat(),
    Dense(64),
    Relu,
    Dense(64),
    Relu,
    Dense(1),
    stax.elementwise(lambda x: x + 1.0 / (1 - config.gamma)),
    Scalarify,
Example #12
def Mpl():
    """MPL used for the Chain experiment in:
    https://arxiv.org/abs/2102.12425"""
    return serial(Dense(128), Relu)
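
Because serial returns an (init_fun, apply_fun) pair, the block built by Mpl() can itself be nested inside a larger network. A minimal sketch, assuming jax.random and jax.numpy (as jnp) are imported; the 4-unit head and 8-dimensional inputs are made-up values:

init_fun, apply_fun = serial(Mpl(), Dense(4))
out_shape, params = init_fun(random.PRNGKey(0), (-1, 8))
outputs = apply_fun(params, jnp.ones((1, 8)))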
Example #13
    one_hots,
    right_pad,
    validate_mLSTM1900_params,
)

# setup logger
logger = logging.getLogger("evotuning")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("evotuning.log")
fh.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s :: %(levelname)s :: %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)

# setup model
model_layers = (mLSTM1900(), mLSTM1900_HiddenStates(), Dense(25), Softmax)
init_fun, predict = serial(*model_layers)


@jit
def evotune_loss(params, inputs, targets):
    logging.debug(f"Input shape: {inputs.shape}")
    logging.debug(f"Output shape: {targets.shape}")
    predictions = vmap(partial(predict, params))(inputs)

    return _neg_cross_entropy_loss(targets, predictions)


def avg_loss(xs: List[np.ndarray], ys: List[np.ndarray], params) -> float:
    """
    Return average loss of a set of parameters,
Example #14
#
# del train_images

# u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
# I = np.eye(v_t.shape[-1])
# I_add = npr.normal(0.0, 0.002, size=I.shape)
# noisy_I = I + I_add

init_fun, conv_net = stax.serial(
    Conv(32, (5, 5), (2, 2), padding="SAME"),
    BatchNorm(),
    Relu,
    Conv(10, (3, 3), (2, 2), padding="SAME"),
    Relu,
    Flatten,
    Dense(num_classes),
    LogSoftmax,
)
_, key = random.split(random.PRNGKey(0))


class DataTopologyAE(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod
    def objective(params, bparam, batch) -> float:
        x, _ = batch
        x = np.reshape(x, (x.shape[0], -1))
        logits = predict_fun(params, x, bparam=bparam[0], rng=key)
        keep = random.bernoulli(key, bparam[0], x.shape)
Example #15
import jax.numpy as np
from jax import random, jacrev, jacfwd, vjp, jvp, linearize, jit
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax

from functools import partial

# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
    Dense(64),
    Relu,
    Dense(10),
    LogSoftmax,
)

# Initialize parameters, not committing to a batch shape
rng = random.PRNGKey(0)
in_shape = (-1, 32)
out_shape, net_params = net_init(rng, in_shape)

# Apply network to dummy inputs
inputs = np.zeros((1, 32))
# predictions = net_apply(net_params, inputs)
# print ("pred: ", predictions)


def net_apply_reverse(inputs, net_params):
    return net_apply(net_params, inputs)


@jit
Example #16
def synth_batches():
    while True:
        images = npr.rand(*input_shape).astype("float32")
        yield images


batches = synth_batches()
inputs = next(batches)
u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
I = np.eye(v_t.shape[-1])
I_add = npr.normal(0.0, 0.002, size=I.shape)
noisy_I = I + I_add

encoder_init, encode = stax.serial(
    Dense(512),
    Relu,
    Dense(512),
    Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512),
    Relu,
    Dense(512),
    Relu,
    Dense(8),
)
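
FanOut(2) followed by stax.parallel splits the encoder trunk into two heads, so encode returns a pair of arrays (a location and a Softplus-positive scale). A small usage sketch, assuming jax.random is imported, np is jax.numpy, and the inputs are flattened 28x28 images:

_, enc_params = encoder_init(random.PRNGKey(0), (-1, 28 * 28))
_, dec_params = decoder_init(random.PRNGKey(1), (-1, 10))
mu, scale = encode(enc_params, np.ones((1, 28 * 28)))  # two (1, 10) arrays
recon = decode(dec_params, mu)                         # (1, 8) output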
Example #17
  train_images = train_byte_images.astype(np.float32) / 255
  train_int_labels = raw_data['train']['label']
  train_labels = one_hot(train_int_labels, 10)
  test_images = raw_data['test']['image'].astype(np.float32) / 255
  test_labels = one_hot(raw_data['test']['label'], 10)
  return dict(train_images=train_images, train_labels=train_labels,
              train_byte_images=train_byte_images, 
              train_int_labels=train_int_labels,
              test_images=test_images, test_labels=test_labels,
              test_byte_images=raw_data['test']['image'],
              test_int_labels=raw_data['test']['label'])


init_random_params, predict = stax.serial(
    Flatten,
    Dense(512), Relu,
    Dense(256), Relu,
    Dense(10), LogSoftmax)
mnist_data = load_mnist()


def subset_train(seed, subset_ratio):
  jrng = random.PRNGKey(seed)
  
  step_size = 0.1
  num_epochs = 10
  batch_size = 128
  momentum_mass = 0.9

  num_train_total = mnist_data['train_images'].shape[0]
  num_train = int(num_train_total * subset_ratio)
Example #18
import pickle
from examples.torch_data import get_data

npr.seed(7)
orth_init_cont = True
input_shape = (30000, 36)


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=-1)
    predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
    return np.mean(predicted_class == target_class)


init_fun, predict_fun = Dense(out_dim=10, W_init=normal(), b_init=normal())


class ModelContClassifier(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod
    def objective(params, batch) -> float:
        x, targets = batch
        logits = predict_fun(params, x)
        logits = logits - logsumexp(logits, axis=1, keepdims=True)
        loss = -np.mean(np.sum(logits * targets, axis=1))
        loss += 5e-6 * (l2_norm(params))  #+ l2_norm(bparam))
        return loss
Example #19
def elbo(key, params, images):
    enc_params, dec_params = params
    mu, logvar = encode(enc_params, images)
    z = sample_z(key, mu, logvar)
    logits = decode(dec_params, z)
    elbo = bernoulli_llh(logits, images) - gaussian_kl(mu, logvar)
    return elbo, logits


if __name__ == "__main__":
    data = MNIST()
    latent_dim = 10

    encoder_init, encode = stax.serial(
        Dense(512),
        Relu,
        Dense(512),
        Relu,
        FanOut(2),
        stax.parallel(Dense(latent_dim), Dense(latent_dim)),
    )

    decoder_init, decode = stax.serial(
        Dense(512), Relu, Dense(512), Relu, Dense(data.num_pixels)
    )

    step_size = 1e-3
    num_epochs = 100
    batch_size = 128


def load_mnist(key, n_train, n_test):
    ...  # dataset-loading body not included in this excerpt
    return train_ds, test_ds



key = PRNGKey(42)
data_key, key = split(key)

n_train, n_test = 5000, 1000
train_ds, test_ds = load_mnist(data_key, n_train, n_test)

n_features = train_ds["X"].shape[1]
n_classes = 10

init_random_params, predict = stax.serial(
    Dense(n_features), Relu,
    Dense(50), Relu,
    Dense(n_classes), LogSoftmax)

init_key, key = split(key)
_, params_tree_init = init_random_params(init_key, input_shape=(-1, n_features))

# Do one step of SGD in full parameter space to get good initial value (“anchor”)
potential_key, key = split(key)
l2_regularizer, batch_size = 1., 512
objective = sub.make_potential(potential_key, predict, train_ds, batch_size, l2_regularizer)

losses = jnp.array([])
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
n_steps = 300
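
The excerpt ends right before the optimization loop. A sketch of what typically follows with optax (assuming jax and optax are imported), under the assumption (not shown in the source) that the objective returned by sub.make_potential takes only the parameter pytree and returns a scalar potential:

params = params_tree_init
opt_state = optimizer.init(params)
for step in range(n_steps):
    # assumption: objective(params) -> scalar; the real signature is not shown here
    value, grads = jax.value_and_grad(objective)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    losses = jnp.append(losses, value)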
Example #21
    descend_and_update(nn_loss_xy)


X, Y, X_test = get_data(seed=5)
key = PRNGKey(0)
x_tr, x_te = split(X, key)
x_tr = x_tr[:, [1]]
x_te = x_te[:, [1]]
y_tr, y_te = split(Y, key)

n_layers = 3
n_neurons = 3

nn_init_fn, nn_apply_fn = stax.serial(
    *chain(*[(Tanh, Dense(n_neurons)) for _ in range(n_layers)]),
    Dense(1),
)

out_shape, init_params = nn_init_fn(PRNGKey(9), x_tr.shape[1:])

n_train = 400
lr = 0.125
momentum = 0.9
nn_loss = partial(make_loss, nn_apply_fn)
nn_loss_xy = partial(nn_loss, x=x_tr, y=y_tr)
params, memo = train_opt(nn_loss_xy, n_train, init_params, lr, momentum)

x = x_te
y = y_te
xlin = jnp.linspace(x.min() - 1, x.max() + 1, 50).reshape(-1, 1)
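
A plausible continuation, not taken from the source: evaluate the trained network on the xlin grid and plot it against the held-out points, assuming matplotlib.pyplot is imported as plt and that train_opt returns parameters in the same pytree structure that nn_apply_fn expects:

y_hat = nn_apply_fn(params, xlin)   # predictions on the grid, shape (50, 1)
plt.scatter(x, y, alpha=0.5, label="held-out data")
plt.plot(xlin, y_hat, label="NN fit")
plt.legend()
plt.show()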
Example #22
def main():
    num_iter = 50000
    # Most people run 1000 steps and the OpenAI gym pendulum is 0.05s per step.
    # The max torque that can be applied is also 2 in their setup.
    T = 1000
    time_delta = 0.05
    max_torque = 2.0
    rng = random.PRNGKey(0)

    dynamics = pendulum_dynamics(
        mass=1.0,
        length=1.0,
        gravity=9.8,
        friction=0.0,
    )

    policy_init, policy_nn = stax.serial(
        Dense(64),
        Tanh,
        Dense(64),
        Tanh,
        Dense(1),
        Tanh,
        stax.elementwise(lambda x: max_torque * x),
    )

    # Should it matter whether theta is wrapped into [0, 2pi]?
    policy = lambda params, x: policy_nn(
        params,
        jnp.array([x[0] % (2 * jnp.pi), x[1],
                   jnp.cos(x[0]),
                   jnp.sin(x[0])]))

    def loss(policy_params, x0):
        x = x0
        acc_cost = 0.0
        for _ in range(T):
            u = policy(policy_params, x)
            x += time_delta * dynamics(x, u)
            acc_cost += time_delta * cost(x, u)
        return acc_cost

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (4, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = jit(value_and_grad(loss))

    loss_per_iter = []
    elapsed_per_iter = []
    x0s = vmap(sample_x0)(random.split(rng, num_iter))
    for i in range(num_iter):
        t0 = time.time()
        loss, g = loss_and_grad(opt.value, x0s[i])
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss)
        elapsed_per_iter.append(elapsed)

        print(f"Episode {i}")
        print(f"    loss = {loss}")
        print(f"    elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "final_params": opt.value
    })

    plt.figure()
    plt.plot(loss_per_iter)
    plt.yscale("log")
    plt.title("ODE control of an inverted pendulum")
    plt.xlabel("Iteration")
    plt.ylabel(f"Policy cost (T = {total_secs}s)")

    # Viz
    num_viz_rollouts = 50
    framerate = 30
    timesteps = jnp.linspace(0,
                             int(T * time_delta),
                             num=int(T * time_delta * framerate))
    rollout = lambda x0: ode.odeint(
        lambda x, _: dynamics(x, policy(opt.value, x)), y0=x0, t=timesteps)

    plt.figure()
    states = rollout(jnp.zeros(2))
    plt.plot(states[:, 0], states[:, 1], marker=".")
    plt.xlabel("theta")
    plt.ylabel("theta dot")
    plt.title("Swing up trajectory")

    plt.figure()
    states = vmap(rollout)(x0s[:num_viz_rollouts])
    for i in range(num_viz_rollouts):
        plt.plot(states[i, :, 0], states[i, :, 1], marker='.', alpha=0.5)
    plt.xlabel("theta")
    plt.ylabel("theta dot")
    plt.title("Phase space trajectory")

    plot_control_contour(lambda x: policy(opt.value, x))
    plot_policy_dynamics(dynamics, lambda x: policy(opt.value, x))

    blt.show()
Example #23
File: mnist_vae.py  Project: zhouj/jax
  """Sample images from the generative model."""
  _, dec_params = params
  code_rng, img_rng = random.split(rng)
  logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
  sampled_images = random.bernoulli(img_rng, np.logaddexp(0., logits))
  return image_grid(nrow, ncol, sampled_images, (28, 28))

def image_grid(nrow, ncol, imagevecs, imshape):
  """Reshape a stack of image vectors into an image grid for plotting."""
  images = iter(imagevecs.reshape((-1,) + imshape))
  return np.vstack([np.hstack([next(images).T for _ in range(ncol)][::-1])
                    for _ in range(nrow)]).T


encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)


if __name__ == "__main__":
  step_size = 0.001
  num_epochs = 100
Example #24
def get_rossi_dataset():
    from lifelines.datasets import load_rossi

    df = load_rossi()

    T_train = df.pop("week").values
    E_train = df.pop("arrest").values
    X_train = df.values

    return X_train, T_train, E_train


x_train, t_train, e_train = get_rossi_dataset()

model = Model([Dense(18), Relu])

model.compile(
    optimizer=optimizers.adam,
    optimizer_kwargs={
        "step_size": optimizers.exponential_decay(0.01, 10, 0.999)
    },
    loss=losses.NonParametric(),
)

model.fit(x_train, t_train, e_train, epochs=2, batch_size=32)

print(model.predict_survival_function(x_train[0], np.arange(0, 10)))

dump(model, "testsavefile")
model = load("testsavefile")
Example #25
def model(x, y):
    nn = numpyro.module("nn", Dense(1), (10,))
    mu = nn(x).squeeze(-1)
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
Example #26
    leaves_of_params = tree_leaves(params)
    return sum(tree_map(lambda p: jnp.sum(jax.scipy.stats.norm.logpdf(p, scale=l2_regularizer)), leaves_of_params))


key = PRNGKey(42)
data_key, init_key, opt_key, sample_key, warmstart_key = split(key, 5)

n_train, n_test = 20000, 1000
train_ds, test_ds = load_mnist(data_key, n_train, n_test)
data = (train_ds["X"], train_ds["y"])
n_features = train_ds["X"].shape[1]
n_classes = 10

# model
init_random_params, predict = stax.serial(
    Dense(200), Relu,
    Dense(50), Relu,
    Dense(n_classes), LogSoftmax)

_, params_init_tree = init_random_params(init_key, input_shape=(-1, n_features))

leaves = tree_leaves(params_init_tree)
n = 0
for i in range(len(leaves)):
    sh = leaves[i].shape
    n += np.prod(sh)
    print(f"size of parameters in leaf {i} is {sh}")
print("total nparams", n)

l2_regularizer = 0.1
batch_size = 512
Example #27
def logistic_regression(rng, dim):
    """Logistic regression."""
    init_params, forward = stax.serial(Dense(1), Sigmoid)
    temp, rng = random.split(rng)
    params = init_params(temp, (-1, dim))[1]
    return params, forward
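
A small usage sketch (the shapes are made up): the returned forward function maps a batch of dim-dimensional inputs to probabilities in (0, 1) via the final Sigmoid.

rng = random.PRNGKey(0)
params, forward = logistic_regression(rng, dim=3)
x = random.normal(rng, (5, 3))
probs = forward(params, x)  # shape (5, 1), each entry in (0, 1)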
Example #28
    xL[-1] = xL_f
    vL = (xL_f - xL_0) / (onp.sum(params) * dt) * params
    for i in range(len(vL) - 1):
        xL[i + 1] = xL[i] + vL[i] * dt

    return xL


f_vL = partial(velocity_to_position, xL_0=xL_0, xL_f=xL_f, dt=dt)
f_coarse_velocity_xL = coarse_to_fine(f_interpolate, f_vL)

# define neural network
neurons1 = 100
#neurons2 = 20
net_init, net_apply = stax.serial(
    Dense(neurons1),
    Relu,
    #Dense(neurons2), Relu,
    Dense(1),
    Sigmoid)
rng = random.PRNGKey(1)
in_shape = (
    -1,
    1,
)
out_shape, net_params = net_init(rng, in_shape)
inputs = np.array([[k] for k in tlist[1:len(tlist)] / T_total])


# neural network profile
def nn_profile(params, inputs, xL_0, xL_f):
Example #29
from examples import datasets


def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(preds * targets)

def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
  rng = random.PRNGKey(0)

  step_size = 0.001
  num_epochs = 10
  batch_size = 128
  momentum_mass = 0.9

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)
Example #30
npr.seed(7)
orth_init_cont = True
input_shape = (30000, 36)


def accuracy(params, bparams, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=-1)
    predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
    return np.mean(predicted_class == target_class)


if orth_init_cont:
    init_fun, predict_fun = stax.serial(
        HomotopyDense(out_dim=18, W_init=orthogonal(), b_init=zeros),
        Dense(out_dim=10, W_init=orthogonal(), b_init=zeros), LogSoftmax)
    #init_fun, predict_fun = Dense(out_dim=10, W_init=normal(), b_init=normal())
else:
    # baseline network
    init_fun, predict_fun = stax.serial(Dense(out_dim=18), Relu,
                                        Dense(out_dim=10), LogSoftmax)


class ModelContClassifier(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod
    def objective(params, bparam, batch) -> float:
        x, targets = batch
        x = np.reshape(x, (x.shape[0], -1))