class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)

  def test_optimizer_serialization(self):
     rng = random.PRNGKey(0)
     module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
     _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
     model = nn.Model(module, initial_params)
     optim_def = optim.Momentum(learning_rate=1.)
     optimizer = optim_def.create(model)
     state = serialization.to_state_dict(optimizer)
     expected_state = {
         'target': {
             'params': {
                 'kernel': onp.ones((1, 1)),
                 'bias': onp.zeros((1, )),
             }
         },
         'state': {
             'step': 0,
             'param_states': {
                 'params': {
                     'kernel': {
                         'momentum': onp.zeros((1, 1))
                     },
                     'bias': {
                         'momentum': onp.zeros((1, ))
                     },
                 }
             }
         },
     }
     self.assertEqual(state, expected_state)
     state = jax.tree_map(lambda x: x + 1, expected_state)
     restored_optimizer = serialization.from_state_dict(optimizer, state)
     optimizer_plus1 = jax.tree_map(lambda x: x + 1, optimizer)
     self.assertEqual(restored_optimizer, optimizer_plus1)
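
For reference, here is a minimal standalone sketch of the optax side of the equivalence test at the top of this listing, assuming the public optax package. LR and STEPS are module-level constants that are not shown in the snippet, so placeholder values of 0.1 and 5 are used here.

import jax.numpy as jnp
import optax

params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

tx = optax.sgd(learning_rate=0.1, momentum=0.9)  # optax counterpart of optim.Momentum
opt_state = tx.init(params)
for _ in range(5):
    # update() transforms the raw gradients; apply_updates() adds them to the params.
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)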
Example #3
def fit_model_test():
    # We just check we can run the functions and that they return "something"
    N = 3
    D = 5
    C = 10
    model = ModelTest(nhidden=0, nclasses=C)
    rng = jax.random.PRNGKey(0)
    X = np.random.randn(N, D)
    y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
    batch = {'X': X, 'y': y}
    params = model.init(rng, X)['params']
    metrics = eval_batch(model, params, batch)
    make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
    optimizer = make_optimizer.create(params)
    optimizer, metrics = train_batch(model, optimizer, batch)
    #print(optimizer)
    num_steps = 2

    def make_iter():
        while True:
            yield batch

    train_iter = make_iter()
    test_iter = make_iter()
    params, history = fit_model(model,
                                train_iter,
                                test_iter,
                                rng,
                                num_steps,
                                make_optimizer,
                                train_batch,
                                eval_batch,
                                print_every=1)
    print('test passed')
Example #4
def create_optimizer(params, learning_rate, beta, weight_decay):
    """Returns a momentum optimizer."""
    optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                   beta=beta,
                                   weight_decay=weight_decay)
    optimizer = optimizer_def.create(params)
    return optimizer
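
A minimal usage sketch for the helper above, assuming the legacy flax.optim API that these examples target; the parameter pytree and gradients below are placeholders.

import jax.numpy as jnp

params = {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros((2,))}
optimizer = create_optimizer(params, learning_rate=0.1, beta=0.9, weight_decay=1e-4)

grads = {'kernel': jnp.full((2, 2), 0.5), 'bias': jnp.full((2,), 0.1)}
optimizer = optimizer.apply_gradient(grads)  # returns a new optimizer (functional update)
updated_params = optimizer.target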
Example #5
 def test_empty_optimizer(self):
   params = {}
   optimizer_def = optim.Momentum(learning_rate=0.1)
   optimizer = optimizer_def.create(params)
   new_optimizer = optimizer.apply_gradient({})
   expected_state = optim.OptimizerState(1, {})
   self.assertEqual(new_optimizer.state, expected_state)
Example #6
def create_optimizer(model: flax.nn.Model,
                     learning_rate: float,
                     beta: float = 0.9) -> flax.optim.Optimizer:
  """Creates an optimizer.

  Learning rate will be ignored when using a learning rate schedule.

  Args:
    model: The FLAX model to optimize.
    learning_rate: Learning rate for the gradient descent.
    beta: Momentum parameter.

  Returns:
    An SGD (or RMSProp) optimizer that targets the model.
  """
  if FLAGS.use_rmsprop:
    # We set beta2 and epsilon to the values used in the efficientnet paper.
    optimizer_def = efficientnet_optim.RMSProp(
        learning_rate=learning_rate, beta=beta, beta2=0.9, eps=0.001)
  else:
    optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                   beta=beta,
                                   nesterov=True)
  optimizer = optimizer_def.create(model)
  return optimizer
Example #7
def create_optimizer(model, learning_rate, beta):
    optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                   beta=beta,
                                   nesterov=True)
    optimizer = optimizer_def.create(model)
    optimizer = jax_utils.replicate(optimizer)
    return optimizer
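
Because the helper above returns a replicated optimizer (one copy per local device for pmap-based training), a single copy usually has to be pulled back out before evaluation or checkpointing. A hedged sketch of that pattern, assuming flax.jax_utils; `model` stands for whatever model/parameter pytree is being trained.

from flax import jax_utils

optimizer = create_optimizer(model, learning_rate=0.1, beta=0.9)  # already replicated inside
# ... run pmapped train steps that update the replicated optimizer ...
single_optimizer = jax_utils.unreplicate(optimizer)  # pull one copy back out
final_params = single_optimizer.target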
Example #8
 def test_init_state(self):
   params = onp.zeros((1,))
   optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
   state = optimizer_def.init_state(params)
   expected_hyper_params = _MomentumHyperParams(0.1, 0.2, 0, False)
   self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
   expected_state = optim.OptimizerState(
       0, _MomentumParamState(onp.zeros((1,))))
   self.assertEqual(state, expected_state)
Example #9
 def test_create(self):
   params = onp.ones((1,))
   optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
   optimizer = optimizer_def.create(params)
   expected_state = optim.OptimizerState(
       0, _MomentumParamState(onp.zeros((1,))))
   self.assertEqual(optimizer.optimizer_def, optimizer_def)
   self.assertEqual(optimizer.state, expected_state)
   self.assertEqual(optimizer.target, params)
Example #10
 def test_optimizer_serialization_to_bytes(self):
   rng = random.PRNGKey(0)
   module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
   _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
   model = nn.Model(module, initial_params)
   optim_def = optim.Momentum(learning_rate=1.)
   optimizer = optim_def.create(model)
   serialized_bytes = serialization.to_bytes(optimizer)
   restored_optimizer = serialization.from_bytes(optimizer, serialized_bytes)
   self.assertEqual(restored_optimizer, optimizer)
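
A minimal sketch of how such serialized bytes are typically persisted, assuming standard file I/O; the file path is a placeholder.

data = serialization.to_bytes(optimizer)
with open('/tmp/optimizer.msgpack', 'wb') as f:
    f.write(data)
with open('/tmp/optimizer.msgpack', 'rb') as f:
    optimizer = serialization.from_bytes(optimizer, f.read())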
Example #11
 def test_apply_gradient(self):
     optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
     params = np.ones((1, ))
     state = optim.OptimizerState(0, _MomentumParamState(np.array([1.])))
     grads = np.array([3.])
     new_params, new_state = optimizer_def.apply_gradient(
         optimizer_def.hyper_params, params, state, grads)
     expected_new_state = optim.OptimizerState(
         1, _MomentumParamState(np.array([3.2])))
     expected_new_params = np.array([1. - 0.32])
     self.assertEqual(new_params, expected_new_params)
     self.assertEqual(new_state, expected_new_state)
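
For reference, the expected values in the test above follow the plain momentum rule as implemented here; a minimal sketch of the arithmetic, using the same hyperparameters as the test:

beta, learning_rate = 0.2, 0.1
momentum, grad, param = 1., 3., 1.
new_momentum = beta * momentum + grad             # 0.2 * 1 + 3 = 3.2
new_param = param - learning_rate * new_momentum  # 1 - 0.32 = 0.68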
Example #12
def get_optimizer(hparams):
    """Constructs  the optimizer from the given HParams.

  Args:
    hparams: Hyper parameters.

  Returns:
    A flax optimizer.
  """
    if hparams.optimizer == 'sgd':
        return optimizers.GradientDescent(
            learning_rate=hparams.lr_hparams['initial_learning_rate'])
    if hparams.optimizer == 'nesterov':
        return optimizers.Momentum(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta=hparams.opt_hparams.get('momentum', 0.9),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
            nesterov=True)
    if hparams.optimizer == 'momentum':
        return optimizers.Momentum(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta=hparams.opt_hparams.get('momentum', 0.9),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
            nesterov=False)
    if hparams.optimizer == 'adam':
        return optimizers.Adam(
            learning_rate=hparams.lr_hparams['initial_learning_rate'],
            beta1=hparams.opt_hparams.get('beta1', 0.9),
            beta2=hparams.opt_hparams.get('beta2', 0.999),
            eps=hparams.opt_hparams.get('epsilon', 1e-8),
            weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        )
    if hparams.optimizer == 'rmsprop':
        return optimizers.RMSProp(
            learning_rate=hparams.lr_hparams.get('initial_learning_rate'),
            beta2=hparams.opt_hparams.get('beta2', 0.9),
            eps=hparams.opt_hparams.get('epsilon', 1e-8))
    else:
        raise NotImplementedError('Optimizer {} not implemented'.format(
            hparams.optimizer))
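
A hedged usage sketch for the factory above; SimpleNamespace stands in for the real HParams object (which is not shown in this listing), and the field names simply mirror those accessed by the function.

from types import SimpleNamespace

hparams = SimpleNamespace(
    optimizer='momentum',
    lr_hparams={'initial_learning_rate': 0.1},
    opt_hparams={'momentum': 0.9, 'weight_decay': 1e-4},
)
optimizer_def = get_optimizer(hparams)        # a flax optimizer definition
# optimizer = optimizer_def.create(params)    # attach it to a parameter pytree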
Example #13
def fit_model(model,
              rng,
              num_steps,
              train_iter,
              test_iter=None,
              train_fn=update_classifier,
              test_fn=eval_classifier,
              make_optimizer=None,
              preprocess_train_batch=None,
              preprocess_test_batch=None,
              print_every=1,
              test_every=None):

    batch = next(train_iter)
    if preprocess_train_batch is not None:
        batch = preprocess_train_batch(batch, rng)
    X = batch['X']
    params = model.init(rng, X)['params']

    if make_optimizer is None:
        make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
    optimizer = make_optimizer.create(params)

    history = {
        'train_loss': [],
        'train_accuracy': [],
        'test_loss': [],
        'test_accuracy': []
    }
    if test_iter is None:
        test_every = 0
    if test_every is None:
        test_every = print_every

    for step in range(num_steps):
        batch = next(train_iter)
        if preprocess_train_batch is not None:
            batch = preprocess_train_batch(batch, rng)
        optimizer, train_metrics = train_fn(model, optimizer, batch)
        if (print_every > 0) and (step % print_every == 0):
            print('train step: {:d}, loss: {:0.4f}, accuracy: {:0.2f}'.format(
                step, train_metrics['loss'], train_metrics['accuracy']))

        if (test_every > 0) and (step % test_every == 0):
            batch = next(test_iter)
            if preprocess_test_batch is not None:
                batch = preprocess_test_batch(batch, rng)
            test_metrics = test_fn(model, optimizer.target, batch)
            history = append_history(history, train_metrics, test_metrics)

    params = optimizer.target
    return params, history
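
A short usage sketch for this variant of fit_model, assuming an infinite batch iterator like the ones built in the test examples above; the iterator yields dicts with 'X' and 'y', and rng is a jax.random.PRNGKey.

params, history = fit_model(model, rng, num_steps=2,
                            train_iter=train_iter, test_iter=test_iter,
                            print_every=1)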
def make_optimizer(optimizer: str,
                   learning_rate: float,
                   weight_decay: float = 5.0e-4):
    if optimizer == 'SGD':
        return optim.GradientDescent(learning_rate=learning_rate)
    elif optimizer == 'Momentum':
        return optim.Momentum(learning_rate=learning_rate,
                              weight_decay=weight_decay)
    elif optimizer == 'Adam':
        return optim.Adam(learning_rate=learning_rate,
                          weight_decay=weight_decay)
    else:
        raise ValueError('Unknown optimizer spec.')
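
A short usage sketch for the dispatcher above, assuming the same legacy flax.optim API; `params` stands for whatever pytree of parameters is being optimized.

optimizer_def = make_optimizer('Momentum', learning_rate=0.05, weight_decay=5.0e-4)
# optimizer = optimizer_def.create(params)         # wrap the parameter pytree
# new_optimizer = optimizer.apply_gradient(grads)  # one SGD-with-momentum step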
Example #15
  def test_compute_grad(self):
    params = onp.ones(())
    optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
    optimizer = optimizer_def.create(params)
    def loss_fn(x):
      return 2. * x
    loss, grad = optimizer.compute_gradient(loss_fn)
    self.assertEqual(loss, 2.)
    self.assertEqual(grad, 2.)

    def loss_aux_fn(x):
      return 3. * x, 4.
    loss, aux, grad = optimizer.compute_gradient(loss_aux_fn)
    self.assertEqual(loss, 3.)
    self.assertEqual(grad, 3.)
    self.assertEqual(aux, 4.)
Example #16
def test():
    # We just check we can run the functions and that they return "something"
    print('testing fit-flax')
    N = 3
    D = 5
    C = 10
    model = ModelTest(nhidden=0, nclasses=C)
    rng = jax.random.PRNGKey(0)
    X = np.random.randn(N, D)
    y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
    batch = {'X': X, 'y': y}
    params = model.init(rng, X)['params']

    # test apply
    logprobs = model.apply({'params': params}, batch['X'])
    assert logprobs.shape == (N, C)

    # test loss
    labels = batch['y']
    loss = softmax_cross_entropy(logprobs, labels)
    assert loss.shape == ()

    # test test_fn
    metrics = eval_classifier(model, params, batch)
    assert np.allclose(loss, metrics['loss'])

    # test train_fn
    make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
    optimizer = make_optimizer.create(params)
    optimizer, metrics = update_classifier(model, optimizer, batch)

    # test fit_model
    num_steps = 2
    train_iter = make_iterator_from_batch(batch)
    test_iter = make_iterator_from_batch(batch)
    params_init = params
    params_new, history = fit_model(model, rng, num_steps, train_iter,
                                    test_iter)

    diff = tree_util.tree_multimap(lambda x, y: x - y, params_init, params_new)
    print(diff)
    norm = l2norm_sq(diff)
    assert norm > 0  # check that parameters have changed :)

    print(history)
    print('test passed')
Example #17
 def test_momentum_with_weight_norm(self):
   params = onp.ones((2, 2)) * 2.
   optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
   state = optimizer_def.init_state(params)
   self.assertEqual(jax.tree_map(onp.shape, state), optim.OptimizerState(
       step=(),
       param_states=_WeightNormParamState(
           direction_state=_MomentumParamState(momentum=(2, 2)),
           scale_state=_MomentumParamState(momentum=(1, 2)),
           mult=(1, 2)
       )
   ))
   grads = onp.ones((2, 2))
   new_params, new_state = optimizer_def.apply_gradient(
       optimizer_def.hyper_params, params, state, grads)
   onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
   onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
Example #18
def create_train_state(rng, config: ml_collections.ConfigDict,
                       model, image_size):
  """Create initial training state."""
  dynamic_scale = None
  platform = jax.local_devices()[0].platform
  if config.half_precision and platform == 'gpu':
    dynamic_scale = optim.DynamicScale()
  else:
    dynamic_scale = None

  params, model_state = initialized(rng, image_size, model)
  optimizer = optim.Momentum(
      beta=config.momentum, nesterov=True).create(params)
  state = TrainState(
      step=0, optimizer=optimizer, model_state=model_state,
      dynamic_scale=dynamic_scale)
  return state
def create_optimizer(model, learning_rate, beta=0.9):
  """Creates an SGD (Nesterov momentum) optimizer.

  Learning rate will be ignored when using a learning rate schedule.

  Args:
    model: The FLAX model to optimize.
    learning_rate: Learning rate for the gradient descent.
    beta: Momentum parameter.

  Returns:
    An SGD optimizer that targets the model.
  """
  optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                 beta=beta,
                                 nesterov=True)
  optimizer = optimizer_def.create(model)
  return optimizer
def create_optimizer(config, model, learning_rate, train_size, sampler_rng):
  """Create optimizer definition based on config flags."""
  if config.optimizer == 'adam':
    optimizer_def = optim.Adam(
        learning_rate=learning_rate, beta1=config.momentum)
  elif config.optimizer == 'momentum':
    optimizer_def = optim.Momentum(
        learning_rate=learning_rate, beta=config.momentum)
  elif config.optimizer == 'sym_euler':
    optimizer_def = sym_euler_sgmcmc.SymEulerSGMCMC(
        train_size,
        sampler_rng,
        learning_rate=learning_rate,
        beta=config.momentum,
        temperature=config.base_temp,
        step_size_factor=1.)
  else:
    raise ValueError('Invalid value %s for config.optimizer.' %
                     config.optimizer)

  if config.weight_norm == 'none':
    pass
  elif config.weight_norm == 'learned':
    optimizer_def = optim.WeightNorm(optimizer_def)
  elif config.weight_norm in ['fixed', 'ws_sqrt', 'learned_b', 'ws']:
    # Applied in layers directly.
    pass
  else:
    raise ValueError('Invalid value %s for config.weight_norm.' %
                     config.weight_norm)

  optimizer = optimizer_def.create(model)

  if not config.debug_run:
    optimizer = optimizer.replicate()
  return optimizer
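
A hedged usage sketch for the factory above; SimpleNamespace stands in for the real config object, and model / sampler_rng are placeholders for the arguments supplied by the surrounding training script.

from types import SimpleNamespace

config = SimpleNamespace(optimizer='momentum', momentum=0.9,
                         weight_norm='learned', debug_run=True)
# With debug_run=True the optimizer is not replicated across devices:
# optimizer = create_optimizer(config, model, learning_rate=0.1,
#                              train_size=50000, sampler_rng=None)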
Example #21
def test():
  # We just check we can run the functions and that they return "something"
  print('testing fit-flax')
  N, D, C = 3, 5, 10
  model = ModelTest(nhidden=0, nclasses=C)
  rng = jax.random.PRNGKey(0)
  X = np.random.randn(N, D)
  y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
  batch = {'X': X, 'y': y}
  params = model.init(rng, X)['params']

  logprobs = model.apply({'params': params}, batch['X'])
  assert logprobs.shape == (N, C)
  labels = batch['y']
  loss = softmax_cross_entropy(logprobs, labels)
  assert loss.shape == ()

  metrics = eval_batch(model, params, batch)
  assert np.allclose(loss, metrics['loss'])

  make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
  optimizer = make_optimizer.create(params)
  optimizer, metrics = train_batch(model, optimizer, batch)
  num_steps = 2
  train_iter = make_iterator_from_batch(batch)
  test_iter = make_iterator_from_batch(batch)
  params_init = params

  params_new, history = fit_model(model, train_iter, test_iter, rng,
      num_steps, make_optimizer, train_batch, eval_batch,
      print_every=1)
  diff = tree_util.tree_multimap(lambda x, y: x - y, params_init, params_new)
  diff_max = tree_util.tree_map(jnp.max, diff)
  assert jnp.abs(diff_max['Dense_0']['kernel']) > 0  # check that parameters have changed

  print('test passed')
Example #22
def create_optimizer(params, learning_rate, beta):
    optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
    optimizer = optimizer_def.create(params)
    return optimizer
def experiment(model_dir='.',  # pylint: disable=dangerous-default-value
               imagenet_subset_dir=None,
               dataset='cifar10',
               batch_size=256,
               eval_batch_size=1024,
               num_epochs=200,
               learning_rate=0.1,
               aug_imagenet_apply_colour_jitter=False,
               aug_imagenet_greyscale_prob=0.0,
               sgd_momentum=0.9,
               sgd_nesterov=True,
               lr_schedule='stepped',
               lr_sched_steps=[[60, 0.2], [120, 0.04], [160, 0.008]],
               lr_sched_halfcoslength=400.0,
               lr_sched_warmup=5.0,
               l2_reg=0.0005,
               weight_decay=0.0,
               architecture='wrn22_10',
               n_val=5000,
               n_sup=1000,
               teacher_alpha=0.999,
               anneal_teacher_alpha=False,
               unsupervised_regularizer='none',
               cons_weight=1.0,
               conf_thresh=0.97,
               conf_avg=False,
               cut_backg_noise=1.0,
               cut_prob=1.0,
               box_reg_scale_mode='fixed',
               box_reg_scale=0.25,
               box_reg_random_aspect_ratio=False,
               cow_sigma_range=(4.0, 8.0),
               cow_prop_range=(0.25, 1.0),
               mix_regularizer='none',
               mix_aug_separately=False,
               mix_logits=True,
               mix_weight=1.0,
               mix_conf_thresh=0.97,
               mix_conf_avg=True,
               mix_conf_mode='mix_prob',
               ict_alpha=0.1,
               mix_box_reg_scale_mode='fixed',
               mix_box_reg_scale=0.25,
               mix_box_reg_random_aspect_ratio=False,
               mix_cow_sigma_range=(4.0, 8.0),
               mix_cow_prop_range=(0.0, 1.0),
               subset_seed=12345,
               val_seed=131,
               run_seed=None,
               log_fn=print,
               checkpoints='on',
               on_epoch_finished_fn=None,
               debug=False):
  """Run experiment."""
  if checkpoints not in {'none', 'on', 'retain'}:
    raise ValueError('checkpoints should be one of (none|on|retain)')

  if checkpoints != 'none':
    checkpoint_path = os.path.join(model_dir, 'checkpoint.pkl')
    checkpoint_new_path = os.path.join(model_dir, 'checkpoint.pkl.new')
  else:
    checkpoint_path = None
    checkpoint_new_path = None

  if dataset not in {'svhn', 'cifar10', 'cifar100', 'imagenet'}:
    raise ValueError('Unknown dataset \'{}\''.format(dataset))

  if architecture not in {'wrn20_10', 'wrn26_10', 'wrn26_2',
                          'wrn20_6_shakeshake', 'wrn26_6_shakeshake',
                          'wrn26_2_shakeshake', 'pyramid',
                          'resnet50', 'resnet101', 'resnet152',
                          'resnet50x2', 'resnet101x2', 'resnet152x2',
                          'resnet50x4', 'resnet101x4', 'resnet152x4',
                          'resnext50_32x4d', 'resnext101_32x8d',
                          'resnext152_32x4d'}:
    raise ValueError('Unknown architecture \'{}\''.format(architecture))

  if lr_schedule not in {'constant', 'stepped', 'cosine'}:
    raise ValueError('Unknown LR schedule \'{}\''.format(lr_schedule))

  if mix_conf_mode not in {'mix_prob', 'mix_conf'}:
    raise ValueError('Unknown mix_conf_mode \'{}\''.format(mix_conf_mode))

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(model_dir)
  else:
    summary_writer = None

  unsup_reg, augment_twice = build_pert_reg(
      unsupervised_regularizer, cut_backg_noise=cut_backg_noise,
      cut_prob=cut_prob, box_reg_scale_mode=box_reg_scale_mode,
      box_reg_scale=box_reg_scale,
      box_reg_random_aspect_ratio=box_reg_random_aspect_ratio,
      cow_sigma_range=cow_sigma_range, cow_prop_range=cow_prop_range)

  mix_reg = build_mix_reg(
      mix_regularizer, ict_alpha=ict_alpha,
      box_reg_scale_mode=mix_box_reg_scale_mode,
      box_reg_scale=mix_box_reg_scale,
      box_reg_random_aspect_ratio=mix_box_reg_random_aspect_ratio,
      cow_sigma_range=mix_cow_sigma_range, cow_prop_range=mix_cow_prop_range)

  if run_seed is None:
    run_seed = subset_seed << 32 | n_val
  train_rng = jax.random.PRNGKey(run_seed)
  init_rng, train_rng = jax.random.split(train_rng)

  if batch_size % jax.device_count() > 0:
    raise ValueError('Train batch size must be divisible by the number of '
                     'devices')
  if eval_batch_size % jax.device_count() > 0:
    raise ValueError('Eval batch size must be divisible by the number of '
                     'devices')
  local_batch_size = batch_size // jax.host_count()
  local_eval_batch_size = eval_batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  if dataset == 'svhn':
    image_size = 32
    top5_err_required = False
    data_source = small_image_data_source.SVHNDataSource(
        n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size,
        augment_twice=augment_twice, subset_seed=subset_seed,
        val_seed=val_seed)
  elif dataset == 'cifar10':
    image_size = 32
    top5_err_required = False
    data_source = small_image_data_source.CIFAR10DataSource(
        n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size, augment_twice=augment_twice,
        subset_seed=subset_seed, val_seed=val_seed)
  elif dataset == 'cifar100':
    image_size = 32
    top5_err_required = False
    data_source = small_image_data_source.CIFAR100DataSource(
        n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size, augment_twice=augment_twice,
        subset_seed=subset_seed, val_seed=val_seed)
  elif dataset == 'imagenet':
    image_size = 224
    top5_err_required = True
    if imagenet_subset_dir is None:
      raise ValueError('Please provide a directory to the imagenet_subset_dir '
                       'command line arg to specify where the ImageNet '
                       'subsets are stored')
    data_source = imagenet_data_source.ImageNetDataSource(
        imagenet_subset_dir, n_val, n_sup, local_batch_size,
        local_eval_batch_size, augment_twice,
        apply_colour_jitter=aug_imagenet_apply_colour_jitter,
        greyscale_prob=aug_imagenet_greyscale_prob,
        load_test_set=(n_val == 0), image_size=image_size,
        subset_seed=subset_seed, val_seed=val_seed)
  else:
    raise RuntimeError

  n_train = data_source.n_train
  train_ds = data_source.train_semisup_ds

  if n_val == 0:
    eval_ds = data_source.test_ds
    n_eval = data_source.n_test
  else:
    eval_ds = data_source.val_ds
    n_eval = data_source.n_val

  log_fn('DATA: |train|={}, |sup|={}, |eval|={}, (|val|={}, |test|={})'.format(
      data_source.n_train, data_source.n_sup, n_eval, data_source.n_val,
      data_source.n_test))

  log_fn('Loaded dataset')

  steps_per_epoch = n_train // batch_size
  steps_per_eval = n_eval // eval_batch_size
  if n_eval % eval_batch_size > 0:
    steps_per_eval += 1
  num_steps = steps_per_epoch * num_epochs

  # Create model
  model_stu, state_stu = create_model(
      init_rng, architecture, device_batch_size, image_size,
      data_source.n_classes)
  state_stu = jax_utils.replicate(state_stu)
  log_fn('Built model')

  # Create optimizer
  optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                 beta=sgd_momentum,
                                 nesterov=sgd_nesterov)

  optimizer_stu = optimizer_def.create(model_stu)
  optimizer_stu = optimizer_stu.replicate()
  del model_stu  # don't keep a copy of the initial model

  # Create learning rate function
  base_learning_rate = learning_rate * batch_size / 256.
  if lr_schedule == 'constant':
    learning_rate_fn = create_constant_learning_rate_fn(base_learning_rate)
  elif lr_schedule == 'stepped':
    learning_rate_fn = create_stepped_learning_rate_fn(
        base_learning_rate, steps_per_epoch, lr_sched_steps=lr_sched_steps,
        warmup_length=lr_sched_warmup)
  elif lr_schedule == 'cosine':
    learning_rate_fn = create_cosine_learning_rate_fn(
        base_learning_rate, steps_per_epoch,
        halfcoslength_epochs=lr_sched_halfcoslength,
        warmup_length=lr_sched_warmup)
  else:
    raise RuntimeError

  if anneal_teacher_alpha:
    if lr_schedule == 'constant':
      one_minus_alpha_fn = create_constant_learning_rate_fn(1.0 - teacher_alpha)
    elif lr_schedule == 'stepped':
      one_minus_alpha_fn = create_stepped_learning_rate_fn(
          1.0 - teacher_alpha, steps_per_epoch, lr_sched_steps=lr_sched_steps)
    elif lr_schedule == 'cosine':
      one_minus_alpha_fn = create_cosine_learning_rate_fn(
          1.0 - teacher_alpha, steps_per_epoch,
          halfcoslength_epochs=lr_sched_halfcoslength)
    else:
      raise RuntimeError
    teacher_alpha_fn = lambda step: 1.0 - one_minus_alpha_fn(step)
  else:
    teacher_alpha_fn = lambda step: teacher_alpha

  log_fn('Built optimizer')

  # Teacher model is just the student as we duplicate it when we modify it
  model_tea = optimizer_stu.target
  # Replicate batch stats
  state_tea = jax.tree_map(lambda x: x, state_stu)

  # Set up epoch and step counter
  # Load existing checkpoint if available
  epoch = 1
  step = 0

  if checkpoints != 'none':
    if tf.io.gfile.exists(checkpoint_path):
      with tf.io.gfile.GFile(checkpoint_path, 'rb') as f_in:
        check = pickle.load(f_in)

        # Student optimizer and batch stats
        optimizer_stu = util.restore_state_list(
            optimizer_stu, check['optimizer_stu'])

        state_stu = util.restore_state_list(
            state_stu, check['state_stu'])

        # Teacher model and batch stats
        model_tea = util.restore_state_list(
            model_tea, check['model_tea'])

        state_tea = util.restore_state_list(
            state_tea, check['state_tea'])

        epoch = check['epoch']
        step = check['step']

        log_fn('Loaded checkpoint from {}'.format(checkpoint_path))

  #
  # Training and evaluation step functions
  #
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn,
                        l2_reg=l2_reg, weight_decay=weight_decay,
                        teacher_alpha_fn=teacher_alpha_fn,
                        unsup_reg=unsup_reg, cons_weight=cons_weight,
                        conf_thresh=conf_thresh,
                        conf_avg=conf_avg,
                        mix_reg=mix_reg, mix_aug_separately=mix_aug_separately,
                        mix_logits=mix_logits, mix_weight=mix_weight,
                        mix_conf_thresh=mix_conf_thresh,
                        mix_conf_avg=mix_conf_avg,
                        mix_conf_mode=mix_conf_mode),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, eval_top_5=top5_err_required),
      axis_name='batch')

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  #
  # Training loop
  #

  log_fn('Training...')
  epoch_metrics_stu = []
  t1 = time.time()
  while step < num_steps:
    train_rng, iter_rng = jax.random.split(train_rng)
    batch = next(train_iter)
    batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
    batch = shard(batch, iter_rng)

    optimizer_stu, state_stu, metrics_stu, model_tea, state_tea = p_train_step(
        optimizer_stu, state_stu, model_tea, state_tea, batch)

    if debug:
      log_fn('Step {} time {}'.format(step, time.time()-t1))

    epoch_metrics_stu.append(metrics_stu)
    if (step + 1) % steps_per_epoch == 0:
      epoch_metrics_stu = util.get_metrics(epoch_metrics_stu)
      train_epoch_metrics = jax.tree_map(lambda x: x.mean(), epoch_metrics_stu)
      if summary_writer is not None:
        for key, vals in epoch_metrics_stu.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

      epoch_metrics_stu = []
      eval_stu_metrics = []
      eval_tea_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # TF to NumPy
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Pad short batches
        eval_batch = util.pad_classification_batch(
            eval_batch, local_eval_batch_size)
        # Shard across local devices
        eval_batch = shard(eval_batch)
        metrics_stu = p_eval_step(optimizer_stu.target, state_stu, eval_batch)
        metrics_tea = p_eval_step(model_tea, state_tea, eval_batch)
        eval_stu_metrics.append(metrics_stu)
        eval_tea_metrics.append(metrics_tea)
      eval_stu_metrics = util.get_metrics(eval_stu_metrics)
      eval_tea_metrics = util.get_metrics(eval_tea_metrics)
      eval_stu_epoch_metrics = jax.tree_map(lambda x: x.sum(), eval_stu_metrics)
      eval_tea_epoch_metrics = jax.tree_map(lambda x: x.sum(), eval_tea_metrics)
      eval_stu_epoch_metrics = avg_eval_metrics(eval_stu_epoch_metrics)
      eval_tea_epoch_metrics = avg_eval_metrics(eval_tea_epoch_metrics)

      t2 = time.time()

      if top5_err_required:
        log_fn('EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
               'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
               'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
               'top-5-err={:.3%}, TEA Eval loss={:.6f}, err={:.3%}, '
               'top-5-err={:.3%}'.format(
                   epoch, t2 - t1, train_epoch_metrics['loss'],
                   train_epoch_metrics['error_rate'],
                   train_epoch_metrics['cons_loss'],
                   train_epoch_metrics['conf_rate'],
                   train_epoch_metrics['mix_loss'],
                   train_epoch_metrics['mix_conf_rate'],
                   eval_stu_epoch_metrics['loss'],
                   eval_stu_epoch_metrics['error_rate'],
                   eval_stu_epoch_metrics['top5_error_rate'],
                   eval_tea_epoch_metrics['loss'],
                   eval_tea_epoch_metrics['error_rate'],
                   eval_tea_epoch_metrics['top5_error_rate'],))
      else:
        log_fn('EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
               'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
               'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
               'TEA Eval loss={:.6f}, err={:.3%}'.format(
                   epoch, t2 - t1, train_epoch_metrics['loss'],
                   train_epoch_metrics['error_rate'],
                   train_epoch_metrics['cons_loss'],
                   train_epoch_metrics['conf_rate'],
                   train_epoch_metrics['mix_loss'],
                   train_epoch_metrics['mix_conf_rate'],
                   eval_stu_epoch_metrics['loss'],
                   eval_stu_epoch_metrics['error_rate'],
                   eval_tea_epoch_metrics['loss'],
                   eval_tea_epoch_metrics['error_rate'],))

      if on_epoch_finished_fn is not None:
        if top5_err_required:
          on_epoch_finished_fn(
              epoch,
              eval_stu_err=eval_stu_epoch_metrics['error_rate'],
              eval_tea_err=eval_tea_epoch_metrics['error_rate'],
              eval_stu_top5_err=eval_stu_epoch_metrics['top5_error_rate'],
              eval_tea_top5_err=eval_tea_epoch_metrics['top5_error_rate'],
          )
        else:
          on_epoch_finished_fn(
              epoch,
              eval_stu_err=eval_stu_epoch_metrics['error_rate'],
              eval_tea_err=eval_tea_epoch_metrics['error_rate'],
          )

      t1 = t2

      if summary_writer is not None:
        summary_writer.scalar(
            'eval_stu_loss', eval_stu_epoch_metrics['loss'], epoch)
        summary_writer.scalar(
            'eval_stu_error_rate', eval_stu_epoch_metrics['error_rate'], epoch)
        summary_writer.scalar(
            'eval_tea_loss', eval_tea_epoch_metrics['loss'], epoch)
        summary_writer.scalar(
            'eval_tea_error_rate', eval_tea_epoch_metrics['error_rate'], epoch)
        if top5_err_required:
          summary_writer.scalar(
              'eval_stu_top5_error_rate',
              eval_stu_epoch_metrics['top5_error_rate'], epoch)
          summary_writer.scalar(
              'eval_tea_top5_error_rate',
              eval_tea_epoch_metrics['top5_error_rate'], epoch)
        summary_writer.flush()

      epoch += 1

      if checkpoints != 'none':
        if jax.host_id() == 0:
          # Write to a new checkpoint file so that we don't immediately
          # overwrite the old one.
          with tf.io.gfile.GFile(checkpoint_new_path, 'wb') as f_out:
            check = dict(
                optimizer_stu=util.to_state_list(optimizer_stu),
                state_stu=util.to_state_list(state_stu),
                model_tea=util.to_state_list(model_tea),
                state_tea=util.to_state_list(state_tea),
                epoch=epoch,
                step=step + 1,
            )
            pickle.dump(check, f_out)
            del check
          # Remove the old checkpoint and rename.
          if tf.io.gfile.exists(checkpoint_path):
            tf.io.gfile.remove(checkpoint_path)
          tf.io.gfile.rename(checkpoint_new_path, checkpoint_path)

    step += 1

  if checkpoints == 'on':
    if jax.host_id() == 0:
      if tf.io.gfile.exists(checkpoint_path):
        tf.io.gfile.remove(checkpoint_path)
Example #24
def train(module,
          model_dir,
          batch_size,
          eval_batch_size,
          num_moco_epochs,
          num_clf_epochs,
          moco_learning_rate,
          clf_learning_rate,
          sgd_momentum=0.9,
          sgd_nesterov=True,
          make_moco_lr_fun=None,
          make_clf_lr_fun=None,
          moco_l2_reg=0.0001,
          clf_l2_reg=0.0,
          feature_size=64 * 8 * 4,
          moco_momentum=0.999,
          emb_size=128,
          moco_temperature=0.07,
          dictionary_size=65536,
          run_seed=0,
          steps_per_epoch=None,
          steps_per_eval=None):
    """Train MoCo model."""
    if make_moco_lr_fun is None:

        def make_moco_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
            return lr_schedule.create_stepped_learning_rate_schedule(
                base_lr, steps_per_epoch, [[120, 0.1], [160, 0.01]])

    if make_clf_lr_fun is None:

        def make_clf_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
            return lr_schedule.create_stepped_learning_rate_schedule(
                base_lr, steps_per_epoch, [[60, 0.2], [75, 0.04], [90, 0.008]])

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(model_dir)
    else:
        summary_writer = None

    #
    #
    # If using more than 1 host, warn the user
    #
    #

    if jax.host_count() > 1:
        logging.info(
            'WARNING: the all_to_all collective used by this program is '
            'not yet supported in multi-host environments')

    train_rng = jax.random.PRNGKey(run_seed)
    (init_moco_rng, init_clf_rng, init_dictionary_rng,
     train_rng) = jax.random.split(train_rng, num=4)

    if batch_size % jax.device_count() > 0:
        raise ValueError('Train batch size must be divisible by the number '
                         'of devices')
    if eval_batch_size % jax.device_count() > 0:
        raise ValueError('Eval batch size must be divisible by the number '
                         'of devices')
    local_batch_size = batch_size // jax.host_count()
    local_eval_batch_size = eval_batch_size // jax.host_count()
    n_devices = jax.device_count()
    n_local_devices = jax.local_device_count()

    device_batch_size = batch_size // n_devices

    image_size = 224
    data_source = imagenet_data_source.load_imagenet(
        train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size,
        greyscale_prob=0.1)

    n_train = data_source.n_train
    train_moco_ds = data_source.train_moco_ds
    train_clf_ds = data_source.train_clf_ds
    eval_ds = data_source.test_ds
    n_eval = data_source.n_test

    logging.info('DATA: |train|=%d, |eval|=%d', data_source.n_train, n_eval)

    if steps_per_epoch is None:
        steps_per_epoch = n_train // batch_size
    if steps_per_eval is None:
        steps_per_eval = n_eval // eval_batch_size
    num_moco_steps = steps_per_epoch * num_moco_epochs
    num_clf_steps = steps_per_epoch * num_clf_epochs

    logging.info('Loaded dataset')

    #
    # Create query model
    #
    model_query, state_query = create_model(init_moco_rng, device_batch_size,
                                            image_size, module)
    state_query = jax_utils.replicate(state_query)

    # Create linear classifier
    feat_model_clf = create_linear_classifier(init_clf_rng, device_batch_size,
                                              feature_size,
                                              data_source.n_classes)

    # Randomly initialise dictionary
    moco_dictionary = jax.random.normal(init_dictionary_rng,
                                        (dictionary_size, emb_size),
                                        dtype=jnp.float32)
    moco_dictionary = normalize_embeddings(moco_dictionary)
    logging.info('Built model')

    #
    # Create optimizer
    #

    optimizer_def = optim.Momentum(learning_rate=moco_learning_rate,
                                   beta=sgd_momentum,
                                   nesterov=sgd_nesterov)
    optimizer_query = optimizer_def.create(model_query)
    optimizer_query = optimizer_query.replicate()
    del model_query  # don't keep a copy of the initial model

    feat_clf_optimizer_def = optim.Momentum(learning_rate=clf_learning_rate,
                                            beta=sgd_momentum,
                                            nesterov=sgd_nesterov)
    feat_clf_optimizer = feat_clf_optimizer_def.create(feat_model_clf)
    feat_clf_optimizer = feat_clf_optimizer.replicate()
    logging.info('Built optimizer')

    #
    # Learning rate schedule
    #

    base_moco_learning_rate = moco_learning_rate * batch_size / 256.
    base_clf_learning_rate = clf_learning_rate * batch_size / 256.
    moco_learning_rate_fn = make_moco_lr_fun(base_moco_learning_rate,
                                             steps_per_epoch)
    clf_learning_rate_fn = make_clf_lr_fun(base_clf_learning_rate,
                                           steps_per_epoch)

    # The key model is a replica of the query model. Since Flax models are
    # immutable, we can start with the query model
    model_key = optimizer_query.target
    # Replicate batch stats
    state_key = jax.tree_map(lambda x: x, state_query)

    # Set up epoch and step counter
    # Load existing checkpoint if available
    moco_epoch = 1
    clf_epoch = 1
    moco_step = 0
    clf_step = 0

    #
    # Training and eval functions
    #
    p_moco_key_step = jax.pmap(functools.partial(moco_key_step),
                               axis_name='batch')
    p_moco_train_step = jax.pmap(functools.partial(
        moco_train_step,
        n_devices=n_devices,
        moco_temperature=moco_temperature,
        learning_rate_fn=moco_learning_rate_fn,
        l2_reg=moco_l2_reg,
        moco_momentum=moco_momentum),
                                 axis_name='batch')
    p_classifier_train_step = jax.pmap(functools.partial(
        classifier_train_step,
        learning_rate_fn=clf_learning_rate_fn,
        l2_reg=clf_l2_reg),
                                       axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step), axis_name='batch')

    # Create MoCo dataset batch iterator
    train_moco_it = iter(train_moco_ds)

    #
    # Training loop
    #

    logging.info('Training MoCo...')

    epoch_metrics_moco = []
    t1 = time.time()
    while moco_step < num_moco_steps:
        (train_rng, shuffle_rng) = jax.random.split(train_rng, num=2)

        batch = next(train_moco_it)
        # TF to NumPy
        batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access

        # Compute key embeddings
        # We have to shuffle the batch to prevent the network from cheating using
        # batch stats
        shuffle_forward = jax.random.shuffle(shuffle_rng,
                                             jnp.arange(local_batch_size))
        shuffle_backward = jnp.zeros((local_batch_size, ), dtype=int)
        shuffle_backward = jax.ops.index_update(shuffle_backward,
                                                shuffle_forward,
                                                jnp.arange(local_batch_size))

        key_batch = dict(x_key=batch['key_image'][shuffle_forward, Ellipsis])
        key_batch_sharded = common_utils.shard(key_batch)
        emb_key, state_key = p_moco_key_step(model_key, state_key,
                                             key_batch_sharded)
        emb_key = emb_key.reshape((-1, emb_size))
        emb_key = emb_key[shuffle_backward, Ellipsis]

        #
        # Main MoCo training step
        #
        moco_batch = batch.copy()
        moco_batch['emb_key'] = emb_key
        sharded_moco_batch = common_utils.shard(moco_batch)

        # Repeat the MoCo dictionary across shards
        sharded_dict = jnp.repeat(moco_dictionary[None, Ellipsis],
                                  n_local_devices,
                                  axis=0)

        # The main train step function is applied slightly differently in
        # multi-host environments
        optimizer_query, state_query, metrics_moco, model_key, code_batch = \
            p_moco_train_step(optimizer_query, state_query, model_key,
                              sharded_moco_batch, sharded_dict)
        code_batch = code_batch[0].reshape((-1, emb_size))

        moco_dictionary = jnp.append(code_batch, moco_dictionary,
                                     axis=0)[:dictionary_size]

        epoch_metrics_moco.append(metrics_moco)
        if (moco_step + 1) % steps_per_epoch == 0:
            epoch_metrics_moco = common_utils.get_metrics(epoch_metrics_moco)
            train_epoch_metrics = jax.tree_map(lambda x: x.mean(),
                                               epoch_metrics_moco)
            if summary_writer is not None:
                for key, vals in epoch_metrics_moco.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              moco_step - len(vals) + i + 1)

            epoch_metrics_moco = []

            t2 = time.time()

            logging.info('MoCo EPOCH %d: (took %.3fs): MoCo loss=%.6f',
                         moco_epoch, t2 - t1, train_epoch_metrics['moco_loss'])

            t1 = t2

            if summary_writer is not None:
                summary_writer.flush()

            moco_epoch += 1

        moco_step += 1

    del train_moco_it

    #
    #
    # Unsupervised MoCo training complete
    # Train classifier
    #
    #

    logging.info('Training Linear Classifier...')

    train_clf_it = iter(train_clf_ds)
    eval_iter = iter(eval_ds)

    epoch_feat_metrics = []
    t1 = time.time()
    while clf_step < num_clf_steps:
        batch = next(train_clf_it)
        # TF to NumPy
        batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
        batch = common_utils.shard(batch)

        feat_clf_optimizer, feat_metrics = p_classifier_train_step(
            feat_clf_optimizer, model_key, state_key, batch)

        epoch_feat_metrics.append(feat_metrics)
        if (clf_step + 1) % steps_per_epoch == 0:
            epoch_feat_metrics = common_utils.get_metrics(epoch_feat_metrics)
            train_epoch_feat_metrics = jax.tree_map(lambda x: x.mean(),
                                                    epoch_feat_metrics)
            if summary_writer is not None:
                for key, vals in epoch_feat_metrics.items():
                    tag = 'train_feat_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              clf_step - len(vals) + i + 1)

            epoch_feat_metrics = []
            eval_feat_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # TF to NumPy
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                # Shard across local devices
                eval_batch = common_utils.shard(eval_batch)
                feat_metrics = p_eval_step(model_key, state_key,
                                           feat_clf_optimizer.target,
                                           eval_batch)
                eval_feat_metrics.append(feat_metrics)
            eval_feat_metrics = common_utils.get_metrics(eval_feat_metrics)
            eval_epoch_feat_metrics = jax.tree_map(lambda x: x.mean(),
                                                   eval_feat_metrics)

            t2 = time.time()

            logging.info(
                'Linear classifier EPOCH %d: (took %.3fs): TRAIN FEAT loss=%.6f, '
                'err=%.3f; EVAL FEAT loss=%.6f, err=%.3f',
                clf_epoch,
                t2 - t1,
                train_epoch_feat_metrics['loss'],
                train_epoch_feat_metrics['error_rate'] * 100.0,
                eval_epoch_feat_metrics['loss'],
                eval_epoch_feat_metrics['error_rate'] * 100.0,
            )

            t1 = t2

            if summary_writer is not None:
                summary_writer.scalar('eval_feat_loss',
                                      eval_epoch_feat_metrics['loss'],
                                      clf_epoch)
                summary_writer.scalar('eval_feat_error_rate',
                                      eval_epoch_feat_metrics['error_rate'],
                                      clf_epoch)
                summary_writer.flush()

            clf_epoch += 1

        clf_step += 1

    return eval_epoch_feat_metrics
Example #25
def create_optimizer(model, learning_rate, beta):
    optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
    optimizer = optimizer_def.create(model)
    return optimizer
Example #26
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=True,
      cache=FLAGS.cache)
  eval_iter = imagenet_train_utils.create_input_iter(
      local_batch_size,
      FLAGS.data_dir,
      image_size,
      input_dtype,
      train=False,
      cache=FLAGS.cache)

  # Create the hyperparameter object
  if FLAGS.hparams_config_dict:
    # In this case, there are multiple training configs defined in the config
    # dict, so we pull out the one this training run should use.
    if 'configs' in FLAGS.hparams_config_dict:
      hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx]
    else:
      hparams_config_dict = FLAGS.hparams_config_dict
    hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        hparams_config_dict)
  else:
    raise ValueError('Please provide a base config dict.')

  os_hparams_utils.write_hparams_to_file_with_host_id_check(
      hparams, FLAGS.model_dir)

  # get num_epochs from hparam instead of FLAGS
  num_epochs = hparams.lr_scheduler.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Estimate compute / memory costs
  if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost:
    estimate_compute_and_memory_cost(
        image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams)
    logging.info('Writing training HLO and estimating compute/memory costs.')

  rng = random.PRNGKey(hparams.seed)
  model, variables = imagenet_train_utils.create_model(
      rng,
      device_batch_size,
      image_size,
      model_dtype,
      hparams=hparams.model_hparams,
      train=True,
      is_teacher=hparams.is_teacher)

  # pylint: disable=g-long-lambda
  if hparams.teacher_model == 'resnet50-8bit':
    teacher_config = w8a8auto_paper_config()
    teacher_hparams = os_hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config)
    teacher_model, _ = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=teacher_hparams.model_hparams,
        train=False,
        is_teacher=True)  # teacher model does not need to be trainable
    # Directory where checkpoints are saved
    ckpt_model_dir = FLAGS.resnet508b_ckpt_path
    # will restore to best checkpoint
    state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None)
    teacher_variables = {'params': state_load['optimizer']['target']}
    teacher_variables.update(state_load['model_state'])
    # create a dictionary for better argument passing
    teacher = {
        'model':
            lambda var, img, labels: jax.nn.softmax(
                teacher_model.apply(var, img)),
        'variables':
            teacher_variables,
    }
  elif hparams.teacher_model == 'labels':
    teacher = {
        'model':
            lambda var, img, labels: common_utils.onehot(
                labels, num_classes=1000),
        'variables': {},  # no need of variables in this case
    }
  else:
    raise ValueError('The specified teacher model is not supported.')

  model_state, params = variables.pop('params')
  if hparams.optimizer == 'sgd':
    optimizer = optim.Momentum(
        beta=hparams.momentum, nesterov=True).create(params)
  elif hparams.optimizer == 'adam':
    optimizer = optim.Adam(
        beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params)
  else:
    raise ValueError('Optimizer type is not supported.')
  state = imagenet_train_utils.TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  del params, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  base_learning_rate = hparams.base_learning_rate * batch_size / 256.
  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch,
                                             hparams.lr_scheduler,
                                             batch_size)

  p_train_step = jax.pmap(
      functools.partial(
          imagenet_train_utils.train_step,
          model,
          learning_rate_fn=learning_rate_fn,
          teacher=teacher),
      axis_name='batch',
      static_broadcasted_argnums=(2, 3, 4))
  p_eval_step = jax.pmap(
      functools.partial(imagenet_train_utils.eval_step, model),
      axis_name='batch',
      static_broadcasted_argnums=(2,))

  epoch_metrics = []
  state_dict_summary_all = []
  state_dict_keys = _get_state_dict_keys_from_flags()
  t_loop_start = time.time()
  last_log_step = 0
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps * steps_per_epoch:
      break
    update_bounds = train_utils.should_update_bounds(
        hparams.activation_bound_update_freq,
        hparams.activation_bound_start_step, step)
    # and pass the result bool value to p_train_step
    # The function should take hparams.weight_quant_start_step as inputs
    quantize_weights = train_utils.should_quantize_weights(
        hparams.weight_quant_start_step, step // steps_per_epoch)
    state, metrics = p_train_step(state, batch, hparams, update_bounds,
                                  quantize_weights)

    state_dict_summary = summary_utils.get_state_dict_summary(
        state.model_state, state_dict_keys)
    state_dict_summary_all.append(state_dict_summary)

    epoch_metrics.append(metrics)
    def should_log(step):
      epoch_no = step // steps_per_epoch
      step_in_epoch = step - epoch_no * steps_per_epoch
      do_log = False
      do_log = do_log or (step + 1 == num_steps)  # log at the end
      end_of_train = step / num_steps > 0.9
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 4) == 0) and not end_of_train)
      do_log = do_log or ((step_in_epoch %
                           (steps_per_epoch // 16) == 0) and end_of_train)
      return do_log

    if should_log(step):
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start)
      last_log_step = step
      t_loop_start = time.time()

      # Write to TensorBoard
      state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

        if FLAGS.write_summary:
          summary_utils.write_state_dict_summaries_to_tb(
              state_dict_summary_all, summary_writer,
              FLAGS.state_dict_summary_freq, step)

      state_dict_summary_all = []
      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = imagenet_train_utils.sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch, quantize_weights)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = imagenet_train_utils.sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
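
# A minimal, self-contained sketch of how a `teacher` dict with the interface
# used above -- teacher['model'](variables, images, labels) -> per-class
# probabilities -- could feed a distillation loss. The helper name
# `distillation_loss` and the plain cross-entropy formulation are assumptions
# for illustration, not the actual imagenet_train_utils.train_step.
import jax
import jax.numpy as jnp


def distillation_loss(student_logits, teacher, images, labels):
  # Teacher target distribution; for the 'labels' teacher this is simply the
  # one-hot encoding of the ground-truth labels.
  teacher_probs = teacher['model'](teacher['variables'], images, labels)
  student_log_probs = jax.nn.log_softmax(student_logits)
  # Cross-entropy between the teacher distribution and the student prediction.
  return -jnp.mean(jnp.sum(teacher_probs * student_log_probs, axis=-1))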
Example #27
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    dynamic_scale = None
    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
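            # float16 has a much narrower dynamic range than bfloat16, so use
            # dynamic loss scaling on GPU to keep small gradients from
            # underflowing; TPU bfloat16 above does not need it.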
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                        FLAGS.data_dir,
                                                        image_size,
                                                        input_dtype,
                                                        train=True,
                                                        cache=FLAGS.cache)
    eval_iter = imagenet_train_utils.create_input_iter(local_batch_size,
                                                       FLAGS.data_dir,
                                                       image_size,
                                                       input_dtype,
                                                       train=False,
                                                       cache=FLAGS.cache)

    # Create the hyperparameter object
    if FLAGS.hparams_config_dict:
        # In this case, there are multiple training configs defined in the config
        # dict, so we pull out the one this training run should use.
        if 'configs' in FLAGS.hparams_config_dict:
            hparams_config_dict = FLAGS.hparams_config_dict.configs[
                FLAGS.config_idx]
        else:
            hparams_config_dict = FLAGS.hparams_config_dict
        hparams = os_hparams_utils.load_hparams_from_config_dict(
            hparams_config.TrainingHParams, models.ResNet.HParams,
            hparams_config_dict)
    else:
        raise ValueError('Please provide a base config dict.')

    os_hparams_utils.write_hparams_to_file_with_host_id_check(
        hparams, FLAGS.model_dir)

    # Get num_epochs from hparams instead of FLAGS
    num_epochs = hparams.lr_scheduler.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    # Estimate compute / memory costs
    if jax.host_id() == 0:
        estimate_compute_and_memory_cost(image_size=image_size,
                                         model_dir=FLAGS.model_dir,
                                         hparams=hparams)
        logging.info(
            'Writing training HLO and estimating compute/memory costs.')

    model, variables = imagenet_train_utils.create_model(
        rng,
        device_batch_size,
        image_size,
        model_dtype,
        hparams=hparams.model_hparams,
        train=True)
    model_state, params = variables.pop('params')
    if hparams.optimizer == 'sgd':
        optimizer = optim.Momentum(beta=hparams.momentum,
                                   nesterov=True).create(params)
    elif hparams.optimizer == 'adam':
        optimizer = optim.Adam(beta1=hparams.adam.beta1,
                               beta2=hparams.adam.beta2).create(params)
    else:
        raise ValueError('Optimizer type is not supported.')
    state = imagenet_train_utils.TrainState(step=0,
                                            optimizer=optimizer,
                                            model_state=model_state,
                                            dynamic_scale=dynamic_scale)
    del params, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    base_learning_rate = hparams.base_learning_rate * batch_size / 256.
    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch,
                                               hparams.lr_scheduler)

    p_train_step = jax.pmap(functools.partial(
        imagenet_train_utils.train_step,
        model,
        learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            static_broadcasted_argnums=(2, 3))
    p_eval_step = jax.pmap(functools.partial(imagenet_train_utils.eval_step,
                                             model),
                           axis_name='batch')

    epoch_metrics = []
    state_dict_summary_all = []
    state_dict_keys = _get_state_dict_keys_from_flags()
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps:
            break
        update_bounds = train_utils.should_update_bounds(
            hparams.activation_bound_update_freq,
            hparams.activation_bound_start_step, step)
        state, metrics = p_train_step(state, batch, hparams, update_bounds)

        state_dict_summary = summary_utils.get_state_dict_summary(
            state.model_state, state_dict_keys)
        state_dict_summary_all.append(state_dict_summary)

        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()

            # Write to TensorBoard
            state_dict_summary_all = common_utils.get_metrics(
                state_dict_summary_all)
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

                summary_utils.write_state_dict_summaries_to_tb(
                    state_dict_summary_all, summary_writer,
                    FLAGS.state_dict_summary_freq, step)

            state_dict_summary_all = []
            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = imagenet_train_utils.sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = imagenet_train_utils.sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(argv):
    del argv
    # BEGIN GOOGLE-INTERNAL
    xm.setup_work_unit()
    # END GOOGLE-INTERNAL

    tf.enable_v2_behavior()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
        # Write summaries in background thread to avoid blocking on device sync
        summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    if FLAGS.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size is None:
        batch_size = min(128 * jax.device_count(), 32768)
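        # Default to 128 examples per device, capped at a global batch of 32768.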
    eval_batch_size = 128 * jax.device_count()
    local_batch_size = batch_size // jax.host_count()
    local_eval_batch_size = eval_batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()
    device_eval_batch_size = eval_batch_size // jax.device_count()
    device_last_eval_batch_size = (input_pipeline.EVAL_IMAGES %
                                   eval_batch_size) // jax.device_count()
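    # The remainder eval batch (EVAL_IMAGES % eval_batch_size images) gets its
    # own, smaller per-device shape; see eval_input_shapes below.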

    model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
    input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32
    if FLAGS.transpose_images:
        train_input_shape = (224, 224, 3, device_batch_size)
        eval_input_shapes = [(224, 224, 3, bs)
                             for bs in (device_eval_batch_size,
                                        device_last_eval_batch_size)]
    else:
        train_input_shape = (device_batch_size, 224, 224, 3)
        eval_input_shapes = [(bs, 224, 224, 3)
                             for bs in (device_eval_batch_size,
                                        device_last_eval_batch_size)]

    num_epochs = FLAGS.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
    logging.info('steps_per_epoch: %f', steps_per_epoch)
    steps_per_eval = int(np.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))
    logging.info('steps_per_eval: %d', steps_per_eval)

    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    beta = FLAGS.momentum
    weight_decay = FLAGS.weight_decay

    logging.info('creating and initializing model and optimizer')
    model, state = create_model(rng, device_batch_size, image_size,
                                model_dtype)
    state = jax_utils.replicate(state)
    if FLAGS.lars:
        weight_opt_def = optim.LARS(base_learning_rate,
                                    beta,
                                    weight_decay=weight_decay)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=False)
        learning_rate_fn = polynomial_learning_rate_fn(batch_size,
                                                       steps_per_epoch,
                                                       num_epochs)
    else:
        weight_opt_def = optim.Momentum(base_learning_rate,
                                        beta,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=True)
        learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                      steps_per_epoch,
                                                      num_epochs)

    def filter_weights(key, _):
        return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
        return 'bias' in key or 'scale' in key

    weight_traversal = optim.ModelParamTraversal(filter_weights)
    other_traversal = optim.ModelParamTraversal(filter_other)
    optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
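    # The MultiOptimizer above applies weight decay (and, with --lars, the LARS
    # trust-ratio scaling) only to kernel weights; biases and BatchNorm scales
    # are updated with plain momentum and no weight decay.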
    optimizer = optimizer_def.create(model)
    optimizer = optimizer.replicate()
    del model  # do not keep a copy of the initial model

    p_train_step = jax.pmap(partial(train_step,
                                    learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')
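    # With --infeed, device_train_loop below runs a whole epoch of train steps
    # on device inside a lax.while_loop, with the host streaming batches via
    # lax.infeed / device.transfer_to_infeed instead of calling p_train_step
    # once per batch.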

    def device_train_loop_cond(args):
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        optimizer, state, metrics, token, step, epoch = args
        (images, labels), token = lax.infeed(
            token,
            shape=(jax.ShapedArray(train_input_shape, model_dtype),
                   jax.ShapedArray((device_batch_size, ), jnp.int32)))
        batch = {'image': images, 'label': labels}
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn)
        step += 1
        return optimizer, state, metrics, token, step, epoch

    def device_train_loop(optimizer, state, metrics, step, epoch):
        token = lax.create_token(step)
        optimizer, state, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, state, metrics, token, step, epoch))
        return optimizer, state, metrics, step

    p_train_epoch = jax.pmap(device_train_loop, axis_name='batch')

    if FLAGS.precompile:
        logging.info('precompiling step/epoch functions')
        if FLAGS.infeed:
            # the device training loop condition will immediately be false
            p_train_epoch(optimizer, state, empty_metrics(),
                          jax_utils.replicate(0), jax_utils.replicate(1))
        else:
            batch = {
                'image':
                jnp.zeros((jax.local_device_count(), ) + train_input_shape,
                          model_dtype),
                'label':
                jnp.zeros((jax.local_device_count(), ) + (device_batch_size, ),
                          jnp.int32)
            }
            p_train_step(optimizer, state, batch, empty_metrics())
        for dbs, eis in zip(
            [device_eval_batch_size, device_last_eval_batch_size],
                eval_input_shapes):
            batch = {
                'image':
                jnp.zeros((jax.local_device_count(), ) + eis, model_dtype),
                'label':
                jnp.zeros((jax.local_device_count(), ) + (dbs, ), jnp.int32)
            }
            p_eval_step(optimizer.target, state, batch, empty_metrics())
        allreduce_metrics(empty_metrics())
        pmean = functools.partial(jax.lax.pmean, axis_name='batch')
        jax.pmap(pmean, axis_name='batch')(state)
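        # Running each pmapped function once here forces XLA compilation up
        # front, so compile time is not counted against the timed training loop.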

    logging.info('constructing datasets')
    # pylint: disable=g-complex-comprehension
    train_ds, eval_ds = [
        input_pipeline.load_split(
            local_batch_size if train else local_eval_batch_size,
            image_size=image_size,
            dtype=input_dtype,
            train=train,
            transpose_images=FLAGS.transpose_images) for train in (True, False)
    ]
    # pylint: enable=g-complex-comprehension
    logging.info('constructing dataset iterators')
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    logging.info('beginning training')
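    # host_step counts training steps issued from the host loop below;
    # device_step is a replicated counter advanced inside the on-device
    # infeed loop.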
    host_step, device_step = 0, jax_utils.replicate(0)
    for epoch in range(num_epochs):
        device_epoch = jax_utils.replicate(epoch)
        metrics = empty_metrics()
        if FLAGS.infeed:
            optimizer, state, metrics, device_step = p_train_epoch(
                optimizer, state, metrics, device_step, device_epoch)
        while int(host_step // steps_per_epoch) == epoch:
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            if FLAGS.infeed:
                for i, device in enumerate(jax.local_devices()):
                    images, labels = batch['image'][i], batch['label'][i]
                    assert images.shape == train_input_shape and labels.dtype == jnp.int32
                    infeed_pool.submit(
                        partial(device.transfer_to_infeed, (images, labels)))
            else:
                optimizer, state, metrics = p_train_step(
                    optimizer, state, batch, metrics)
            host_step += 1
        if FLAGS.train_metrics:
            metrics = allreduce_metrics(metrics)
            if jax.host_id() == 0:
                summary_thread.submit(
                    partial(write_summary, summary_writer, metrics, 'train',
                            epoch + 1))
        if not FLAGS.distributed_batchnorm:  # otherwise it's already synced
            pmean = functools.partial(jax.lax.pmean, axis_name='batch')
            state = jax.pmap(pmean, axis_name='batch')(state)
        metrics = empty_metrics()
        for _ in range(steps_per_eval):
            batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))  # pylint: disable=protected-access
            metrics = p_eval_step(optimizer.target, state, batch, metrics)
        metrics = allreduce_metrics(metrics)
        if jax.host_id() == 0:
            summary_thread.submit(
                partial(write_summary, summary_writer, metrics, 'eval',
                        epoch + 1))
        # TODO(deveci): do something like this from the summary thread:
        # if summary['accuracy'] > TARGET_ACCURACY:
        #   break
    if jax.host_id() == 0:
        summary_thread.shutdown()
    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #29
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = create_input_iter(local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=FLAGS.cache)
    eval_iter = create_input_iter(local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=FLAGS.cache)

    num_epochs = FLAGS.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    base_learning_rate = base_learning_rate / FLAGS.loss_scaling
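    # Dividing by FLAGS.loss_scaling presumably compensates for a static loss
    # scaling factor applied to the loss elsewhere in train_step (assumption
    # based on the flag name).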

    model, model_state = create_model(rng, device_batch_size, image_size,
                                      model_dtype)
    optimizer = optim.Momentum(beta=FLAGS.momentum,
                               nesterov=True).create(model)
    state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
    del model, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch, num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #30
class FlaxOptimizersEquivalenceTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([0., 0.3, 500.,
                                            5.]), jnp.array([300., 3.]))

    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
        ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(
            LR, beta=0.9)),  # Different names.
        ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
         optim.Momentum(LR, beta=0.9, nesterov=True)),
        ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
        ('centered_rmsprop', alias.rmsprop(
            LR, centered=True), optim.RMSProp(LR, centered=True)),
        ('adam', alias.adam(LR), optim.Adam(LR)),
        ('adam_w', alias.adamw(LR, weight_decay=1e-4),
         optim.Adam(LR, weight_decay=1e-4)),  # Different name.
        (
            'adagrad',
            alias.adagrad(LR,
                          initial_accumulator_value=0.),  # Different default!
            optim.Adagrad(LR)),
        ('lamb', alias.lamb(LR), optim.LAMB(LR)),
        ('lars',
         alias.lars(LR,
                    weight_decay=.5,
                    trust_coefficient=0.003,
                    momentum=0.9,
                    eps=1e-3),
         optim.LARS(
             LR, weight_decay=.5, trust_coefficient=0.003, beta=0.9,
             eps=1e-3)),
        ('adafactor',
         alias.adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2),
         optim.Adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2)),
    )
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

        # flax/optim
        flax_params = self.init_params
        flax_optimizer = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_optimizer = flax_optimizer.apply_gradient(
                self.per_step_updates)
            flax_params = flax_optimizer.target

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            updates, state = optax_optimizer.update(self.per_step_updates,
                                                    state, optax_params)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
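
# A minimal migration sketch suggested by the equivalence test above: the
# flax.optim.Momentum update pattern maps onto optax.sgd plus
# optax.apply_updates. The toy `params`/`grads` pytrees are made up purely for
# illustration.
import jax.numpy as jnp
import optax

params = {'w': jnp.array([1., 2.]), 'b': jnp.array([0.5])}
grads = {'w': jnp.array([0.1, 0.1]), 'b': jnp.array([0.01])}

tx = optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)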