# Imports assumed by the listings below (`gradient_estimators` is the
# project-local module providing MultiParticleEstimator).
import sys
import time

import jax
import jax.numpy as jnp
import optax
from jax import flatten_util

import gradient_estimators


def full_pes_grad(key, params, data, targets, K, sigma, N):
    theta, unflatten_fn = flatten_util.ravel_pytree(params)

    estimator = gradient_estimators.MultiParticleEstimator(
        key=key,
        theta_shape=theta.shape,
        n_chunks=1,
        n_particles_per_chunk=N,
        K=K,
        T=None,
        sigma=sigma,
        method='lockstep',
        estimator_type=args.estimate,
        init_state_fn=init_state_fn,
        unroll_fn=unroll,
    )

    # Accumulate PES gradient contributions over the full sequence,
    # advancing the persistent particle states by K steps at a time.
    T = len(data)
    t = 0
    gradient_estimate = jax.tree_map(lambda x: jnp.zeros(x.shape), theta)
    while t < T:
        grad_pes_term = estimator.grad_estimate(theta, update_state=True)
        gradient_estimate += grad_pes_term
        t += K
    return gradient_estimate
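# Usage sketch (not from the original listing): the argument values below are
# hypothetical, and `params`, `data`, `targets`, plus the globals `args`,
# `init_state_fn`, and `unroll` referenced inside full_pes_grad, are assumed
# to be defined elsewhere in the script.
key = jax.random.PRNGKey(3)
pes_grad = full_pes_grad(key, params, data, targets,
                         K=10,        # truncation length per PES chunk
                         sigma=0.01,  # std of the Gaussian perturbations
                         N=100)       # number of perturbation particles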
theta = jnp.concatenate(theta_vals)

outer_opt = optax.adam(args.outer_lr, b1=args.outer_b1, b2=args.outer_b2,
                       eps=args.outer_eps)
outer_opt_state = outer_opt.init(theta)

key = jax.random.PRNGKey(args.seed)
estimator = gradient_estimators.MultiParticleEstimator(
    key=key,
    theta_shape=theta.shape,
    n_chunks=args.n_chunks,
    n_particles_per_chunk=args.n_per_chunk,
    K=args.K,
    T=args.T,
    sigma=args.sigma,
    method='lockstep',
    telescoping=args.telescoping,
    estimator_type=args.estimate,
    init_state_fn=init_state_fn,
    unroll_fn=unroll,
)

start_time = time.time()
total_inner_iterations = 0
total_inner_iterations_including_N = 0

# Meta-optimization loop
for outer_iteration in range(args.outer_iterations):
    outer_grad = estimator.grad_estimate(theta)
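    # Sketch of the rest of the loop body (the listing above is truncated
    # after grad_estimate): a standard optax update applying the estimated
    # outer gradient to theta. This is an assumed continuation, not verbatim
    # code from the original.
    updates, outer_opt_state = outer_opt.update(outer_grad, outer_opt_state)
    theta = optax.apply_updates(theta, updates)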
# Get a single fixed sequence for which to compute the gradient
data, targets = get_batch(train_data, 0, T)
print('data.shape = {}, targets.shape = {}'.format(data.shape, targets.shape))

hidden = init_hidden(batch_size, nhid, nlayers)
theta, _ = flatten_util.ravel_pytree(params)

# Compute the ground-truth gradient estimate using vanilla ES
# ---------------------------------------------------------------------------
es_estimator = gradient_estimators.MultiParticleEstimator(
    key=key,
    theta_shape=theta.shape,
    n_chunks=1,
    n_particles_per_chunk=5000,
    K=T,  # Take the full sequence as one unroll for vanilla ES
    T=T,
    sigma=sigma,
    method='lockstep',
    estimator_type='es',
    init_state_fn=init_state_fn,
    unroll_fn=unroll,
)

base_grad = es_estimator.grad_estimate(theta, update_state=False)
flat_base_grad, _ = flatten_util.ravel_pytree(base_grad)
# Squared L2 norm of the ground-truth gradient (used to normalize variances)
total_grad_norm = jnp.linalg.norm(flat_base_grad)**2
print('Finished computing base grad')
sys.stdout.flush()
# ---------------------------------------------------------------------------

pes_var_dict = {}
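# Sketch of an assumed continuation: populate pes_var_dict by drawing several
# independent PES gradient estimates per truncation length K and recording
# their empirical variance, normalized by the squared norm of the
# ground-truth ES gradient computed above. The K values, trial count, and
# particle count N here are hypothetical, not from the original listing.
for K in [1, 10, 100]:
    flat_grads = []
    for trial in range(20):
        trial_key = jax.random.PRNGKey(trial)
        pes_grad = full_pes_grad(trial_key, params, data, targets,
                                 K=K, sigma=sigma, N=100)
        flat_grad, _ = flatten_util.ravel_pytree(pes_grad)
        flat_grads.append(flat_grad)
    stacked = jnp.stack(flat_grads)  # (num_trials, num_params)
    # Total variance across parameters, relative to the base gradient norm
    pes_var_dict[K] = jnp.sum(jnp.var(stacked, axis=0)) / total_grad_norm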