The following examples show how flax.optim.DynamicScale is used to apply
dynamic loss scaling in mixed-precision (float16) training.

Example #1
def create_train_state(rng, config: ml_collections.ConfigDict,
                       model, image_size):
  """Create initial training state."""
  platform = jax.local_devices()[0].platform
  dynamic_scale = None
  if config.half_precision and platform == 'gpu':
    dynamic_scale = optim.DynamicScale()

  params, model_state = initialized(rng, image_size, model)
  optimizer = optim.Momentum(
      beta=config.momentum, nesterov=True).create(params)
  state = TrainState(
      step=0, optimizer=optimizer, model_state=model_state,
      dynamic_scale=dynamic_scale)
  return state
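
For context, a minimal sketch of how the dynamic_scale stored in this state is
typically consumed inside a train step. This is an assumption-laden sketch, not
the library's train_step: loss_fn is a hypothetical closure over one batch, and
the value_and_grad return signature follows the test in Example #2.

import jax

def compute_grads(state, loss_fn):
  # Sketch: loss/grad computation guarded by dynamic loss scaling.
  if state.dynamic_scale:
    # Scales the loss before differentiation, unscales the gradients,
    # and reports whether every gradient is finite.
    grad_fn = state.dynamic_scale.value_and_grad(loss_fn)
    dynamic_scale, is_fin, loss, grads = grad_fn(state.optimizer.target)
  else:
    loss, grads = jax.value_and_grad(loss_fn)(state.optimizer.target)
    dynamic_scale, is_fin = None, True
  return dynamic_scale, is_fin, loss, grads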
Example #2
  def test_dynamic_scale(self):
    def loss_fn(p):
      return jnp.asarray(p, jnp.float16) ** 2
    p = jnp.array(1., jnp.float32)

    dyn_scale = optim.DynamicScale(growth_interval=2)
    step = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
    inf = float('inf')
    nan = float('nan')
    expected_values = [
        (False, nan, 32768.0, inf),
        (False, 1.0, 16384.0, inf),
        (True, 1.0, 16384.0, 2.0),
        (True, 1.0, 16384.0, 2.0),
        (True, 1.0, 32768.0, 2.0),
        (False, 1.0, 16384.0, inf),
    ]

    for expected in expected_values:
      dyn_scale, is_fin, loss, grad = step(dyn_scale, p)
      values = onp.array((is_fin, loss, dyn_scale.scale, grad))
      onp.testing.assert_allclose(values, expected)
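
The expected values trace DynamicScale's policy: starting from the default
scale of 65536, the scale is halved whenever a gradient overflows to inf in
float16 and grows again once enough consecutive steps produce finite
gradients; once the scale fits in float16, the unscaled gradient of p ** 2 at
p = 1.0 is 2.0. A hedged sketch of the update rule, inferred from the expected
values above rather than copied from the library source:

def update_scale(scale, fin_steps, is_fin, growth_interval=2,
                 growth_factor=2.0, backoff_factor=0.5):
  # Inferred rule: back off on overflow, grow after a stable stretch.
  if not is_fin:
    return scale * backoff_factor, 0  # overflow: halve, reset the counter
  if fin_steps == growth_interval:
    return scale * growth_factor, 0   # stable long enough: double
  return scale, fin_steps + 1

# Replaying the test's is_fin sequence reproduces the expected scales:
scale, fin_steps = 65536.0, 0
for is_fin in (False, False, True, True, True, False):
  scale, fin_steps = update_scale(scale, fin_steps, is_fin)
  print(scale)  # 32768, 16384, 16384, 16384, 32768, 16384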
Example #3
def create_train_state(rng, config: ml_collections.ConfigDict, model,
                       image_size, learning_rate_fn):
    """Create initial training state."""
    platform = jax.local_devices()[0].platform
    dynamic_scale = None
    if config.half_precision and platform == 'gpu':
        dynamic_scale = optim.DynamicScale()

    params, batch_stats = initialized(rng, image_size, model)
    tx = optax.sgd(
        learning_rate=learning_rate_fn,
        momentum=config.momentum,
        nesterov=True,
    )
    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              batch_stats=batch_stats,
                              dynamic_scale=dynamic_scale)
    return state
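
Unlike the other examples, this one targets the optax-based Flax API. A hedged
sketch of the TrainState class it assumes: the standard
flax.training.train_state.TrainState extended with the two extra fields passed
to TrainState.create above (the field types are assumptions):

from typing import Any, Optional

from flax.training import train_state


class TrainState(train_state.TrainState):
    batch_stats: Any
    dynamic_scale: Optional[Any]  # e.g. optim.DynamicScale, or None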
Example #4
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    dynamic_scale = None
    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = create_input_iter(local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=FLAGS.cache)
    eval_iter = create_input_iter(local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=FLAGS.cache)

    num_epochs = FLAGS.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = FLAGS.learning_rate * batch_size / 256.

    model, model_state = create_model(rng, device_batch_size, image_size,
                                      model_dtype)
    optimizer = optim.Momentum(beta=FLAGS.momentum,
                               nesterov=True).create(model)
    state = TrainState(step=0,
                       optimizer=optimizer,
                       model_state=model_state,
                       dynamic_scale=dynamic_scale)
    del model, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch, num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
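
The train_step being pmapped above is not shown. A hedged sketch of the usual
pattern (modeled on the reference Flax ImageNet example, with assumed names):
when dynamic_scale reports non-finite gradients, the update is effectively
skipped by keeping the previous optimizer state, while the backed-off loss
scale carries over to the next step.

import functools

import jax
import jax.numpy as jnp


def apply_gradient_safely(optimizer, dynamic_scale, loss_fn, lr):
    # One update step guarded by dynamic loss scaling.
    grad_fn = dynamic_scale.value_and_grad(loss_fn)
    dynamic_scale, is_fin, loss, grad = grad_fn(optimizer.target)
    new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
    # If any gradient overflowed to inf/nan, keep the old parameters and
    # optimizer state; dynamic_scale has already lowered the scale.
    new_optimizer = jax.tree_multimap(
        functools.partial(jnp.where, is_fin), new_optimizer, optimizer)
    return new_optimizer, dynamic_scale, loss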
Example #5
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # get num_epochs from hparam instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  rng = random.PRNGKey(hparams.seed)
  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True,
      is_teacher=hparams.is_teacher)

  # pylint: disable=g-long-lambda
  if hparams.teacher_model == 'resnet50-8bit':
    teacher_config = w8a8auto_paper_config()
    teacher_hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config)
    teacher_model, _ = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=teacher_hparams.model_hparams,
        train=False,
        is_teacher=True)  # teacher model does not need to be trainable
    # Directory where checkpoints are saved
    ckpt_model_dir = FLAGS.resnet508b_ckpt_path
    # will restore to best checkpoint
    state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None)
    teacher_variables = {'params': state_load['optimizer']['target']}
    teacher_variables.update(state_load['model_state'])
    # create a dictionary for better argument passing
    teacher = {
        'model':
            lambda var, img, labels: jax.nn.softmax(
                teacher_model.apply(var, img)),
        'variables':
            teacher_variables,
    }
  elif hparams.teacher_model == 'labels':
    teacher = {
        'model':
            lambda var, img, labels: common_utils.onehot(
                labels, num_classes=1000),
        'variables': {},  # no need of variables in this case
    }
  else:
    raise ValueError('The specified teacher model is not supported.')

  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')
  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler,
                                             batch_size)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn,
          teacher=teacher),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3, 4))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch',
      static_broadcasted_argnums=(2,))

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  last_log_step = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps * steps_per_epoch:
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    # should_quantize_weights takes hparams.weight_quant_start_step as input;
    # the resulting bool is passed on to p_train_step below.
    quantize_weights = train_utils.should_quantize_weights(
        hparams.weight_quant_start_step, step // steps_per_epoch)
    state, metrics = p_train_step(state, batch, hparams, update_bounds,
                                  quantize_weights)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)
    def should_log(step):
      epoch_no = step // steps_per_epoch
      step_in_epoch = step - epoch_no * steps_per_epoch
      do_log = False
      do_log = do_log or (step + 1 == num_steps)  # log at the end
      end_of_train = step / num_steps > 0.9
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 4) == 0) and not end_of_train)
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 16) == 0) and end_of_train)
      return do_log

    if should_log(step):
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start)
      last_log_step = step
      t_loop_start = time.time()

      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

        if FLAGS.write_summary:
          summary_utils.write_state_dict_summaries_to_tb(
              state_dict_summary_all, summary_writer,
              FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, quantize_weights)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #6
def main(unused_argv):
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    rng, key = random.split(rng)
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    dynamic_scale = optim.DynamicScale()
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer, dynamic_scale=dynamic_scale)
    del optimizer, variables

    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    train_pstep = jax.pmap(functools.partial(train_step, model),
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=(2, ))

    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              FLAGS.randomized),
                                  axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3, ),
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step of the last checkpoint.
    init_step = state.optimizer.state.step + 1
    state = flax.jax_utils.replicate(state)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch buffer size = 3 x batch_size.
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    reset_timer = True
    for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)
        state, stats, keys = train_pstep(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0],
                                      step)
                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                avg_loss = np.mean(
                    np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(
                    np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                steps_per_sec = FLAGS.print_every / (time.time() -
                                                     t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec,
                                      step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
                print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                      f"/{FLAGS.max_steps:d}: " +
                      f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                state_to_save = jax.device_get(
                    jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(FLAGS.train_dir,
                                            state_to_save,
                                            int(step),
                                            keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(
                lambda x: x[0], state)).optimizer.target
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, eval_variables),
                test_case["rays"],
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                num_rays = jnp.prod(
                    jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(
                    f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec"
                )
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(FLAGS.max_steps),
                                    keep=100)
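
Here DynamicScale is paired with Adam rather than SGD with momentum. A hedged
sketch of the utils.TrainState this example constructs, assumed to be a
flax.struct.dataclass carrying just the optimizer and the loss scale (which is
consistent with the state.optimizer.state.step access above):

import flax


@flax.struct.dataclass
class TrainState:
    optimizer: flax.optim.Optimizer
    dynamic_scale: flax.optim.DynamicScale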
Example #7
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    dynamic_scale = None
    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                        FLAGS.data_dir,
                                                        image_size,
                                                        input_dtype,
                                                        train=True,
                                                        cache=FLAGS.cache)
    eval_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                       FLAGS.data_dir,
                                                       image_size,
                                                       input_dtype,
                                                       train=False,
                                                       cache=FLAGS.cache)

    # Create the hyperparameter object
    if FLAGS.hparams_config_dict:
        # In this case, there are multiple training configs defined in the config
        # dict, so we pull out the one this training run should use.
        if 'configs' in FLAGS.hparams_config_dict:
            hparams_config_dict = FLAGS.hparams_config_dict.configs[
                FLAGS.config_idx]
        else:
            hparams_config_dict = FLAGS.hparams_config_dict
        hparams = os_hparams_utils.load_hparams_from_config_dict(
            hparams_config.TrainingHParams, models.ResNet.HParams,
            hparams_config_dict)
    else:
        raise ValueError('Please provide a base config dict.')

    os_hparams_utils.write_hparams_to_file_with_host_id_check(
        hparams, FLAGS.model_dir)

    # get num_epochs from hparam instead of FLAGS
    num_epochs = hparams.lr_scheduler.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    # Estimate compute / memory costs
    if jax.host_id() == 0:
        estimate_compute_and_memory_cost(image_size=image_size,
                                         model_dir=FLAGS.model_dir,
                                         hparams=hparams)
        logging.info(
            'Writing training HLO and estimating compute/memory costs.')

    model, variables = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=hparams.model_hparams,
        train=True)
    model_state, params = variables.pop('params')
    if hparams.optimizer == 'sgd':
        optimizer = optim.Momentum(beta=hparams.momentum,
                                   nesterov=True).create(params)
    elif hparams.optimizer == 'adam':
        optimizer = optim.Adam(beta1=hparams.adam.beta1,
                               beta2=hparams.adam.beta2).create(params)
    else:
        raise ValueError('Optimizer type is not supported.')
    state = imagenet_train_utils.TrainState(step=0,
                                            optimizer=optimizer,
                                            model_state=model_state,
                                            dynamic_scale=dynamic_scale)
    del params, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    base_learning_rate = hparams.base_learning_rate * batch_size / 256.
    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch,
                                               hparams.lr_scheduler)

    p_train_step = jax.pmap(functools.partial(
        imagenet_train_utils.train_step,
        model,
        learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            static_broadcasted_argnums=(2, 3))
    p_eval_step = jax.pmap(functools.partial(imagenet_train_utils.eval_step,
                                             model),
                           axis_name='batch')

    epoch_metrics = []
    state_dict_summary_all = []
    state_dict_keys = _get_state_dict_keys_from_flags()
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
            break
        update_bounds = train_utils.should_update_bounds(
            hparams.activation_bound_update_freq,
            hparams.activation_bound_start_step, step)
        state, metrics = p_train_step(state, batch, hparams, update_bounds)

        state_dict_summary = summary_utils.get_state_dict_summary(
            state.model_state, state_dict_keys)
        state_dict_summary_all.append(state_dict_summary)

        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()

            # Write to TensorBoard
            state_dict_summary_all = common_utils.get_metrics(
                state_dict_summary_all)
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

                summary_utils.write_state_dict_summaries_to_tb(
                    state_dict_summary_all, summary_writer,
                    FLAGS.state_dict_summary_freq, step)

            state_dict_summary_all = []
            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = imagenet_train_utils.sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = imagenet_train_utils.sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #8
def train_and_evaluate(model_dir: str,
                       batch_size: int,
                       num_epochs: int,
                       learning_rate: float,
                       momentum: float,
                       cache: bool,
                       half_precision: bool,
                       num_train_steps: int = -1,
                       num_eval_steps: int = -1):
    """Runs model training and evaluation loop.

  Args:
    model_dir: Directory where the checkpoints and tensorboard summaries
      should be written to.
    batch_size: Batch size of the input.
    num_epochs: Number of epochs to cycle through before stopping.
    learning_rate: The learning rate for a batch size of 256. The effective
      learning rate is scaled linearly with the batch size.
    momentum: Momentum value for the momentum optimizer.
    cache: Determines whether the dataset should be cached.
    half_precision: Determines whether bfloat16/float16 should be used
      instead of float32.
    num_train_steps: Number of training steps to be executed in a
      single epoch. The default of -1 means the entire TRAIN split is used.
    num_eval_steps: Number of evaluation steps to be executed in a
      single epoch. The default of -1 means the entire VALIDATION split is used.
  """
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    dynamic_scale = None
    if half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    dataset_builder = tfds.builder('imagenet2012:5.*.*')
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=cache)

    if num_train_steps == -1:
        steps_per_epoch = \
            dataset_builder.info.splits['train'].num_examples // batch_size
    else:
        steps_per_epoch = num_train_steps

    if num_eval_steps == -1:
        steps_per_eval = \
            dataset_builder.info.splits['validation'].num_examples // batch_size
    else:
        steps_per_eval = num_eval_steps

    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = learning_rate * batch_size / 256.

    model, model_state = create_model(rng, device_batch_size, image_size,
                                      model_dtype)
    optimizer = optim.Momentum(beta=momentum, nesterov=True).create(model)
    state = TrainState(step=0,
                       optimizer=optimizer,
                       model_state=model_state,
                       dynamic_scale=dynamic_scale)
    del model, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(model_dir, state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch, num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(model_dir, state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #9
def train_and_evaluate(config: ml_collections.ConfigDict, model_dir: str):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    model_dir: Directory where the tensorboard summaries are written to.
  """

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  if config.batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = config.batch_size // jax.host_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if config.half_precision:
    if platform == 'tpu':
      input_dtype = tf.bfloat16
    else:
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    input_dtype = tf.float32

  dataset_builder = tfds.builder('imagenet2012:5.*.*')
  train_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype, train=True,
      cache=config.cache)
  eval_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype, train=False,
      cache=config.cache)

  steps_per_epoch = (
      dataset_builder.info.splits['train'].num_examples // config.batch_size
  )

  if config.num_train_steps == -1:
    num_steps = steps_per_epoch * config.num_epochs
  else:
    num_steps = config.num_train_steps

  if config.steps_per_eval == -1:
    num_validation_examples = dataset_builder.info.splits[
        'validation'].num_examples
    steps_per_eval = num_validation_examples // config.batch_size
  else:
    steps_per_eval = config.steps_per_eval

  steps_per_checkpoint = steps_per_epoch * 10

  base_learning_rate = config.learning_rate * config.batch_size / 256.

  variables = initialized(rng, image_size, config.half_precision)
  optimizer = optim.Momentum(
      beta=config.momentum, nesterov=True).create(variables['params'])
  state = TrainState(
      step=0, optimizer=optimizer, batch_stats=variables['batch_stats'],
      dynamic_scale=dynamic_scale)

  state = restore_checkpoint(state, model_dir)
  # step_offset > 0 if restarting from checkpoint
  step_offset = int(state.step)
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, config.num_epochs)

  p_train_step = jax.pmap(
      functools.partial(train_step, half_precision=config.half_precision,
                        learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, config.half_precision)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state, model_dir)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()