Ejemplo n.º 1
0
def vf(r, g, eta, rhoc, rho, Nkn=0.0):
    """Terminal velocity of a condensate particle.

    The drag regime is picked from the Davies number computed by
    ``Ndavies``: ``vf_stokes`` below 42.877543, ``vf_midNre`` from
    42.877543 up to (but excluding) 119643.38, and ``vf_largeNre`` at or
    above 119643.38. All three branches are evaluated and ``jnp.select``
    keeps the applicable one element-wise; where no regime condition holds
    (e.g. a NaN Davies number) the result is 0.

    Args:
        r: particle size (cm)
        g: gravity (cm/s2)
        eta: dynamic viscosity (g/s/cm)
        rhoc: condensate density (g/cm3)
        rho: atmosphere density (g/cm3)
        Nkn: Knudsen number; only the Stokes branch uses it

    Returns:
        terminal velocity (cm/s)

    Example:

        >>> #terminal velocity at T=300K, for Earth atmosphere/gravity.
        >>> g=980.
        >>> rhoc=1.0
        >>> rho=1.29*1.e-3 #g/cm3
        >>> vfactor,Tr=vc.calc_vfactor(atm="Air")
        >>> eta=vc.eta_Rosner(300.0,vfactor)
        >>> r=jnp.logspace(-5,0,70)
        >>> vf(r,g,eta,rhoc,rho) #terminal velocity (cm/s)
    """
    stokes_limit = 42.877543
    large_nre_limit = 119643.38

    # Buoyancy-corrected density difference used by every drag regime.
    delta_rho = rhoc - rho
    davies_number = Ndavies(r, g, eta, delta_rho, rho)

    regimes = [
        davies_number < stokes_limit,
        (davies_number >= stokes_limit) * (davies_number < large_nre_limit),
        davies_number >= large_nre_limit,
    ]
    velocities = [
        vf_stokes(r, g, eta, delta_rho, Nkn),
        vf_midNre(r, g, eta, delta_rho, rho),
        vf_largeNre(r, g, eta, delta_rho, rho),
    ]
    return jnp.select(regimes, velocities)
Ejemplo n.º 2
0
def train(init_random_params,
          x_train,
          xt_train,
          x_test,
          xt_test,
          log_dir=None):
    """Train a model with full-batch Adam and a three-stage learning-rate decay.

    Args:
        init_random_params: initializer called as
            ``init_random_params(rng, (-1, 1))``; its second return value is
            used as the initial parameters.
        x_train, xt_train: training inputs and targets, passed to ``loss`` as
            the batch ``(x_train, xt_train)``.
        x_test, xt_test: held-out inputs and targets, used only for reporting.
        log_dir: if given, the final parameters are pickled to
            ``{log_dir}/new_lnn_model_<timestamp>.pickle`` and the loss curves
            are plotted with ``plot_loss``.

    Returns:
        ``(params, train_losses, test_losses)``: the final parameters and the
        losses recorded every ``batch_size`` iterations. (Earlier versions
        returned ``None``; callers that ignore the result are unaffected.)
    """
    run_name = datetime.now().strftime("%d_%m_%Y.%H_%M")

    rng = jax.random.PRNGKey(0)
    _, init_params = init_random_params(rng, (-1, 1))

    batch_size = 100
    test_every = 10
    num_batches = 1500
    num_steps = batch_size * num_batches
    # Learning-rate stage boundaries: the first, second and last third.
    first_boundary = batch_size * (num_batches // 3)
    second_boundary = batch_size * (2 * num_batches // 3)

    train_batch = (x_train, xt_train)
    test_batch = (x_test, xt_test)
    train_losses = []
    test_losses = []

    # Adam with learning-rate decay 1e-3 -> 3e-4 -> 1e-4. The last stage is
    # the select default so every step has a learning rate; the previous
    # third condition ``t > second_boundary`` left step ``t == second_boundary``
    # with no true condition, i.e. a learning rate of 0.
    opt_init, opt_update, get_params = optimizers.adam(
        lambda t: jnp.select([t < first_boundary, t < second_boundary],
                             [1e-3, 3e-4], 1e-4))
    opt_state = opt_init(init_params)

    @jax.jit
    def update_derivative(i, opt_state, batch):
        params = get_params(opt_state)
        # One gradient evaluation per step (it used to be traced twice, and
        # every gradient was kept alive in an unused list).
        return opt_update(i, jax.grad(loss)(params, batch, None), opt_state)

    for iteration in range(num_steps + 1):
        if iteration % batch_size == 0:
            params = get_params(opt_state)
            train_loss = loss(params, train_batch)
            train_losses.append(train_loss)
            test_loss = loss(params, test_batch)
            test_losses.append(test_loss)
            if iteration % (batch_size * test_every) == 0:
                print(
                    f"iteration={iteration}, train_loss={train_loss:.6f}, test_loss={test_loss:.6f}"
                )
        opt_state = update_derivative(iteration, opt_state, train_batch)

    params = get_params(opt_state)

    if log_dir is not None:
        with open(f'{log_dir}/new_lnn_model_{run_name}.pickle',
                  'wb') as handle:
            pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)

        plot_loss(train_losses, test_losses, model_name='lnn', log_dir=log_dir)

    return params, train_losses, test_losses
Ejemplo n.º 3
0
def train(args, model, data):
    """Train ``model`` on ``data`` with Adam and a three-stage learning-rate decay.

    Args:
        args: config object; reads ``model`` ('gln' or 'baseline_nn'),
            ``batch_size``, ``num_batches``, ``learn_rate`` and ``test_every``.
        model: ``(nn_forward_fn, init_params)`` pair. ``nn_forward_fn`` is
            published as a module global, presumably because the loss
            functions read it -- TODO confirm.
        data: dict with at least the keys 'x', 'dx', 'test_x', 'test_dx'.

    Returns:
        ``(params, train_losses, test_losses)``: final parameters and the
        losses recorded every ``args.batch_size`` iterations.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    global opt_update, get_params, nn_forward_fn
    import numpy as np  # local: only needed for the array check below

    (nn_forward_fn, init_params) = model
    # Move array-valued entries to the default device once, up front. The
    # previous ``type(v) is jnp.ndarray`` test never matched a concrete array
    # type, so nothing was actually transferred.
    data = {
        k: jax.device_put(v) if isinstance(v, (np.ndarray, jnp.ndarray)) else v
        for k, v in data.items()
    }
    time.sleep(2)  # NOTE(review): the reason for this pause is not visible here.

    # choose our loss function
    if args.model == 'gln':
        loss = gln_loss
    elif args.model == 'baseline_nn':
        loss = baseline_loss
    else:
        raise ValueError(f"unsupported model type: {args.model!r}")

    # Adam with decay lr -> lr/10 -> lr/100 over thirds of training. The
    # last stage is the select default so every step has a learning rate;
    # the previous ``t > second_boundary`` condition gave the step at
    # exactly ``second_boundary`` a learning rate of 0.
    first_boundary = args.batch_size * (args.num_batches // 3)
    second_boundary = args.batch_size * (2 * args.num_batches // 3)
    opt_init, opt_update, get_params = optimizers.adam(
        lambda t: jnp.select([t < first_boundary, t < second_boundary],
                             [args.learn_rate, args.learn_rate / 10],
                             args.learn_rate / 100))
    opt_state = opt_init(init_params)

    @jax.jit
    def update_derivative(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, jax.grad(loss)(params, batch, None), opt_state)

    train_batch = (data['x'], data['dx'])
    test_batch = (data['test_x'], data['test_dx'])
    train_losses, test_losses = [], []
    for iteration in range(args.batch_size * args.num_batches + 1):
        if iteration % args.batch_size == 0:
            params = get_params(opt_state)
            train_loss = loss(params, train_batch)
            train_losses.append(train_loss)
            test_loss = loss(params, test_batch)
            test_losses.append(test_loss)
            if iteration % (args.batch_size * args.test_every) == 0:
                print(
                    f"iteration={iteration}, train_loss={train_loss:.6f}, test_loss={test_loss:.6f}"
                )
        opt_state = update_derivative(iteration, opt_state, train_batch)

    params = get_params(opt_state)
    return params, train_losses, test_losses
Ejemplo n.º 4
0
 def ppf(cf):
     """Return the bin index ``i`` with ``cdf(i) <= cf < cdf(i + 1)``.

     The condition holds element-wise for every entry of the result.
     ``cf`` is presumably an integer cumulative frequency in
     ``[0, 2**post_prec)`` -- TODO confirm against callers.
     """
     # Initial guess: evaluate the gaussian inverse CDF at the centre of
     # the ``cf`` bucket, then locate the bin containing that point.
     # Binary search (``digitize``) is faster than the actual gaussian cdf
     # at the precisions typically used; the cdf is O(1) whereas search is
     # O(precision), so at high precision the cdf would win.
     centre = (cf + 0.5) / (1 << post_prec)
     point = norm.ppf(centre, mean, stdd)
     guess = jnp.digitize(point, std_gaussian_bins(prior_prec)) - 1

     # norm.cdf and norm.ppf are not exact inverses in finite precision
     # (extremely rare in float64, common in float32), so the guess can land
     # in the wrong bin. Nudge each index one bin at a time until every
     # entry is bracketed. NOTE(review): assumes such a bin always exists,
     # otherwise the loop would not terminate.
     def unresolved(i):
         return ~jnp.all((cdf(i) <= cf) & (cf < cdf(i + 1)))

     def step_towards(i):
         too_high = cf < cdf(i)
         too_low = cf >= cdf(i + 1)
         return jnp.select([too_high, too_low], [i - 1, i + 1], i)

     return lax.while_loop(unresolved, step_towards, guess)
Ejemplo n.º 5
0
from jaxmeta.data import tensor_grid
from dataset import Batch_Generator

# Space-time domain. Row 0 holds the lower corner and row 1 the upper
# corner; column 0 is x and column 1 is t, i.e. (x, t) in [-1, 1] x [0, 0.02].
# (An alternative configuration used an upper t of 0.25.)
domain = jnp.array([
    [-1.0, 0.0],
    [1.0, 0.02],
])

# Model parameter epsilon and its matching data file.
# Alternative setting: epsilon = 0.7 with
# "problem2_2_snapshot_epsilon_0.7.mat".
epsilon = 1e-12
data_file = "problem2_2_snapshot_epsilon_1e-12.mat"


# initial conditions
def u0_fn(x, t):
    """Initial u: 2.0 where x <= 0 and 1.0 where x > 0; ``t`` is ignored."""
    left_side = x <= 0
    right_side = x > 0
    return jnp.select([left_side, right_side], [2.0, 1.0])


def v0_fn(x, t):
    """Initial v: zeros shaped like ``x``; ``t`` is ignored."""
    return jnp.zeros_like(x)


# Boundary conditions (l/r presumably left/right) reuse the initial-condition
# functions.
ul_fn = ur_fn = u0_fn
vl_fn = vr_fn = v0_fn

# Record types for initial/boundary (Dirichlet) data and collocation points.
dataset_Dirichlet = namedtuple("dataset_Dirichlet", "x t u v")
dataset_Collocation = namedtuple("dataset_Collocation", "x t")

def generate_dataset(n_i, n_b, n_cx, n_ct, n_dx, n_dt):
	x_i = jnp.linspace(*domain[:, 0], n_i).reshape((-1, 1))
	t_i = jnp.zeros_like(x_i)
	u_i = u0_fn(x_i, t_i)
	v_i = v0_fn(x_i, t_i)
Ejemplo n.º 6
0
def select(condlist, choicelist, default=0):
  """``jnp.select`` that accepts ``JaxArray`` wrappers and returns one.

  Every ``JaxArray`` in ``condlist``, in ``choicelist`` and (previously
  missed) ``default`` is unwrapped to its ``.value`` before calling
  ``jnp.select``. The result is wrapped in a new ``JaxArray``.
  """
  def _unwrap(obj):
    # Plain arrays and scalars pass through untouched.
    return obj.value if isinstance(obj, JaxArray) else obj

  condlist = [_unwrap(c) for c in condlist]
  choicelist = [_unwrap(c) for c in choicelist]
  return JaxArray(jnp.select(condlist, choicelist, default=_unwrap(default)))
Ejemplo n.º 7
0
def logistic_logpmf(img, means, inv_scales):
    """Log-probability of ``img`` under a discretized logistic distribution.

    Each pixel value is treated as a bin of half-width 1/255 centred on the
    value. Pixels equal to -1 use the mass below the bin's upper edge.
    Pixels equal to 1 use the mass above the bin's lower edge. All other
    pixels use the mass inside the bin.

    Args:
        img: pixel values; -1 and 1 are treated as the edge bins.
        means: location of the logistic, broadcast against ``img``.
        inv_scales: inverse scale of the logistic, broadcast against ``img``.

    Returns:
        Element-wise log-probabilities.
    """
    half_bin = 1 / 255
    offset = img - means
    # -softplus(z) == log(1 - sigmoid(z)): log mass above the lower edge.
    log_upper_tail = -jnp.logaddexp(0, (offset - half_bin) * inv_scales)
    # -softplus(-z) == log(sigmoid(z)): log mass below the upper edge.
    log_lower_tail = -jnp.logaddexp(0, -(offset + half_bin) * inv_scales)
    # sigmoid(a) - sigmoid(b) == sigmoid(a) * (1 - sigmoid(b)) * (1 - exp(b - a)).
    # Here b - a == -inv_scales / 127.5, so, assuming log1mexp(x) computes
    # log(1 - exp(-x)), the in-bin log mass is the two tails plus that term.
    log_in_bin = log1mexp(inv_scales / 127.5) + log_upper_tail + log_lower_tail
    return jnp.select([img == -1, img == 1],
                      [log_lower_tail, log_upper_tail],
                      log_in_bin)
Ejemplo n.º 8
0
def train(args, model, data, rng):
    """Train on random mini-batches and track the best test-loss parameters.

    Each iteration does three things:

    - draws ``args.batch_size`` random training indices (with replacement);
    - takes one Adam step;
    - evaluates the loss of the *pre-update* parameters on that batch.

    Full train/test losses (normalised by sample count) are recorded in
    three cases: when that batch loss reaches a new minimum, every 1000
    iterations, and every 100 iterations during the first 1000. Training
    stops early once the test loss becomes NaN.

    Args:
        args: config object; reads ``num_epochs``, ``lr``, ``lr2``,
            ``batch_size`` and ``l2reg`` (plus whatever ``make_loss`` reads).
        model: ``(nn_forward_fn, init_params)`` pair; ``nn_forward_fn`` is
            published as a module global.
        data: dict with the keys 'x', 'dx', 'test_x', 'test_dx'.
        rng: JAX PRNG key used to draw the mini-batches.

    Returns:
        ``(params, train_losses, test_losses, best_loss)``. The parameters
        that achieved ``best_loss`` are stored in the module global
        ``best_params``.
    """
    global opt_update, get_params, nn_forward_fn
    global best_params, best_loss
    best_params = None
    best_loss = np.inf
    best_small_loss = np.inf
    (nn_forward_fn, init_params) = model
    data = {k: jax.device_put(v) for k, v in data.items()}

    num_train = len(data['x'])
    num_test = len(data['test_x'])
    full_train = (data['x'], data['dx'])
    full_test = (data['test_x'], data['test_dx'])

    loss = make_loss(args)
    # Step decay: args.lr for the first half of training, args.lr2 after.
    opt_init, opt_update, get_params = optimizers.adam(lambda t: jnp.select([
        t < args.num_epochs // 2, t >= args.num_epochs // 2
    ], [args.lr, args.lr2]))
    opt_state = opt_init(init_params)

    @jax.jit
    def update_derivative(i, opt_state, batch, l2reg):
        # Returns the updated state and the parameters *before* the update
        # (the point at which the gradient was taken).
        params = get_params(opt_state)
        return opt_update(i,
                          jax.grad(loss, 0)(params, batch, l2reg),
                          opt_state), params

    train_losses, test_losses = [], []

    for iteration in range(args.num_epochs):
        # Derive a fresh key by splitting. ``rng += 1`` is not a supported way
        # to make new keys in JAX and raises for typed key arrays.
        rng, batch_key = jax.random.split(rng)
        rand_idx = jax.random.randint(batch_key, (args.batch_size, ), 0,
                                      num_train)

        batch = (data['x'][rand_idx], data['dx'][rand_idx])
        opt_state, params = update_derivative(iteration, opt_state, batch,
                                              args.l2reg)
        small_loss = loss(params, batch, 0.0)

        new_small_loss = False
        if small_loss < best_small_loss:
            best_small_loss = small_loss
            new_small_loss = True

        if (new_small_loss or iteration % 1000 == 0
                or (iteration < 1000 and iteration % 100 == 0)):
            params = get_params(opt_state)
            train_loss = loss(params, full_train, 0.0) / num_train
            train_losses.append(train_loss)
            test_loss = loss(params, full_test, 0.0) / num_test
            test_losses.append(test_loss)

            # NaN compares False, so a NaN loss never becomes the best.
            if test_loss < best_loss:
                best_loss = test_loss
                best_params = params

            if jnp.isnan(test_loss).any():
                break

            print(
                f"iteration={iteration}, train_loss={train_loss:.6f}, test_loss={test_loss:.6f}"
            )

    params = get_params(opt_state)
    return params, train_losses, test_losses, best_loss