Example #1
def get_summary_writers(workdir):
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = workdir + '/log/' + current_time
    train_log_dir = log_dir + '/train'
    eval_log_dir = log_dir + '/eval'
    train_summary_writer = tensorboard.SummaryWriter(train_log_dir)
    eval_summary_writer = tensorboard.SummaryWriter(eval_log_dir)
    return train_summary_writer, eval_summary_writer
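A minimal usage sketch for the writers returned above (not part of the original example): the workdir path and metric values are placeholders, and the imports the helper relies on (datetime and flax.metrics.tensorboard) are assumed to be in scope.

train_writer, eval_writer = get_summary_writers('/tmp/workdir')
for step in range(3):
    # Placeholder values; a real loop would log actual train/eval metrics.
    train_writer.scalar('train_loss', 1.0 / (step + 1), step)
    eval_writer.scalar('eval_loss', 1.2 / (step + 1), step)
train_writer.flush()
eval_writer.flush()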
Example #2
def main(_):
    master = jax.host_id() == 0
    # make sure TF does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    # The pool is used to perform misc operations such as logging in an async way.
    pool = multiprocessing.pool.ThreadPool()

    # load configs from a config json string
    hparams = FLAGS.config
    logging.info('=========== Hyperparameters ============')
    logging.info(hparams)

    if hparams.get('debug'):
        logging.warning('DEBUG MODE IS ENABLED!')

    # set tensorflow random seed
    tf.random.set_seed(jax.host_id() + hparams.rng_seed)
    experiment_dir = FLAGS.experiment_dir
    logging.info('Experiment directory: %s', experiment_dir)
    summary_writer = None

    if master and hparams.write_summary:
        tensorboard_dir = os.path.join(experiment_dir, 'tb_summaries')
        gfile.makedirs(tensorboard_dir)
        summary_writer = tensorboard.SummaryWriter(tensorboard_dir)

    run(hparams, experiment_dir, summary_writer)

    pool.close()
    pool.join()
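The run() function itself is not shown above; below is a minimal sketch (not the original implementation) of how it could consume the writer while guarding against the None passed on non-master hosts. The loop body and metric values are placeholders.

def run(hparams, experiment_dir, summary_writer):
    # Placeholder training loop; a real implementation would derive the step
    # count and metrics from hparams and the actual model.
    for step in range(10):
        loss = 1.0 / (step + 1)  # stand-in for a real training loss
        if summary_writer is not None:
            summary_writer.scalar('train_loss', loss, step)
    if summary_writer is not None:
        summary_writer.flush()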
Example #3
def train(train_ds, test_ds):
    """Train MNIST to completion."""
    rng = random.PRNGKey(0)

    batch_size = FLAGS.batch_size
    num_epochs = FLAGS.num_epochs
    model_dir = FLAGS.model_dir

    summary_writer = tensorboard.SummaryWriter(model_dir)

    model = create_model(rng)
    optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.momentum)

    input_rng = onp.random.RandomState(0)

    for epoch in range(1, num_epochs + 1):
        optimizer, train_metrics = train_epoch(optimizer, train_ds, batch_size,
                                               epoch, input_rng)
        loss, accuracy = eval_model(optimizer.target, test_ds)
        logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, loss,
                     accuracy * 100)
        summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
        summary_writer.scalar('train_accuracy', train_metrics['accuracy'],
                              epoch)
        summary_writer.scalar('eval_loss', loss, epoch)
        summary_writer.scalar('eval_accuracy', accuracy, epoch)
    return optimizer
Example #4
def eval_once(run_configuration, checkpoint_path, optimizer=None):
    """Evaluates a single checkpoint on a single epoch of data."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    optimizer = optimizer or adapter.create_optimizer(run_configuration)
    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    eval_name = config.eval_name or 'eval'
    log_dir = os.path.join(run_dir, eval_name)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Restore checkpoint
    optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
    step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)
    eval_step = adapter.make_eval_step()
    eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

    # Perform evaluation
    tick = time.time()
    metrics_all = []

    example = None
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)
    for unused_eval_step, example in zip(range(config.eval_steps),
                                         dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        metrics, logits, state = eval_step_parallel(optimizer.target,
                                                    train_inputs)
        metrics_all.append(metrics)

    # Write results.
    metrics_all = common_utils.get_metrics(metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])
    if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = len(metrics_all) / (tock - tick)
        examples_per_sec = denominator / (tock - tick)
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
            summary_writer.scalar(key, val, step)

        adapter.write_summaries(example, logits, summary_writer, info, step,
                                state)
        summary_writer.flush()
Example #5
    def __init__(self,
                 base_dir,
                 create_agent_fn,
                 create_environment_fn=gym_lib.create_gym_environment,
                 checkpoint_file_prefix='ckpt',
                 logging_file_prefix='log',
                 log_every_n=1,
                 num_iterations=200,
                 training_steps=250000,
                 evaluation_steps=125000,
                 max_steps_per_episode=1000,
                 clip_rewards=False):
        """Initialize the Runner object in charge of running a full experiment.

    Args:
      base_dir: str, the base directory to host all required sub-directories.
      create_agent_fn: A function that takes as argument an environment, and
        returns an agent.
      create_environment_fn: A function which receives a problem name and
        creates a Gym environment for that problem (e.g. an Atari 2600 game).
      checkpoint_file_prefix: str, the prefix to use for checkpoint files.
      logging_file_prefix: str, prefix to use for the log files.
      log_every_n: int, the frequency for writing logs.
      num_iterations: int, the iteration number threshold (must be greater than
        start_iteration).
      training_steps: int, the number of training steps to perform.
      evaluation_steps: int, the number of evaluation steps to perform.
      max_steps_per_episode: int, maximum number of steps after which an episode
        terminates.
      clip_rewards: bool, whether to clip rewards in [-1, 1].

    This constructor will take the following actions:
    - Initialize an environment.
    - Initialize a logger.
    - Initialize an agent.
    - Reload from the latest checkpoint, if available, and initialize the
      Checkpointer object.
    """
        assert base_dir is not None
        self._logging_file_prefix = logging_file_prefix
        self._log_every_n = log_every_n
        self._num_iterations = num_iterations
        self._training_steps = training_steps
        self._evaluation_steps = evaluation_steps
        self._max_steps_per_episode = max_steps_per_episode
        self._base_dir = base_dir
        self._clip_rewards = clip_rewards
        self._create_directories()
        self._summary_writer = tensorboard.SummaryWriter(base_dir)
        self._environment = create_environment_fn()
        self._agent = create_agent_fn(self._environment,
                                      summary_writer=self._summary_writer)
        self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)
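The agent factory is supplied by the caller; below is a minimal, hypothetical create_agent_fn showing how the injected summary_writer could be used. MyAgent and its episode logging are placeholders, not part of the original Runner.

class MyAgent:
    """Hypothetical agent that logs episode returns to the injected writer."""

    def __init__(self, environment, summary_writer=None):
        self._environment = environment
        self._summary_writer = summary_writer
        self._episode = 0

    def end_episode(self, episode_return):
        if self._summary_writer is not None:
            self._summary_writer.scalar('episode_return', episode_return,
                                        self._episode)
        self._episode += 1


def create_agent_fn(environment, summary_writer=None):
    return MyAgent(environment, summary_writer=summary_writer)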
Example #6
def train(module: models.ActorCritic, optimizer: flax.optim.base.Optimizer,
          config: ml_collections.ConfigDict, model_dir: str):
    """Main training loop.

  Args:
    module: the actor-critic model
    optimizer: optimizer for the actor-critic model
    config: object holding hyperparameters and the training information
    model_dir: path to directory where checkpoints and logging info are stored

  Returns:
    optimizer: the trained optimizer
  """
    game = config.game + 'NoFrameskip-v4'
    simulators = [
        agent.RemoteSimulator(game) for _ in range(config.num_agents)
    ]
    summary_writer = tensorboard.SummaryWriter(model_dir)
    summary_writer.hparams(dict(config))
    loop_steps = config.total_frames // (config.num_agents *
                                         config.actor_steps)
    log_frequency = 40
    checkpoint_frequency = 500

    for s in range(loop_steps):
        # Bookkeeping and testing.
        if s % log_frequency == 0:
            score = test_episodes.policy_test(1, module, optimizer.target,
                                              game)
            frames = s * config.num_agents * config.actor_steps
            summary_writer.scalar('game_score', score, frames)
            print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n')
        if s % checkpoint_frequency == 0:
            checkpoints.save_checkpoint(model_dir, optimizer, s)

        # Core training code.
        alpha = 1. - s / loop_steps if config.decaying_lr_and_clip_param else 1.
        all_experiences = get_experience(optimizer.target, module, simulators,
                                         config.actor_steps)
        trajectories = process_experience(all_experiences, config.actor_steps,
                                          config.num_agents, config.gamma,
                                          config.lambda_)
        lr = config.learning_rate * alpha
        clip_param = config.clip_param * alpha
        for e in range(config.num_epochs):
            permutation = onp.random.permutation(config.num_agents *
                                                 config.actor_steps)
            trajectories = tuple(map(lambda x: x[permutation], trajectories))
            optimizer, loss = train_step(module, optimizer, trajectories,
                                         clip_param, config.vf_coeff,
                                         config.entropy_coeff, lr,
                                         config.batch_size)
    return optimizer
Example #7
def train_agent(iterations, modeldir, logdir):
    """Train and convert the model."""
    summary_writer = tensorboard.SummaryWriter(logdir)

    rng = random.PRNGKey(0)
    rng, init_rng = random.split(rng)
    policygradient = PolicyGradient()
    params = policygradient.init(
        init_rng, jnp.ones([1, common.BOARD_SIZE,
                            common.BOARD_SIZE]))['params']
    optimizer = create_optimizer(model_params=params,
                                 learning_rate=LEARNING_RATE)

    # Main training loop
    progress_bar = tf.keras.utils.Progbar(iterations)
    for i in range(iterations):
        predict_fn = functools.partial(run_inference, optimizer.target)
        board_log, action_log, result_log = common.play_game(predict_fn)
        rewards = common.compute_rewards(result_log)
        summary_writer.scalar('game_length', len(board_log), i)
        optimizer = train_step(optimizer, board_log, action_log, rewards)

        summary_writer.flush()
        progress_bar.add(1)

    summary_writer.close()

    # Convert to tflite model
    model = PolicyGradient()
    jax_predict_fn = lambda input: model.apply({'params': optimizer.target},
                                               input)

    tf_predict = tf.function(
        jax2tf.convert(jax_predict_fn, enable_xla=False),
        input_signature=[
            tf.TensorSpec(shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
                          dtype=tf.float32,
                          name='input')
        ],
        autograph=False)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_predict.get_concrete_function()], tf_predict)

    tflite_model = converter.convert()

    # Save the model
    with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
        f.write(tflite_model)

    print('TFLite model generated!')
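As a short follow-up sketch (not part of the original example), the saved planestrike.tflite file could be loaded and run with the TensorFlow Lite interpreter; modeldir and common.BOARD_SIZE are reused from the example above, and the all-zeros board input is just a placeholder.

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path=os.path.join(modeldir, 'planestrike.tflite'))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Placeholder board state matching the converter's input signature.
board = np.zeros((1, common.BOARD_SIZE, common.BOARD_SIZE), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], board)
interpreter.invoke()
probabilities = interpreter.get_tensor(output_details[0]['index'])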
Example #8
def train_and_evaluate(config, workdir):
  """Execute model training and evaluation loop.

  Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
  """
  ## get random seed
  rng = jax.random.PRNGKey(0)
  ## Get data
  train_ds, test_ds = get_datasets("cifar10")

  ## Initializing model and inferring dimensions of layers from one example batch
  model = models.ResNet18(num_classes=10)
  init_params = model.init(
      rng, jnp.ones((1, 32, 32, 3))
  )  # figure this shape out automatically ?
  params = init_params

  solver, solver_param_name = get_solver(
      FLAGS, config, loss_fun, losses=loss_fun)  # losses is not defined yet!
  params, state = solver.init(params)

  ## Path to dump results
  dumpath = create_dumpfile(config, solver_param_name, workdir, "cifar10")

  summary_writer = tensorboard.SummaryWriter(dumpath)
  summary_writer.hparams(dict(config))

  for epoch in range(1, config.num_epochs + 1):
    rng, _ = jax.random.split(rng)
    params, state = train_epoch(
        config, solver, params, state, train_ds, rng
    )
    test_loss, test_accuracy = eval_model(params, test_ds)
    train_loss, train_accuracy = eval_model(params, train_ds)
    print("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch, test_loss,
          test_accuracy * 100)
    print("train epoch: %d, train_loss: %.4f, train_accuracy: %.2f", epoch,
          train_loss, train_accuracy * 100)
    logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch, test_loss,
                 test_accuracy * 100)
    summary_writer.scalar("train_loss", train_loss, epoch)
    summary_writer.scalar("test_loss", loss, epoch)
    summary_writer.scalar("train_accuracy", train_accuracy, epoch)
    summary_writer.scalar("test_accuracy", test_accuracy, epoch)

  summary_writer.flush()
Example #9
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.
  """
    train_ds, test_ds = get_datasets("mnist")
    # Get solver
    solver, solver_param_name = get_solver(FLAGS, config, loss_fun, losses)

    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    init_params = CNN().init(init_rng, jnp.ones([1, 28, 28, 1]))["params"]
    params, state = solver.init(init_params)

    # Full path to dump results
    dumpath = create_dumpfile(config, solver_param_name, workdir, "mnist")

    summary_writer = tensorboard.SummaryWriter(dumpath)
    summary_writer.hparams(dict(config))

    # Run solver.
    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)

        params, state, train_metrics = train_epoch(config, solver, params,
                                                   state, train_ds, epoch,
                                                   input_rng)
        test_loss, test_accuracy = eval_model(params, test_ds)

        print("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch, test_loss,
              test_accuracy * 100)
        logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch,
                     test_loss, test_accuracy * 100)

        summary_writer.scalar("train_loss", train_metrics["loss"], epoch)
        summary_writer.scalar("train_accuracy", train_metrics["accuracy"],
                              epoch)
        summary_writer.scalar("eval_loss", test_loss, epoch)
        summary_writer.scalar("eval_accuracy", test_accuracy, epoch)

    summary_writer.flush()
Example #10
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    cnn = CNN()
    params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(config.learning_rate, config.momentum)
    state = train_state.TrainState.create(apply_fn=cnn.apply,
                                          params=params,
                                          tx=tx)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_metrics = train_epoch(state, train_ds, config.batch_size,
                                           epoch, input_rng)
        loss, accuracy = eval_model(state.params, test_ds)

        logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, loss,
                     accuracy * 100)

        summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
        summary_writer.scalar('train_accuracy', train_metrics['accuracy'],
                              epoch)
        summary_writer.scalar('eval_loss', loss, epoch)
        summary_writer.scalar('eval_accuracy', accuracy, epoch)

    summary_writer.flush()
    return state
Example #11
    def setUp(self):
        super().setUp()

        self._batch_size = 128  # Note: Tests are run on GPU/TPU.
        self._batch_size_test = 128
        self._shuffle_buffer_size = 1024
        self._rng = jax.random.PRNGKey(42)
        self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
        self._num_classes = 10
        self._num_epochs = 1

        self._learning_rate_fn = lambda _: 0.01
        self._weight_decay = 0.0001
        self._momentum = 0.9
        self._rng = jax.random.PRNGKey(42)

        self._min_loss = jnp.finfo(float).eps
        self._max_loss = 2.0 * math.log(self._num_classes)

        self._dataset_name = 'MNIST'
        self._model_name = 'MNIST_CNN'

        self._summarywriter = tensorboard.SummaryWriter('/tmp/')

        self._dataset = dataset_factory.create_dataset(
            self._dataset_name,
            self._batch_size,
            self._batch_size_test,
            shuffle_buffer_size=self._shuffle_buffer_size)

        self._model, self._state = model_factory.create_model(
            self._model_name,
            self._rng, (self._input_shape, ),
            num_classes=self._num_classes)

        self._optimizer = flax.optim.Momentum(
            learning_rate=self._learning_rate_fn(0),
            beta=self._momentum,
            weight_decay=self._weight_decay)
Example #12
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(
            state, train_ds, config.batch_size, input_rng)
        _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
                                                  test_ds['label'])

        print(
            'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
            % (epoch, train_loss, train_accuracy * 100, test_loss,
               test_accuracy * 100))

        summary_writer.scalar('train_loss', train_loss, epoch)
        summary_writer.scalar('train_accuracy', train_accuracy, epoch)
        summary_writer.scalar('test_loss', test_loss, epoch)
        summary_writer.scalar('test_accuracy', test_accuracy, epoch)

    summary_writer.flush()
    return state
Example #13
def train_and_evaluate(model_dir: str, num_epochs: int, batch_size: int,
                       learning_rate: float, momentum: float):
  """Execute model training and evaluation loop.

  Args:
    model_dir: Directory where the tensorboard summaries are written to.
    num_epochs: Number of epochs to cycle through the dataset before stopping.
    batch_size: Batch size of the input.
    learning_rate: Learning rate for the momentum optimizer.
    momentum: Momentum value for the momentum optimizer.

  Returns:
    The trained optimizer.
  """
  train_ds, test_ds = get_datasets()
  rng = random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(model_dir)

  rng, init_rng = random.split(rng)
  model = create_model(init_rng)
  optimizer = create_optimizer(model, learning_rate, momentum)

  for epoch in range(1, num_epochs + 1):
    rng, input_rng = random.split(rng)
    optimizer, train_metrics = train_epoch(
        optimizer, train_ds, batch_size, epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)

    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)

    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  summary_writer.flush()
  return optimizer
Example #14
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The trained optimizer.
  """
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    params = get_initial_params(init_rng)
    optimizer = create_optimizer(params, config.learning_rate, config.momentum)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        optimizer, train_metrics = train_epoch(optimizer, train_ds,
                                               config.batch_size, epoch,
                                               input_rng)
        loss, accuracy = eval_model(optimizer.target, test_ds)

        logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, loss,
                     accuracy * 100)

        summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
        summary_writer.scalar('train_accuracy', train_metrics['accuracy'],
                              epoch)
        summary_writer.scalar('eval_loss', loss, epoch)
        summary_writer.scalar('eval_accuracy', accuracy, epoch)

    summary_writer.flush()
    return optimizer
Example #15
def run_train(run_configuration):
    """Runs the training workflow."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    log_dir = os.path.join(run_dir, 'train')
    checkpoint_path = run_configuration.original_checkpoint_path

    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)

    summary_freq = config.logging.summary_freq
    metrics_all = []
    tick = time.time()
    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rngs, logits, state = train_step(
            optimizer, train_inputs, dropout_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % config.logging.save_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.logging.save_freq:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(
                    checkpoint_dir, jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if summary_freq and step % summary_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train step: %d, loss: %.4f', step, summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = summary_freq / (tock - tick)
                examples_per_sec = denominator / (tock - tick)
                tick = tock
                summary_writer.scalar('per-second/steps', steps_per_sec, step)
                summary_writer.scalar('per-second/examples', examples_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar(key, val, step)

                adapter.write_summaries(example, logits, summary_writer, info,
                                        step, state)

                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []
Example #16
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = create_input_iter(local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=FLAGS.cache)
    eval_iter = create_input_iter(local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=FLAGS.cache)

    num_epochs = FLAGS.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    base_learning_rate = base_learning_rate / FLAGS.loss_scaling

    model, model_state = create_model(rng, device_batch_size, image_size,
                                      model_dtype)
    optimizer = optim.Momentum(beta=FLAGS.momentum,
                               nesterov=True).create(model)
    state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
    del model, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch, num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #17
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = "tpu_driver"
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')

    # Load Dataset
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_token = 2  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': vocab_size,
        'emb_dim': 1024,
        'num_heads': 16,
        'num_layers': 6,
        'qkv_dim': 1024,
        'mlp_dim': 4096,
        'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
    }

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    model, cache_def = create_model(init_rng, input_shape, target_shape,
                                    transformer_kwargs)
    optimizer = create_optimizer(model, FLAGS.learning_rate,
                                 FLAGS.weight_decay)
    # We access model only from optimizer below via optimizer.target.
    del model

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(
        eval_step,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                           axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if step % FLAGS.eval_frequency != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for _, pred_batch in enumerate(predict_iter):
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (per_device_batchsize, FLAGS.max_predict_length),
                    dtype=cache_dtype))
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_token, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
Example #18
def train(
    model: models.ActorCritic,
    config: ml_collections.ConfigDict,
    model_dir: str):
  """Main training loop.

  Args:
    model: the actor-critic model
    config: object holding hyperparameters and the training information
    model_dir: path to directory where checkpoints and logging info are stored

  Returns:
    state: the trained train state
  """
  
  game = config.game + 'NoFrameskip-v4'
  simulators = [agent.RemoteSimulator(game)
                for _ in range(config.num_agents)]
  summary_writer = tensorboard.SummaryWriter(model_dir)
  summary_writer.hparams(dict(config))
  loop_steps = config.total_frames // (config.num_agents * config.actor_steps)
  log_frequency = 40
  checkpoint_frequency = 500
  # train_step does multiple steps per call for better performance
  # compute number of steps per call here to convert between the number of
  # train steps and the inner number of optimizer steps
  iterations_per_step = (config.num_agents * config.actor_steps
      // config.batch_size)

  initial_params = get_initial_params(jax.random.PRNGKey(0), model)
  state = create_train_state(initial_params, model, config,
                             loop_steps * config.num_epochs * iterations_per_step)
  del initial_params
  state = checkpoints.restore_checkpoint(model_dir, state)
  # number of train iterations done by each train_step

  start_step = int(state.step) // config.num_epochs // iterations_per_step
  logging.info('Start training from step: %s', start_step)

  for step in range(start_step, loop_steps):
    # Bookkeeping and testing.
    if step % log_frequency == 0:
      score = test_episodes.policy_test(1, state.apply_fn, state.params, game)
      frames = step * config.num_agents * config.actor_steps
      summary_writer.scalar('game_score', score, frames)
      logging.info('Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score)

    # Core training code.
    alpha = 1. - step / loop_steps if config.decaying_lr_and_clip_param else 1.
    all_experiences = get_experience(
        state, simulators, config.actor_steps)
    trajectories = process_experience(
        all_experiences, config.actor_steps, config.num_agents, config.gamma,
        config.lambda_)
    clip_param = config.clip_param * alpha
    for _ in range(config.num_epochs):
      permutation = np.random.permutation(
          config.num_agents * config.actor_steps)
      trajectories = tuple(x[permutation] for x in trajectories)
      state, _ = train_step(
          state, trajectories, config.batch_size,
          clip_param=clip_param,
          vf_coeff=config.vf_coeff,
          entropy_coeff=config.entropy_coeff)
    if (step + 1) % checkpoint_frequency == 0:
      checkpoints.save_checkpoint(model_dir, state, step + 1)
  return state
Example #19
def train(config, workdir):
  """Runs a training loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """

  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
    ) == 0:
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if (step +
        1) % config.training.snapshot_freq == 0 or step == num_train_steps:
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples
          sample = samples[-1]
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
Example #20
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(config)
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    num_train_steps = config.num_train_steps
    num_eval_steps = config.num_eval_steps
    eval_freq = config.eval_frequency
    random_seed = config.random_seed
    model_type = config.model_type

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))
    else:
        summary_writer = None

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    logging.info('Training on %s', FLAGS.task_name)

    if model_type in ['wideresnet', 'resnet', 'simple_cnn']:
        normalize = True
    else:  # transformer-based models
        normalize = False
    (train_ds, eval_ds, test_ds, num_classes, vocab_size,
     input_shape) = task_registry.TASK_DATA_DICT[FLAGS.task_name](
         n_devices=jax.local_device_count(),
         batch_size=batch_size,
         normalize=normalize)
    train_iter = iter(train_ds)
    model_kwargs = {}
    flatten_input = True

    if model_type in ['wideresnet', 'resnet', 'simple_cnn']:
        model_kwargs.update({
            'num_classes': num_classes,
        })
        flatten_input = False

    else:  # transformer models
        # we will flatten the input
        bs, h, w, c = input_shape
        assert c == 1
        input_shape = (bs, h * w * c)
        model_kwargs.update({
            'vocab_size': vocab_size,
            'max_len': input_shape[1],
            'classifier': True,
            'num_classes': num_classes,
        })

    model_kwargs.update(config.model)

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model, state = get_model(init_rng, input_shape, model_type, model_kwargs)

    optimizer = create_optimizer(model, learning_rate, config.weight_decay)
    del model  # Don't keep a copy of the initial model.

    start_step = 0
    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer, state = checkpoints.restore_checkpoint(
            FLAGS.model_dir, (optimizer, state))
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer and state
    optimizer = jax_utils.replicate(optimizer)
    state = jax_utils.replicate(state)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=config.factors,
        base_learning_rate=learning_rate,
        warmup_steps=config.warmup,
        steps_per_cycle=config.get('steps_per_cycle', None),
    )
    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        num_classes=num_classes,
        grad_clip_norm=config.get('grad_clip_norm', None),
        flatten_input=flatten_input),
                            axis_name='batch')

    p_eval_step = jax.pmap(
        functools.partial(eval_step,
                          num_classes=num_classes,
                          flatten_input=flatten_input),
        axis_name='batch',
    )

    optimizer, state, step = train_loop(config, dropout_rngs, eval_ds,
                                        eval_freq, num_eval_steps,
                                        num_train_steps, optimizer, state,
                                        p_eval_step, p_train_step, start_step,
                                        train_iter, summary_writer)

    logging.info('Starting testing')
    logging.info('====================')
    test(optimizer, state, p_eval_step, step, test_ds, summary_writer,
         FLAGS.model_dir)
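Note: Example #20 builds its per-device train and eval steps by closing over hyperparameters with functools.partial and mapping the result with jax.pmap(axis_name='batch'). A minimal, self-contained sketch of that pattern, assuming a toy linear model and plain SGD update (loss_fn, train_step, and the hyperparameters here are illustrative):

import functools
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    preds = batch['x'] @ params
    return jnp.mean((preds - batch['y']) ** 2)

def train_step(params, batch, *, learning_rate):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average loss and gradients across all devices on the 'batch' axis.
    loss = jax.lax.pmean(loss, axis_name='batch')
    grads = jax.lax.pmean(grads, axis_name='batch')
    return params - learning_rate * grads, {'loss': loss}

p_train_step = jax.pmap(
    functools.partial(train_step, learning_rate=1e-2), axis_name='batch')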
Example #21
0
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If the training or eval batch size is not divisible by the
      total number of processes and devices, or the config is underspecified.
  """
    n_processes = jax.process_count()  # Number of processes
    n_devices = jax.local_device_count()  # Number of local devices per process

    if config.train_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Training batch size must be divisible by the total number of devices, "
            "but training batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.train_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    if config.eval_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Eval batch size must be divisible by the total number of devices, "
            "but eval batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.eval_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    per_process_train_batch_size = config.train_batch_size // n_processes
    per_process_eval_batch_size = config.eval_batch_size // n_processes

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round up evaluation frequency to a multiple of 10.
    eval_frequency = int(
        math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

    is_regression_task = config.dataset_name == "glue/stsb"

    num_classes = (1 if is_regression_task else
                   ds_info.features["label"].num_classes)

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()

    frozen_config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config=frozen_config,
                                               n_classes=num_classes)

    params = _init_params(model, init_rng, config)

    optimizer = _create_adam_optimizer(config.learning_rate, params)

    # In case current job restarts, ensure that we continue from where we left
    # off.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    start_step = int(optimizer.state.step)

    # Otherwise, try to restore optimizer and model state from config checkpoint.
    if (start_step == 0 and "init_checkpoint_dir" in config
            and config.init_checkpoint_dir):
        optimizer = _restore_pretrained_model(optimizer, params, config)

    # We access model state only from optimizer via optimizer.target.
    del params

    optimizer = jax_utils.replicate(optimizer)

    if is_regression_task:
        compute_stats = functools.partial(_compute_regression_stats,
                                          model=model,
                                          pad_id=tokenizer.pad_id())
    else:
        compute_stats = functools.partial(_compute_classification_stats,
                                          model=model,
                                          pad_id=tokenizer.pad_id())

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    glue_inputs = functools.partial(input_pipeline.glue_inputs,
                                    dataset_name=config.dataset_name,
                                    max_seq_length=config.max_seq_length,
                                    tokenizer=tokenizer)
    train_ds = glue_inputs(split=tfds.Split.TRAIN,
                           batch_size=per_process_train_batch_size,
                           training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI contains two validation and test datasets.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, n_devices)

    loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics,
                                            model=model,
                                            pad_id=tokenizer.pad_id())
    p_train_step = jax.pmap(functools.partial(
        train_utils.train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        learning_rate_fn=learning_rate_fn),
                            axis_name="batch")
    p_eval_step = jax.pmap(functools.partial(train_utils.eval_step,
                                             metric_fn=compute_stats),
                           axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                              is_regression_task)

    train_metrics = []

    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            optimizer, train_step_metrics, rngs = p_train_step(optimizer,
                                                               train_batch,
                                                               rng=rngs)
            train_metrics.append(train_step_metrics)

        if ((step > 0 and config.save_checkpoints_steps
             and step % config.save_checkpoints_steps == 0)
                or step == num_train_steps - 1) and jax.process_index() == 0:
            # Save un-replicated optimizer and model state.
            checkpoints.save_checkpoint(workdir,
                                        jax_utils.unreplicate(optimizer),
                                        step,
                                        keep=2)

        # Periodic metric handling.
        if step % eval_frequency != 0 and step < num_train_steps - 1:
            continue

        logging.info("Gathering training metrics at step: %d", step)

        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = {
            "loss":
            jnp.sum(train_metrics["loss"]) /
            jnp.sum(train_metrics["num_labels"]),
            "learning_rate":
            learning_rate_fn(step)
        }
        if not is_regression_task:
            train_summary["accuracy"] = jnp.sum(
                train_metrics["correct_predictions"]) / jnp.sum(
                    train_metrics["num_labels"])

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        train_metrics = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            eval_ds = glue_inputs(split=tfds.Split.VALIDATION + split_suffix,
                                  batch_size=per_process_eval_batch_size,
                                  training=False)

            all_stats = []
            for _, eval_batch in zip(range(config.max_num_eval_steps),
                                     eval_ds):
                all_stats.append(
                    _evaluate(p_eval_step, optimizer.target, eval_batch,
                              n_devices))
            flat_stats = {}
            # All batches of output stats share the same keys.
            for k in all_stats[0]:
                flat_stats[k] = np.concatenate([stat[k] for stat in all_stats],
                                               axis=0)
            eval_summary = eval_metrics_fn(flat_stats)

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()
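Note: Example #21 shards each host batch with common_utils.shard before feeding it to the pmapped train step. The sketch below shows the reshape that performs: the leading batch dimension is split into (local_device_count, per_device_batch). Field names and shapes are illustrative.

import jax
import numpy as np
from flax.training import common_utils

n_devices = jax.local_device_count()
host_batch = {
    'input_ids': np.zeros((8 * n_devices, 128), np.int32),
    'label': np.zeros((8 * n_devices,), np.int32),
}
sharded = common_utils.shard(host_batch)
# sharded['input_ids'].shape == (n_devices, 8, 128)
# sharded['label'].shape == (n_devices, 8)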
Example #22
0
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If the training or eval batch size is not divisible by the
      number of hosts and devices, or the config is underspecified.
  """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    tokenizer.SetEncodeExtraOptions("")
    # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer.

    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()

    config = ml_collections.FrozenConfigDict(config)

    model = models.PreTrainingModel(config=config)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        decay_steps=config.num_train_steps - config.num_warmup_steps,
    )

    tx = optax.adamw(learning_rate_fn,
                     b1=0.9,
                     b2=0.999,
                     eps=1e-6,
                     weight_decay=0.01)
    if config.clipped_grad_norm:
        tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm),
                         tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(FlaxTrainState.create,
                          apply_fn=model.apply,
                          params=params,
                          tx=tx))()

    # We access model params only via state.params
    del params

    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)
    train_ds, eval_ds = _init_train_and_eval_ds(tokenizer, config)
    train_iter = iter(train_ds)

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")

    eval_step = functools.partial(_compute_eval_stats, model=model)
    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    seconds = 0.
    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, config.num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            tick = time.time()
            state, train_step_stats, rngs = p_train_step(state,
                                                         train_batch,
                                                         rng=rngs)
            if config.measure_step_speed:
                jax.tree_map(lambda opt: opt.block_until_ready(), state)
                tock = time.time()
                seconds += tock - tick

            train_stats.append(train_step_stats)

        if (step > 0 and config.save_checkpoints_steps
                and step % config.save_checkpoints_steps == 0):
            # We allow all hosts to potentially save checkpoints because some model
            # parameters are sharded across devices. Parameters replicated across
            # devices (i.e. not sharded) will only be checkpointed by host 0.
            unreplicated_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(workdir,
                                        unreplicated_state,
                                        sharded_match_fn,
                                        step,
                                        keep=config.checkpoints_to_keep)
            del unreplicated_state  # Only used for checkpointing.

        # Periodic metric handling.
        if step % config.eval_frequency != 0 and step > 0:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_pretraining_metrics(train_metrics)
        train_summary["learning_rate"] = learning_rate_fn(step)
        if config.measure_step_speed:
            train_summary["steps_per_sec"] = (step - start_step + 1) / seconds

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering evaluation metrics at step: %d", step)

        eval_stats = []
        for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
            eval_batch = common_utils.shard(eval_batch)
            eval_stats.append(p_eval_step(state.params, eval_batch))
        eval_metrics = train_utils.collect_metrics(eval_stats)
        eval_summary = train_utils.compute_pretraining_metrics(
            eval_metrics, record_grad_norm=False)

        if jax.process_index() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
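Note: The pretraining example above composes optax.clip_by_global_norm with optax.adamw via optax.chain and wraps the result in a train state. A hedged sketch of that setup using Flax's standard train_state.TrainState (the example itself uses a project-specific FlaxTrainState); hyperparameter values are placeholders.

import optax
from flax.training import train_state

def make_state(apply_fn, params, learning_rate_fn, clip_norm=1.0):
    tx = optax.adamw(learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6,
                     weight_decay=0.01)
    if clip_norm:
        # Clip the global gradient norm before the AdamW update.
        tx = optax.chain(optax.clip_by_global_norm(clip_norm), tx)
    return train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx)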
Example #23
0
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    rng, key = random.split(rng)
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, variables

    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    train_pstep = jax.pmap(functools.partial(train_step, model),
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=(2, ))

    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              FLAGS.randomized),
                                  axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3, ),
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step of the last checkpoint.
    init_step = state.optimizer.state.step + 1
    state = flax.jax_utils.replicate(state)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch_buffer_size = 3 x batch_size
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    reset_timer = True
    for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)
        state, stats, keys = train_pstep(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0],
                                      step)
                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                avg_loss = np.mean(
                    np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(
                    np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                steps_per_sec = FLAGS.print_every / (time.time() -
                                                     t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec,
                                      step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
                print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                      f"/{FLAGS.max_steps:d}: " +
                      f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                state_to_save = jax.device_get(
                    jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(FLAGS.train_dir,
                                            state_to_save,
                                            int(step),
                                            keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(
                lambda x: x[0], state)).optimizer.target
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, eval_variables),
                test_case["rays"],
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                num_rays = jnp.prod(
                    jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(
                    f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec"
                )
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(FLAGS.max_steps),
                                    keep=100)
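Note: The training loop above saves checkpoints by pulling the device-0 copy out of the replicated state before writing it with flax.training.checkpoints; the same pattern recurs in the following examples. A minimal sketch of that unreplicate-then-save step, with placeholder names.

import jax
from flax.training import checkpoints

def save_unreplicated(workdir, replicated_state, step, keep=100):
    # All devices hold identical (replicated) values, so device 0 suffices.
    state_to_save = jax.device_get(
        jax.tree_map(lambda x: x[0], replicated_state))
    checkpoints.save_checkpoint(workdir, state_to_save, int(step), keep=keep)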
Example #24
0
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, *rays), axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )
    rng, key = random.split(rng)
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr_init)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(step=0,
                                   optimizer=optimizer,
                                   model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    offset = state.step + 1
    state = jax_utils.replicate(state)
    del init_model, init_state

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    ptrain_step = jax.pmap(train_step,
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=2)
    # Prefetch_buffer_size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats[0])
        if step % FLAGS.gc_every == 0:
            gc.collect()
        # --- Train logs start ---
        # Put the training time visualization before the host_id check as in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)
        if jax.host_id() != 0:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            summary_writer.scalar("train_loss", stats[0].loss[0], step)
            summary_writer.scalar("train_psnr", stats[0].psnr[0], step)
            if len(stats) > 1:
                summary_writer.scalar("train_loss_coarse", stats[1].loss[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0],
                                      step)
            avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
            avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
            stats_trace = []
            summary_writer.scalar("train_avg_loss", avg_loss, step)
            summary_writer.scalar("train_avg_psnr", avg_psnr, step)
            summary_writer.scalar("learning_rate", lr, step)
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " +
                  f"i_loss={stats[0].loss[0]:0.5f}, " +
                  f"avg_loss={avg_loss:0.5f}, " + f"lr={lr:0.2e}, " +
                  f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(FLAGS.train_dir,
                                        state_to_save,
                                        state_to_save.step,
                                        keep=100)
        # --- Train logs end ---

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(state.step),
                                    keep=100)
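Note: The loops above convert a wall-clock timer into steps/sec and rays/sec before logging them. A small sketch of that bookkeeping; the function name and arguments are illustrative, and summary_writer is assumed to be a flax.metrics.tensorboard.SummaryWriter.

import time

def log_throughput(summary_writer, step, batch_size, print_every, t_loop_start):
    # print_every steps elapsed since t_loop_start.
    steps_per_sec = print_every / (time.time() - t_loop_start)
    rays_per_sec = batch_size * steps_per_sec
    summary_writer.scalar("steps_per_sec", steps_per_sec, step)
    summary_writer.scalar("rays_per_sec", rays_per_sec, step)
    return time.time()  # New timer start for the next logging window.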
Example #25
0
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  eval_freq = FLAGS.eval_frequency
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  config = models.TransformerConfig(
      vocab_size=len(vocabs['forms']),
      output_vocab_size=len(vocabs['xpos']),
      max_len=FLAGS.max_length)

  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=config.max_len)
  train_iter = iter(train_ds)

  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=config.max_len,
      repeat=1)

  model = models.Transformer(config)

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)

  # call a jitted initialization function to get the initial parameter tree
  @jax.jit
  def initialize_variables(init_rng):
    init_batch = jnp.ones((config.max_len, 1), jnp.float32)
    init_variables = model.init(init_rng, inputs=init_batch, train=False)
    return init_variables
  init_variables = initialize_variables(init_rng)

  optimizer_def = optim.Adam(learning_rate, beta1=0.9, beta2=0.98,
      eps=1e-9, weight_decay=1e-1)
  optimizer = optimizer_def.create(init_variables['params'])
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(train_step, model=model, learning_rate_fn=learning_rate_fn),
      axis_name='batch')

  def eval_step(params, batch):
    """Calculate evaluation metrics on a batch."""
    inputs, targets = batch['inputs'], batch['targets']
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = model.apply({'params': params}, inputs=inputs, train=False)
    return compute_metrics(logits, targets, weights)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())
  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

    optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.process_index() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()

      metrics_all = []  # reset metric accumulation for next evaluation cycle.

      eval_metrics = []
      eval_iter = iter(eval_ds)

      for eval_batch in eval_iter:
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          # pad up to batch size
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)

        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])

      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.process_index() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
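Note: Example #25 pads the final odd-sized eval batch with pad_examples so every batch matches the size expected by the pmapped eval step. One plausible implementation of that helper (the real one is defined elsewhere in the source) repeats the last example along axis 0:

import numpy as np

def pad_examples(x, desired_batch_size):
    """Pads a batch along axis 0 by repeating its last example."""
    batch_pad = desired_batch_size - x.shape[0]
    reps = (batch_pad,) + (1,) * (x.ndim - 1)
    return np.concatenate([x, np.tile(x[-1], reps)], axis=0)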
Example #26
0
def main(unused_argv):
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    lpips_values = []
    if not FLAGS.eval_once:
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
        lpips_values.append(float(lpips))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in lpips_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
      with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(lpips_values))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
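Note: utils.compute_psnr is called above with a pre-computed mean squared error over images scaled to [0, 1]. Under that assumption, the standard definition is simply -10 * log10(mse); a one-line sketch:

import jax.numpy as jnp

def compute_psnr(mse):
    """PSNR for images in [0, 1], given their mean squared error."""
    return -10.0 * jnp.log10(mse)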
Example #27
0
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  num_eval_steps = FLAGS.num_eval_steps
  eval_freq = FLAGS.eval_frequency
  max_length = FLAGS.max_length
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')

  parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params')
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length)

  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length,
      repeat=1)
  train_iter = iter(train_ds)
  bs = device_batch_size * jax.device_count()

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (bs, max_length)
  transformer_kwargs = {
      'vocab_size': len(vocabs['forms']),
      'output_vocab_size': len(vocabs['xpos']),
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max_length,
  }
  model = create_model(init_rng, tuple(input_shape), transformer_kwargs)

  optimizer = create_optimizer(model, learning_rate)
  del model  # don't keep a copy of the initial model
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # reset metric accumulation for next evaluation cycle.
      metrics_all = []

      eval_metrics = []
      eval_iter = iter(eval_ds)
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          logging.info('Uneven batch size %d.', cur_pred_batch_size)
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)

        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])

      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
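Note: Examples #25 and #27 aggregate metrics the same way: per-step metric dicts are stacked with common_utils.get_metrics, summed, and normalized by the token-count denominator; Example #27 additionally derives a clipped perplexity from the mean loss. A condensed sketch of that aggregation; metrics_all is assumed to be the list of pmapped per-step metrics.

import jax
import jax.numpy as jnp
from flax.training import common_utils

def summarize(metrics_all):
    metrics = common_utils.get_metrics(metrics_all)
    sums = jax.tree_map(jnp.sum, metrics)
    denominator = sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, sums)
    # Clipped perplexity from the mean cross-entropy loss.
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    return summary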
Example #28
0
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # BOS special attention only makes sense if we are using relative attention
    # and it's not the baseline.
    if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention
                                        or FLAGS.attention_mask_type
                                        == 'baseline'):
        raise ValueError(
            "bos_special_attention doesn't work when use_relative_attention={} and "
            'attention_mask_type={}'.format(FLAGS.use_relative_attention,
                                            FLAGS.attention_mask_type))

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    hparam_str_dict = json.loads(FLAGS.xm_parameters)
    hparam_str = ','.join([
        '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
        for k in hparam_str_dict.keys()
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    predict_io_shape = (FLAGS.per_device_batch_size,
                        FLAGS.num_strings_per_task,
                        FLAGS.predict_max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    # Parse io and program token sequences (for eval).
    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
        return inps, outs

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        program = program[program != bos_token]

        try:
            return dsl.decode_program(program.tolist(), id_token_table)
        except:  # pylint: disable=bare-except
            return None  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
    padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:])
    logging.info('padded_shapes: %s', padded_shapes)
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=padded_shapes,
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch size of predict dataset to handle beam search.
    predict_padded_shapes = (predict_io_shape[1:], predict_io_shape[1:],
                             program_shape[1:])
    logging.info('predict_padded_shapes: %s', predict_padded_shapes)
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)
    train_ds = dataset.skip(FLAGS.num_eval_steps)
    if FLAGS.train_set_batches > 0:
        train_ds = train_ds.take(FLAGS.train_set_batches)
    train_ds = train_ds.repeat()

    test_dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.test_dataset_filepattern, token_id_table, char_id_table)
    test_dataset = test_dataset.padded_batch(
        batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
    quick_test_dataset = (test_dataset.take(
        FLAGS.num_quick_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))
    final_test_dataset = (test_dataset.take(
        FLAGS.num_final_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    default_config = base_models.TransformerConfig(
        vocab_size=io_vocab_size, output_vocab_size=program_vocab_size)
    base_config = base_models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token,
        num_input_relative_position_buckets=FLAGS.num_position_buckets,
        max_input_distance=min(FLAGS.max_distance,
                               default_config.max_input_distance),
        num_output_relative_position_buckets=FLAGS.num_position_buckets,
        max_output_distance=min(FLAGS.max_distance,
                                default_config.max_output_distance),
        num_input_cross_output_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_input_cross_output_distance=min(
            FLAGS.max_distance,
            default_config.max_input_cross_output_distance),
        num_program_relative_position_buckets=FLAGS.num_position_buckets,
        max_program_distance=min(FLAGS.max_distance,
                                 default_config.max_program_distance),
        num_program_cross_embed_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_program_cross_embed_distance=min(
            FLAGS.max_distance,
            default_config.max_program_cross_embed_distance),
        bidirectional_program_attention=FLAGS.bidirectional_program_attention)
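    # The same base config is specialized three ways below: training keeps dropout
    # on, eval runs deterministically, and predict also disables target shifting
    # and extends max_len to cover prediction-time inputs.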
    train_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config,
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)
    eval_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config.replace(deterministic=True),
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)
    predict_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config.replace(
            shift=False,
            deterministic=True,
            decode=not FLAGS.slow_decode,
            max_len=max(FLAGS.max_characters, FLAGS.max_program_length,
                        FLAGS.predict_max_characters)),
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

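    # One dropout RNG per local device, so each pmapped replica has its own stream.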
    dropout_rng = jax.random.split(rng, jax.local_device_count())
    del rng

    m = models.DecomposeAttentionTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)
        if FLAGS.finetune_start_step > 0:
            logging.info(
                'Checking that start_step (%s) >= finetune_start_step (%s)',
                start_step, FLAGS.finetune_start_step)
            assert start_step >= FLAGS.finetune_start_step
            steps_to_skip = start_step - FLAGS.finetune_start_step
        else:
            steps_to_skip = start_step

        # TODO(kshi): It is likely that this code can lead to the job stalling for
        # 10+ hours when restarting from a checkpoint that had been trained a long
        # time, possibly because dataset skipping is slow.
        logging.info('Skipping %s steps...', steps_to_skip)
        train_ds = train_ds.skip(steps_to_skip)
        dummy_p_train_step = jax.pmap(
            lambda dropout_rng: jax.random.split(dropout_rng)[1])
        for _ in range(steps_to_skip):
            dropout_rng = dummy_p_train_step(dropout_rng)
        logging.info('Finished skipping steps')
        logging.info('Host %s has dropout_rng = %s', jax.host_id(),
                     dropout_rng)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    # TODO(jxihong): Implement fast decoding.
    assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

    if FLAGS.finetune_start_step <= 0:
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr)
    else:
        # Constant LR for finetuning.
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr, factors='constant')
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eos_token=eos_token,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
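    # beam_size (positional argument index 4) is marked static below, so each
    # distinct beam size compiles its own version of the predict step.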
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        eos_token=eos_token,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        slow_decode=FLAGS.slow_decode),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, ))

    # Main Train Loop
    # ---------------------------------------------------------------------------

    logging.info('Starting training!')
    metrics_all = []
    tick = time.time()
    train_iter = train_ds.as_numpy_iterator()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, dropout_rng = p_train_step(optimizer,
                                                       inputs,
                                                       outputs,
                                                       programs,
                                                       dropout_rng=dropout_rng)
        metrics_all.append(metrics)
        is_last_step = step == FLAGS.num_train_steps - 1

        # Save a Checkpoint
        if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.

        # Training Metrics
        if (step and step % FLAGS.log_freq == 0) or is_last_step:
            logging.info('Gathering training metrics.')
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums)
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)

            if jax.host_id() == 0:
                logging.info('Train in step: %d, loss: %.4f', step,
                             summary['loss'])
                tock = time.time()
                steps_per_sec = FLAGS.log_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('train/steps per second', steps_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar('train/' + key, val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

        # Evaluation Metrics
        if (step and step % FLAGS.eval_freq == 0) or is_last_step:
            logging.info('Gathering evaluation metrics.')
            t_evaluation_start = time.time()
            eval_metrics = []
            for batches in eval_ds.as_numpy_iterator():
                inputs, outputs, programs = common_utils.shard(batches)

                metrics = p_eval_step(optimizer.target, inputs, outputs,
                                      programs)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            if jax.host_id() == 0:
                logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                             time.time() - t_evaluation_start, step,
                             eval_summary['loss'])
                for key, val in eval_summary.items():
                    summary_writer.scalar('eval/' + key, val, step)
                summary_writer.flush()

        # Beam search metrics.
        if (step and step % FLAGS.predict_freq == 0) or is_last_step:
            logging.info('Gathering beam search metrics.')
            test_ds = final_test_dataset if is_last_step else quick_test_dataset

            for dataset, predict_or_test in [(predict_ds, 'predict'),
                                             (test_ds, 'test')]:

                for beam_size in [1, 10]:
                    t_inference_start = time.time()
                    total_successes = 0
                    total_denominator = 0
                    pred_successes = collections.defaultdict(int)
                    pred_denominators = collections.defaultdict(int)

                    ios, targets, predictions, top_of_beams = [], [], [], []
                    for batches in dataset.as_numpy_iterator():
                        pred_batch = batches
                        # Handle final odd-sized batch by padding instead of dropping it.
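                        # common_utils.shard needs the batch size to be a
                        # multiple of the local device count.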
                        cur_pred_batch_size = pred_batch[0].shape[0]
                        if cur_pred_batch_size % n_devices:
                            padded_size = int(
                                np.ceil(cur_pred_batch_size / n_devices) *
                                n_devices)
                            # pylint: disable=cell-var-from-loop
                            pred_batch = jax.tree_map(
                                lambda x: pad_examples(x, padded_size),
                                pred_batch)
                        inputs, outputs, programs = common_utils.shard(
                            pred_batch)

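                        # Slow decoding (the only supported mode here) needs no
                        # autoregressive cache.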
                        cache = (p_init_cache(inputs, outputs, programs)
                                 if not FLAGS.slow_decode else None)
                        predicted = p_pred_step(optimizer.target, inputs,
                                                outputs, cache, beam_size)
                        predicted = tohost(predicted)
                        inputs, outputs, programs = map(
                            tohost, (inputs, outputs, programs))

                        for i, beams in enumerate(predicted):
                            inps, outs = decode_io(inputs[i], outputs[i])
                            p, p_score = eval_predicted(
                                beams,
                                inps,
                                outs,
                                parse_beam_fn=decode_program)

                            # Split by length of program.
                            program = programs[i]
                            num_expressions = len(
                                decode_program(program).expressions)
                            pred_denominators[num_expressions] += 1
                            total_denominator += 1
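                            # Success requires the prediction to satisfy
                            # every I/O example.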
                            if p_score >= len(inps):
                                pred_successes[num_expressions] += 1
                                total_successes += 1

                            ios.append(' ; '.join(map(str, zip(inps, outs))))
                            targets.append(
                                decode_program(programs[i]).to_string())
                            try:
                                predictions.append(p.to_string())
                            except:  # pylint: disable=bare-except
                                predictions.append('Did not compile')
                            logging.info('ios: %s', ios[-1])
                            logging.info('target: %s', targets[-1])
                            beams_log = []
                            for beam in beams:
                                try:
                                    beams_log.append(
                                        decode_program(beam).to_string())
                                except:  # pylint: disable=bare-except
                                    beams_log.append('Did not compile')
                            logging.info('predicted beam: %s',
                                         '\n'.join(beams_log))

                            top_of_beam = []
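                            # Log the last four beams (assumed best-scoring),
                            # in reverse order.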
                            for index, beam in enumerate(beams[:-5:-1]):
                                try:
                                    decoded_program = decode_program(
                                        beam).to_string()
                                except:  # pylint: disable=bare-except
                                    decoded_program = 'Did not compile'
                                top_of_beam.append(
                                    'index: {}, decoded: {}, tokens: {}'.
                                    format(index, decoded_program, beam))
                            top_of_beams.append('\n\n'.join(top_of_beam))

                    all_total_successes, all_total_denominator = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (total_successes, total_denominator)))
                    all_pred_successes, all_pred_denominators = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (pred_successes, pred_denominators)))

                    # Record beam search results as text summaries.
                    message = []
                    for n in np.random.choice(np.arange(len(predictions)), 8):
                        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                                f'predicted: {predictions[n]}\n\n'
                                f'top of beam:\n\n{top_of_beams[n]}\n\n')
                        message.append(text)

                    # Write to tensorboard.
                    if jax.host_id() == 0:
                        accuracy = 100 * all_total_successes / all_total_denominator
                        logging.info(
                            '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                            predict_or_test, step, beam_size,
                            all_total_successes, all_total_denominator,
                            accuracy,
                            time.time() - t_inference_start)
                        summary_writer.scalar(
                            '{}/beam-size-{}'.format(predict_or_test,
                                                     beam_size), accuracy,
                            step)

                        for length in sorted(all_pred_successes.keys()):
                            this_length_accuracy = (
                                100 * all_pred_successes[length] /
                                all_pred_denominators[length])
                            logging.info(
                                '  accuracy for length %s: %s / %s = %.2f%%',
                                length, all_pred_successes[length],
                                all_pred_denominators[length],
                                this_length_accuracy)
                            summary_writer.scalar(
                                '{}-by-length/beam-size-{}-length-{}'.format(
                                    predict_or_test, beam_size, length),
                                this_length_accuracy, step)

                        summary_writer.text(
                            '{}-samples-beam-{}'.format(
                                predict_or_test, beam_size),
                            '\n------\n'.join(message), step)
                        summary_writer.flush()
Example #29
0
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils

  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))

  logging.info('Saving experimental results to %s', experiment_dir)

  host_count = jax.host_count()
  local_device_count = jax.local_device_count()
  logging.info('Device count: %d, host count: %d, local device count: %d',
               jax.device_count(), host_count, local_device_count)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)

  input_shape = (1,) + dataset.shape
  base_model, _ = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
Example #30
0
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.mkdir(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparameters.
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            if key not in hparam_str_dict:
                hparam_str_dict[key] = value

    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        io_string = ''
        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
            io_string += inps[-1] + ' < ' + outs[-1] + ' > '
        return inps, outs, io_string[:-3]  # Remove last separator.

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except:  # pylint: disable=bare-except
            return None, ''  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease the batch size of the predict dataset to accommodate beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token)
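    # eval_config runs the model deterministically; predict_config also disables
    # target shifting and enables incremental (cached) decoding.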
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(shift=False,
                                          deterministic=True,
                                          decode=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    m = models.ProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_lib.train_step,
        learning_rate_fn=learning_rate_fn,
        config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(train_lib.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_lib.initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
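    # eos_token, the max decode length, and beam_size (positional args 4-6) are
    # static, so each new beam size recompiles the pmapped predict step.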
    p_pred_step = jax.pmap(functools.partial(train_lib.predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, 5, 6))

    # Main Train Loop
    # ---------------------------------------------------------------------------
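    # Split the RNG so each local device has its own stream for the pmapped
    # train step.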
    train_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, train_rngs = p_train_step(optimizer,
                                                      inputs,
                                                      outputs,
                                                      programs,
                                                      train_rng=train_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f', step,
                         summary['loss'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)

            metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
            eval_metrics.append(metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions = [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    # pylint: disable=cell-var-from-loop
                    pred_batch = jax.tree_map(
                        lambda x: train_lib.pad_examples(x, padded_size),
                        pred_batch)
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache = p_init_cache(inputs, outputs, programs)
                predicted = p_pred_step(optimizer.target, inputs, outputs,
                                        cache, eos_token, programs.shape[-1],
                                        beam_size)
                predicted = train_lib.tohost(predicted)
                inputs, outputs, programs = map(train_lib.tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_score = train_lib.eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')

            all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries.
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step,
                    all_pred_acc / all_pred_denominator)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      all_pred_acc / all_pred_denominator,
                                      step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()