Example #1
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the TensorBoard summaries are written.
    try_checkpoint: Whether to try to load a checkpoint (usually enabled;
        disabling can be practical for debugging).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    msg = f'Running with seed {config.seed}.'
    logging.info(msg)
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    train_ds, test_ds, shape, num_classes = datasets.get_dataset(
        config, data_rng)

    # config.mask_shape = mask_shape
    config.data_shape = shape
    config.num_classes = num_classes

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    # Create output directory for saving samples.
    output_path = work_dir
    tf.io.gfile.makedirs(output_path)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    tx = optax.adam(config.learning_rate,
                    b1=0.9,
                    b2=config.beta2,
                    eps=1e-08,
                    eps_root=0.0)
    state = custom_train_state.TrainState.create(params=variables['params'],
                                                 tx=tx)

    if try_checkpoint:
        state, start_epoch = checkpoint.restore_from_path(work_dir, state)
        if start_epoch is None:
            start_epoch = 1
    else:
        # For debugging we start at zero, so we immediately do detailed eval.
        start_epoch = 0

    if is_first_host and start_epoch == 1:
        config_dict = dict(config)
        writer.write_hparams(config_dict)

    if is_first_host and start_epoch in (0, 1):
        # Dump config file to work dir for easy model loading.
        config_path = os.path.join(work_dir, 'config')
        with tf.io.gfile.GFile(config_path, 'wb') as fp:
            pickle.dump(config, fp)

    test_rng, train_rng = jax.random.split(rng)

    kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    p_train_step = jax.pmap(functools.partial(train_step,
                                              model=model,
                                              config=config),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # The only broadcast axes are the input and output rng keys: the rng is the
    # first argument and the last return value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, None))

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    with metric_writers.ensure_flushes(writer):
        for epoch in range(start_epoch, config.num_epochs + 1):
            # Train part.
            state, train_metrics, train_rng = train_epoch(
                p_train_step, state, train_ds, config.batch_size, epoch,
                train_rng, kl_tracker_train)

            # Val part.
            eval_metrics, test_rng = eval_model(p_eval_step, test_rng, state,
                                                test_ds, epoch)

            # Metric logging.
            if is_first_host:
                log_standard_metrics(writer, train_metrics, eval_metrics,
                                     epoch)

            kl_values = kl_tracker_train.get_kl_per_t()
            kl_history.append(np.array(kl_values))

            # Prune to avoid too much memory consumption.
            kl_history = kl_history[-50:]

            if epoch == 15 or epoch % config.detailed_eval_every == 0:
                if is_first_host:
                    loss_components_path = os.path.join(
                        work_dir, 'loss_components')
                    with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
                        pickle.dump(kl_history[-1], fp)

                test_rng = extensive_eval(config, test_rng, writer,
                                          output_path, model, state,
                                          kl_history, test_ds, epoch)

            # Save to checkpoint.
            if is_first_host and epoch % config.save_every == 0:
                # Save to epoch + 1 since current epoch has just been completed.
                logging.info('saving checkpoint')
                checkpoint.save_checkpoint(
                    work_dir,
                    state=flax.jax_utils.unreplicate(state),
                    step=epoch + 1,
                    keep=2)
                logging.info('finished saving checkpoint')

        return state
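A minimal launcher sketch for wiring a `train_and_evaluate` like the one above to a command line; the flag names and the `main` entry point below are assumptions for illustration, not part of the original example.

from absl import app
from absl import flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('config', None, 'Training configuration file.')
flags.DEFINE_string('work_dir', None, 'Directory for summaries and checkpoints.')


def main(argv):
    del argv  # Unused.
    # Delegates to the example above; returns the final train state.
    train_and_evaluate(FLAGS.config, FLAGS.work_dir)


if __name__ == '__main__':
    flags.mark_flags_as_required(['config', 'work_dir'])
    app.run(main)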
Example #2
def generate(config: ml_collections.ConfigDict):
    """Generates memories."""
    # Establish host information
    local_device_count = jax.local_device_count()
    device_count = jax.device_count()
    process_count = jax.process_count()
    process_index = jax.process_index()

    task = memory_generation_task.MemoryGenerationTask
    model_config = ml_collections.FrozenConfigDict(config.model_config)
    model = task.build_model(model_config)
    p_predict_step = jax.pmap(functools.partial(
        task.make_prediction_fn(config),
        model_config,
    ),
                              axis_name='batch')
    rng = jax.random.PRNGKey(config.seed)

    # Initialization needs to be pmapped because models use collective ops.
    # Create dummy input
    dummy_input = {
        key: jnp.tile(value, (local_device_count, ) + (1, ) * value.ndim)
        for key, value in task.dummy_input(config).items()
    }

    rng, init_rng = jax.random.split(rng)
    init_rng = jax.random.split(init_rng, local_device_count)

    logging.info('Initializing model.')
    initial_variables = jax.pmap(model.init,
                                 'batch',
                                 static_broadcasted_argnums=2)(init_rng,
                                                               dummy_input,
                                                               True)
    logging.info('Finished initializing model.')
    initial_variables = initial_variables.unfreeze()

    if config.load_weights is not None:
        logging.info('Loading model weights from file')
        loaded_variables = task.load_weights(config)
        unexpected, missing = checkpoint_utils.merge_nested_dicts(
            initial_variables, loaded_variables)
        logging.info('*** Unexpected features: ***')
        for feature_name in unexpected:
            logging.info('\t%s', feature_name)
        # In prediction mode we don't allow any features to be missing.
        # pylint: disable=g-explicit-length-test
        if len(missing) > 0:
            raise ValueError('Missing features: %s' % ','.join(missing))

    # model_params = jax_utils.unreplicate(initial_variables['params'])
    model_params = initial_variables['params']
    model_vars = {
        key: value
        for key, value in initial_variables.items() if key != 'params'
    }
    # We access model params only from train state.
    del initial_variables

    writer = metric_writers.create_default_writer(
        config.output_dir, just_logging=process_index > 0)

    max_length = config.get('max_length_with_entity_tokens',
                            model_config.encoder_config.max_length)

    num_total_memories = math.ceil(config.num_total_memories / process_count)
    memory_saver = memory_generation_task.MemorySaver(
        num_total_memories=num_total_memories,
        memory_dim=config.memory_dim,
        max_length=max_length,
        max_mentions_per_sample=config.max_mentions_per_sample,
        memory_key_dim=config.get('memory_key_dim'))
    n_samples = 0
    data_iter = get_data_iterator(config)

    logging.info('Start memory generation.')
    with metric_writers.ensure_flushes(writer):
        for step, batch in enumerate(data_iter):
            batch = jax.tree_map(jnp.asarray, batch)
            predictions = p_predict_step(
                model_params,
                model_vars,
                batch,
            )
            predictions = jax.device_get(predictions)
            memory_saver.add_memories(batch, predictions)
            n_devices, batch_size, _ = batch['text_ids'].shape
            logging.log_first_n(
                logging.INFO, 'Process %d / %d: '
                'Finished generating step %d, local devices %d, batch size %d',
                5, process_index, process_count, step, n_devices, batch_size)

            n_samples += device_count * config.per_device_batch_size
            if (step % config.log_every_steps == 0
                    or memory_saver.get_num_memories() >= num_total_memories):
                writer.write_scalars(
                    step,
                    dict(n_memories=memory_saver.get_num_memories(),
                         n_samples=n_samples))

            if memory_saver.get_num_memories() >= num_total_memories:
                break

    logging.info('Process %d / %d: Finished generating memories: %d out of %d',
                 process_index, process_count, memory_saver.get_num_memories(),
                 num_total_memories)

    start_time = time.time()
    logging.info('Process %d / %d: Start saving generated memories to files.',
                 process_index, process_count)
    memory_saver.save(config.output_dir,
                      num_shards=config.num_shards,
                      stride=process_count,
                      offset=process_index,
                      shard_size_divisible=config.shard_size_divisible)

    logging.info(
        'Process %d / %d: Finished saving generated memories to files in %.2f seconds',
        process_index, process_count,
        time.time() - start_time)
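One detail in the example above that is easy to misread is the dummy-input tiling before the pmapped `model.init`: `jnp.tile` with a reps tuple longer than the array's rank prepends a device axis. A standalone sketch, with shapes made up for illustration:

import jax
import jax.numpy as jnp

local_device_count = jax.local_device_count()
value = jnp.zeros((4, 8), dtype=jnp.int32)  # one dummy feature of shape (4, 8)
tiled = jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
# The reps tuple has one more entry than value.ndim, so a leading axis is added:
# tiled.shape == (local_device_count, 4, 8), i.e. one copy per local device.
assert tiled.shape == (local_device_count, 4, 8)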
Example #3
    def train_and_evaluate(self, workdir):
        """Runs a training and evaluation loop.

    Args:
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training will be resumed from the latest checkpoint.
    """

        tf.io.gfile.makedirs(workdir)
        config = self.config
        substeps = config.training.substeps

        # Learning rate schedule.
        num_train_steps = config.training.num_train_steps
        logging.info('num_train_steps=%d', num_train_steps)

        # Get train state
        state = self._train_state

        # Set up checkpointing of the model and the input pipeline.
        checkpoint_dir = os.path.join(workdir, 'checkpoints')
        ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
        state = ckpt.restore_or_initialize(state)
        initial_step = int(state.step)

        # Distribute training.
        state = flax_utils.replicate(state)

        writer = metric_writers.create_default_writer(
            workdir, just_logging=jax.process_index() > 0)
        if initial_step == 0:
            writer.write_hparams(dict(config))

        logging.info('Starting training loop at step %d.', initial_step)
        hooks = []
        report_progress = periodic_actions.ReportProgress(
            num_train_steps=num_train_steps, writer=writer)
        if jax.process_index() == 0:
            hooks += [
                report_progress,
                periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
            ]
        step = initial_step
        with metric_writers.ensure_flushes(writer):
            while step < num_train_steps:
                # `step` is a Python integer. `state.step` is a JAX integer on
                # the GPU/TPU devices.
                is_last_step = step + substeps >= num_train_steps

                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    inputs = jax.tree_map(np.asarray, next(self._train_iter))
                    state, outputs = self._update_func(state, inputs)

                # Quick indication that training is happening.
                logging.log_first_n(logging.INFO, 'Finished training step %d.',
                                    5, step)
                for h in hooks:
                    h(step)

                new_step = int(state.step[0])
                assert new_step == step + substeps
                step = new_step

                is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
                if step % config.logs.log_loss_every_steps == 0 and not is_eval:

                    def avg_over_substeps(x):
                        assert x.shape[0] == substeps
                        return float(x.mean(axis=0))

                    # Extract scalars and images.
                    outputs = flax_utils.unreplicate(outputs)
                    outputs = jax.tree_map(avg_over_substeps, outputs)
                    scalars = outputs['scalars']
                    writer.write_scalars(step, scalars)

                if is_eval:
                    with report_progress.timed('eval_full'):
                        outputs = self._eval_epoch(params=state.ema_params)
                        outputs = flax_utils.unreplicate(outputs)
                        scalars = outputs['scalars']
                        writer.write_scalars(step, scalars)

                if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
                    with report_progress.timed('checkpoint'):
                        ckpt.save(flax_utils.unreplicate(state))

        logging.info('Finishing training at step %d', num_train_steps)
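The `avg_over_substeps` helper above relies on each logged scalar carrying a leading `substeps` axis after unreplication. A toy sketch of that reduction, with invented values:

import jax
import numpy as np

substeps = 4
outputs = {'scalars': {'loss': np.arange(substeps, dtype=np.float32)}}


def avg_over_substeps(x):
    assert x.shape[0] == substeps
    return float(x.mean(axis=0))


scalars = jax.tree_map(avg_over_substeps, outputs)['scalars']
print(scalars)  # the loss averages to 1.5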
Example #4
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
      config)
  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  print('Using state', type(state))
  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, model=model),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    epoch = 1
    while True:
      msg = f'Starting epoch {epoch}'
      logging.info(msg)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(
            state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(
              t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

        if step == 1000 or (step > 0 and
                            step % config.detailed_eval_every_steps == 0):
          if is_first_process:
            loss_components_path = os.path.join(workdir, 'loss_components')
            with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
              pickle.dump(kl_history[-1], fp)

          extensive_eval_rngs = extensive_eval(
              config, extensive_eval_rngs, writer, workdir,
              model, state, kl_history, test_ds, step,
              decode_tokens)

        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_every_steps steps.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step, overwrite=True)
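The training loop above relies on `common_utils.shard` to add the device axis that `jax.pmap` expects. A small sketch of that reshape, with batch sizes assumed for illustration:

import jax
import numpy as np
from flax.training import common_utils

local_devices = jax.local_device_count()
per_device_batch = 2
batch = {'inputs': np.zeros((local_devices * per_device_batch, 250, 1), np.int32)}
sharded = common_utils.shard(batch)
# The leading axis is split into (devices, per-device batch) for pmap:
assert sharded['inputs'].shape == (local_devices, per_device_batch, 250, 1)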
Example #5
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    logging.info('Starting training at %s', workdir)
    tf.io.gfile.makedirs(workdir)
    if jax.process_index() == 0:
        with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f:
            json.dump(config.to_dict(), f, indent=2)
    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    train_ds, eval_ds = input_pipeline.create_datasets(config.dataset,
                                                       data_rng)
    train_iter = iter(train_ds)

    test_ds = []
    for split in config.dataset.test_splits:
        ds = input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        test_ds.append(ds)

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.dataset.num_epochs
    logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps,
                 steps_per_epoch)
    learning_rate_fn = functools.partial(
        train_utils.get_learning_rate,
        base_learning_rate=config.learning_rate,
        num_train_steps=num_train_steps,
        schedule_type=config.learning_rate_schedule,
        warmup_proportion=config.warmup_proportion,
        step_boundaries=config.learning_rate_step_boundaries)

    # Initialize model.
    inputs = train_utils.get_init_inputs(train_ds)
    rng, model_rng = jax.random.split(rng)
    eval_config = models.TransformerConfig(**config.model.to_dict())
    train_config = eval_config.replace(deterministic=False)
    model = models.Model(eval_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        grad_clip=config.grad_clip),
                            axis_name='batch',
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name='batch')

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(train_utils.flatten_config(config))

    logging.info('Starting training loop at step %d.', initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(
                num_profile_steps=config.num_profile_steps, logdir=workdir)
        ]

    rng, train_rngs = jax.random.split(rng)
    train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
    train_rngs = jax.random.split(train_rngs, jax.local_device_count())

    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceContext('train', step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics = p_train_step(batch=batch,
                                              rng=train_rngs,
                                              state=state)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            if config.log_loss_every_steps > 0 and (
                    step % config.log_loss_every_steps == 0 or is_last_step):
                train_metrics = common_utils.get_metrics(train_metrics)
                lr = train_metrics.pop('learning_rate').mean()
                train_summary = train_utils.metrics_summary(
                    train_metrics, 'train')
                train_summary['learning_rate'] = lr
                writer.write_scalars(step, train_summary)
                train_metrics = []

            if config.eval_every_steps > 0 and (step % config.eval_every_steps
                                                == 0 or is_last_step):
                with report_progress.timed('eval'):
                    eval_summary = evaluate(p_eval_step, state, eval_ds,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_summary)

            if config.checkpoint_every_steps > 0 and (
                    step % config.checkpoint_every_steps == 0 or is_last_step):
                with report_progress.timed('checkpoint'):
                    ckpt.save(flax_utils.unreplicate(state))
                logging.info('Checkpoint saved to %s', checkpoint_dir)

    logging.info('Finishing training at step %d', num_train_steps)
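A recurring pattern in these loops (used above for `train_rngs`) is to fold the process index into a key so every host draws an independent stream, then split one key per local device for the pmapped step. In isolation:

import jax

rng = jax.random.PRNGKey(0)
rng, train_rngs = jax.random.split(rng)
# Different stream on every host/process:
train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
# One key per local device, ready to be passed to a pmapped train step:
train_rngs = jax.random.split(train_rngs, jax.local_device_count())
# train_rngs now has shape (local_device_count, key_size); with the default
# threefry implementation, key_size == 2.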
Example #6
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the TensorBoard summaries are written.
    try_checkpoint: Whether to try to load a checkpoint (usually enabled;
        disabling can be practical for debugging).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    if config.dataset.name.endswith('speech_commands09'):
        ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # Immediately create infinite iterators.
    it = jax.tree_map(util_fns.get_iterator, ds)

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    def tx_fn(lr):
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    start_step = None
    if try_checkpoint:
        state, start_step = checkpoint.restore_from_path(work_dir, state)
    start_step = start_step or 0

    # Use different rngs for train & eval.
    rng_train, rng_eval, rng_sample = jax.random.split(rng, 3)

    kl_tracker = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        **config.learning_rate)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=config,
        learning_rate_fn=learning_rate_fn,
        model=model),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # The only broadcast axes are the input and output rng keys: the rng is the
    # first argument and the last return value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, 0, None))

    # Training length.
    logging.info('Training will start from step %d', start_step)

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    # Setup hooks.
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if is_first_host:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=work_dir, num_profile_steps=5)
        ]

    with metric_writers.ensure_flushes(writer):
        batch_metrics = []
        for step in range(start_step, config.num_train_steps):
            logging.log_first_n(logging.INFO, f'Train step: {step}', 5)
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                state, metrics, rng_train = p_train_step(
                    rng_train, next(it['train']), state)
            batch_metrics.append(metrics)

            # Cycle though hooks.
            for h in hooks:
                h(step)

            is_last_step = step == config.num_train_steps - 1

            if (step % config.log_every_steps == 0) or is_last_step:
                with report_progress.timed('training_metrics'):
                    ################### Process batch metrics ############################
                    batch_metrics = jax.device_get(
                        flax.jax_utils.unreplicate(batch_metrics))

                    if 't_batch' in metrics:
                        # TODO(agritsenko): Factor out into a separate function.
                        # This processes the loss per t, although two nested for-loops
                        # (counting the one inside kl_tracker), it actually does not hurt
                        # timing performance meaningfully.
                        batch_t = [
                            metrics['t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        batch_nelbo_per_t = [
                            metrics['nelbo_per_t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t):
                            kl_tracker.update(t, nelbo_per_t)

                    ################### Process batch metrics ############################
                    metrics = {
                        key:
                        np.mean([metrics[key] for metrics in batch_metrics])
                        for key in batch_metrics[0] if 'batch' not in key
                    }

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             train_metrics=metrics)
                    batch_metrics = []

            if config.eval_every_steps and (
                (step % config.eval_every_steps == 0) or is_last_step):
                with report_progress.timed('eval'):
                    ####################### Run evaluation ###############################
                    metrics, rng_eval = eval_model(
                        p_eval_step, rng_eval, state, it['eval'],
                        (ds_metadata['eval']['num_batches'] *
                         config.get('num_eval_passes', 1)))

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             eval_metrics=metrics)

                # Track KL (unrelated to the eval, but nice to not do every step).
                kl_values = kl_tracker.get_kl_per_t()
                kl_history.append(np.array(kl_values))
                kl_history = kl_history[-50:]

            if config.sample_every_steps and (
                (step % config.sample_every_steps == 0) or is_last_step):
                with report_progress.timed('sample'):
                    ######################### Run sampling ###############################
                    chain = model.sample(jax.random.fold_in(rng_sample, step),
                                         state.ema_params,
                                         config.sample_batch_size,
                                         chain_out_size=config.get(
                                             'chain_out_size',
                                             model.num_stages))

                    if is_first_host:
                        chain = jax.device_get(chain)
                        long_sample = np.reshape(chain[-1],
                                                 (1, -1, 1)).astype(np.float32)
                        long_sample = (2. *
                                       long_sample) / config.num_classes - 1.
                        writer.write_audios(step, {'samples': long_sample},
                                            sample_rate=config.sample_rate)

            ######################### Checkpointing #################################
            if is_first_host and config.checkpoint_every_steps and (
                (step % config.checkpoint_every_steps == 0) or is_last_step):
                logging.info('Saving checkpoint: step %d', step)
                with report_progress.timed('checkpoint'):
                    checkpoint.save_checkpoint(
                        work_dir,
                        state=flax.jax_utils.unreplicate(state),
                        step=step)
                logging.info('Finished saving checkpoint: step %d', step)

        return state
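The metric reduction above keeps only scalar entries (every key that does not contain 'batch') and averages them over the logging window. A toy version with invented values:

import numpy as np

batch_metrics = [{'loss': 1.0, 't_batch': np.arange(4)},
                 {'loss': 3.0, 't_batch': np.arange(4)}]
metrics = {
    key: np.mean([m[key] for m in batch_metrics])
    for key in batch_metrics[0] if 'batch' not in key
}
print(metrics)  # only 'loss' survives, averaged to 2.0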
Example #7
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, eval_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
    state = checkpoints.restore_checkpoint(workdir, state)
    initial_step = int(state.step) + 1

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
        ]
    train_metrics = None

    # Prefetch_buffer_size = 6 x batch_size
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is a JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weight=True)
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index()
                    == 0) and (step % config.train.checkpoint_every_steps == 0
                               or is_last_step):
                # Write final metrics to file
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = metric_update.compute()
                    for k, v in log_dict.items():
                        log_dict[k] = v.item()
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    state_to_save = jax.device_get(
                        jax.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)
Example #8
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if config.batch_size % n_devices:
        raise ValueError(
            "Batch size must be divisible by the number of devices")

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=config.dataset_name,
        eval_dataset_name=config.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        vocab_path=vocab_path,
        target_vocab_size=config.vocab_size,
        batch_size=config.batch_size,
        max_corpus_chars=config.max_corpus_chars,
        max_length=config.max_target_length,
        max_eval_length=config.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
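        # np.argmax on the boolean mask finds the first position where
        # toks == eos_id, so the slice keeps everything up to and including the
        # first EOS token (or just the first token if no EOS is present).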
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    input_shape = (config.batch_size, config.max_target_length)
    target_shape = (config.batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap"d training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    metrics_all = []
    with metric_writers.ensure_flushes(writer):
        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            # Shard data to devices and do a training step.
            batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
            optimizer, metrics, dropout_rngs = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            metrics_all.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Save a checkpoint on one host after every checkpoint_freq steps.
            if (config.save_checkpoints and step % config.checkpoint_freq == 0
                    and step > 0 and jax.host_id() == 0):
                checkpoints.save_checkpoint(workdir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

            # Periodic metric handling.
            if step % config.eval_frequency != 0 and step > 0:
                continue

            # Training Metrics
            logging.info("Gathering training metrics.")
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop("learning_rate").mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop("denominator")
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary["learning_rate"] = lr
            summary = {"train_" + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            metrics_all = []

            # Eval Metrics
            logging.info("Gathering evaluation metrics.")
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop("denominator")
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            eval_summary = {"eval_" + k: v for k, v in eval_summary.items()}
            writer.write_scalars(step, eval_summary)

            # Translation and BLEU Score.
            logging.info("Translating evaluation dataset.")
            t_inference_start = time.time()
            sources, references, predictions = [], [], []
            for pred_batch in predict_ds:
                pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch["inputs"].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_map(
                        lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                        pred_batch)
                pred_batch = common_utils.shard(pred_batch)
                cache = p_init_cache(pred_batch["inputs"])
                predicted = p_pred_step(pred_batch["inputs"], optimizer.target,
                                        cache, eos_id,
                                        config.max_predict_length)
                predicted = tohost(predicted)
                inputs = tohost(pred_batch["inputs"])
                targets = tohost(pred_batch["targets"])
                # Iterate through non-padding examples of batch.
                for i, s in enumerate(predicted[:cur_pred_batch_size]):
                    sources.append(decode_tokens(inputs[i]))
                    references.append(decode_tokens(targets[i]))
                    predictions.append(decode_tokens(s))
            logging.info(
                "Translation: %d predictions %d references %d sources.",
                len(predictions), len(references), len(sources))
            logging.info("Translation time: %.4f s step %d.",
                         time.time() - t_inference_start, step)

            # Calculate BLEU score for translated eval corpus against reference.
            bleu_matches = bleu.bleu_partial(references, predictions)
            all_bleu_matches = per_host_sum_pmap(bleu_matches)
            bleu_score = bleu.complete_bleu(*all_bleu_matches)
            # Save translation samples for tensorboard.
            exemplars = ""
            for n in np.random.choice(np.arange(len(predictions)), 8):
                exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
            writer.write_scalars(step, {"bleu": bleu_score})
            writer.write_texts(step, {"samples": exemplars})
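Both the train and eval summaries above are produced by summing the stacked per-step metrics and dividing by the accumulated `denominator` (the token count). In isolation, with toy numbers:

import jax
import jax.numpy as jnp

# Shape of each leaf after common_utils.get_metrics: one entry per train step.
metrics_all = {'loss': jnp.array([4.0, 8.0]), 'denominator': jnp.array([2.0, 2.0])}
metrics_sums = jax.tree_map(jnp.sum, metrics_all)
denominator = metrics_sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
# summary['loss'] == 12.0 / 4.0 == 3.0; keys are then prefixed with 'train_'.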
Example #9
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    splits = input_pipeline.create_datasets(config, data_rng)
    num_classes = splits.info.features["label"].num_classes
    train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = splits.train.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)
    # We treat the learning rate in the config as the learning rate for batch size
    # 32 but scale it according to our batch size.
    global_batch_size = config.per_device_batch_size * jax.device_count()
    base_learning_rate = config.learning_rate * global_batch_size / 32.0
    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = create_train_state(
        config,
        model_rng,
        input_shape=splits.train.element_spec["input"].shape[1:],
        num_classes=num_classes)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir,
                                          {"train_iter": train_iter},
                                          max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Count number of trainable parameters. This must be done before replicating
    # the state to avoid double-counting replicated parameters.
    param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target))

    # Distribute training over local devices.
    state = flax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        model=model,
        learning_rate_fn=learning_rate_fn,
        weight_decay=config.weight_decay),
                            axis_name=_PMAP_AXIS_NAME)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))
        # Log the number of trainable params.
        writer.write_scalars(initial_step, {"param_count": param_count})

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics = None
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is a JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics_update = p_train_step(state=state, batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            # When combining train and eval, we do not evaluate while training.
            if ((step % config.eval_every_steps == 0 or is_last_step)
                    and not config.combine_train_val_and_eval_on_test):
                with report_progress.timed("eval"):
                    eval_metrics = evaluate(model, state, splits.validation,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_metrics.compute())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                with report_progress.timed("checkpoint"):
                    ckpt.save(flax_utils.unreplicate(state))

            if is_last_step and config.combine_train_val_and_eval_on_test:
                # Evaluate a single time on the test set when requested.
                with report_progress.timed("test"):
                    test_metrics = evaluate(model, state, splits.test,
                                            config.num_eval_steps)
                writer.write_scalars(step, test_metrics.compute())

    logging.info("Finishing training at step %d", num_train_steps)
Exemple #10
0
def train(base_dir, config):
    """Train function."""
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
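
Exemple #10 resolves its update function dynamically via globals().get(f'train_step_{train_method}'), so a module-level function named, e.g., train_step_mse must exist with the calling convention used in the loop (state, metrics, states, targets, plus any remaining entries of config.train as keyword arguments). The sketch below is hypothetical: it assumes a simple squared-error objective and a clu.metrics collection with a single loss average, which may not match the project's real losses or metrics.

import flax
import jax
import jax.numpy as jnp
import optax
from clu import metrics as clu_metrics


@flax.struct.dataclass
class TrainMetrics(clu_metrics.Collection):
    # One plausible definition of the collection used in the loop above.
    loss: clu_metrics.Average.from_output('loss')


def train_step_mse(state, metrics, states, targets, **unused_kwargs):
    """Hypothetical target of the `train_step_{method}` lookup above."""

    def loss_fn(params):
        predictions = state.apply_fn(params, states)
        return jnp.mean(optax.l2_loss(predictions, targets))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = metrics.merge(TrainMetrics.single_from_model_output(loss=loss))
    return state, metrics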
Exemple #11
0
        jax.random.normal(  # charlinel(why benefit of np?)
            weight_key, (d, num_tasks),
            dtype=jnp.float64))

    assert optimizer == 'sgd', 'Non-sgd not yet supported.'

    writer = metric_writers.create_default_writer(logdir=str(workdir), )

    hooks = [
        periodic_actions.PeriodicCallback(
            every_steps=5_000,
            callback_fn=lambda step, t: chkpt_manager.save((step, Phi)))
    ]

    # Perform num_epochs gradient steps.
    with metric_writers.ensure_flushes(writer):
        for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1),
                               initial=initial_step,
                               total=num_epochs):
            # Draw one or many source states to update, and its task.
            source_states, key = utils.draw_states(num_states, main_batch_size,
                                                   key)
            task, key = utils.draw_states(num_tasks, 1, key)  # bad Marc!

            # Use the source states to update our estimate of the feature norm.
            # Do this pre-LISSA to avoid a bad first gradient.
            if method == 'lissa' and estimate_feature_norm:
                max_norm = utils.compute_max_feature_norm(
                    Phi[source_states, :])
                estimated_feature_norm += 0.01 * (max_norm -
                                                  estimated_feature_norm)
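
The snippet above keeps a running estimate of the maximum feature norm for the LISSA update via utils.compute_max_feature_norm. Here is a one-line sketch of what that helper plausibly computes, assuming it returns the largest squared L2 norm over the rows of a [batch, d] feature matrix (the real helper may return the unsquared norm).

import jax.numpy as jnp


def compute_max_feature_norm(features):
    """Sketch: largest squared L2 norm across the rows of a [batch, d] matrix."""
    return jnp.max(jnp.sum(features ** 2, axis=-1))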
Exemple #12
0
def evaluate(base_dir, config, *, train_state):
    """Eval function."""
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})
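
Exemple #12 evaluates a LinearModule(config.eval.num_tasks) head on top of frozen encoder features, initializing it with a zero vector of size config.encoder.embedding_dim and vmapping state.apply_fn over per-example features. A minimal sketch of a compatible module, assuming it is a single bias-free linear layer; the real module may differ.

import flax.linen as nn


class LinearModule(nn.Module):
    """Hypothetical linear read-out: one prediction per auxiliary task."""
    num_tasks: int

    @nn.compact
    def __call__(self, phi):
        return nn.Dense(self.num_tasks, use_bias=False)(phi)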
Exemple #13
0
def train(config: ml_collections.ConfigDict):
  """Run training."""

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from train state.
  del initial_variables

  # Restore unreplicated train state from last checkpoint
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)

  del rng

  # Load datasets
  logging.info('Loading dataset.')

  # Make sure we don't re-use the same data if we load weights or a checkpoint.
  seed = config.seed + start_step
  if config.load_weights:
    seed = seed + hash(config.load_weights)

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=seed,  # Offset seed (computed above) so resumed runs see new data.
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that none of the samples are dropped.')
  else:
    logging.warning('Eval data is NOT padded -- some samples might be dropped.')

  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # compile multidevice versions of train/eval/predict step
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state,
            model_vars,
            batch,
            dropout_rngs,
        )
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)

          writer.write_scalars(step, summary)
          train_metrics = []

          with report_progress.timed('eval'):
            eval_results, eval_auxiliary = evaluate(
                eval_step_fn=p_eval_step,
                train_state=train_state,
                model_vars=model_vars,
                eval_data=eval_data,
            )
            writer.write_scalars(step, eval_results)

            if config.get('save_samples_every_steps') is not None:
              with report_progress.timed('save_samples'):
                if config.get('save_first_batch_only', True):
                  # Keep only the first eval batch for postprocessing below.
                  eval_auxiliary = eval_auxiliary[:1]
                eval_processed = [
                    postprocessing_fn(batch, auxiliary_output)
                    for batch, auxiliary_output in eval_auxiliary
                ]
                data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and step != 0)
      if (save_model and jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
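
Exemple #13 masks weight decay through optim_utils.create_dict_mask(initial_variables['params'], config.weight_decay_exclude), which is not shown here. A hedged sketch of how such a mask can be built with flax.traverse_util, assuming weight_decay_exclude is a list of substrings matched against the flattened parameter path; the project's real helper may use different matching rules.

from flax import traverse_util


def create_dict_mask(params, exclude_patterns):
    """Sketch: pytree of booleans, True where weight decay should apply."""
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: not any(pattern in '/'.join(path) for pattern in exclude_patterns)
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)

The resulting pytree mirrors params, so it can be passed as the mask argument of optax.adamw as done above.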
Exemple #14
0
  def run_train(self, experiment_dir, work_unit_dir,
                rng):
    """Training loop with fixed number of steps and checkpoint every steps."""
    del experiment_dir  # unused
    tf.io.gfile.makedirs(work_unit_dir)

    config = self.config

    total_bs = config.train.batch_size
    assert total_bs % jax.device_count() == 0, (
        f'num total devices {jax.device_count()} must divide the batch size '
        f'{total_bs}')
    device_bs = total_bs // jax.device_count()
    logging.info('total_bs=%d device_bs=%d', total_bs, device_bs)

    # Logging setup
    writer = metric_writers.create_default_writer(
        work_unit_dir, just_logging=jax.host_id() > 0)
    if jax.host_id() == 0:
      utils.write_config_json(config, os.path.join(work_unit_dir,
                                                   'config.json'))

    # Build input pipeline
    logging.info('Substeps per training step: %d', config.train.substeps)
    train_ds = self.dataset.get_tf_dataset(
        split='train',
        batch_shape=(
            jax.local_device_count(),  # for pmap
            config.train.substeps,  # for lax.scan over multiple substeps
            device_bs,  # batch size per device
        ),
        global_rng=jax.random.PRNGKey(config.seed),
        repeat=True,
        shuffle=True,
        augment=True,
        shard_id=jax.host_id(),
        num_shards=jax.host_count())
    train_iter = utils.numpy_iter(train_ds)
    eval_ds = self.dataset.get_tf_dataset(
        split='eval',
        batch_shape=(jax.local_device_count(), device_bs),
        global_rng=jax.random.PRNGKey(config.seed),
        repeat=True,
        shuffle=True,
        augment=False,
        shard_id=jax.host_id(),
        num_shards=jax.host_count())
    eval_iter = utils.numpy_iter(eval_ds)

    samples_shape = (device_bs, *self.dataset.data_shape)

    self.p_gen_samples = utils.dist(
        functools.partial(self._gen_samples, samples_shape=samples_shape),
        accumulate='concat',
        axis_name='batch')

    # Set up model and training state
    state = jax.device_get(self.make_init_state())
    checkpoint_dir = os.path.join(work_unit_dir, 'checkpoints')
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    initial_step = int(state.step)
    state = flax.jax_utils.replicate(state)

    # Training step
    train_step = functools.partial(self.step_fn, next(rng), True)
    train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
    train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))

    # Eval step (does not modify parameters; no substeps)
    eval_base_rng = next(rng)

    # Training loop
    logging.info('Entering training loop at step %i', initial_step)
    utils.assert_synced(state)
    last_log_time = last_ckpt_time = time.time()
    prev_step = initial_step

    with metric_writers.ensure_flushes(writer):
      for batch in train_iter:

        state, metrics = train_step(state, batch)
        new_step = int(state.step[0])
        assert new_step == prev_step + config.train.substeps

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                            new_step)
        # Log metrics
        if new_step % config.train.log_loss_every_steps == 0:
          # Unreplicate metrics, average over substeps, and cast to python float
          metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))

          def avg_over_substeps(x):
            assert x.shape[0] == config.train.substeps
            return float(x.mean(axis=0))

          metrics = jax.tree_map(avg_over_substeps, metrics)
          metrics['train/steps_per_sec'] = float(
              config.train.log_loss_every_steps / (time.time() - last_log_time))
          writer.write_scalars(new_step, metrics)
          last_log_time = time.time()

        # Eval
        should_eval = new_step % config.train.eval_every_steps == 0
        if prev_step == 0 or should_eval:
          # Samples

          samples_to_log = {
              'eval/samples':
                  self.get_model_samples(
                      params=state.ema_params, rng=next(rng))
          }

          if samples_to_log:
            assert all(v.shape == (total_bs, *self.dataset.data_shape)
                       for v in samples_to_log.values())
            # tf.summary.image asks for a batch, so insert a new axis
            writer.write_images(
                new_step, {
                    k: utils.np_tile_imgs(v.astype('uint8'))[None, :, :, :]
                    for k, v in samples_to_log.items()
                })

          # Eval metrics
          if config.train.get('calc_eval_metrics', True):
            eval_metrics = self._calc_eval_metrics(
                state=state,
                eval_iter=eval_iter,
                eval_steps=config.train.get('eval_number_steps',
                                            self.dataset.num_eval // total_bs),
                eval_base_rng=eval_base_rng,
                total_bs=total_bs)
            if eval_metrics is not None:
              writer.write_scalars(new_step, eval_metrics)

        # Checkpointing: only if checkpoint_every_secs is not None.
        if config.train.checkpoint_every_secs is not None:
          should_ckpt = (
              time.time() - last_ckpt_time >=
              config.train.checkpoint_every_secs)
          should_ckpt = (
              prev_step == 0 or new_step == config.train.num_train_steps or
              should_ckpt)
        else:
          should_ckpt = False

        if should_ckpt and jax.host_id() == 0:
          checkpoints.save_checkpoint(
              checkpoint_dir,
              flax.jax_utils.unreplicate(state),
              step=new_step,
              keep=3)
          last_ckpt_time = time.time()

        # Keep extra checkpoints without removal. Training does not resume
        # from these checkpoints.
        if (('retain_checkpoint_every_steps' in config.train) and
            ((new_step % config.train.retain_checkpoint_every_steps == 0) or
             (new_step == config.train.num_train_steps)) and
            (jax.host_id() == 0)):
          # Below, overwrite=True because training might resume from a
          # checkpoint from an earlier step than the latest retained checkpoint,
          # causing the latest retained checkpoint to be overwritten.
          checkpoints.save_checkpoint(
              os.path.join(work_unit_dir, 'retained_checkpoints'),
              flax.jax_utils.unreplicate(state),
              step=new_step,
              keep=int(1e10),
              overwrite=True)

        prev_step = new_step
        if new_step == config.train.num_train_steps:
          logging.info('Finished training for %d iterations.', new_step)
          break
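
In Exemple #14, self.step_fn is first bound to an RNG and a train=True flag and then wrapped in jax.lax.scan, so it must follow scan's (carry, x) -> (carry, y) contract: it receives the training state and one substep's slice of the batch, and returns the updated state plus that substep's metrics (which scan stacks along the substep axis, matching avg_over_substeps). A hypothetical sketch of such a method follows; the batch key 'image', the dropout RNG handling, and the loss are assumptions.

import jax
import jax.numpy as jnp


def step_fn(self, base_rng, train, state, batch):
    """Hypothetical substep: the scan carry is the train state, y is a metrics dict."""
    rng = jax.random.fold_in(base_rng, state.step)  # fresh randomness per substep

    def loss_fn(params):
        per_example_loss = self.model.apply(
            {'params': params}, batch['image'], train=train,
            rngs={'dropout': rng})
        return jnp.mean(per_example_loss)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    grads = jax.lax.pmean(grads, axis_name='batch')  # average grads across devices
    state = state.apply_gradients(grads=grads)
    return state, {'train/loss': loss}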
Exemple #15
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
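
create_learning_rate_scheduler is referenced above but not defined in this excerpt. Below is a minimal sketch of the usual Transformer recipe such a factory returns for these two arguments, i.e. linear warmup followed by inverse square-root decay; the real factory may support more schedule components.

import jax.numpy as jnp


def create_learning_rate_scheduler(base_learning_rate, warmup_steps):
    """Sketch: linear warmup, then inverse square-root (rsqrt) decay."""

    def step_fn(step):
        warmup = jnp.minimum(1.0, step / warmup_steps)
        rsqrt_decay = jnp.sqrt(warmup_steps / jnp.maximum(step, warmup_steps))
        return base_learning_rate * warmup * rsqrt_decay

    return step_fn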
Exemple #16
0
def train(*,
          workdir,
          compute_phi,
          compute_psi,
          params,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          d,
          num_tasks,
          compute_feature_norm_on_oracle_states,
          sample_states,
          eval_states,
          use_tabular_gradient=True):
    """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    compute_phi: A function that takes params and states and returns
      a matrix of phis.
    compute_psi: A function that takes an array of states and an array
      of tasks and returns Psi[states, tasks].
    params: Parameters used as the first argument for compute_phi.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', or 'oracle'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' is supported.
    covariance_batch_size: the 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to construct the weight vector.
    d: The dimension of the representation.
    num_tasks: The total number of tasks.
    compute_feature_norm_on_oracle_states: If True, computes the feature norm
      using the oracle states (all the states in synthetic experiments).
      Otherwise, computes the norm using the sampled batch.
      Only applies to LISSA.
    sample_states: A function that takes an rng key and a number of states
      to sample, and returns a tuple containing
      (a vector of sampled states, an updated rng key).
    eval_states: An array of states to use to compute metrics on.
      This will be used to compute Phi = compute_phi(params, eval_states).
    use_tabular_gradient: If true, the train step will calculate the
      gradient using the tabular calculation. Otherwise, it will use a
      jax.vjp to backpropagate the gradient.
  """
    # Create an explicit weight vector (needed for explicit method only).
    if method == 'explicit':
        key, weight_key = jax.random.split(key)
        explicit_weight_matrix = jax.random.normal(weight_key, (d, num_tasks),
                                                   dtype=jnp.float32)
        params['explicit_weight_matrix'] = explicit_weight_matrix

    if optimizer == 'sgd':
        optimizer = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        optimizer = optax.adam(learning_rate)
    else:
        raise ValueError(f'Unknown optimizer {optimizer}.')
    optimizer_state = optimizer.init(params)

    chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
    initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize(
        (0, params, optimizer_state))

    writer = metric_writers.create_default_writer(logdir=str(workdir), )

    # Checkpointing and logging too much can use a lot of disk space.
    # Therefore, we don't want to checkpoint more than 10 times per experiment,
    # or keep more than 1k Phis per experiment.
    checkpoint_period = max(num_epochs // 10, 100_000)
    log_period = max(1_000, num_epochs // 1_000)

    def _checkpoint_callback(step, t, params, optimizer_state):
        del t  # Unused.
        chkpt_manager.save((step, params, optimizer_state))

    hooks = [
        periodic_actions.PeriodicCallback(every_steps=checkpoint_period,
                                          callback_fn=_checkpoint_callback)
    ]

    fixed_train_kwargs = {
        'compute_phi': compute_phi,
        'compute_psi': compute_psi,
        'optimizer': optimizer,
        'method': method,
        # In the tabular case, the eval_states are all the states.
        'oracle_states': eval_states,
        'lissa_kappa': lissa_kappa,
        'main_batch_size': main_batch_size,
        'covariance_batch_size': covariance_batch_size,
        'weight_batch_size': weight_batch_size,
        'd': d,
        'num_tasks': num_tasks,
        'compute_feature_norm_on_oracle_states':
            compute_feature_norm_on_oracle_states,
        'sample_states': sample_states,
        'use_tabular_gradient': use_tabular_gradient,
    }
    variable_kwargs = {
        'params': params,
        'optimizer_state': optimizer_state,
        'key': key,
    }

    @jax.jit
    def _eval_step(phi_params):
        eval_phi = compute_phi(phi_params, eval_states)
        eval_psi = compute_psi(eval_states)  # pytype: disable=wrong-arg-count

        metrics = compute_metrics(eval_phi, optimal_subspace)
        metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)}
        return metrics

    # Perform num_epochs gradient steps.
    with metric_writers.ensure_flushes(writer):
        for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1),
                               initial=initial_step,
                               total=num_epochs):

            variable_kwargs = _train_step(**fixed_train_kwargs,
                                          **variable_kwargs)

            if step % log_period == 0:
                metrics = _eval_step(variable_kwargs['params']['phi_params'])
                writer.write_scalars(step, metrics)

            for hook in hooks:
                hook(step,
                     params=variable_kwargs['params'],
                     optimizer_state=variable_kwargs['optimizer_state'])

    writer.flush()
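
The docstring of the train function above spells out the per-step sample budget for the 'lissa' method. A tiny helper making that accounting explicit; the numbers in the usage line are purely illustrative.

def lissa_samples_per_step(covariance_batch_size, main_batch_size,
                           weight_batch_size):
    """Per-step sample count for 'lissa', per the docstring above."""
    return 2 * covariance_batch_size + main_batch_size + 2 * weight_batch_size


# Illustrative only: 2 * 32 + 64 + 2 * 16 == 160 states per gradient step.
assert lissa_samples_per_step(32, 64, 16) == 160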
Exemple #17
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size,
        split_tokenizer=FLAGS.split_tokenizer)

    if FLAGS.aux_eval_dataset:
        aux_datasets = []
        aux_names = FLAGS.aux_eval_dataset.split(',')
        for name in aux_names:
            _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
                dataset_name=name,
                eval_dataset_name=None,
                shard_idx=jax.process_index(),
                shard_count=jax.process_count(),
                data_dir=FLAGS.data_dir,
                vocab_path=vocab_path,
                target_vocab_size=FLAGS.vocab_size,
                batch_size=FLAGS.batch_size,
                max_length=FLAGS.max_target_length,
                max_eval_length=FLAGS.max_eval_target_length,
                paracrawl_size=FLAGS.paracrawl_size,
                is_scores_path=FLAGS.is_scores_path,
                num_to_keep=FLAGS.data_selection_size,
                pseudo_path=FLAGS.pseudo_path,
                repeat_count=FLAGS.repeat_count,
                newscommentary_size=FLAGS.newscommentary_size)
            aux_datasets.append(aux_eval_ds)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # It's possible that this is supposed to be the per-device batch size.
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, just continue where we left off.
        model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
        optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)

    flag_key = [
        k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
    ]
    if flag_key:
        flag_key = flag_key[0]
        local_flags = {
            f.name: f.value
            for f in FLAGS.flags_by_module_dict()[flag_key]
        }
        writer.write_hparams(local_flags)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_util.train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing),
                            axis_name='batch',
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(train_util.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_util.initialize_cache,
        max_decode_len=FLAGS.max_predict_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(train_util.predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info('Starting training loop.')
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=FLAGS.model_dir,
                                     num_profile_steps=5)
        ]
    train_metrics = []
    total_steps = start_step + FLAGS.num_train_steps
    if FLAGS.eval_only:
        total_steps = start_step + 1
    best_eval_loss = 1000
    curr_eval_loss = 1000
    eval_loss_history = []
    last_eval_step = 0
    do_resample_data = False
    gradual_selection_size = FLAGS.data_selection_size
    dynamic_eval_freq = FLAGS.eval_frequency
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, total_steps):
            is_last_step = step == total_steps - 1

            # Resample training data for gradual FT
            if do_resample_data:
                # resample data
                do_resample_data = False
                gradual_selection_size *= .7
                dynamic_eval_freq = int(gradual_selection_size / 1000 / 4)

                train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
                    dataset_name=FLAGS.dataset_name,
                    eval_dataset_name=FLAGS.eval_dataset_name,
                    shard_idx=jax.process_index(),
                    shard_count=jax.process_count(),
                    data_dir=FLAGS.data_dir,
                    vocab_path=vocab_path,
                    target_vocab_size=FLAGS.vocab_size,
                    batch_size=FLAGS.batch_size,
                    max_length=FLAGS.max_target_length,
                    max_eval_length=FLAGS.max_eval_target_length,
                    paracrawl_size=FLAGS.paracrawl_size,
                    is_scores_path=FLAGS.is_scores_path,
                    num_to_keep=int(gradual_selection_size),
                    pseudo_path=FLAGS.pseudo_path,
                    repeat_count=FLAGS.repeat_count,
                    newscommentary_size=FLAGS.newscommentary_size,
                    split_tokenizer=FLAGS.split_tokenizer)
                train_iter = iter(train_ds)

            # Shard data to devices and do a training step.
            if not FLAGS.eval_only:
                logging.info('Doing Training.')
                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    try:
                        batch = common_utils.shard(
                            jax.tree_map(np.asarray, next(train_iter)))
                        optimizer, metrics = p_train_step(
                            optimizer, batch, dropout_rng=dropout_rngs)
                        train_metrics.append(metrics)
                    except StopIteration:
                        is_last_step = True

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
                if not FLAGS.eval_only:
                    with report_progress.timed('training_metrics'):
                        logging.info('Gathering training metrics.')
                        train_metrics = common_utils.get_metrics(train_metrics)
                        lr = train_metrics.pop('learning_rate').mean()
                        metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                        denominator = metrics_sums.pop('denominator')
                        summary = jax.tree_map(lambda x: x / denominator,
                                               metrics_sums)  # pylint: disable=cell-var-from-loop
                        summary['learning_rate'] = lr
                        summary = {'train_' + k: v for k, v in summary.items()}
                        writer.write_scalars(step, summary)
                        train_metrics = []

                if FLAGS.eval_only:
                    p_eval_per_pos_step = jax.pmap(functools.partial(
                        train_util.eval_per_pos_step, config=eval_config),
                                                   axis_name='batch')
                    # Get per example loss
                    loss_filename = FLAGS.model_dir + '/test_losses.csv'
                    train_util.write_per_example_losses(
                        p_eval_step=p_eval_per_pos_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=FLAGS.num_eval_steps,
                        loss_filename=loss_filename)
                else:
                    with report_progress.timed('eval'):
                        eval_results = train_util.evaluate(
                            p_eval_step=p_eval_step,
                            target=optimizer.target,
                            eval_ds=eval_ds,
                            num_eval_steps=FLAGS.num_eval_steps)
                        curr_eval_loss = eval_results['loss']
                        eval_loss_history.append(curr_eval_loss)
                        if len(eval_loss_history) > 1:
                            improvement_rate = 0.000004
                            orig_loss = eval_loss_history[-2]
                            true_improvement = orig_loss - curr_eval_loss
                            expected_improvement = (
                                step - last_eval_step) * improvement_rate
                            # percent_change = (orig_loss - curr_eval_loss) / orig_loss
                            # percent_change *= 100
                            if true_improvement < expected_improvement:  # percent_change<.1:
                                do_resample_data = True
                        last_eval_step = step
                        writer.write_scalars(
                            step,
                            {'eval_' + k: v
                             for k, v in eval_results.items()})

                if FLAGS.aux_eval_dataset:
                    for aux_i, aux_eval_ds in enumerate(aux_datasets):
                        with report_progress.timed('aux_eval'):
                            eval_results = train_util.evaluate(
                                p_eval_step=p_eval_step,
                                target=optimizer.target,
                                eval_ds=aux_eval_ds,
                                num_eval_steps=FLAGS.num_eval_steps)
                            writer.write_scalars(
                                step, {
                                    'aux' + str(aux_i) + '_eval_' + k: v
                                    for k, v in eval_results.items()
                                })

                if FLAGS.compute_bleu:
                    with report_progress.timed('translate_and_bleu'):
                        decode_file = FLAGS.model_dir + '/decodes.csv'
                        exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                            p_pred_step=p_pred_step,
                            p_init_cache=p_init_cache,
                            target=optimizer.target,
                            predict_ds=predict_ds,
                            decode_tokens=decode_tokens,
                            max_predict_length=FLAGS.max_predict_length,
                            num_eval_steps=FLAGS.num_eval_steps,
                            decode_file=decode_file if FLAGS.eval_only else '')
                        writer.write_scalars(step, {'bleu': bleu_score})
                        writer.write_texts(step, {'samples': exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0
                               or is_last_step)
            if (FLAGS.save_checkpoints and save_checkpoint
                    and jax.process_index() == 0):
                if curr_eval_loss < best_eval_loss:  # only save better checkpoints
                    best_eval_loss = curr_eval_loss
                    with report_progress.timed('checkpoint'):
                        checkpoints.save_checkpoint(
                            FLAGS.model_dir,
                            jax_utils.unreplicate(optimizer),
                            step,
                            keep=FLAGS.chkpts_to_keep,
                            overwrite=True)

            if is_last_step:
                break
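
The data-resampling trigger in the eval branch above compares the actual drop in eval loss against a minimum expected drop of `improvement_rate` loss units per training step since the last evaluation. Pulled out as a standalone helper, the check looks like this (a sketch only; the name `should_resample` is ours, the original sets `do_resample_data` inline):

def should_resample(eval_loss_history, step, last_eval_step,
                    improvement_rate=4e-6):
    """Returns True if eval loss improved less than expected since last eval."""
    if len(eval_loss_history) < 2:
        return False
    orig_loss = eval_loss_history[-2]
    curr_eval_loss = eval_loss_history[-1]
    true_improvement = orig_loss - curr_eval_loss
    expected_improvement = (step - last_eval_step) * improvement_rate
    return true_improvement < expected_improvement

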
def train_and_evaluate(config, workdir, strategy):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
    strategy: Distribution strategy to use for distributing the model.
  """
    tf.io.gfile.makedirs(workdir)

    tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0),
                                                              2)
    tf.random.set_seed(tf_rng.numpy()[0])

    # Input pipeline.
    ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets(
        config, data_rng, strategy=strategy)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = (ds_info.splits["train"].num_examples //
                           config.global_batch_size * config.num_epochs)
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    # We treat the learning rate in the config as the learning rate for batch size
    # 256 but scale it according to our batch size.
    base_learning_rate = config.learning_rate * config.global_batch_size / 256.0
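    # E.g. (hypothetical numbers) a config learning rate of 0.1 with a global
    # batch size of 1024 gives base_learning_rate = 0.1 * 1024 / 256 = 0.4.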

    # Initialize model.
    num_classes = ds_info.features["label"].num_classes

    if config.distill_teacher:
        do_distill = True
        teacher_file_list = (config.distill_teacher).split(",")
        teacher_models = load_teacher_models(teacher_file_list, num_classes,
                                             config, strategy)
        distill_params = {}
        distill_params["alpha"] = config.distill_alpha
        distill_params["beta"] = config.distill_fd_beta
        distill_params["teacher_model"] = TeacherModel(teacher_models,
                                                       name="teacher")
    else:
        do_distill = False
        distill_params = None

    state = create_state(config, num_classes=num_classes, strategy=strategy)

    ckpt_manager = tf.train.CheckpointManager(checkpoint=state,
                                              directory=workdir,
                                              max_to_keep=5)

    if ckpt_manager.latest_checkpoint:
        state.restore(ckpt_manager.latest_checkpoint)
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")
    initial_step = state.global_step.numpy().item()

    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    writer = metric_writers.create_default_writer(workdir)
    writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            state.model.trainable = True

            # `step` is a Python integer. `global_step` is a TF variable on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            train_step(state, train_iter, config.weight_decay,
                       learning_rate_fn, do_distill, distill_params, strategy)

            state.train_metrics.update_state_lr(
                learning_rate_fn(state.global_step.numpy().item()))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            report_progress(step)

            if step == initial_step:
                parameter_overview.log_parameter_overview(state.model)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, state.train_metrics.result())
                state.train_metrics.reset_states()
                state.train_metrics.reset_lr()

            if step % config.eval_every_steps == 0 or is_last_step:
                state.model.trainable = False
                if config.dataset == "imagenet-lt":
                    evaluate(state, val_ds, state.val_metrics, strategy)
                    writer.write_scalars(step, state.val_metrics.result())
                    logging.info("Num val images %d",
                                 state.val_metrics.accuracy.count.numpy())

                evaluate(state, test_ds, state.test_metrics, strategy)
                writer.write_scalars(step, state.test_metrics.result())

                logging.info("Num test images %d",
                             state.test_metrics.accuracy.count.numpy())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                checkpoint_path = ckpt_manager.save(step)
                logging.info("Saved checkpoint %s", checkpoint_path)

    logging.info("Finishing training at step %d", step)
    logging.info("Saving the final weights")
    file_path = "%s/final_weights" % workdir
    state.model.save_weights(file_path, save_format="tf")
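
The `get_learning_rate` schedule bound via `functools.partial` above is defined elsewhere in the project and is not part of this excerpt. A rough, hypothetical stand-in that matches the call signature (linear warmup over `warmup_epochs`, then cosine decay; the decay choice is an assumption, not the original implementation):

import numpy as np


# Hypothetical stand-in for `get_learning_rate`; only the argument names match
# the call site above, the warmup-plus-cosine-decay schedule is assumed.
def get_learning_rate(step, *, base_learning_rate, steps_per_epoch, num_epochs,
                      warmup_epochs):
    epoch = step / steps_per_epoch
    if epoch < warmup_epochs:
        # Linear warmup from 0 to base_learning_rate.
        return base_learning_rate * epoch / warmup_epochs
    # Cosine decay from base_learning_rate down to 0 over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(num_epochs - warmup_epochs, 1)
    return base_learning_rate * 0.5 * (1.0 + np.cos(np.pi * min(progress, 1.0)))
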
Example #19
def monitor_and_sample(config, work_dir):
    """Monitors `work_dir` for new checkpoints and run sampling on them.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
  """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    # TODO(agritsenko): We are loading the datasets just to get the metadata.
    #  Can we be smarter about this?
    if config.dataset.name.endswith('speech_commands09'):
        _, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, rng_sample = jax.random.split(rng)

    def tx_fn(lr):
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    # Wait for new checkpoints in a loop.
    ckpt_path_iterator = checkpoint.checkpoints_iterator(work_dir, target=None)

    with metric_writers.ensure_flushes(writer):
        for _ in ckpt_path_iterator:
            state, step = checkpoint.restore_from_path(work_dir, state)
            is_last_step = step == config.num_train_steps - 1
            logging.info('Loaded checkpoint for step: %d', step)

            # Replicate the state
            state = flax.jax_utils.replicate(state)

            ######################### Run sampling ###############################
            chain = model.sample(jax.random.fold_in(rng_sample, step),
                                 state.ema_params,
                                 config.sample_batch_size,
                                 chain_out_size=config.get(
                                     'chain_out_size', model.num_stages))

            if is_first_host:
                chain = jax.device_get(chain)
                long_sample = np.reshape(chain[-1],
                                         (1, -1, 1)).astype(np.float32)
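                # Rescale integer class indices to waveform values in roughly
                # [-1, 1): e.g. with num_classes == 256, index 0 maps to -1.0
                # and index 255 maps to ~0.99.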
                long_sample = (2. * long_sample) / config.num_classes - 1.
                long_sample = long_sample.astype(np.float32)
                writer.write_audios(step, {'samples': long_sample},
                                    sample_rate=config.sample_rate)

            if is_last_step:
                break
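
The `checkpoint.checkpoints_iterator` helper above comes from the project's own checkpoint module and is not shown in this excerpt (the original also takes a `target` argument that this sketch omits). A minimal stand-in with the same block-and-yield behaviour, assuming Flax-style `checkpoint_<step>` filenames in `work_dir`, could look like this:

import time

import tensorflow as tf


# Hypothetical stand-in for `checkpoint.checkpoints_iterator`: polls `work_dir`
# and yields each newly written checkpoint path exactly once, oldest first.
def checkpoints_iterator(work_dir, poll_interval_secs=30):
    seen = set()
    while True:
        for path in sorted(tf.io.gfile.glob(f'{work_dir}/checkpoint_*')):
            if path not in seen:
                seen.add(path)
                yield path
        time.sleep(poll_interval_secs)
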
Example #20
def predict_and_evaluate(config, workdir, ckpt_path=None):
    """Runs a testing and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If ckpt_path
      is not given, evaluation uses the latest checkpoint in this directory.
    ckpt_path: The checkpoint to evaluate. If not specified, use the latest
      checkpoint.
  """
    logging.info('Starting testing at %s', workdir)
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    test_ds = []
    for split in config.dataset.test_splits:
        ds = input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        test_ds.append(ds)

    # Initialize model.
    inputs = train_utils.get_init_inputs(test_ds[0])
    rng, model_rng = jax.random.split(rng)
    predict_config = models.TransformerConfig(**config.model.to_dict())
    predict_config = predict_config.replace(decode=True)
    model = models.Model(predict_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)

    logging.info('Testing and evaluating checkpoint %s', ckpt_path)
    try:
        state = ckpt.restore(state, ckpt_path)
    except FileNotFoundError:
        state = ckpt.restore_or_initialize(state)
    step = int(state.step)

    p_pred_step = jax.pmap(functools.partial(predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(3, ))
    p_init_cache = jax.pmap(functools.partial(init_cache,
                                              config=predict_config),
                            axis_name='batch')

    # Distribute testing.
    state = flax_utils.replicate(state)
    with metric_writers.ensure_flushes(writer):
        test_metrics = {}
        for ds, split in zip(test_ds, config.dataset.test_splits):
            ds_metrics = evaluate_sequence_accuracy(p_pred_step, p_init_cache,
                                                    state, ds, config, split,
                                                    workdir,
                                                    config.num_test_steps)
            ds_metrics = {f'{k}_{split}': v for k, v in ds_metrics.items()}
            test_metrics.update(ds_metrics)
        writer.write_scalars(step, test_metrics)
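
For context, a typical launcher for the function above might wire it up to absl flags and an ml_collections config file along these lines (the flag names and the launcher itself are illustrative assumptions, not part of the original source):

from absl import app, flags
from ml_collections import config_flags

# Hypothetical launcher; flag names and defaults are illustrative only.
FLAGS = flags.FLAGS
flags.DEFINE_string('workdir', None, 'Directory with checkpoints and summaries.')
flags.DEFINE_string('ckpt_path', None, 'Optional explicit checkpoint to evaluate.')
config_flags.DEFINE_config_file('config', None, 'Training/eval configuration.')


def main(argv):
    del argv  # Unused.
    predict_and_evaluate(FLAGS.config, FLAGS.workdir, ckpt_path=FLAGS.ckpt_path)


if __name__ == '__main__':
    app.run(main)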