Code example #1
  def test_should_update_bounds(self, frequency, start_step, current_step,
                                should_update):
    self.assertEqual(
        train_utils.should_update_bounds(
            activation_bound_update_freq=frequency,
            activation_bound_start_step=start_step,
            step=current_step), should_update)
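
The snippet above is a parameterized test method (the decorator and test class are not shown); it checks train_utils.should_update_bounds against expected booleans for various frequency / start-step / step combinations. As a rough sketch of the contract implied by the argument names, and not the actual train_utils code, such a helper might look like this:

def should_update_bounds(activation_bound_update_freq,
                         activation_bound_start_step, step):
  # Hypothetical sketch: activation bounds start being updated at
  # activation_bound_start_step and are then refreshed at a fixed frequency.
  if activation_bound_start_step < 0 or step < activation_bound_start_step:
    return False
  if activation_bound_update_freq <= 0:
    # One plausible convention: a non-positive frequency means "update once,
    # at the start step".
    return step == activation_bound_start_step
  return (step - activation_bound_start_step) % activation_bound_update_freq == 0
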
Code example #2
File: train.py  Project: tallamjr/google-research
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # get num_epochs from hparam instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  rng = random.PRNGKey(hparams.seed)
  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True,
      is_teacher=hparams.is_teacher)

  # pylint: disable=g-long-lambda
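  # Both teacher choices below expose the same (variables, images, labels) ->
  # class-probability interface, so the `teacher` dict handed to train_step
  # can wrap either a quantized ResNet-50 teacher or plain one-hot labels.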
  if hparams.teacher_model == 'resnet50-8bit':
    teacher_config = w8a8auto_paper_config()
    teacher_hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config)
    teacher_model, _ = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=teacher_hparams.model_hparams,
        train=False,
        is_teacher=True)  # teacher model does not need to be trainable
    # Directory where checkpoints are saved
    ckpt_model_dir = FLAGS.resnet508b_ckpt_path
    # will restore to best checkpoint
    state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None)
    teacher_variables = {'params': state_load['optimizer']['target']}
    teacher_variables.update(state_load['model_state'])
    # create a dictionary for better argument passing
    teacher = {
        'model':
            lambda var, img, labels: jax.nn.softmax(
                teacher_model.apply(var, img)),
        'variables':
            teacher_variables,
    }
  elif hparams.teacher_model == 'labels':
    teacher = {
        'model':
            lambda var, img, labels: common_utils.onehot(
                labels, num_classes=1000),
        'variables': {},  # no need of variables in this case
    }
  else:
    raise ValueError('The specified teacher model is not supported.')

  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')
  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler,
                                             batch_size)

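  # Positional args 2-4 of the pmapped function (hparams, update_bounds,
  # quantize_weights) are static: they are broadcast to all devices as Python
  # values, and changing any of them triggers a recompilation.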
  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn,
          teacher=teacher),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3, 4))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch',
      static_broadcasted_argnums=(2,))

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  last_log_step = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if (hparams.early_stop_steps >= 0 and
        step > hparams.early_stop_steps * steps_per_epoch):
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    # Decide whether weights should be quantized at this point in training,
    # based on hparams.weight_quant_start_step; the resulting bool is passed
    # to p_train_step below.
    quantize_weights = train_utils.should_quantize_weights(
        hparams.weight_quant_start_step, step // steps_per_epoch)
    state, metrics = p_train_step(state, batch, hparams, update_bounds,
                                  quantize_weights)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)
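    # Log roughly four times per epoch for most of training, sixteen times per
    # epoch during the final ~10% of steps, and once at the very end.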
    def should_log(step):
      epoch_no = step // steps_per_epoch
      step_in_epoch = step - epoch_no * steps_per_epoch
      do_log = False
      do_log = do_log or (step + 1 == num_steps)  # log at the end
      end_of_train = step / num_steps > 0.9
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 4) == 0) and not end_of_train)
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 16) == 0) and end_of_train)
      return do_log

    if should_log(step):
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start)
      last_log_step = step
      t_loop_start = time.time()

      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

        if FLAGS.write_summary:
          summary_utils.write_state_dict_summaries_to_tb(
              state_dict_summary_all, summary_writer,
              FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, quantize_weights)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
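
The learning rate above follows the usual linear scaling rule: the base rate is multiplied by batch_size / 256 before the schedule is built. With hypothetical flag values the arithmetic is:

# Hypothetical values for illustration only.
base_learning_rate = 0.1   # hparams.base_learning_rate
batch_size = 1024          # FLAGS.batch_size
scaled_lr = base_learning_rate * batch_size / 256.  # 0.1 * 4.0 = 0.4
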
Code example #3
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()
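    # With, e.g., a global batch of 2048 on 8 hosts / 64 devices, each host
    # reads 256 images per step and each device processes 32 of them.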

    platform = jax.local_devices()[0].platform

    dynamic_scale = None
    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                        FLAGS.data_dir,
                                                        image_size,
                                                        input_dtype,
                                                        train=True,
                                                        cache=FLAGS.cache)
    eval_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                       FLAGS.data_dir,
                                                       image_size,
                                                       input_dtype,
                                                       train=False,
                                                       cache=FLAGS.cache)

    # Create the hyperparameter object
    if FLAGS.hparams_config_dict:
        # In this case, there are multiple training configs defined in the config
        # dict, so we pull out the one this training run should use.
        if 'configs' in FLAGS.hparams_config_dict:
            hparams_config_dict = FLAGS.hparams_config_dict.configs[
                FLAGS.config_idx]
        else:
            hparams_config_dict = FLAGS.hparams_config_dict
        hparams = os_hparams_utils.load_hparams_from_config_dict(
            hparams_config.TrainingHParams, models.ResNet.HParams,
            hparams_config_dict)
    else:
        raise ValueError('Please provide a base config dict.')

    os_hparams_utils.write_hparams_to_file_with_host_id_check(
        hparams, FLAGS.model_dir)

    # get num_epochs from hparam instead of FLAGS
    num_epochs = hparams.lr_scheduler.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    # Estimate compute / memory costs
    if jax.host_id() == 0:
        estimate_compute_and_memory_cost(image_size=image_size,
                                         model_dir=FLAGS.model_dir,
                                         hparams=hparams)
        logging.info(
            'Writing training HLO and estimating compute/memory costs.')

    model, variables = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=hparams.model_hparams,
        train=True)
    model_state, params = variables.pop('params')
    if hparams.optimizer == 'sgd':
        optimizer = optim.Momentum(beta=hparams.momentum,
                                   nesterov=True).create(params)
    elif hparams.optimizer == 'adam':
        optimizer = optim.Adam(beta1=hparams.adam.beta1,
                               beta2=hparams.adam.beta2).create(params)
    else:
        raise ValueError('Optimizer type is not supported.')
    state = imagenet_train_utils.TrainState(step=0,
                                            optimizer=optimizer,
                                            model_state=model_state,
                                            dynamic_scale=dynamic_scale)
    del params, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    base_learning_rate = hparams.base_learning_rate * batch_size / 256.
    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch,
                                               hparams.lr_scheduler)

    p_train_step = jax.pmap(functools.partial(
        imagenet_train_utils.train_step,
        model,
        learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            static_broadcasted_argnums=(2, 3))
    p_eval_step = jax.pmap(functools.partial(imagenet_train_utils.eval_step,
                                             model),
                           axis_name='batch')

    epoch_metrics = []
    state_dict_summary_all = []
    state_dict_keys = _get_state_dict_keys_from_flags()
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
            break
        update_bounds = train_utils.should_update_bounds(
            hparams.activation_bound_update_freq,
            hparams.activation_bound_start_step, step)
        state, metrics = p_train_step(state, batch, hparams, update_bounds)

        state_dict_summary = summary_utils.get_state_dict_summary(
            state.model_state, state_dict_keys)
        state_dict_summary_all.append(state_dict_summary)

        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()

            # Write to TensorBoard
            state_dict_summary_all = common_utils.get_metrics(
                state_dict_summary_all)
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

                summary_utils.write_state_dict_summaries_to_tb(
                    state_dict_summary_all, summary_writer,
                    FLAGS.state_dict_summary_freq, step)

            state_dict_summary_all = []
            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = imagenet_train_utils.sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = imagenet_train_utils.sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
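
Both training loops derive their step bookkeeping the same way. For a sense of scale, here is the arithmetic with hypothetical ImageNet-sized constants and a global batch of 2048 (the real values come from input_pipeline and the hparams):

# Hypothetical values for illustration only.
TRAIN_IMAGES, EVAL_IMAGES = 1_281_167, 50_000
batch_size, num_epochs = 2048, 90

steps_per_epoch = TRAIN_IMAGES // batch_size   # 625
steps_per_eval = EVAL_IMAGES // batch_size     # 24
steps_per_checkpoint = steps_per_epoch * 10    # 6250
num_steps = steps_per_epoch * num_epochs       # 56250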