def test_batch(self):
    """Test that batch layer is indeed ignored.

    Code taken from: https://github.com/google/flax/issues/932
    """
    key = jax.random.PRNGKey(0)
    inputs = jnp.ones((5, 4, 4, 3))
    targets = jax.random.uniform(key, (5, 4, 4, 7))
    foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, inputs))
    tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask())

    @self.variant
    def train_step(params, x, y):
        # Forward pass with mutable batch stats; L1 reconstruction loss.
        y1, new_batch_stats = Foo(filters=7, train=True).apply(
            params, x, mutable=['batch_stats'])
        return jnp.abs(y - y1).sum(), new_batch_stats

    state = self.variant(tx.init)(foo_vars['params'])
    grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, inputs, targets)
    updates, state = self.variant(tx.update)(dict(grads['params']), state)
    # The BatchNorm params are expected to fall outside the weight-decay mask,
    # so their "updates" should equal the raw gradients.
    chex.assert_trees_all_close(updates['BatchNorm_0'],
                                grads['params']['BatchNorm_0'])
def train_with_bc(make_demonstrations: Callable[[int], Iterator[types.Transition]],
                  networks: networks_lib.FeedForwardNetwork,
                  loss: losses.Loss,
                  num_steps: int = 100000) -> networks_lib.Params:
    """Trains the given network with BC and returns the params.

    Args:
      make_demonstrations: A function (batch_size) -> iterator with
        demonstrations to be imitated.
      networks: Network taking (params, obs, is_training, key) as input.
      loss: BC loss to use.
      num_steps: Number of training steps.

    Returns:
      The trained network params.
    """
    bc_learner = learning.BCLearner(
        network=networks,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss,
        demonstrations=make_demonstrations(256),
        optimizer=optax.adam(1e-4),
        num_sgd_steps_per_step=1)

    # Train the agent.
    for _ in range(num_steps):
        bc_learner.step()

    return bc_learner.get_variables(['policy'])[0]
def __init__(self, name, param_store=None, tensorboard_dir=None):
    """Builds the agent: env, online/target Q-networks, updater and buffer."""
    env = make_env(name, tensorboard_dir)

    # Function approximators: an online network plus a frozen target copy.
    self.q = coax.Q(forward_pass, env)
    self.q_targ = self.q.copy()

    # Updater driving the online network towards the target's bootstrap value.
    self.q_updater = coax.td_learning.QLearning(
        self.q, q_targ=self.q_targ, optimizer=optax.adam(3e-4))

    # schedule for beta parameter used in PrioritizedReplayBuffer
    self.buffer_beta = coax.utils.StepwiseLinearFunction((0, 0.4), (1000000, 1))

    # Only the parameter-store-less (primary) worker owns the replay buffer.
    super().__init__(
        env=env,
        param_store=param_store,
        pi=coax.BoltzmannPolicy(self.q, temperature=0.015),
        tracer=coax.reward_tracing.NStep(n=1, gamma=0.99),
        buffer=(coax.experience_replay.PrioritizedReplayBuffer(
            capacity=1000000, alpha=0.6) if param_store is None else None),
        buffer_warmup=50000,
        name=name)
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: impala_networks.IMPALANetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    """Creates the IMPALA learner (global-norm clipping followed by Adam)."""
    del environment_spec, replay_client  # Unused by this learner.

    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.adam(
            self._config.learning_rate,
            b1=self._config.adam_momentum_decay,
            b2=self._config.adam_variance_decay,
            eps=self._config.adam_eps,
            eps_root=self._config.adam_eps_root))

    return learning.IMPALALearner(
        networks=networks,
        iterator=dataset,
        optimizer=optimizer,
        random_key=random_key,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        baseline_cost=self._config.baseline_cost,
        max_abs_reward=self._config.max_abs_reward,
        counter=counter,
        logger=logger_fn('learner'),
    )
def test_optimizer_epoch(self):
    """Checks the epoch-based lr schedule advances every `steps_per_epoch` steps."""
    optax_op = optax.adam(1e-3)
    lr_schedule = lambda step, epoch: epoch
    optimizer = elegy.Optimizer(optax_op, lr_schedule=lr_schedule, steps_per_epoch=2)

    # BUG FIX: `np.random.uniform((3, 4))` passed the shape tuple as the `low`
    # argument; use `size=(3, 4)` so params and grads are (3, 4) arrays.
    params = np.random.uniform(size=(3, 4))
    grads = np.random.uniform(size=(3, 4))
    rng = elegy.RNGSeq(42)

    optimizer_states = optimizer.init(
        rng=rng,
        net_params=params,
    )

    # With steps_per_epoch=2 the epoch — and therefore the lr returned by
    # `lr_schedule` — increments once every two apply() calls.
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
    assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)
def sparsify_basis(Q, lr=1e-2):  # (n,r)
    """Convenience function to attempt to sparsify a given basis by applying an
    orthogonal transformation W, Q' = QW where Q' has only 1s, 0s and -1s.

    Notably this method does not have the same convergence guarantees of
    krylov_constraint_solve and can fail (even silently). Intended to be used
    only for visualization purposes, use at your own risk.
    """
    W = np.random.randn(Q.shape[-1], Q.shape[-1])
    W, _ = np.linalg.qr(W)  # Start from a random orthogonal matrix.
    W = device_put(W.astype(jnp.float32))
    opt_init, opt_update = optax.adam(lr)
    opt_update = jit(opt_update)
    opt_state = opt_init(W)  # init stats

    def loss(W):
        # BUG FIX: these expressions had been mangled into "[email protected]"
        # by an email-obfuscation filter; reconstructed as the sum of a
        # sparsity term, an orthogonality penalty, and a log-det penalty
        # discouraging W from collapsing.
        return (jnp.abs(Q @ W.T).mean()
                + .1 * (jnp.abs(W.T @ W - jnp.eye(W.shape[0]))).mean()
                + .01 * jax.numpy.linalg.slogdet(W)[1] ** 2)

    loss_and_grad = jit(jax.value_and_grad(loss))
    for i in tqdm(range(3000), desc=f'sparsifying basis'):
        lossval, grad = loss_and_grad(W)
        updates, opt_state = opt_update(grad, opt_state, W)
        W = optax.apply_updates(W, updates)
        if lossval > 1e2 and i > 100:  # Solve diverged due to too high learning rate
            logging.warning(f"basis sparsification diverged, trying lower learning rate {lr/3:.2e}")
            return sparsify_basis(Q, lr=lr/3)

    # Snap the rotated basis entries to {-1, 0, +1}.
    Q = np.copy(Q @ W.T)
    Q[np.abs(Q) < 1e-2] = 0
    Q[np.abs(Q) > 1e-2] /= np.abs(Q[np.abs(Q) > 1e-2])

    # Sanity check: distinct columns should map to distinct magnitudes.
    A = Q@(1+np.arange(Q.shape[-1]))
    if len(np.unique(np.abs(A))) != Q.shape[-1]+1 and len(np.unique(np.abs(A))) != Q.shape[-1]:
        logging.error(f"Basis elems did not separate: found only {len(np.unique(np.abs(A)))}/{Q.shape[-1]}")
        #raise ConvergenceError(f"Basis elems did not separate: found only {len(np.unique(A))}/{Q.shape[-1]}")
    return Q
def main():
    """Trains the transformer LM: build dataset, model, updater, then loop."""
    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = hk.transform(
        build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                         dropout_rate))
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))
    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    state = updater.init(rng, next(train_dataset))

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        state, metrics = updater.update(state, next(train_dataset))
def train(  # pylint: disable=invalid-name
    Phi, Psi, num_epochs, learning_rate, key, estimator, alpha, optimizer,
    use_l2_reg, reg_coeff, use_penalty, j, num_rows, skipsize=1):
    """Training function.

    Runs `num_epochs` update steps on Phi and records the iterate and gradient
    every `skipsize` steps (the initial Phi is always recorded).

    Returns:
      A pair of stacked arrays (Phis, grads).
    """
    Phis = [Phi]  # pylint: disable=invalid-name
    grads = []
    if optimizer == 'sgd':
        optim = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        optim = optax.adam(learning_rate)
    opt_state = optim.init(Phi)
    for i in tqdm(range(num_epochs)):
        key, subkey = jax.random.split(key)
        Phi, opt_state, grad = estimates.nabla_phi_analytical(
            Phi, Psi, subkey, optim, opt_state, estimator, alpha, use_l2_reg,
            reg_coeff, use_penalty, j, num_rows)
        # BUG FIX: iterates were appended both unconditionally and under the
        # `skipsize` check, duplicating every recorded entry (and defeating
        # the thinning). Record only every `skipsize`-th step.
        if i % skipsize == 0:
            Phis.append(Phi)
            grads.append(grad)
    return jnp.stack(Phis), jnp.stack(grads)
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: networks_lib.FeedForwardNetwork,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: Optional[specs.EnvironmentSpec],
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    """Creates the SGD learner configured from this builder's config."""
    del environment_spec  # Unused here.

    return learning_lib.SGDLearner(
        network=networks,
        random_key=random_key,
        optimizer=optax.adam(self._config.learning_rate,
                             eps=self._config.adam_eps),
        target_update_period=self._config.target_update_period,
        data_iterator=dataset,
        loss_fn=self._loss_fn,
        replay_client=replay_client,
        replay_table_name=self._config.replay_table_name,
        counter=counter,
        num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
        logger=logger_fn('learner'))
def test_beta_bernoulli(elbo):
    """SVI on a Beta-Bernoulli model should recover a posterior mean near 0.8."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)  # 8 successes, 2 failures.

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optax.adam(0.05), elbo)
    svi_state = svi.init(random.PRNGKey(1), data)
    # The raw optimizer state holds the unconstrained value; presumably
    # log(1.0) == 0 for the positivity constraint — hence the 0.0 here.
    assert_allclose(
        svi.optim.get_params(svi_state.optim_state)["alpha_q"], 0.0)

    def body_fn(i, val):
        new_state, _ = svi.update(val, data)
        return new_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    assert_allclose(
        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
        0.8,
        atol=0.05,
        rtol=0.05,
    )
def test_fitting_surrogate_posterior_stateless(self):
    """Fits an ASVI surrogate statelessly and samples from the result."""
    if not JAX_MODE:
        self.skipTest('Requires optax.')
    import optax  # pylint: disable=g-import-not-at-top
    prior_dist = self.make_prior_dist()
    observations = self.get_observations(prior_dist)
    init_fn, build_surrogate_posterior_fn = (
        tfp.experimental.vi.build_asvi_surrogate_posterior_stateless(
            prior=prior_dist))
    target_log_prob = self.get_target_log_prob(observations, prior_dist)

    def loss_fn(*params, seed=None):
        # Negative ELBO estimated with 10 samples from the surrogate.
        surrogate_posterior = build_surrogate_posterior_fn(*params)
        zs, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
            10, seed=seed)
        return tf.reduce_mean(q_lp - target_log_prob(*zs), axis=0)

    # Test vi fit surrogate posterior works
    optimized_params, _ = tfp.math.minimize_stateless(
        loss_fn,
        init=init_fn(seed=test_util.test_seed()),
        num_steps=5,  # Don't optimize to completion.
        optimizer=optax.adam(0.1),
        seed=test_util.test_seed(sampler_type='stateless'))
    # BUG FIX: the builder takes positional params (as unpacked in `loss_fn`
    # above), so the optimized parameter structure must be unpacked too.
    surrogate_posterior = build_surrogate_posterior_fn(*optimized_params)
    surrogate_posterior.sample(
        100, seed=test_util.test_seed(sampler_type='stateless'))
def test_batch_overfit(train_dataset):
    """A tiny transformer should overfit the repeated batch to near-zero loss."""
    vocab_size, d_model, num_heads, num_layers = 100, 32, 8, 1
    dropout_rate, grad_clip_value, learning_rate = 0.01, 0.25, 2e-2
    max_iter = 100

    # Set up the model, loss, and updater.
    forward_fn = hk.transform(
        build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                         dropout_rate, max_iter))
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))
    updater = Updater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters from the first batch.
    rng = jax.random.PRNGKey(428)
    state = updater.init(rng, next(train_dataset))

    for _ in range(100):
        state, metrics = updater.update(state, next(train_dataset))

    assert metrics['loss'] < 0.1
    assert metrics['step'] == 99
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: r2d2_networks.R2D2Networks,
    dataset: Iterator[r2d2_learning.R2D2ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    """Creates the R2D2 learner from this builder's config."""
    del environment_spec  # Unused.

    # The learner updates the parameters (and initializes them).
    return r2d2_learning.R2D2Learner(
        unroll=networks.unroll,
        initial_state=networks.initial_state,
        batch_size=self._batch_size_per_device,
        random_key=random_key,
        burn_in_length=self._config.burn_in_length,
        discount=self._config.discount,
        importance_sampling_exponent=(
            self._config.importance_sampling_exponent),
        max_priority_weight=self._config.max_priority_weight,
        target_update_period=self._config.target_update_period,
        iterator=dataset,
        optimizer=optax.adam(self._config.learning_rate),
        bootstrap_n=self._config.bootstrap_n,
        tx_pair=self._config.tx_pair,
        clip_rewards=self._config.clip_rewards,
        replay_client=replay_client,
        counter=counter,
        logger=logger_fn('learner'))
def main(batch_size: int = 64, k: int = 5, debug: bool = False):
    """Trains a k-component mixture density network on a noisy inverse-sine set."""
    # Synthesize the dataset: x is a noisy function of y (multi-valued inverse).
    noise = np.float32(np.random.normal(size=(3000, 1)))  # random noise
    y_train = np.float32(np.random.uniform(-10.5, 10.5, (1, 3000))).T
    X_train = np.float32(
        np.sin(0.75 * y_train) * 7.0 + y_train * 0.5 + noise * 1.0)
    # Scale both axes into [-1, 1].
    X_train = X_train / np.abs(X_train.max())
    y_train = y_train / np.abs(y_train.max())

    visualize_data(X_train, y_train)

    model = elegy.Model(module=MixtureModel(k=k),
                        loss=MixtureNLL(),
                        optimizer=optax.adam(3e-4))
    model.summary(X_train[:batch_size], depth=1)
    model.fit(
        x=X_train,
        y=y_train,
        epochs=500,
        batch_size=batch_size,
        shuffle=True,
    )
    visualize_model(X_train, y_train, model, k)
def optimize_club(num_steps: int):
    """Solves the karate club problem by optimizing the assignments of students.

    Args:
      num_steps: Number of gradient steps to run.
    """
    # FIX: docstring typo "karte club" -> "karate club".
    network = hk.without_apply_rng(hk.transform(network_definition))
    zacharys_karate_club = get_zacharys_karate_club()
    labels = get_ground_truth_assignments_for_zacharys_karate_club()
    params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

    @jax.jit
    def prediction_loss(params):
        decoded_nodes = network.apply(params, zacharys_karate_club)
        # We interpret the decoded nodes as a pair of logits for each node.
        log_prob = jax.nn.log_softmax(decoded_nodes)
        # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
        # and John A (Node 33).
        return -(log_prob[0, 0] + log_prob[33, 1])

    opt_init, opt_update = optax.adam(1e-2)
    opt_state = opt_init(params)

    @jax.jit
    def update(params, opt_state):
        g = jax.grad(prediction_loss)(params)
        updates, opt_state = opt_update(g, opt_state)
        return optax.apply_updates(params, updates), opt_state

    @jax.jit
    def accuracy(params):
        decoded_nodes = network.apply(params, zacharys_karate_club)
        return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels)

    for step in range(num_steps):
        logging.info("step %r accuracy %r", step, accuracy(params).item())
        params, opt_state = update(params, opt_state)
def subspace_sampler(key, loglikelihood, logprior, params_init_tree, build_sampler,
                     data, batch_size, subspace_dim, nsamples,
                     opt=optax.adam(learning_rate=0.1), nsteps_full=0,
                     nsteps_sub=0, projection_matrix=None, use_cv=True,
                     pbar=True):
    """Draws posterior samples inside a low-dimensional subspace of the params."""
    subspace_key, sample_key = split(key)

    if nsteps_full > 0 or nsteps_sub > 0:
        # Find good control variate / starting point in subspace
        params_tree, params_sub, log_post_trace, subspace_fns = subspace_optimizer(
            subspace_key, loglikelihood, logprior, params_init_tree, data,
            batch_size, subspace_dim, nsteps_full, nsteps_sub, opt, pbar=pbar)
    else:
        # No warmup: random start in a (possibly freshly generated) subspace.
        params_sub = jax.random.normal(subspace_key, (subspace_dim,))
        params_init_flat, _ = jax.flatten_util.ravel_pytree(params_init_tree)
        full_dim = len(params_init_flat)
        if projection_matrix is None:
            projection_matrix = generate_random_basis(subspace_key, full_dim,
                                                      subspace_dim)
        subspace_fns = make_subspace_fns(loglikelihood, logprior,
                                         params_init_tree, projection_matrix)

    loglik_sub, logprior_sub, subspace_to_pytree_fn = subspace_fns

    # With use_cv the sampler is centered on the optimized subspace point.
    if use_cv:
        sampler_sub = build_sampler(loglikelihood=loglik_sub,
                                    logprior=logprior_sub, data=data,
                                    batch_size=batch_size,
                                    centering_value=params_sub, pbar=pbar)
    else:
        sampler_sub = build_sampler(loglikelihood=loglik_sub,
                                    logprior=logprior_sub, data=data,
                                    batch_size=batch_size, pbar=pbar)

    params_sub_samples = sampler_sub(sample_key, nsamples, params_sub)
    params_tree_samples = vmap(subspace_to_pytree_fn)(params_sub_samples)
    return params_tree_samples, params_sub_samples, subspace_fns
def __init__(self, f, f_targ=None, optimizer=None, loss_function=None,
             policy_regularizer=None):
    """Stores the function approximators, loss, regularizer and optimizer."""
    self._f = f
    self._f_targ = f if f_targ is None else f_targ
    self.loss_function = huber if loss_function is None else loss_function
    if not isinstance(policy_regularizer, (Regularizer, type(None))):
        raise TypeError(
            f"policy_regularizer must be a Regularizer, got: {type(policy_regularizer)}"
        )
    self.policy_regularizer = policy_regularizer

    # optimizer
    self._optimizer = optax.adam(1e-3) if optimizer is None else optimizer
    self._optimizer_state = self.optimizer.init(self._f.params)

    def apply_grads_func(opt, opt_state, params, grads):
        updates, new_opt_state = opt.update(grads, opt_state, params)
        return new_opt_state, optax.apply_updates(params, updates)

    # The optimizer is static under jit so each optimizer gets its own trace.
    self._apply_grads_func = jit(apply_grads_func, static_argnums=0)
def subspace_optimizer(key, loglikelihood, logprior, params_init_tree, data,
                       batch_size, subspace_dim, nwarmup, nsteps,
                       opt=optax.adam(learning_rate=0.1),
                       projection_matrix=None, pbar=True):
    """Optimizes the log-posterior inside a random subspace of the parameters."""
    opt_key, subspace_key, sub_init_key, sub_opt_key = split(key, 4)

    # Find good anchor in full space during warmup phase.
    if nwarmup > 0:
        full_optimizer = build_optax_optimizer(opt, loglikelihood, logprior,
                                               data, batch_size, pbar)
        params_init_tree, _ = full_optimizer(opt_key, nwarmup, params_init_tree)

    # Make random subspace around the anchor.
    if projection_matrix is None:
        params_init_flat, _ = jax.flatten_util.ravel_pytree(params_init_tree)
        full_dim = len(params_init_flat)
        projection_matrix = generate_random_basis(
            subspace_key, full_dim, subspace_dim)  # TODO: add SVD

    loglik_sub, logprior_sub, subspace_to_pytree_fn = make_subspace_fns(
        loglikelihood, logprior, params_init_tree, projection_matrix)
    subspace_fns = (loglik_sub, logprior_sub, subspace_to_pytree_fn)

    # Do subspace optimization starting from a random location.
    params_subspace = jax.random.normal(sub_init_key, (subspace_dim,))
    optimizer_sub = build_optax_optimizer(opt, loglik_sub, logprior_sub, data,
                                          batch_size, pbar)
    params_subspace, log_post_trace = optimizer_sub(sub_opt_key, nsteps,
                                                    params_subspace)
    params_tree = subspace_to_pytree_fn(params_subspace)
    return params_tree, params_subspace, log_post_trace, subspace_fns
def main(_):
    """Entry point: trains the LM with periodic checkpointing and logging."""
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = hk.transform(
        build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                         FLAGS.num_layers, FLAGS.dropout_rate))
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
    optimizer = optax.chain(
        optax.clip_by_global_norm(FLAGS.grad_clip_value),
        optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))
    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    state = updater.init(rng, next(train_dataset))

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        state, metrics = updater.update(state, next(train_dataset))
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
def train_model(workdir):
    """Train for a fixed number of steps and decode during training."""
    key = jax.random.PRNGKey(0)
    key, init_key = jax.random.split(key)
    model = Seq2seq(teacher_force=False, hidden_size=FLAGS.hidden_size)
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=get_initial_params(model, init_key),
        tx=optax.adam(FLAGS.learning_rate))
    writer = metric_writers.create_default_writer(workdir)

    for step in range(FLAGS.num_train_steps):
        key, lstm_key = jax.random.split(key)
        batch = get_batch(FLAGS.batch_size)
        state, metrics = train_step(state, batch, lstm_key)
        # Periodically log the metrics and decode a small sample batch.
        if step % FLAGS.decode_frequency == 0:
            writer.write_scalars(step, metrics)
            key, lstm_key = jax.random.split(key)
            decode_batch(state.params, get_batch(5), lstm_key)

    return state
def adam(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8,
         eps_root: float = 0.0) -> Optimizer:
    """The classic Adam optimiser.

    Adam adapts the effective step size per weight using exponential moving
    averages of the first and second moments of past gradients.

    References:
      [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

    Args:
      learning_rate: A fixed global scaling factor (or schedule).
      b1: Exponential decay rate tracking the first moment of past gradients.
      b2: Exponential decay rate tracking the second moment of past gradients.
      eps: Small constant added to the denominator outside the square root
        (as in the Adam paper) to avoid division by zero when rescaling.
      eps_root: Small constant added inside the square root (as in RMSProp);
        needed e.g. when computing (meta-)gradients through Adam.

    Returns:
      The corresponding `Optimizer`.
    """
    return create_optimizer_from_optax(
        optax.adam(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps,
                   eps_root=eps_root))
def main(_):
    """Trains the MNIST density model, reporting validation loss periodically."""
    optimizer = optax.adam(FLAGS.learning_rate)

    @jax.jit
    def update(params: hk.Params, prng_key: PRNGKey, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, OptState]:
        """Single SGD update step."""
        grads = jax.grad(loss_fn)(params, prng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), new_opt_state

    prng_seq = hk.PRNGSequence(42)
    # Initialize with a dummy all-zeros image batch.
    params = log_prob.init(next(prng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
    opt_state = optimizer.init(params)

    train_ds = load_dataset(tfds.Split.TRAIN, FLAGS.batch_size)
    valid_ds = load_dataset(tfds.Split.TEST, FLAGS.batch_size)

    for step in range(FLAGS.training_steps):
        params, opt_state = update(params, next(prng_seq), opt_state,
                                   next(train_ds))
        if step % FLAGS.eval_frequency == 0:
            val_loss = eval_fn(params, next(valid_ds))
            logging.info("STEP: %5d; Validation loss: %.3f", step, val_loss)
def test_basic_variational_fitting_stateless(self):
    """Fits a factored surrogate statelessly and checks the loss decreases."""
    if not JAX_MODE:
        self.skipTest('Uses `optax` for stateless optimization')
    import optax  # pylint: disable=g-import-not-at-top
    batch_shape = [2, 3]
    num_timesteps = 5
    num_inits = 10
    observed_time_series = self._build_tensor(
        np.random.randn(*(batch_shape + [num_timesteps])))
    model = self._build_model(observed_time_series)

    seed = test_util.test_seed(sampler_type='stateless')
    init_seed, fit_seed = tfp.random.split_seed(seed, n=2)
    init_fn, build_surrogate_fn = (
        tfp.sts.build_factored_surrogate_posterior_stateless(
            model, batch_shape=num_inits))
    jd = model.joint_distribution(
        observed_time_series=observed_time_series)

    _, loss_curve = tfp.vi.fit_surrogate_posterior_stateless(
        jd.log_prob,
        build_surrogate_posterior_fn=build_surrogate_fn,
        initial_parameters=init_fn(init_seed),
        sample_size=3,
        num_steps=10,
        optimizer=optax.adam(1e-1),
        jit_compile=True,
        seed=fit_seed)
    # Even 10 steps should make measurable progress on the mean loss.
    self.assertLess(np.mean(loss_curve[-1]), np.mean(loss_curve[0]))
def test_sdp_dual_simple_no_crash(self, model_type):
    """Smoke test: the dual solver runs both with and without keyword args."""
    verif_instance = test_utils.make_toy_verif_instance(
        seed=0, target_label=1, label=2, nn=model_type)
    # Exercise every supported keyword argument at once.
    kwargs = {
        'key': jax.random.PRNGKey(0),
        'opt': optax.adam(1e-3),
        'num_steps': 10,
        'eval_every': 5,
        'verbose': False,
        'use_exact_eig_eval': False,
        'use_exact_eig_train': False,
        'n_iter_lanczos': 5,
        'kappa_reg_weight': 1e-5,
        'kappa_zero_after': 8,
        'device_type': None,
    }
    verif_instance = utils.make_sdp_verif_instance(verif_instance)

    # Check all kwargs work.
    dual_val, _ = sdp_verify.solve_sdp_dual_simple(verif_instance, **kwargs)
    assert isinstance(dual_val, float)

    # Check code runs without kwargs.
    dual_val, _ = sdp_verify.solve_sdp_dual_simple(verif_instance, num_steps=5)
    assert isinstance(dual_val, float)
def create_train_state(config, rng, init_samples):
    """Creates the training state.

    Args:
      config: Configuration object providing `learning_rate` (and whatever
        `create_model` consumes).
      rng: PRNG key used to initialize the model parameters.
      init_samples: Sample inputs unpacked into `model.init`.

    Returns:
      A `train_state.TrainState` wrapping the model and an Adam optimizer.
    """
    model = create_model(config)
    variables = model.init(rng, *init_samples)
    optimizer = optax.adam(learning_rate=config.learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables, tx=optimizer)
def main(_):
    """Trains a BC agent offline on demonstrations, evaluating periodically."""
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    demonstrations = bc_utils.make_demonstrations(
        environment.environment, FLAGS.batch_size).as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

    def logp_fn(logits, actions):
        # Log-probability of the taken actions under the policy logits.
        logits_actions = jnp.sum(
            jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
        return logits_actions - special.logsumexp(logits, axis=-1)

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=demonstrations,
                           num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        # Epsilon-greedy over the policy's action distribution parameters.
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner, 'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core, key, variable_client,
                                    backend='cpu')
    eval_loop = acme.EnvironmentLoop(
        environment=environment,
        actor=evaluator,
        logger=loggers.TerminalLogger('evaluation', time_delta=0.))

    # Run the environment loop: alternate training and evaluation forever.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
def test_chunking(self, relaxer):
    """Chunked and unchunked backward concretization must produce equal bounds."""
    batch_size = 3
    input_size = 2
    hidden_size = 5
    final_size = 4

    input_shape = (batch_size, input_size)
    hidden_lay_weight_shape = (input_size, hidden_size)
    final_lay_weight_shape = (hidden_size, final_size)

    inp_lb, inp_ub = test_utils.sample_bounds(
        jax.random.PRNGKey(0), input_shape, minval=-1., maxval=1.)
    inp_bound = jax_verify.IntervalBound(inp_lb, inp_ub)
    hidden_lay_weight = jax.random.uniform(jax.random.PRNGKey(1),
                                           hidden_lay_weight_shape)
    final_lay_weight = jax.random.uniform(jax.random.PRNGKey(2),
                                          final_lay_weight_shape)

    def model_fun(inp):
        # Two-layer ReLU network.
        hidden = inp @ hidden_lay_weight
        act = jax.nn.relu(hidden)
        return act @ final_lay_weight

    # Parameterized relaxers additionally optimize their parameters.
    if isinstance(relaxer, linear_bound_utils.ParameterizedLinearBoundsRelaxer):
        concretizing_transform = (
            backward_crown.OptimizingLinearBoundBackwardTransform(
                relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
                optax.adam(1.e-3), num_opt_steps=10))
    else:
        concretizing_transform = backward_crown.LinearBoundBackwardTransform(
            relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)

    # max_chunk_size=16 forces chunking; 0 presumably disables it (matching
    # the `unchunked` name).
    chunked_concretizer = backward_crown.ChunkedBackwardConcretizer(
        concretizing_transform, max_chunk_size=16)
    unchunked_concretizer = backward_crown.ChunkedBackwardConcretizer(
        concretizing_transform, max_chunk_size=0)

    chunked_algorithm = bound_utils.BackwardConcretizingAlgorithm(
        chunked_concretizer)
    unchunked_algorithm = bound_utils.BackwardConcretizingAlgorithm(
        unchunked_concretizer)

    chunked_bound, _ = bound_propagation.bound_propagation(
        chunked_algorithm, model_fun, inp_bound)
    unchunked_bound, _ = bound_propagation.bound_propagation(
        unchunked_algorithm, model_fun, inp_bound)

    np.testing.assert_array_almost_equal(chunked_bound.lower,
                                         unchunked_bound.lower)
    np.testing.assert_array_almost_equal(chunked_bound.upper,
                                         unchunked_bound.upper)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: DQNConfig,
):
    """Initialize the agent."""
    # Data is communicated via reverb replay.
    reverb_replay = replay.make_reverb_prioritized_nstep_replay(
        environment_spec=environment_spec,
        n_step=config.n_step,
        batch_size=config.batch_size,
        max_replay_size=config.max_replay_size,
        min_replay_size=config.min_replay_size,
        priority_exponent=config.priority_exponent,
        discount=config.discount,
    )
    self._server = reverb_replay.server

    # Global-norm gradient clipping followed by Adam.
    optimizer = optax.chain(
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.adam(config.learning_rate),
    )
    key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        obs_spec=environment_spec.observations,
        random_key=key_learner,
        optimizer=optimizer,
        discount=config.discount,
        importance_sampling_exponent=config.importance_sampling_exponent,
        target_update_period=config.target_update_period,
        iterator=reverb_replay.data_iterator,
        replay_client=reverb_replay.client,
    )

    # The actor selects actions according to the policy.
    def policy(params: networks_lib.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        action_values = network.apply(params, observation)
        return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

    actor = actors.FeedForwardActor(
        policy=policy,
        rng=hk.PRNGSequence(key_actor),
        variable_client=variable_utils.VariableClient(learner, ''),
        adder=reverb_replay.adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(config.batch_size, config.min_replay_size),
        observations_per_step=config.batch_size / config.samples_per_insert,
    )
def main(logdir: str = "runs",
         steps_per_epoch: tp.Optional[int] = None,
         epochs: int = 10,
         batch_size: int = 32):
    """Benchmarks single-device vs distributed CNN training on MNIST."""
    platform = jax.local_devices()[0].platform
    ndevices = len(jax.devices())
    print('devices ', jax.devices())
    print('platform ', platform)
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = dataset["train"]["image"][..., None]
    y_train = dataset["train"]["label"]
    X_test = dataset["test"]["image"][..., None]
    y_test = dataset["test"]["label"]

    accuracies = {}
    # we run distributed=False twice to remove any initial warmup costs
    for distributed in [False, False, True]:
        print(f'Distributed training = {distributed}')
        start_time = time.time()
        model = eg.Model(module=CNN(),
                         loss=eg.losses.Crossentropy(),
                         metrics=eg.metrics.Accuracy(),
                         optimizer=optax.adam(1e-3),
                         seed=42)
        if distributed:
            model = model.distributed()
        bs = batch_size  # int(batch_size / ndevices)
        history = model.fit(inputs=X_train,
                            labels=y_train,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            batch_size=bs,
                            validation_data=(X_test, y_test),
                            shuffle=True,
                            verbose=3)
        ev = model.evaluate(x=X_test, y=y_test, verbose=1)
        print('eval ', ev)
        accuracies[distributed] = ev['accuracy']
        end_time = time.time()
        # BUG FIX: previously `print(f'time taken ', {end_time - start_time})`,
        # which rendered the elapsed time as a one-element set literal.
        print(f'time taken {end_time - start_time}')
    print(accuracies)
def create_train_state(key, rng, batch_size, learning_rate):
    """Builds a TrainState for the model, initialized on a dummy MNIST batch.

    NOTE(review): `self` is referenced but is not a parameter — presumably
    this function is defined inside a method/closure where `self` is in
    scope; confirm with the enclosing definition.
    """
    dummy_batch = jnp.ones([batch_size, 28, 28, 1], jnp.float32)
    variables = self.model().init(key, dummy_batch, rng)
    return train_state.TrainState.create(
        apply_fn=self.model().apply,
        params=variables['params'],
        tx=optax.adam(learning_rate),
    )