def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

  # Build the network.
  init_fn, f = stax.serial(
      layers.Dense(2048),
      stax.Tanh,
      layers.Dense(10))
  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = tangents.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply = optimizers.momentum(FLAGS.learning_rate, 0.9)
  opt_apply = jit(opt_apply)
  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):
    params = optimizers.get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = optimizers.get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', y_test, f(params, x_test),
                     f_lin(params_lin, x_test), loss)
def mapped_update(i, opt_state, batch):
  _, opt_update = optimizer(lr_fun)
  params = jax_opt.get_params(opt_state)
  grads = jax.grad(loss_fun)(params, batch, predict_fun)
  grads = jax.tree_util.tree_map(lambda g: jax.lax.psum(g, "batch"), grads)
  return opt_update(i, grads, opt_state)
def private_update(rng, i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  rng = random.fold_in(rng, i)  # get new key for new random numbers
  return opt_update(
      i,
      private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                   FLAGS.noise_multiplier, FLAGS.batch_size),
      opt_state)
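As context for private_update above: random.fold_in deterministically derives a fresh key from the step index, so each DP-SGD step draws independent noise while the whole run stays reproducible. A minimal illustration (not from the source):

from jax import random

rng = random.PRNGKey(0)
key_a = random.fold_in(rng, 3)  # key for step 3
key_b = random.fold_in(rng, 4)  # a different, independent key for step 4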
def update_fn(i, opt_state, rng, model_args=(), guide_args=()):
  model_init, guide_init = _seed(model, guide, rng)
  params = optimizers.get_params(opt_state)
  loss_val, grads = value_and_grad(loss)(params, model_init, guide_init,
                                         model_args, guide_args, kwargs)
  opt_state = optim_update(i, grads, opt_state)
  rng, = random.split(rng, 1)
  return loss_val, opt_state, rng
def update_w_gc(i, opt_state, opt_update, x_bxt, f_bxt, f_mask_bxt,
                max_grad_norm, l2reg):
  """Update the parameters with gradient-clipped gradient descent.

  Arguments:
    i: batch number
    opt_state: parameters plus optimizer state
    opt_update: optimizer update function
    x_bxt: rnn inputs
    f_bxt: rnn targets
    f_mask_bxt: masks for when target is defined
    max_grad_norm: maximum norm value the gradient is allowed to take
    l2reg: l2 regularization hyperparameter

  Returns:
    opt_state tuple (as above) that includes updated parameters and
      optimizer state.
  """
  params = optimizers.get_params(opt_state)
  unflatten = flatten_util.ravel_pytree(params)[1]  # Requires shape

  def training_loss(params, x_bxt, f_bxt, l2reg):
    return loss(params, x_bxt, f_bxt, f_mask_bxt, l2reg)['total']

  grads = grad(training_loss)(params, x_bxt, f_bxt, l2reg)
  flat_grads = flatten(grads)
  grad_norm = np.sqrt(np.sum(flat_grads**2))
  normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                          flat_grads * (max_grad_norm / grad_norm))
  uf_grads = unflatten(normed_grads)
  return opt_update(i, uf_grads, opt_state)
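The flatten/unflatten clipping idiom above can be packaged as a standalone helper. Here is a self-contained sketch (the helper name clip_by_global_norm is ours, not from the source) built on jax.flatten_util.ravel_pytree, the same utility update_w_gc relies on:

import jax.numpy as np
from jax import grad
from jax.flatten_util import ravel_pytree

def clip_by_global_norm(grads, max_grad_norm):
  # Flatten the gradient pytree, rescale if its global norm is too large,
  # then restore the original structure.
  flat_grads, unflatten = ravel_pytree(grads)
  grad_norm = np.sqrt(np.sum(flat_grads**2))
  normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                          flat_grads * (max_grad_norm / grad_norm))
  return unflatten(normed_grads)

params = {'w': np.ones(3), 'b': np.ones(())}
f = lambda p: np.sum(p['w']**2) + p['b']**2
clipped = clip_by_global_norm(grad(f)(params), max_grad_norm=1.0)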
def main(args):
  # Generate some data.
  data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

  # Construct an SVI object so we can do variational inference on our
  # model/guide pair.
  opt_init, opt_update = optimizers.adam(args.learning_rate)
  svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update)
  rng = PRNGKey(0)
  opt_state = svi_init(rng, model_args=(data,))

  # Training loop
  rng, = random.split(rng, 1)

  def body_fn(i, val):
    opt_state_, rng_ = val
    loss, opt_state_, rng_ = svi_update(i, opt_state_, rng_,
                                        model_args=(data,))
    return opt_state_, rng_

  opt_state, _ = lax.fori_loop(0, args.num_steps, body_fn, (opt_state, rng))

  # Report the final values of the variational parameters in the guide
  # after training.
  params = optimizers.get_params(opt_state)
  for name, value in params.items():
    print("{} = {}".format(name, value))

  # For this simple (conjugate) model we know the exact posterior. In
  # particular we know that the variational distribution should be
  # centered near 3.0. So let's check this explicitly.
  assert np.abs(params["guide_loc"] - 3.0) < 0.1
def evaluate(opt_state, images):
  params = optimizers.get_params(opt_state)
  elbo_rng, data_rng, image_rng = random.split(test_rng, 3)
  binarized_test = random.bernoulli(data_rng, images)
  test_elbo = elbo(elbo_rng, params, binarized_test) / images.shape[0]
  sampled_images = image_sample(image_rng, params, nrow, ncol)
  return test_elbo, sampled_images
def ppo_opt_step(i,
                 opt_state,
                 ppo_opt_update,
                 policy_net_apply,
                 old_policy_params,
                 value_net_apply,
                 value_net_params,
                 padded_observations,
                 padded_actions,
                 padded_rewards,
                 reward_mask,
                 gamma=0.99,
                 lambda_=0.95,
                 epsilon=0.1):
  """PPO optimizer step."""
  new_policy_params = optimizers.get_params(opt_state)
  g = grad(ppo_loss, argnums=1)(
      policy_net_apply,
      new_policy_params,
      old_policy_params,
      value_net_apply,
      value_net_params,
      padded_observations,
      padded_actions,
      padded_rewards,
      reward_mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)
  return ppo_opt_update(i, g, opt_state)
def test_optim_multi_params():
  params = {'x': np.array([1., 1., 1.]), 'y': np.array([-1., -1., -1.])}
  opt_init, opt_update = optimizers.adam(step_size=1e-2)
  opt_state = opt_init(params)
  for i in range(1000):
    opt_state = step(i, opt_state, opt_update)
  for _, param in optimizers.get_params(opt_state).items():
    assert np.allclose(param, np.zeros(3))
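test_optim_multi_params above calls a step helper defined elsewhere in its test module. A plausible minimal version (an assumption on our part, not the source's definition) that descends a quadratic pulling every parameter toward zero, which is what the test's final assertion expects:

import jax.numpy as np
from jax import grad
from jax.experimental import optimizers

def step(i, opt_state, opt_update):
  # Hypothetical quadratic loss over every leaf of the params dict.
  params = optimizers.get_params(opt_state)
  loss = lambda p: sum(np.sum(v**2) for v in p.values())
  return opt_update(i, grad(loss)(params), opt_state)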
def body_fun(i, loop_carry):
  (rng, opt_state, images) = loop_carry
  rng, elbo_rng, data_rng = random.split(rng, 3)
  batch = binarize_batch(data_rng, i, images)
  loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
  g = grad(loss)(optimizers.get_params(opt_state))
  loop_carry = rng, opt_update(i, g, opt_state), images
  return loop_carry
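A body_fun with this loop-carry signature is meant to be driven by lax.fori_loop. A hedged sketch of the call site (num_batches and train_images are assumed to be in scope):

rng, opt_state, _ = lax.fori_loop(
    0, num_batches, body_fun, (rng, opt_state, train_images))
params = optimizers.get_params(opt_state)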
def mapped_update(i, opt_state, batch, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = num_devices.
  _, opt_update = optimizer(lr_fun)
  params = jax_opt.get_params(opt_state)
  grads = jax.grad(loss_fun)(params, batch, predict_fun, rng)
  grads = jax.tree_util.tree_map(lambda g: jax.lax.psum(g, "batch"), grads)
  return opt_update(i, grads, opt_state)
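What makes jax.lax.psum(g, "batch") work in both mapped_update variants is that the function runs under jax.pmap with axis_name="batch". A self-contained toy sketch of that wiring (all names here are illustrative, not the source's):

import jax
import jax.numpy as np

def loss_fn(params, batch):
  x, y = batch
  return np.mean((x * params - y)**2)

def update(params, batch):
  grads = jax.grad(loss_fn)(params, batch)
  # Sum per-device gradients across the mapped "batch" axis.
  grads = jax.lax.psum(grads, "batch")
  return params - 0.01 * grads

n = jax.local_device_count()
params = np.zeros((n,))            # one (scalar) parameter copy per device
xs, ys = np.ones((n, 8)), np.ones((n, 8))
p_update = jax.pmap(update, axis_name="batch")
params = p_update(params, (xs, ys))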
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
  opt_init, opt_update = optimizers.momentum(step_size, mass)

  @jit
  def update(i, opt_state):
    x = optimizers.get_params(opt_state)
    return opt_update(i, grad(f)(x), opt_state)

  opt_state = opt_init(x)
  for i in range(num_steps):
    opt_state = update(i, opt_state)
  return optimizers.get_params(opt_state)
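A hypothetical use of minimize (not in the source), fitting a simple quadratic; for a well-conditioned quadratic like this, a larger step size than the default converges quickly:

import jax.numpy as np

x_min = minimize(lambda x: np.sum((x - 3.0)**2), np.zeros(3),
                 num_steps=2000, step_size=0.1)
# x_min is close to [3., 3., 3.]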
def update_w_gc(i, opt_state, opt_update, x_bxt, f_bxt, max_grad_norm,
                l2reg):
  """Update the parameters with gradient-clipped gradient descent."""
  params = optimizers.get_params(opt_state)
  unflatten = flatten_util.ravel_pytree(params)[1]  # Requires shape
  grads = grad(loss)(params, x_bxt, f_bxt, l2reg)
  flat_grads = flatten(grads)
  grad_norm = np.sqrt(np.sum(flat_grads**2))
  normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                          flat_grads * (max_grad_norm / grad_norm))
  uf_grads = unflatten(normed_grads)
  return opt_update(i, uf_grads, opt_state)
def update_w_gc(i, opt_state, lfads_hps, lfads_opt_hps, key, x_bxt,
                kl_warmup):
  max_grad_norm = lfads_opt_hps['max_grad_norm']
  keep_rate = lfads_opt_hps['keep_rate']
  params = optimizers.get_params(opt_state)
  grads = grad(lfads.lfads_training_loss)(params, lfads_hps, key, x_bxt,
                                          kl_warmup, keep_rate)
  flat_grads = flatten_lfads(grads)
  grad_norm = np.sqrt(np.sum(flat_grads**2))
  normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                          flat_grads * (max_grad_norm / grad_norm))
  uf_grads = unflatten_lfads(normed_grads)
  return opt_update(i, uf_grads, opt_state)
def reconstruct_img(epoch):
  img = test_fetch(0, test_idx)[0][0]
  plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)),
             img, cmap='gray')
  _, test_sample = binarize(rng, img)
  params = optimizers.get_params(opt_state)
  z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
  z = dist.norm(z_mean, z_var).rvs(random_state=rng)
  img_loc = decode(params['decoder'], z).reshape([28, 28])
  plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)),
             img_loc, cmap='gray')
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(FLAGS.train_size,
                                                    FLAGS.test_size)

  # Build the network.
  init_fn, f = stax.serial(
      layers.Dense(4096),
      stax.Tanh,
      layers.Dense(10))
  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  theta = tangents.ntk(f, batch_size=32)
  g_dd = theta(params, x_train)
  g_td = theta(params, x_test, x_train)
  predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = f(params, x_train)
  fx_test = f(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = optimizers.get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
  util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
def value_opt_step(i,
                   opt_state,
                   opt_update,
                   value_net_apply,
                   padded_observations,
                   padded_rewards,
                   reward_mask,
                   gamma=0.99):
  """Value optimizer step."""
  value_params = optimizers.get_params(opt_state)
  # Note this partial application here and argnums above in ppo_opt_step.
  g = grad(functools.partial(value_loss, value_net_apply))(
      value_params,
      padded_observations,
      padded_rewards,
      reward_mask,
      gamma=gamma)
  return opt_update(i, g, opt_state)
def learn(opt_step, opt_state, params_Q_eval, params_Q_next):
  mini_batch = sample(memory, BATCH_SIZE)
  # Periodically refresh the target network parameters.
  if opt_step % TAU == 0:
    params_Q_next = params_Q_eval.copy()

  input_states = np.stack([transition[0] for transition in mini_batch])
  next_states = np.stack([transition[3] for transition in mini_batch])

  predicted_Q = pred_Q(params_Q_eval, input_states)
  predicted_Q_next = pred_Q(params_Q_next, next_states)

  max_action = np.argmax(predicted_Q_next, axis=1)
  rewards = np.array([transition[2] for transition in mini_batch])

  # Build the TD target: for each transition, overwrite only the entry of
  # its greedy next action (row-wise indexing).
  Q_target = onp.array(predicted_Q)
  Q_target[onp.arange(len(mini_batch)), max_action] = (
      rewards + GAMMA * np.max(predicted_Q_next, axis=1))

  opt_state = step(opt_step, opt_state, (input_states, Q_target))
  params_Q_eval = optimizers.get_params(opt_state)
  return opt_state, params_Q_eval, params_Q_next
def train_fn(data_dir=None,
             output_dir=None,
             model=gin.REQUIRED,
             dataset=gin.REQUIRED,
             train_steps=1000,
             eval_steps=10,
             eval_frequency=100):
  """Train the given model on the given dataset.

  Args:
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train (a function).
    dataset: The name of the dataset to train on.
    train_steps: for how many steps to train.
    eval_steps: for how many steps to do evaluation.
    eval_frequency: how often (every this many steps) to run evaluation.
  """
  (train_batches, eval_batches, input_name,
   input_shape) = input_pipeline.train_and_eval_batches(dataset, data_dir)
  train_stream = dataset_to_stream(train_batches, input_name)

  # Training loop.
  opt_init, opt_update = optimizer()
  model_init, model_predict = model()

  @jax.jit
  def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, jax.grad(loss)(params, batch, model_predict),
                      opt_state)

  _, init_params = model_init([-1] + input_shape)
  step, train_sw, eval_sw = 0, None, None
  if output_dir is not None:
    _make_directory(output_dir)
    # Load parameters.
    loaded_params, loaded_step = load_params_and_step(output_dir)
    if loaded_params is not None:
      init_params = loaded_params
    if loaded_step is not None:
      step = loaded_step
    # Create summary writers.
    eval_sw = _make_summary_writer(os.path.join(output_dir, "eval_log"))
    train_sw = _make_summary_writer(os.path.join(output_dir, "train_log"))

  log("Starting training.")
  opt_state = opt_init(init_params)
  gin_config_saved = False
  while step < train_steps:
    # Training.
    start_time = time.time()
    for _ in range(eval_frequency):
      opt_state = update(step, opt_state, next(train_stream))
      step += 1
    epoch_time = time.time() - start_time
    log("Step {}, last {} steps in {:0.2f} sec".format(
        step, eval_frequency, epoch_time))

    # Save the model.
    params = optimizers.get_params(opt_state)
    save_params_and_step(params, step, output_dir)

    # Save the config if not saved yet.
    # Gin file only includes used parameters, so we save it at this point.
    if output_dir and not gin_config_saved:
      gin_config_saved = True
      config_path = os.path.join(output_dir, "config.gin")
      with gfile.GFile(config_path, "w") as f:
        f.write(gin.operative_config_str())

    # Evaluation.
    eval_stream = dataset_to_stream(eval_batches, input_name)
    eval_train_stream = dataset_to_stream(train_batches, input_name)
    train_acc, eval_acc, train_loss, eval_loss = 0.0, 0.0, 0.0, 0.0
    for _ in range(eval_steps):
      train_acc += accuracy(params, next(eval_train_stream), model_predict)
      eval_acc += accuracy(params, next(eval_stream), model_predict)
      train_loss += loss(params, next(eval_train_stream), model_predict)
      eval_loss += loss(params, next(eval_stream), model_predict)
    train_acc /= eval_steps
    eval_acc /= eval_steps
    train_loss /= eval_steps
    eval_loss /= eval_steps

    log("Train accuracy {:0.4f} loss {:0.8f}".format(train_acc, train_loss))
    if train_sw:
      train_sw.scalar("steps/s", eval_frequency / epoch_time, step=step)
      train_sw.scalar("accuracy", train_acc, step=step)
      train_sw.scalar("loss", train_loss, step=step)
    log("Eval accuracy {:0.4f} loss {:0.8f}".format(eval_acc, eval_loss))
    if eval_sw:
      eval_sw.scalar("accuracy", eval_acc, step=step)
      eval_sw.scalar("loss", eval_loss, step=step)
def body_fun(i, opt_state):
  elbo_rng, data_rng = random.split(random.fold_in(rng, i))
  batch = binarize_batch(data_rng, i, train_images)
  loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
  g = grad(loss)(optimizers.get_params(opt_state))
  return opt_update(i, g, opt_state)
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an
      init_fun and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable
      and returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history
      and returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval
      disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    run_debug_step: bool, if True, will run the model and loss without @jit
      for one step.

  Returns:
    trax.State
  """
  rng = random.PRNGKey(0)
  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs()

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict_original = model()

  # We need a model_predict that fills in the random generator if needed.
  def model_predict(x, y, **kwargs):
    """Same as model_predict_original but fill in rng if it isn't passed."""
    if "rng" in kwargs:
      return model_predict_original(x, y, **kwargs)
    return model_predict_original(x, y, rng=rng, **kwargs)

  # Setup state.
  step = state.step or 0
  params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)

  # jit model_predict and update so they're fast.
  jit_model_predict = jax.jit(model_predict)  # for evaluation
  jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = itertools.chain(
      [1,  # first epoch only 1 step
       eval_frequency - 1],
      itertools.repeat(eval_frequency))
  step_log(step, "Starting training")

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss(params, next(train_stream), model_predict)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator.
    print()

    # Timer.
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train.
      opt_state = jit_update_fun(step, opt_state, next(train_stream))
      step += 1

      # LR log.
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fun(step), step=step)

    # Timer.
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate.
    params = jax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict, params),
        eval_steps=eval_steps,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state.
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config.
    # Gin only tracks the used parameters, so we save it after the first
    # epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history.
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

    # Flush summary writers.
    train_sw.writer.flush()
    eval_sw.writer.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
def update(i, opt_state, batch):
  _, opt_update = optimizer(lr_fun)
  params = jax_opt.get_params(opt_state)
  return opt_update(i, jax.grad(loss_fun)(params, batch, predict_fun),
                    opt_state)
def train_fn(output_dir,
             data_dir,
             model=gin.REQUIRED,
             dataset=gin.REQUIRED,
             train_steps=1000,
             eval_steps=10,
             eval_frequency=100):
  """Train the given model on the given dataset.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    data_dir: Directory where the data is located.
    model: The model to train as a callable returning 2 callables, an
      init_fun and apply_fun.
    dataset: The name of the TFDS dataset to train on. To train on a T2T
      dataset, prefix the name with "t2t_".
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps).
  """
  gfile.makedirs(output_dir)

  # Make Inputs.
  inputs = inputs_lib.make_inputs(dataset, data_dir)

  # Setup optimizer and model.
  opt_init, opt_update = optimizer()
  model_init, model_predict = model()

  # Setup state.
  state = restore_state(output_dir)
  step = state.step or 0
  params_initializer = lambda: model_init([-1] + inputs.input_shape)[1]
  opt_state = opt_init(state.params or params_initializer())

  # Create summary writers.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # jit model_predict and update so they're fast.
  jit_predict = jax.jit(model_predict)  # for evaluation

  @jax.jit
  def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, jax.grad(loss)(params, batch, model_predict),
                      opt_state)

  print()
  step_log(step, "starting training")
  train_gen = inputs.train_fn()
  is_first_step = True
  epoch_steps = 1  # First evaluation after the first training step.
  while step < train_steps:
    print()

    # Train.
    start_time = time.time()
    for _ in range(epoch_steps):
      opt_state = update(step, opt_state, next(train_gen))
      if step % 10 == 0:  # Log learning rate curve each 10 steps.
        train_sw.scalar("training/learning rate",
                        learning_rate(step), step=step)
      step += 1
    epoch_time = time.time() - start_time
    step_log(step, "ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))

    # Save state.
    params = optimizers.get_params(opt_state)
    save_state(State(params=params, step=step), output_dir,
               save_gin=is_first_step)

    # Evaluate.
    step_log(step, "starting evaluation")
    train_metrics, eval_metrics = evaluate(
        inputs, functools.partial(jit_predict, params), eval_steps)
    log_metrics(train_metrics, train_sw, "train", step)
    log_metrics(eval_metrics, eval_sw, "eval ", step)

    # Log non-metric reports and flush.
    if not is_first_step:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)
    train_sw.writer.flush()
    eval_sw.writer.flush()

    # After the first step, train for eval_frequency steps before evaluating.
    epoch_steps = (eval_frequency - 1) if is_first_step else eval_frequency
    is_first_step = False

  print()
  step_log(step, "finished training")
def optimize_lfads(init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
  """Optimize the LFADS model and print batch based optimization data.

  Arguments:
    init_params: a dict of parameters to be trained
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training
    eval_data: nexamples x time x ndims np array of data for evaluation

  Returns:
    a dictionary of trained parameters
  """
  batch_size = lfads_hps['batch_size']
  num_batches = lfads_opt_hps['num_batches']
  print_every = lfads_opt_hps['print_every']

  # Build some functions used in optimization.
  kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
  decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                           lfads_opt_hps['decay_steps'],
                                           lfads_opt_hps['decay_factor'])
  opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                         b1=lfads_opt_hps['adam_b1'],
                                         b2=lfads_opt_hps['adam_b2'],
                                         eps=lfads_opt_hps['adam_eps'])
  update_w_gc = get_update_w_gc_fun(init_params, opt_update)
  update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

  # Begin optimization loop.
  all_tlosses = []
  all_elosses = []
  start_time = time.time()
  opt_state = opt_init(init_params)
  for bidx in range(num_batches):
    kl_warmup = kl_warmup_fun(bidx)
    didxs = onp.random.randint(0, train_data.shape[0], batch_size)
    x_bxt = train_data[didxs].astype(onp.float32)
    key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
    opt_state = update_w_gc_jit(bidx, opt_state, lfads_hps, lfads_opt_hps,
                                key, x_bxt, kl_warmup)

    if bidx % print_every == 0:
      params = optimizers.get_params(opt_state)

      # Training loss.
      didxs = onp.random.randint(0, train_data.shape[0], batch_size)
      x_bxt = train_data[didxs].astype(onp.float32)
      key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
      tlosses = lfads.lfads_losses_jit(params, lfads_hps, key, x_bxt,
                                       kl_warmup, 1.0)

      # Evaluation loss.
      key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
      didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
      ex_bxt = eval_data[didxs].astype(onp.float32)
      # lfads_eval_losses_jit was freezing, so reuse lfads_losses_jit here.
      elosses = lfads.lfads_losses_jit(params, lfads_hps, key, ex_bxt,
                                       kl_warmup, 1.0)

      # Saving, printing.
      all_tlosses.append(tlosses)
      all_elosses.append(elosses)
      batch_time = time.time() - start_time
      s = ("Batch {} in {:0.2f} sec, Step size: {:0.5f}, "
           "Training loss {:0.0f}, Eval loss {:0.0f}")
      print(s.format(bidx, batch_time, decay_fun(bidx),
                     tlosses['total'], elosses['total']))
      start_time = time.time()

  tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
  elosses_thru_training = utils.merge_losses_dicts(all_elosses)
  optimizer_details = {
      'tlosses': tlosses_thru_training,
      'elosses': elosses_thru_training,
  }
  return optimizers.get_params(opt_state), optimizer_details
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  random_seed=None):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  onp.random.seed(random_seed)

  value_losses = []
  ppo_objective = []
  average_rewards = []

  env = env if env is not None else gym.make(env_name)
  batch_observations_shape = (-1,) + env.observation_space.shape
  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  rng_key = jax_random.PRNGKey(0)
  ((policy_net_params, policy_net_apply),
   (value_net_params, value_net_apply)) = initialize_policy_and_value_nets(
       rng_key, num_actions, batch_observations_shape)

  (ppo_opt_state, ppo_opt_update), (value_opt_state,
                                    value_opt_update) = initialize_optimizers(
                                        policy_net_params, value_net_params)

  for i in range(epochs):
    t = time.time()
    t0 = t
    trajs = collect_trajectories(
        env,
        policy_net_apply,
        policy_net_params,
        num_trajectories=batch_size,
        policy=POLICY,
        # Exploration epsilon, distinct from the PPO clipping epsilon below.
        epsilon=(10.0 / (i + 10.0)))

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    average_rewards.append(avg_reward)

    logging.debug("Average sum rewards [%0.2f]", avg_reward)
    logging.debug("Collecting trajectories took %0.2f msec.", get_time(t))
    logging.debug("Average Trajectory size [%0.2f]",
                  float(sum(len(traj[0]) for traj in trajs)) / len(trajs))

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=20)

    logging.debug("Padding trajectories took %0.2f msec.", get_time(t))
    logging.debug("Padded Actions' shape [%s]", str(padded_actions.shape))

    # Linear annealing from 0.1 to 0.0.
    epsilon = 0.1 if epochs == 1 else 0.1 * (1.0 - (i / (epochs - 1)))

    t = time.time()
    cur_value_loss = value_loss(value_net_apply,
                                value_net_params,
                                padded_observations,
                                padded_rewards,
                                reward_mask,
                                gamma=GAMMA)

    logging.debug("Calculating value loss took %0.2f msec.", get_time(t))
    value_losses.append(cur_value_loss)

    t = time.time()
    cur_ppo_loss = ppo_loss(policy_net_apply,
                            policy_net_params,
                            policy_net_params,
                            value_net_apply,
                            value_net_params,
                            padded_observations,
                            padded_actions,
                            padded_rewards,
                            reward_mask,
                            gamma=GAMMA,
                            lambda_=LAMBDA,
                            epsilon=epsilon)

    logging.debug("Calculating PPO loss took %0.2f msec.", get_time(t))
    ppo_objective.append(-cur_ppo_loss)

    # Run optimizers.
    logging.debug("PPO Optimization")
    t1 = time.time()

    for j in range(num_optimizer_steps):
      t = time.time()
      # Update the optimizer state.
      ppo_opt_state = ppo_opt_step(j,
                                   ppo_opt_state,
                                   ppo_opt_update,
                                   policy_net_apply,
                                   policy_net_params,
                                   value_net_apply,
                                   value_net_params,
                                   padded_observations,
                                   padded_actions,
                                   padded_rewards,
                                   reward_mask,
                                   gamma=GAMMA,
                                   lambda_=LAMBDA,
                                   epsilon=epsilon)
      t2 = time.time()
      # Get the new params.
      new_policy_net_params = optimizers.get_params(ppo_opt_state)
      if ((j + 1) % print_every_optimizer_steps == 0 or
          j == num_optimizer_steps - 1):
        new_ppo_loss = ppo_loss(policy_net_apply,
                                new_policy_net_params,
                                policy_net_params,
                                value_net_apply,
                                value_net_params,
                                padded_observations,
                                padded_actions,
                                padded_rewards,
                                reward_mask,
                                gamma=GAMMA,
                                lambda_=LAMBDA,
                                epsilon=epsilon)
        logging.debug("One PPO grad desc took: %0.2f msec", get_time(t, t2))
        logging.debug("PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                      new_ppo_loss)
      # Update the params.
      policy_net_params = new_policy_net_params

    logging.debug("Total PPO loss reduction [%0.2f]%%",
                  (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

    logging.debug("Value Optimization")

    for j in range(num_optimizer_steps):
      t = time.time()
      value_opt_state = value_opt_step(j,
                                       value_opt_state,
                                       value_opt_update,
                                       value_net_apply,
                                       padded_observations,
                                       padded_rewards,
                                       reward_mask,
                                       gamma=GAMMA)
      t2 = time.time()
      value_net_params = optimizers.get_params(value_opt_state)
      if ((j + 1) % print_every_optimizer_steps == 0 or
          j == num_optimizer_steps - 1):
        new_value_loss = value_loss(value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=GAMMA)
        logging.debug("One value grad desc took: %0.2f msec",
                      get_time(t, t2))
        logging.debug("Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                      new_value_loss)

    logging.debug("Total value loss reduction [%0.2f]%%",
                  (100 * (cur_value_loss - new_value_loss) /
                   np.abs(cur_value_loss)))

    logging.debug("Grad desc took %0.2f msec", get_time(t1))

    # Set the optimized params to new params.
    policy_net_params = optimizers.get_params(ppo_opt_state)
    value_net_params = optimizers.get_params(value_opt_state)

    logging.info(
        "Epoch [% 6d], average reward [%10.2f], ppo loss [%10.2f], "
        "value loss [%10.2f], took [%10.2f msec]", i, avg_reward,
        new_ppo_loss, new_value_loss, get_time(t0))

  logging.debug("value_losses: %s", np.stack(value_losses))
  logging.debug("ppo_objective: %s", np.stack(ppo_objective))
  logging.debug("average_rewards: %s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
def data_stream():
  # Shuffle with a numpy RNG (separate from the JAX PRNGKey used below) and
  # yield minibatches forever; training draws num_epochs * num_batches of
  # them.
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]

batches = data_stream()

opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  epoch_time = time.time() - start_time

  params = optimizers.get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
def update(i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  return opt_update(i, jax.grad(loss)(params, batch, model_predict),
                    opt_state)
def testNTKGDPrediction(self, shape, out_logits):
  key = random.PRNGKey(1)

  key, split = random.split(key)
  data_train = random.normal(split, shape)

  key, split = random.split(key)
  label_ids = random.randint(split, (shape[0],), 0, out_logits)
  data_labels = np.eye(out_logits)[label_ids]

  key, split = random.split(key)
  data_test = random.normal(split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
  grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

  theta = tangents.ntk(f)
  g_dd = theta(params, data_train)
  g_td = theta(params, data_test, data_train)
  predictor = tangents.gradient_descent_predictor(
      g_dd, data_labels, loss, g_td)

  step_size = 1.0
  train_time = 100.0
  steps = int(train_time / step_size)

  opt_init, opt_update = opt.sgd(step_size)
  opt_state = opt_init(params)

  fx_initial_train = f(params, data_train)
  fx_initial_test = f(params, data_test)

  fx_pred_train, fx_pred_test = predictor(fx_initial_train, fx_initial_test,
                                          0.0)

  # NOTE(schsam): I think at the moment stax always generates 32-bit results
  # since the weights are explicitly cast to float32.
  self.assertAllClose(fx_initial_train, fx_pred_train, False)
  self.assertAllClose(fx_initial_test, fx_pred_test, False)

  for i in range(steps):
    params = opt.get_params(opt_state)
    opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

  params = opt.get_params(opt_state)
  fx_train = f(params, data_train)
  fx_test = f(params, data_test)

  fx_pred_train, fx_pred_test = predictor(fx_initial_train, fx_initial_test,
                                          train_time)

  # Put errors in units of RMS distance of the function values during
  # optimization.
  fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
  fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

  fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
  fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

  self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train), False,
                      0.1, 0.1)
  self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), False,
                      0.1, 0.1)
def update(i, opt_state, batch):
  params = optimizers.get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)
def testNTKMomentumPrediction(self, shape, out_logits):
  key = random.PRNGKey(1)

  key, split = random.split(key)
  data_train = random.normal(split, shape)

  key, split = random.split(key)
  label_ids = random.randint(split, (shape[0],), 0, out_logits)
  data_labels = np.eye(out_logits)[label_ids]

  key, split = random.split(key)
  data_test = random.normal(split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
  grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

  theta = tangents.ntk(f)
  g_dd = theta(params, data_train)
  g_td = theta(params, data_test, data_train)

  step_size = 1.0
  train_time = 100.0
  steps = int(train_time / np.sqrt(step_size))

  init_fn, predict_fn, get_fn = tangents.momentum_predictor(
      g_dd, data_labels, loss, step_size, g_td)
  opt_init, opt_update = momentum(step_size)
  opt_state = opt_init(params)

  fx_initial_train = f(params, data_train)
  fx_initial_test = f(params, data_test)

  lin_state = init_fn(fx_initial_train, fx_initial_test)

  for i in range(steps):
    params = opt.get_params(opt_state)
    opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

  params = opt.get_params(opt_state)
  fx_train = f(params, data_train)
  fx_test = f(params, data_test)

  lin_state = predict_fn(lin_state, train_time)
  fx_pred_train, fx_pred_test = get_fn(lin_state)

  # Put errors in units of RMS distance of the function values during
  # optimization.
  fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
  fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

  fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
  fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

  self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train), False,
                      0.1, 0.1)
  self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), False,
                      0.1, 0.1)
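The pattern shared by every snippet above is the pre-flax jax.experimental.optimizers API: opt_init builds the optimizer state from initial parameters, opt_update folds a gradient into that state, and optimizers.get_params extracts the current parameters. A self-contained sketch of the full loop (the linear-regression setup is illustrative, not from any snippet above):

import jax.numpy as np
from jax import grad, jit
from jax.experimental import optimizers

def loss(params, batch):
  x, y = batch
  w, b = params
  return np.mean((np.dot(x, w) + b - y)**2)

opt_init, opt_update = optimizers.sgd(step_size=0.1)

@jit
def update(i, opt_state, batch):
  # Extract params, compute the gradient, and fold it into the state.
  params = optimizers.get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init((np.zeros(2), np.zeros(())))
batch = (np.ones((4, 2)), 3.0 * np.ones(4))
for i in range(200):
  opt_state = update(i, opt_state, batch)
params = optimizers.get_params(opt_state)  # trained (w, b)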