Exemplo n.º 1
0
def full_pes_grad(key, params, data, targets, K, sigma, N):
    """Sum the estimator's gradient estimates over one pass through `data`.

    A fresh `MultiParticleEstimator` is built for the flattened `params`, and
    `grad_estimate(theta, update_state=True)` is called once per unroll of
    length `K` until `len(data)` steps have been covered. The per-unroll
    estimates are added together and returned.

    Args:
        key: PRNG key forwarded to the estimator.
        params: Parameter pytree; flattened to a single vector before use.
        data: Sequence whose length determines the total number of steps `T`.
            Only `len(data)` is read here.
        targets: Unused by this function; kept for interface compatibility.
        K: Unroll length; the step counter advances by `K` per estimate.
        sigma: Forwarded to the estimator as `sigma`.
        N: Forwarded to the estimator as `n_particles_per_chunk`.

    Returns:
        An array with the same shape as the flattened parameter vector,
        holding the sum of the individual gradient estimates (all zeros when
        `data` is empty).

    Raises:
        ValueError: If `data` is non-empty and `K` is not positive, since the
            accumulation loop could never finish.

    Note:
        Reads the module-level names `args.estimate`, `init_state_fn` and
        `unroll`, which must be defined by the surrounding script.
    """
    # Only the flat vector is needed; the unflatten function is not used.
    theta, _ = flatten_util.ravel_pytree(params)

    estimator = gradient_estimators.MultiParticleEstimator(
        key=key,
        theta_shape=theta.shape,
        n_chunks=1,
        n_particles_per_chunk=N,
        K=K,
        T=None,
        sigma=sigma,
        method='lockstep',
        estimator_type=args.estimate,
        init_state_fn=init_state_fn,
        unroll_fn=unroll,
    )

    T = len(data)
    if T > 0 and K <= 0:
        # Without this guard the loop below would spin forever.
        raise ValueError('K must be positive, got {}'.format(K))

    # theta is a single flat array, so no pytree mapping is required (and the
    # top-level jax.tree_map alias is no longer available in recent JAX).
    gradient_estimate = jnp.zeros(theta.shape)

    t = 0
    while t < T:
        # update_state=True so that each call continues from the state left
        # by the previous call rather than starting over.
        gradient_estimate += estimator.grad_estimate(theta, update_state=True)
        t += K
    return gradient_estimate
Exemplo n.º 2
0
# Join the arrays in `theta_vals` (defined outside this excerpt) into the
# single array `theta` that the outer loop works on.
theta = jnp.concatenate(theta_vals)

# Outer-loop optimizer: Adam with learning rate, betas and epsilon taken
# from `args`; its state is initialised for `theta`.
outer_opt = optax.adam(args.outer_lr,
                       b1=args.outer_b1,
                       b2=args.outer_b2,
                       eps=args.outer_eps)
outer_opt_state = outer_opt.init(theta)

# PRNG key seeded from `args.seed`; handed to the estimator below.
key = jax.random.PRNGKey(args.seed)

# Gradient estimator configured entirely from `args`. `init_state_fn` and
# `unroll` are defined elsewhere in the file.
# NOTE(review): the exact meaning of n_chunks / n_particles_per_chunk / K / T
# / telescoping is defined by gradient_estimators.MultiParticleEstimator --
# confirm against that class.
estimator = gradient_estimators.MultiParticleEstimator(
    key=key,
    theta_shape=theta.shape,
    n_chunks=args.n_chunks,
    n_particles_per_chunk=args.n_per_chunk,
    K=args.K,
    T=args.T,
    sigma=args.sigma,
    method='lockstep',
    telescoping=args.telescoping,
    estimator_type=args.estimate,
    init_state_fn=init_state_fn,
    unroll_fn=unroll,
)

# Wall-clock start and iteration counters, all starting from zero.
# NOTE(review): presumably updated further down the loop body, which is not
# visible in this excerpt.
start_time = time.time()
total_inner_iterations = 0
total_inner_iterations_including_N = 0

# Meta-optimization loop
for outer_iteration in range(args.outer_iterations):
    # One gradient estimate for the current theta per outer iteration.
    # NOTE(review): the rest of the loop body (e.g. applying the optimizer
    # update) is outside this excerpt.
    outer_grad = estimator.grad_estimate(theta)
Exemplo n.º 3
0
# Get a single fixed sequence for which to compute the gradient
# (`train_data`, `T` and `get_batch` are defined outside this excerpt.)
data, targets = get_batch(train_data, 0, T)
print('data.shape = {}, targets.shape = {}'.format(data.shape, targets.shape))

# NOTE(review): `hidden` is not used in the visible code -- presumably
# consumed further down; confirm.
hidden = init_hidden(batch_size, nhid, nlayers)
# Flatten the parameter pytree; only the flat vector is needed here.
theta, _ = flatten_util.ravel_pytree(params)

# Compute the ground-truth gradient estimate using vanilla ES
# ---------------------------------------------------------------------------
es_estimator = gradient_estimators.MultiParticleEstimator(
    key=key,
    theta_shape=theta.shape,
    n_chunks=1,
    n_particles_per_chunk=5000,
    K=T,  # Take the full sequence as one unroll for vanilla ES
    T=T,
    sigma=sigma,
    method='lockstep',
    estimator_type='es',
    init_state_fn=init_state_fn,
    unroll_fn=unroll,
)

# Single estimate with update_state=False.
base_grad = es_estimator.grad_estimate(theta, update_state=False)
flat_base_grad, _ = flatten_util.ravel_pytree(base_grad)
# Despite the name, this is the *squared* L2 norm of the flattened gradient.
total_grad_norm = jnp.linalg.norm(flat_base_grad)**2
print('Finished computing base grad')
# Flush so the progress message is emitted immediately.
sys.stdout.flush()
# ---------------------------------------------------------------------------

# Starts empty.
# NOTE(review): presumably filled with PES variance results by code below
# this excerpt -- confirm.
pes_var_dict = {}