Example #1
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

    # Build the network
    init_fn, f = stax.serial(layers.Dense(2048), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = tangents.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply = optimizers.momentum(FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = optimizers.get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = optimizers.get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
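These examples target the early jax.experimental.optimizers API, in which the optimizer factory returned an (opt_init, opt_update) pair and parameters were read back with the module-level optimizers.get_params. In current JAX the same optimizers live in jax.example_libraries.optimizers and each factory returns a triple that already includes get_params. A minimal, self-contained sketch of the shared training-step pattern under that assumption:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Toy objective: drive the parameters toward zero.
def loss(params):
    return jnp.sum(params ** 2)

# Modern factories return a triple; the older API unpacked two values and
# called optimizers.get_params(opt_state) instead of the returned get_params.
opt_init, opt_update, get_params = optimizers.momentum(1e-2, mass=0.9)
opt_state = opt_init(jnp.ones(3))

@jax.jit
def step(i, opt_state):
    params = get_params(opt_state)
    return opt_update(i, jax.grad(loss)(params), opt_state)

for i in range(200):
    opt_state = step(i, opt_state)
print(get_params(opt_state))  # close to [0., 0., 0.]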
Example #2
 def mapped_update(i, opt_state, batch):
     _, opt_update = optimizer(lr_fun)
     params = jax_opt.get_params(opt_state)
     grads = jax.grad(loss_fun)(params, batch, predict_fun)
     grads = jax.tree_util.tree_map(lambda g: jax.lax.psum(g, "batch"),
                                    grads)
     return opt_update(i, grads, opt_state)
Example #3
 def private_update(rng, i, opt_state, batch):
   params = optimizers.get_params(opt_state)
   rng = random.fold_in(rng, i)  # get new key for new random numbers
   return opt_update(
       i,
       private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                    FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
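The random.fold_in(rng, i) call above deterministically derives a fresh subkey for step i from a base key, so every update sees new noise without threading a split key through the loop. A small sketch of just that behaviour, independent of the DP-SGD specifics above:

from jax import random

base = random.PRNGKey(0)
# Each step index folds into a distinct, reproducible subkey.
keys = [random.fold_in(base, i) for i in range(3)]
draws = [random.normal(k, (2,)) for k in keys]  # different noise per step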
Example #4
 def update_fn(i, opt_state, rng, model_args=(), guide_args=()):
     model_init, guide_init = _seed(model, guide, rng)
     params = optimizers.get_params(opt_state)
     loss_val, grads = value_and_grad(loss)(params, model_init, guide_init, model_args, guide_args, kwargs)
     opt_state = optim_update(i, grads, opt_state)
     rng, = random.split(rng, 1)
     return loss_val, opt_state, rng
Example #5
def update_w_gc(i, opt_state, opt_update, x_bxt, f_bxt, f_mask_bxt, max_grad_norm, l2reg):
  """Update the parameters w/ gradient clipped, gradient descent updates.

  Arguments: 
    i: batch number
    opt_state: parameters plus optimizer state
    x_bxt: rnn inputs
    f_bxt: rnn targets
    f_mask_bxt: masks for when target is defined
    max_grad_norm: maximum norm value gradient is allowed to take
    l2reg: l2 regularization hyperparameter
  
  Returns: 
    opt_state tuple (as above) that includes updated parameters and optimzier 
      state.
  """
  params = optimizers.get_params(opt_state)
  unflatten = flatten_util.ravel_pytree(params)[1]  # keep the unflatten fn; it records the parameter shapes

  def training_loss(params, x_bxt, f_bxt, l2reg):
    return loss(params, x_bxt, f_bxt, f_mask_bxt, l2reg)['total']
  
  grads = grad(training_loss)(params, x_bxt, f_bxt, l2reg)
  flat_grads = flatten(grads)
  grad_norm = np.sqrt(np.sum(flat_grads**2))
  normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                          flat_grads * (max_grad_norm / grad_norm))
  uf_grads = unflatten(normed_grads)
  return opt_update(i, uf_grads, opt_state)
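The flatten / clip / unflatten sequence above implements clipping by global gradient norm. The same utilities module also ships optimizers.clip_grads(grad_tree, max_norm), which applies the identical rescaling directly to a pytree; a short sketch, assuming jax.example_libraries.optimizers:

import jax.numpy as jnp
from jax.example_libraries import optimizers

grads = {'w': jnp.full((3,), 10.0), 'b': jnp.asarray(10.0)}
# Rescale the whole pytree so its global L2 norm is at most 1.0.
clipped = optimizers.clip_grads(grads, 1.0)
print(optimizers.l2_norm(clipped))  # ~1.0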
Example #6
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100, )) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    opt_init, opt_update = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update)
    rng = PRNGKey(0)
    opt_state = svi_init(rng, model_args=(data, ))

    # Training loop
    rng, = random.split(rng, 1)

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            opt_state_,
                                            rng_,
                                            model_args=(data, ))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, args.num_steps, body_fn, (opt_state, rng))

    # Report the final values of the variational parameters
    # in the guide after training.
    params = optimizers.get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
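The training loop above is expressed with lax.fori_loop, so the whole loop is traced and compiled once rather than dispatching each step from Python. A minimal sketch of the fori_loop(lower, upper, body_fun, init_val) contract, unrelated to the SVI specifics:

import jax.numpy as jnp
from jax import lax

# body_fun receives (loop index, carry) and returns the next carry.
def body_fun(i, carry):
    total, count = carry
    return total + jnp.float32(i), count + 1

total, count = lax.fori_loop(0, 10, body_fun, (jnp.float32(0.0), 0))
# total == 45.0, count == 10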
Example #7
 def evaluate(opt_state, images):
   params = optimizers.get_params(opt_state)
   elbo_rng, data_rng, image_rng = random.split(test_rng, 3)
   binarized_test = random.bernoulli(data_rng, images)
   test_elbo = elbo(elbo_rng, params, binarized_test) / images.shape[0]
   sampled_images = image_sample(image_rng, params, nrow, ncol)
   return test_elbo, sampled_images
Example #8
def ppo_opt_step(i,
                 opt_state,
                 ppo_opt_update,
                 policy_net_apply,
                 old_policy_params,
                 value_net_apply,
                 value_net_params,
                 padded_observations,
                 padded_actions,
                 padded_rewards,
                 reward_mask,
                 gamma=0.99,
                 lambda_=0.95,
                 epsilon=0.1):
    """PPO optimizer step."""
    new_policy_params = optimizers.get_params(opt_state)
    g = grad(ppo_loss, argnums=1)(policy_net_apply,
                                  new_policy_params,
                                  old_policy_params,
                                  value_net_apply,
                                  value_net_params,
                                  padded_observations,
                                  padded_actions,
                                  padded_rewards,
                                  reward_mask,
                                  gamma=gamma,
                                  lambda_=lambda_,
                                  epsilon=epsilon)
    return ppo_opt_update(i, g, opt_state)
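grad(ppo_loss, argnums=1) differentiates the loss with respect to its second positional argument (the new policy parameters) while treating all other arguments as constants; the value-network step later instead closes over the apply function with functools.partial and differentiates the default first argument. A tiny illustration of argnums, independent of the PPO code:

import jax.numpy as jnp
from jax import grad

def f(a, b):
    return jnp.sum(a * b ** 2)

a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])
grad(f)(a, b)             # df/da = b ** 2
grad(f, argnums=1)(a, b)  # df/db = 2 * a * b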
Example #9
def test_optim_multi_params():
    params = {'x': np.array([1., 1., 1.]), 'y': np.array([-1, -1., -1.])}
    opt_init, opt_update = optimizers.adam(step_size=1e-2)
    opt_state = opt_init(params)
    for i in range(1000):
        opt_state = step(i, opt_state, opt_update)
    for _, param in optimizers.get_params(opt_state).items():
        assert np.allclose(param, np.zeros(3))
Example #10
 def body_fun(i, loop_carry):
   (rng, opt_state, images) = loop_carry
   rng, elbo_rng, data_rng = random.split(rng, 3)
   batch = binarize_batch(data_rng, i, images)
   loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
   g = grad(loss)(optimizers.get_params(opt_state))
   loop_carry = rng, opt_update(i, g, opt_state), images
   return loop_carry
Example #11
 def mapped_update(i, opt_state, batch, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = num_devices.
     _, opt_update = optimizer(lr_fun)
     params = jax_opt.get_params(opt_state)
     grads = jax.grad(loss_fun)(params, batch, predict_fun, rng)
     grads = jax.tree_util.tree_map(lambda g: jax.lax.psum(g, "batch"),
                                    grads)
     return opt_update(i, grads, opt_state)
Example #12
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
    opt_init, opt_update = optimizers.momentum(step_size, mass)

    @jit
    def update(i, opt_state):
        x = optimizers.get_params(opt_state)
        return opt_update(i, grad(f)(x), opt_state)

    opt_state = opt_init(x)
    for i in range(num_steps):
        opt_state = update(i, opt_state)
    return optimizers.get_params(opt_state)
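A hypothetical call to this helper, assuming any differentiable scalar objective (the objective and values below are illustrative only):

x_min = minimize(lambda x: np.sum((x - 3.0) ** 2), np.zeros(3),
                 num_steps=2000, step_size=1e-2)
# x_min should end up close to [3., 3., 3.]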
Example #13
def update_w_gc(i, opt_state, opt_update, x_bxt, f_bxt, max_grad_norm, l2reg):
  """Update the parameters w/ gradient clipped, gradient descent updates."""
  params = optimizers.get_params(opt_state)
  unflatten = flatten_util.ravel_pytree(params)[1]  # keep the unflatten fn; it records the parameter shapes

  grads = grad(loss)(params, x_bxt, f_bxt, l2reg)
  flat_grads = flatten(grads)
  grad_norm = np.sqrt(np.sum(flat_grads**2))
  normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                          flat_grads * (max_grad_norm / grad_norm))
  uf_grads = unflatten(normed_grads)
  return opt_update(i, uf_grads, opt_state)
Example #14
    def update_w_gc(i, opt_state, lfads_hps, lfads_opt_hps, key, x_bxt,
                    kl_warmup):
        max_grad_norm = lfads_opt_hps['max_grad_norm']
        keep_rate = lfads_opt_hps['keep_rate']

        params = optimizers.get_params(opt_state)

        grads = grad(lfads.lfads_training_loss)(params, lfads_hps, key, x_bxt,
                                                kl_warmup, keep_rate)
        flat_grads = flatten_lfads(grads)
        grad_norm = np.sqrt(np.sum(flat_grads**2))
        normed_grads = np.where(grad_norm <= max_grad_norm, flat_grads,
                                flat_grads * (max_grad_norm / grad_norm))
        uf_grads = unflatten_lfads(normed_grads)
        return opt_update(i, uf_grads, opt_state)
Example #15
 def reconstruct_img(epoch):
     img = test_fetch(0, test_idx)[0][0]
     plt.imsave(os.path.join(RESULTS_DIR,
                             'original_epoch={}.png'.format(epoch)),
                img,
                cmap='gray')
     _, test_sample = binarize(rng, img)
     params = optimizers.get_params(opt_state)
     z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
     z = dist.norm(z_mean, z_var).rvs(random_state=rng)
     img_loc = decode(params['decoder'], z).reshape([28, 28])
     plt.imsave(os.path.join(RESULTS_DIR,
                             'recons_epoch={}.png'.format(epoch)),
                img_loc,
                cmap='gray')
Example #16
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = optimizers.get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
Example #17
def value_opt_step(i,
                   opt_state,
                   opt_update,
                   value_net_apply,
                   padded_observations,
                   padded_rewards,
                   reward_mask,
                   gamma=0.99):
    """Value optimizer step."""
    value_params = optimizers.get_params(opt_state)
    # Note this partial application here and argnums above in ppo_opt_step.
    g = grad(functools.partial(value_loss,
                               value_net_apply))(value_params,
                                                 padded_observations,
                                                 padded_rewards,
                                                 reward_mask,
                                                 gamma=gamma)
    return opt_update(i, g, opt_state)
Example #18
    def learn(opt_step, opt_state, params_Q_eval, params_Q_next):
        mini_batch = sample(memory, BATCH_SIZE)

        if opt_step % TAU == 0:
            params_Q_next = params_Q_eval.copy()

        input_states = np.stack([transition[0] for transition in mini_batch])
        next_states = np.stack([transition[3] for transition in mini_batch])

        predicted_Q = pred_Q(params_Q_eval, input_states)
        predicted_Q_next = pred_Q(params_Q_next, next_states)

        max_action = np.argmax(predicted_Q_next, axis=1)
        rewards = np.array([transition[2] for transition in mini_batch])

        Q_target = onp.array(predicted_Q)
        # Index per row so each transition only updates its own greedy action.
        Q_target[onp.arange(len(mini_batch)), max_action] = (
            rewards + GAMMA * np.max(predicted_Q_next, axis=1))

        opt_state = step(opt_step, opt_state, (input_states, Q_target))
        params_Q_eval = optimizers.get_params(opt_state)

        return opt_state, params_Q_eval, params_Q_next
Example #19
def train_fn(data_dir=None, output_dir=None,
             model=gin.REQUIRED,
             dataset=gin.REQUIRED,
             train_steps=1000, eval_steps=10, eval_frequency=100):
  """Train the given model on the given dataset.

  Args:
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train (a function).
    dataset: The name of the dataset to train on.
    train_steps: for how many steps to train.
    eval_steps: for how many steps to do evaluation.
    eval_frequency: how often (every this many steps) to run evaluation.
  """
  (train_batches, eval_batches,
   input_name, input_shape) = input_pipeline.train_and_eval_batches(
       dataset, data_dir)
  train_stream = dataset_to_stream(train_batches, input_name)

  # Training loop.
  opt_init, opt_update = optimizer()
  model_init, model_predict = model()

  @jax.jit
  def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, jax.grad(loss)(
        params, batch, model_predict), opt_state)

  _, init_params = model_init([-1] + input_shape)
  step, train_sw, eval_sw = 0, None, None
  if output_dir is not None:
    _make_directory(output_dir)
    # Load parameters.
    loaded_params, loaded_step = load_params_and_step(output_dir)
    if loaded_params is not None:
      init_params = loaded_params
    if loaded_step is not None:
      step = loaded_step

    # Create summary writers.
    eval_sw = _make_summary_writer(os.path.join(output_dir, "eval_log"))
    train_sw = _make_summary_writer(os.path.join(output_dir, "train_log"))

  log("Starting training.")
  opt_state = opt_init(init_params)
  gin_config_saved = False
  while step < train_steps:
    # Training.
    start_time = time.time()
    for _ in range(eval_frequency):
      opt_state = update(step, opt_state, next(train_stream))
      step += 1
    epoch_time = time.time() - start_time
    log("Step {}, last {} steps in {:0.2f} sec".format(
        step, eval_frequency, epoch_time))

    # Save the model.
    params = optimizers.get_params(opt_state)
    save_params_and_step(params, step, output_dir)

    # Save the config if not saved yet.
    # Gin file only includes used parameters, so we save it at this point.
    if output_dir and not gin_config_saved:
      gin_config_saved = True
      config_path = os.path.join(output_dir, "config.gin")
      with gfile.GFile(config_path, "w") as f:
        f.write(gin.operative_config_str())

    # Evaluation.
    eval_stream = dataset_to_stream(eval_batches, input_name)
    eval_train_stream = dataset_to_stream(train_batches, input_name)
    train_acc, eval_acc, train_loss, eval_loss = 0.0, 0.0, 0.0, 0.0
    for _ in range(eval_steps):
      train_acc += accuracy(params, next(eval_train_stream), model_predict)
      eval_acc += accuracy(params, next(eval_stream), model_predict)
      train_loss += loss(params, next(eval_train_stream), model_predict)
      eval_loss += loss(params, next(eval_stream), model_predict)
    train_acc /= eval_steps
    eval_acc /= eval_steps
    train_loss /= eval_steps
    eval_loss /= eval_steps
    log("Train accuracy {:0.4f} loss {:0.8f}".format(train_acc, train_loss))
    if train_sw:
      train_sw.scalar("steps/s", epoch_time / eval_frequency, step=step)
      train_sw.scalar("accuracy", train_acc, step=step)
      train_sw.scalar("loss", train_loss, step=step)
    log("Eval  accuracy {:0.4f} loss {:0.8f}".format(eval_acc, eval_loss))
    if eval_sw:
      eval_sw.scalar("accuracy", eval_acc, step=step)
      train_sw.scalar("loss", eval_loss, step=step)
Example #20
 def body_fun(i, opt_state):
   elbo_rng, data_rng = random.split(random.fold_in(rng, i))
   batch = binarize_batch(data_rng, i, train_images)
   loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
   g = grad(loss)(optimizers.get_params(opt_state))
   return opt_update(i, g, opt_state)
Example #21
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          run_debug_step=False):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
    rng = random.PRNGKey(0)
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    inputs = inputs()

    # Setup optimizer and model
    state = restore_state(output_dir)
    history = state.history
    lr_fun = lr_schedule(history)
    opt_init, _ = optimizer(lr_fun)
    model_init, model_predict_original = model()

    # We need a model_predict that fills in the random generator if needed.
    def model_predict(x, y, **kwargs):
        """Same as model_predict_original but fill in rng if it isn't passed."""
        if "rng" in kwargs:
            return model_predict_original(x, y, **kwargs)
        return model_predict_original(x, y, rng=rng, **kwargs)

    # Setup state
    step = state.step or 0
    params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
    params = state.params or params_initializer()
    opt_state = opt_init(params)

    # jit model_predict and update so they're fast
    jit_model_predict = jax.jit(model_predict)  # for evaluation
    jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

    print()
    train_stream = inputs.train_stream()
    epoch_steps = itertools.chain(
        [
            1,  # first epoch only 1 step
            eval_frequency - 1
        ],
        itertools.repeat(eval_frequency))
    step_log(step, "Starting training")

    # Non-compiled debug step helps find problems in models easier.
    if run_debug_step:
        debug_loss = loss(params, next(train_stream), model_predict)
        step_log(step, "Debug step loss %.8f" % debug_loss)

    for epoch, epoch_steps in epochs(train_steps, epoch_steps):
        # Log separator
        print()

        # Timer
        start_time = time.time()

        for _ in range(epoch_steps):
            # Train
            opt_state = jit_update_fun(step, opt_state, next(train_stream))
            step += 1

            # LR log
            if step == 1 or step % 10 == 0:
                train_sw.scalar("training/learning rate",
                                lr_fun(step),
                                step=step)

        # Timer
        epoch_time = time.time() - start_time
        step_log(
            step,
            "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
        if epoch_steps > 1:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)

        # Evaluate
        params = jax_opt.get_params(opt_state)
        evaluate_train_and_eval(step=step,
                                inputs=inputs,
                                predict_fun=functools.partial(
                                    jit_model_predict, params),
                                eval_steps=eval_steps,
                                train_sw=train_sw,
                                eval_sw=eval_sw,
                                history=history)

        # Save state
        save_state(State(params=params, step=step, history=history),
                   output_dir)

        # Save Gin config
        # Gin only tracks the used parameters, so we save it after the first epoch.
        if epoch == 1:
            save_gin(output_dir, train_sw)

        # Update learning rate with new history
        old_lr_fun = lr_fun
        lr_fun = lr_schedule(history)
        if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
            jit_update_fun = _jit_update_fun(model_predict, loss, optimizer,
                                             lr_fun)

        # Flush summary writers
        train_sw.writer.flush()
        eval_sw.writer.flush()

    step_log(step, "Training done")
    return State(params=params, step=step, history=history)
Example #22
 def update(i, opt_state, batch):
     _, opt_update = optimizer(lr_fun)
     params = jax_opt.get_params(opt_state)
     return opt_update(i,
                       jax.grad(loss_fun)(params, batch, predict_fun),
                       opt_state)
Example #23
def train_fn(output_dir,
             data_dir,
             model=gin.REQUIRED,
             dataset=gin.REQUIRED,
             train_steps=1000,
             eval_steps=10,
             eval_frequency=100):
    """Train the given model on the given dataset.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    data_dir: Directory where the data is located.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    dataset: The name of the TFDS dataset to train on. To train on a T2T
      dataset, prefix the name with "t2t_".
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps).
  """
    gfile.makedirs(output_dir)

    # Make Inputs
    inputs = inputs_lib.make_inputs(dataset, data_dir)

    # Setup optimizer and model
    opt_init, opt_update = optimizer()
    model_init, model_predict = model()

    # Setup state
    state = restore_state(output_dir)
    step = state.step or 0
    params_initializer = lambda: model_init([-1] + inputs.input_shape)[1]
    opt_state = opt_init(state.params or params_initializer())

    # Create summary writers.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    # jit model_predict and update so they're fast
    jit_predict = jax.jit(model_predict)  # for evaluation

    @jax.jit
    def update(i, opt_state, batch):
        params = optimizers.get_params(opt_state)
        return opt_update(i,
                          jax.grad(loss)(params, batch, model_predict),
                          opt_state)

    print()
    step_log(step, "starting training")
    train_gen = inputs.train_fn()
    is_first_step = True
    epoch_steps = 1  # First evaluation after the first training step.
    while step < train_steps:
        print()

        # Train
        start_time = time.time()
        for _ in range(epoch_steps):
            opt_state = update(step, opt_state, next(train_gen))
            if step % 10 == 0:  # Log learning rate curve each 10 steps.
                train_sw.scalar("training/learning rate",
                                learning_rate(step),
                                step=step)
            step += 1
        epoch_time = time.time() - start_time
        step_log(
            step,
            "ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))

        # Save state
        params = optimizers.get_params(opt_state)
        save_state(State(params=params, step=step),
                   output_dir,
                   save_gin=is_first_step)

        # Evaluate
        step_log(step, "starting evaluation")
        train_metrics, eval_metrics = evaluate(
            inputs, functools.partial(jit_predict, params), eval_steps)
        log_metrics(train_metrics, train_sw, "train", step)
        log_metrics(eval_metrics, eval_sw, "eval ", step)

        # Log non-metric reports and flush.
        if not is_first_step:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)
        train_sw.writer.flush()
        eval_sw.writer.flush()

        # After the first step, train for eval_frequency steps before evaluating
        epoch_steps = (eval_frequency - 1) if is_first_step else eval_frequency
        is_first_step = False

    print()
    step_log(step, "finished training")
Example #24
def optimize_lfads(init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
    """Optimize the LFADS model and print batch based optimization data.

  Arguments:
    init_params: a dict of parameters to be trained
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training
    eval_data: nexamples x time x ndims np array of data for evaluation

  Returns:
    a tuple of (dict of trained parameters, dict of training and eval loss
    histories)."""

    batch_size = lfads_hps['batch_size']
    num_batches = lfads_opt_hps['num_batches']
    print_every = lfads_opt_hps['print_every']

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
    decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                             lfads_opt_hps['decay_steps'],
                                             lfads_opt_hps['decay_factor'])
    opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                           b1=lfads_opt_hps['adam_b1'],
                                           b2=lfads_opt_hps['adam_b2'],
                                           eps=lfads_opt_hps['adam_eps'])
    update_w_gc = get_update_w_gc_fun(init_params, opt_update)
    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []
    start_time = time.time()
    opt_state = opt_init(init_params)
    for bidx in range(num_batches):
        kl_warmup = kl_warmup_fun(bidx)
        didxs = onp.random.randint(0, train_data.shape[0], batch_size)
        x_bxt = train_data[didxs].astype(onp.float32)
        key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
        opt_state = update_w_gc_jit(bidx, opt_state, lfads_hps, lfads_opt_hps,
                                    key, x_bxt, kl_warmup)

        if bidx % print_every == 0:
            params = optimizers.get_params(opt_state)

            # Training loss
            didxs = onp.random.randint(0, train_data.shape[0], batch_size)
            x_bxt = train_data[didxs].astype(onp.float32)
            key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
            tlosses = lfads.lfads_losses_jit(params, lfads_hps, key, x_bxt,
                                             kl_warmup, 1.0)

            # Evaluation loss
            key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
            didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
            ex_bxt = eval_data[didxs].astype(onp.float32)
            # lfads_eval_losses_jit is not used here because it was freezing.
            elosses = lfads.lfads_losses_jit(params, lfads_hps, key, ex_bxt,
                                             kl_warmup, 1.0)
            # Saving, printing.
            all_tlosses.append(tlosses)
            all_elosses.append(elosses)
            batch_time = time.time() - start_time
            s = "Batch {} in {:0.2f} sec, Step size: {:0.5f}, \
              Training loss {:0.0f}, Eval loss {:0.0f}"

            print(
                s.format(bidx, batch_time, decay_fun(bidx), tlosses['total'],
                         elosses['total']))
            start_time = time.time()

            tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
            elosses_thru_training = utils.merge_losses_dicts(all_elosses)
            optimizer_details = {
                'tlosses': tlosses_thru_training,
                'elosses': elosses_thru_training
            }
    return optimizers.get_params(opt_state), optimizer_details
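optimizers.exponential_decay(step_size, decay_steps, decay_rate), used above, returns a schedule: a function from step index to learning rate, which the Adam factory accepts in place of a constant step_size. A brief sketch of that pattern, assuming jax.example_libraries.optimizers and illustrative hyperparameter values:

from jax.example_libraries import optimizers

# lr(step) = 1e-3 * 0.95 ** (step / 1000)
decay_fun = optimizers.exponential_decay(1e-3, 1000, 0.95)
opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                   b1=0.9, b2=0.999, eps=1e-8)
print(decay_fun(0), decay_fun(1000))  # 0.001, ~0.00095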
Example #25
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  random_seed=None):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    onp.random.seed(random_seed)

    value_losses = []
    ppo_objective = []
    average_rewards = []

    env = env if env is not None else gym.make(env_name)

    batch_observations_shape = (-1, ) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    rng_key = jax_random.PRNGKey(0)
    ((policy_net_params, policy_net_apply),
     (value_net_params, value_net_apply)) = initialize_policy_and_value_nets(
         rng_key, num_actions, batch_observations_shape)

    (ppo_opt_state,
     ppo_opt_update), (value_opt_state,
                       value_opt_update) = initialize_optimizers(
                           policy_net_params, value_net_params)

    for i in range(epochs):
        t = time.time()
        t0 = t
        trajs = collect_trajectories(
            env,
            policy_net_apply,
            policy_net_params,
            num_trajectories=batch_size,
            policy=POLICY,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        average_rewards.append(avg_reward)

        logging.debug("Average sum rewards [%0.2f]", avg_reward)
        logging.debug("Collecting trajectories took %0.2f msec.", get_time(t))
        logging.debug("Average Trajectory size [%0.2f]",
                      float(sum(len(traj[0]) for traj in trajs)) / len(trajs))

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=20)

        logging.debug("Padding trajectories took %0.2f msec.", get_time(t))
        logging.debug("Padded Actions' shape [%s]", str(padded_actions.shape))

        # Linear annealing from 0.1 to 0.0
        epsilon = 0.1 if epochs == 1 else 0.1 * (1.0 - (i / (epochs - 1)))

        t = time.time()
        cur_value_loss = value_loss(value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=GAMMA)

        logging.debug("Calculating value loss took %0.2f msec.", get_time(t))
        value_losses.append(cur_value_loss)

        t = time.time()
        cur_ppo_loss = ppo_loss(policy_net_apply,
                                policy_net_params,
                                policy_net_params,
                                value_net_apply,
                                value_net_params,
                                padded_observations,
                                padded_actions,
                                padded_rewards,
                                reward_mask,
                                gamma=GAMMA,
                                lambda_=LAMBDA,
                                epsilon=epsilon)
        # ppo_loss = 11.00110011
        logging.debug("Calculating PPO loss took %0.2f msec.", get_time(t))
        ppo_objective.append(-cur_ppo_loss)

        # Run optimizers.
        logging.debug("PPO Optimization")
        t1 = time.time()

        for j in range(num_optimizer_steps):
            t = time.time()
            # Update the optimizer state.
            ppo_opt_state = ppo_opt_step(j,
                                         ppo_opt_state,
                                         ppo_opt_update,
                                         policy_net_apply,
                                         policy_net_params,
                                         value_net_apply,
                                         value_net_params,
                                         padded_observations,
                                         padded_actions,
                                         padded_rewards,
                                         reward_mask,
                                         gamma=GAMMA,
                                         lambda_=LAMBDA,
                                         epsilon=epsilon)
            t2 = time.time()
            # Get the new params.
            new_policy_net_params = optimizers.get_params(ppo_opt_state)
            if ((j + 1) % print_every_optimizer_steps
                    == 0) or (j == num_optimizer_steps - 1):
                new_ppo_loss = ppo_loss(policy_net_apply,
                                        new_policy_net_params,
                                        policy_net_params,
                                        value_net_apply,
                                        value_net_params,
                                        padded_observations,
                                        padded_actions,
                                        padded_rewards,
                                        reward_mask,
                                        gamma=GAMMA,
                                        lambda_=LAMBDA,
                                        epsilon=epsilon)
                logging.debug("One PPO grad desc took: %0.2f msec",
                              get_time(t, t2))
                logging.debug("PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                              new_ppo_loss)
            # Update the params.
            policy_net_params = new_policy_net_params

        logging.debug("Total PPO loss reduction [%0.2f]%%",
                      (100 *
                       (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

        logging.debug("Value Optimization")

        for j in range(num_optimizer_steps):
            t = time.time()
            value_opt_state = value_opt_step(j,
                                             value_opt_state,
                                             value_opt_update,
                                             value_net_apply,
                                             padded_observations,
                                             padded_rewards,
                                             reward_mask,
                                             gamma=GAMMA)
            t2 = time.time()
            value_net_params = optimizers.get_params(value_opt_state)
            if ((j + 1) % print_every_optimizer_steps
                    == 0) or (j == num_optimizer_steps - 1):
                new_value_loss = value_loss(value_net_apply,
                                            value_net_params,
                                            padded_observations,
                                            padded_rewards,
                                            reward_mask,
                                            gamma=GAMMA)
                logging.debug("One value grad desc took: %0.2f msec",
                              get_time(t, t2))
                logging.debug("Value loss [%10.2f] -> [%10.2f]",
                              cur_value_loss, new_value_loss)
        logging.debug(
            "Total value loss reduction [%0.2f]%%",
            (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

        logging.debug("Grad desc took %0.2f msec", get_time(t1))

        # Set the optimized params to new params.
        policy_net_params = optimizers.get_params(ppo_opt_state)
        value_net_params = optimizers.get_params(value_opt_state)

        logging.info(
            "Epoch [% 6d], average reward [%10.2f], ppo loss [%10.2f], "
            "value loss [%10.2f], took [%10.2f msec]", i, avg_reward,
            new_ppo_loss, new_value_loss, get_time(t0))

    logging.debug("value_losses: %s", np.stack(value_losses))
    logging.debug("ppo_objective: %s", np.stack(ppo_objective))
    logging.debug("average_rewards: %s", average_rewards)

    return ((policy_net_params, value_net_params), average_rewards,
            np.stack(value_losses), np.stack(ppo_objective))
Example #26
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params(rng, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  print("\nStarting training...")
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    params = optimizers.get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Example #27
 def update(i, opt_state, batch):
   params = optimizers.get_params(opt_state)
   return opt_update(i, jax.grad(loss)(
       params, batch, model_predict), opt_state)
Example #28
    def testNTKGDPrediction(self, shape, out_logits):

        key = random.PRNGKey(1)

        key, split = random.split(key)
        data_train = random.normal(split, shape)

        key, split = random.split(key)
        label_ids = random.randint(split, (shape[0], ), 0, out_logits)
        data_labels = np.eye(out_logits)[label_ids]

        key, split = random.split(key)
        data_test = random.normal(split, shape)

        key, w_split, b_split = random.split(key, 3)
        params = (random.normal(w_split, (shape[-1], out_logits)),
                  random.normal(b_split, (out_logits, )))

        def f(params, x):
            w, b = params
            return np.dot(x, w) / shape[-1] + b

        loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
        grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

        theta = tangents.ntk(f)
        g_dd = theta(params, data_train)
        g_td = theta(params, data_test, data_train)

        predictor = tangents.gradient_descent_predictor(
            g_dd, data_labels, loss, g_td)

        step_size = 1.0
        train_time = 100.0
        steps = int(train_time / step_size)

        opt_init, opt_update = opt.sgd(step_size)
        opt_state = opt_init(params)

        fx_initial_train = f(params, data_train)
        fx_initial_test = f(params, data_test)

        fx_pred_train, fx_pred_test = predictor(fx_initial_train,
                                                fx_initial_test, 0.0)

        # NOTE(schsam): I think at the moment stax always generates 32-bit results
        # since the weights are explicitly cast to float32.
        self.assertAllClose(fx_initial_train, fx_pred_train, False)
        self.assertAllClose(fx_initial_test, fx_pred_test, False)

        for i in range(steps):
            params = opt.get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

        params = opt.get_params(opt_state)
        fx_train = f(params, data_train)
        fx_test = f(params, data_test)

        fx_pred_train, fx_pred_test = predictor(fx_initial_train,
                                                fx_initial_test, train_time)

        # Put errors in units of RMS distance of the function values during
        # optimization.
        fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
        fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

        fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
        fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

        self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train),
                            False, 0.1, 0.1)
        self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), False,
                            0.1, 0.1)
Example #29
 def update(i, opt_state, batch):
   params = optimizers.get_params(opt_state)
   return opt_update(i, grad(loss)(params, batch), opt_state)
Example #30
    def testNTKMomentumPrediction(self, shape, out_logits):

        key = random.PRNGKey(1)

        key, split = random.split(key)
        data_train = random.normal(split, shape)

        key, split = random.split(key)
        label_ids = random.randint(split, (shape[0], ), 0, out_logits)
        data_labels = np.eye(out_logits)[label_ids]

        key, split = random.split(key)
        data_test = random.normal(split, shape)

        key, w_split, b_split = random.split(key, 3)
        params = (random.normal(w_split, (shape[-1], out_logits)),
                  random.normal(b_split, (out_logits, )))

        def f(params, x):
            w, b = params
            return np.dot(x, w) / shape[-1] + b

        loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
        grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

        theta = tangents.ntk(f)
        g_dd = theta(params, data_train)
        g_td = theta(params, data_test, data_train)

        step_size = 1.0
        train_time = 100.0
        steps = int(train_time / np.sqrt(step_size))

        init_fn, predict_fn, get_fn = tangents.momentum_predictor(
            g_dd, data_labels, loss, step_size, g_td)

        opt_init, opt_update = momentum(step_size)
        opt_state = opt_init(params)

        fx_initial_train = f(params, data_train)
        fx_initial_test = f(params, data_test)

        lin_state = init_fn(fx_initial_train, fx_initial_test)

        for i in range(steps):
            params = opt.get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

        params = opt.get_params(opt_state)
        fx_train = f(params, data_train)
        fx_test = f(params, data_test)

        lin_state = predict_fn(lin_state, train_time)

        fx_pred_train, fx_pred_test = get_fn(lin_state)

        # Put errors in units of RMS distance of the function values during
        # optimization.
        fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
        fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

        fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
        fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

        self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train),
                            False, 0.1, 0.1)
        self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), False,
                            0.1, 0.1)