Example #1
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Should try to load checkpoint (usually enabled, practical
        for debugging purposes to disable).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    msg = f'Running with seed {config.seed}.'
    logging.info(msg)
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    train_ds, test_ds, shape, num_classes = datasets.get_dataset(
        config, data_rng)

    # config.mask_shape = mask_shape
    config.data_shape = shape
    config.num_classes = num_classes

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    # Create output directory for saving samples.
    output_path = work_dir
    tf.io.gfile.makedirs(output_path)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    tx = optax.adam(config.learning_rate,
                    b1=0.9,
                    b2=config.beta2,
                    eps=1e-08,
                    eps_root=0.0)
    state = custom_train_state.TrainState.create(params=variables['params'],
                                                 tx=tx)

    if try_checkpoint:
        state, start_epoch = checkpoint.restore_from_path(work_dir, state)
        if start_epoch is None:
            start_epoch = 1
    else:
        # For debugging we start at zero, so we immediately do detailed eval.
        start_epoch = 0

    if is_first_host and start_epoch == 1:
        config_dict = dict(config)
        writer.write_hparams(config_dict)

    if is_first_host and start_epoch in (0, 1):
        # Dump config file to work dir for easy model loading.
        config_path = os.path.join(work_dir, 'config')
        with tf.io.gfile.GFile(config_path, 'wb') as fp:
            pickle.dump(config, fp)

    test_rng, train_rng = jax.random.split(rng)

    kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    p_train_step = jax.pmap(functools.partial(train_step,
                                              model=model,
                                              config=config),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # Only the input and output rng keys are broadcast (the None entries in
    # in_axes/out_axes): the rng is the first argument and the last return
    # value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, None))

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    with metric_writers.ensure_flushes(writer):
        for epoch in range(start_epoch, config.num_epochs + 1):
            # Train part.
            state, train_metrics, train_rng = train_epoch(
                p_train_step, state, train_ds, config.batch_size, epoch,
                train_rng, kl_tracker_train)

            # Val part.
            eval_metrics, test_rng = eval_model(p_eval_step, test_rng, state,
                                                test_ds, epoch)

            # Metric logging.
            if is_first_host:
                log_standard_metrics(writer, train_metrics, eval_metrics,
                                     epoch)

            kl_values = kl_tracker_train.get_kl_per_t()
            kl_history.append(np.array(kl_values))

            # Prune to avoid too much memory consumption.
            kl_history = kl_history[-50:]

            if epoch == 15 or epoch % config.detailed_eval_every == 0:
                if is_first_host:
                    loss_components_path = os.path.join(
                        work_dir, 'loss_components')
                    with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
                        pickle.dump(kl_history[-1], fp)

                test_rng = extensive_eval(config, test_rng, writer,
                                          output_path, model, state,
                                          kl_history, test_ds, epoch)

            # Save to checkpoint.
            if is_first_host and epoch % config.save_every == 0:
                # Save to epoch + 1 since current epoch has just been completed.
                logging.info('saving checkpoint')
                checkpoint.save_checkpoint(
                    work_dir,
                    state=flax.jax_utils.unreplicate(state),
                    step=epoch + 1,
                    keep=2)
                logging.info('finished saving checkpoint')

        return state
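
A minimal invocation sketch for this variant, assuming `config` behaves like an
`ml_collections.ConfigDict` (the loop calls `dict(config)` and assigns new
attributes, which that class supports). The field names below are the ones the
code above reads; the values and the `/tmp/...` path are purely illustrative.

import ml_collections

config = ml_collections.ConfigDict()
config.seed = 0
config.learning_rate = 1e-3
config.beta2 = 0.999
config.batch_size = 128
config.num_epochs = 100
config.detailed_eval_every = 10
config.save_every = 5
# ...plus whatever dataset fields datasets.get_dataset expects.

final_state = train_and_evaluate(config, work_dir='/tmp/example_run')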
Example #2
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Should try to load checkpoint (usually enabled, practical
        for debugging purposes to disable).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    if config.dataset.name.endswith('speech_commands09'):
        ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # Immediately create infinite iterators.
    it = jax.tree_map(util_fns.get_iterator, ds)

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    def tx_fn(lr):
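        # Instead of passing the optimizer directly, we use a fn that returns
        # the optimizer given a learning rate.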
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    start_step = None
    if try_checkpoint:
        state, start_step = checkpoint.restore_from_path(work_dir, state)
    start_step = start_step or 0

    # Use different rngs for train & eval.
    rng_train, rng_eval, rng_sample = jax.random.split(rng, 3)

    kl_tracker = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        **config.learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=config,
        learning_rate_fn=learning_rate_fn,
        model=model),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # Only the input and output rng keys are broadcast (the None entries in
    # in_axes/out_axes): the rng is the first argument and the last return
    # value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, 0, None))

    # Training length.
    logging.info('Training will start from step %d', start_step)

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    # Setup hooks.
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if is_first_host:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=work_dir, num_profile_steps=5)
        ]

    with metric_writers.ensure_flushes(writer):
        batch_metrics = []
        for step in range(start_step, config.num_train_steps):
            logging.log_first_n(logging.INFO, 'Train step: %d', 5, step)
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                state, metrics, rng_train = p_train_step(
                    rng_train, next(it['train']), state)
            batch_metrics.append(metrics)

            # Cycle though hooks.
            for h in hooks:
                h(step)

            is_last_step = step == config.num_train_steps - 1

            if (step % config.log_every_steps == 0) or is_last_step:
                with report_progress.timed('training_metrics'):
                    ################### Process batch metrics ############################
                    batch_metrics = jax.device_get(
                        flax.jax_utils.unreplicate(batch_metrics))

                    if 't_batch' in metrics:
                        # TODO(agritsenko): Factor out into a separate function.
                        # Process the loss per t. Although this adds two nested
                        # for-loops (counting the one inside kl_tracker), it does
                        # not meaningfully hurt timing performance.
                        batch_t = [
                            m['t_batch'].reshape(-1) for m in batch_metrics
                        ]
                        batch_nelbo_per_t = [
                            m['nelbo_per_t_batch'].reshape(-1)
                            for m in batch_metrics
                        ]
                        for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t):
                            kl_tracker.update(t, nelbo_per_t)

                    ################### Aggregate batch metrics ##########################
                    metrics = {
                        key: np.mean([m[key] for m in batch_metrics])
                        for key in batch_metrics[0] if 'batch' not in key
                    }

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             train_metrics=metrics)
                    batch_metrics = []

            if config.eval_every_steps and (
                (step % config.eval_every_steps == 0) or is_last_step):
                with report_progress.timed('eval'):
                    ####################### Run evaluation ###############################
                    metrics, rng_eval = eval_model(
                        p_eval_step, rng_eval, state, it['eval'],
                        (ds_metadata['eval']['num_batches'] *
                         config.get('num_eval_passes', 1)))

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             eval_metrics=metrics)

                # Track KL (unrelated to the eval, but nice to not do every step).
                kl_values = kl_tracker.get_kl_per_t()
                kl_history.append(np.array(kl_values))
                kl_history = kl_history[-50:]

            if config.sample_every_steps and (
                (step % config.sample_every_steps == 0) or is_last_step):
                with report_progress.timed('sample'):
                    ######################### Run sampling ###############################
                    chain = model.sample(jax.random.fold_in(rng_sample, step),
                                         state.ema_params,
                                         config.sample_batch_size,
                                         chain_out_size=config.get(
                                             'chain_out_size',
                                             model.num_stages))

                    if is_first_host:
                        chain = jax.device_get(chain)
                        long_sample = np.reshape(chain[-1],
                                                 (1, -1, 1)).astype(np.float32)
                        long_sample = (2. *
                                       long_sample) / config.num_classes - 1.
                        writer.write_audios(step, {'samples': long_sample},
                                            sample_rate=config.sample_rate)

            ######################### Checkpointing #################################
            if is_first_host and config.checkpoint_every_steps and (
                (step % config.checkpoint_every_steps == 0) or is_last_step):
                logging.info('Saving checkpoint: step %d', step)
                with report_progress.timed('checkpoint'):
                    checkpoint.save_checkpoint(
                        work_dir,
                        state=flax.jax_utils.unreplicate(state),
                        step=step)
                logging.info('Finished saving checkpoint: step %d', step)

        return state
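
A sketch of how this audio variant might be driven, again assuming an
`ml_collections.ConfigDict`-style config. The field names are taken from the
code above; the values are illustrative, and the keyword expected inside
`config.learning_rate` depends on `train_utils.create_learning_rate_scheduler`,
so the key name used below is an assumption.

import ml_collections

config = ml_collections.ConfigDict()
config.seed = 0
config.dataset = ml_collections.ConfigDict({'name': 'speech_commands09'})
config.beta2 = 0.999
config.weight_decay = 0.01
# Kwargs forwarded to train_utils.create_learning_rate_scheduler; the key name
# is assumed, not taken from the code above.
config.learning_rate = ml_collections.ConfigDict({'base_learning_rate': 1e-3})
config.num_train_steps = 100_000
config.log_every_steps = 100
config.eval_every_steps = 1_000
config.sample_every_steps = 10_000
config.sample_batch_size = 4
config.checkpoint_every_steps = 5_000
# ...plus any further fields input_pipeline_sc09.get_dataset requires.

final_state = train_and_evaluate(config, work_dir='/tmp/sc09_run')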
Example #3
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
      config)
  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  logging.info('Using state: %s', type(state))
  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
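  # The first argument (the replicated train state) is donated via
  # donate_argnums=(0,) so XLA may reuse its buffers for the updated state.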
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, model=model),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    epoch = 1
    while True:
      msg = f'Starting epoch {epoch}'
      logging.info(msg)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(
            state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(
              t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

        if step == 1000 or (step > 0 and
                            step % config.detailed_eval_every_steps == 0):
          if is_first_process:
            loss_components_path = os.path.join(workdir, 'loss_components')
            with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
              pickle.dump(kl_history[-1], fp)

          extensive_eval_rngs = extensive_eval(
              config, extensive_eval_rngs, writer, workdir,
              model, state, kl_history, test_ds, step,
              decode_tokens)

        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step, overwrite=True)
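
Finally, a usage sketch for this text variant under the same
`ml_collections.ConfigDict` assumption. Every field name below appears in the
code above; the values are placeholders, and the `lr_factors` string merely
follows the usual Flax-style scheduler convention, which is an assumption here.

import ml_collections

config = ml_collections.ConfigDict()
config.seed = 0
config.batch_size = 256
config.weight_decay = 0.01
config.learning_rate = 1e-3
config.lr_factors = 'constant * linear_warmup'  # Format assumed.
config.warmup_steps = 1_000
config.clip_grad = 1.0
config.num_train_steps = 500_000
config.eval_every_steps = 1_000
config.detailed_eval_every_steps = 10_000
config.checkpoint_every_steps = 10_000
config.restore_checkpoints = True
config.save_checkpoints = True
# ...plus the dataset/tokenizer fields read by input_pipeline.get_datasets.

train_and_evaluate(config, workdir='/tmp/text_run')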