Ejemplo n.º 1
0
    def loss_fn(model, state, batch, prng_key):
        """Compute the training loss together with its auxiliary outputs.

        The model is applied to ``batch['image']`` inside the stateful and
        stochastic flax contexts. The cross-entropy against ``batch['label']``
        is then combined with an L2 weight penalty, the penalty returned by
        the model (scaled by ``config.std_penalty_mult``) and the summed prior
        penalties (scaled by ``config.prior_reg``).

        Returns:
          ``(loss, aux)`` where ``aux`` is the tuple ``(new_state, logits,
          prior_penalty, model_penalty, weight_penalty, summaries)``.
        """
        def _prior_penalty(prior_factory, name_filter):
            # Sum the prior term over every parameter picked by name_filter.
            selected = optim.ModelParamTraversal(name_filter)
            return sum(map(prior_factory(), selected.iterate(model)))

        with flax.deprecated.nn.stateful(state) as new_state:
            with flax.deprecated.nn.stochastic(prng_key):
                logits, model_penalty, summaries = model(batch['image'])
        ce_loss = train_functions.cross_entropy_loss(logits, batch['label'])

        # l2_reg must stay at zero whenever a prior is configured.
        # NOTE(review): kernel_prior is checked twice here; one of the two was
        # possibly meant to be a different prior setting -- confirm.
        any_prior = (config.kernel_prior != 'none'
                     or config.bias_prior != 'none'
                     or config.kernel_prior != 'none')
        assert l2_reg == 0 or not any_prior, 'Either set priors or l2_reg > 0, not both.'

        # Only leaves with more than one dimension enter the L2 term.
        weight_l2 = sum(
            jnp.sum(leaf**2)
            for leaf in jax.tree_leaves(model.params)
            if leaf.ndim > 1)
        weight_penalty = l2_reg * 0.5 * weight_l2

        kernel_prior_penalty = _prior_penalty(
            kernel_prior, lambda n, _: n.endswith('kernel'))
        scale_prior_penalty = _prior_penalty(
            scale_prior, lambda n, _: n.endswith('scale') or n.endswith('gamma'))
        bias_prior_penalty = _prior_penalty(
            bias_prior, lambda n, _: n.endswith(('bias', 'rebias', 'beta')))

        prior_penalty = (
            kernel_prior_penalty + bias_prior_penalty + scale_prior_penalty)
        prior_penalty = prior_penalty * config.prior_reg

        scaled_model_penalty = config.std_penalty_mult * model_penalty
        loss = ce_loss + weight_penalty + scaled_model_penalty + prior_penalty

        aux = (new_state, logits, prior_penalty, model_penalty,
               weight_penalty, summaries)
        return loss, aux
Ejemplo n.º 2
0
def create_optimizer(config, params):
    """Build a MultiOptimizer that applies weight decay selectively.

    Parameters whose path contains "bias", "layer_norm" or "layernorm" are
    optimized without weight decay; all other parameters use
    ``config.weight_decay``.

    Args:
        config: object providing ``optimizer`` ("adam" or "lamb"),
            ``learning_rate``, ``adam_beta1``, ``adam_beta2``,
            ``adam_epsilon`` and ``weight_decay``.
        params: the target passed to ``optimizer_def.create``.

    Returns:
        The optimizer created from the MultiOptimizer definition.

    Raises:
        ValueError: if ``config.optimizer`` is neither "adam" nor "lamb".
    """
    if config.optimizer == "adam":
        optimizer_cls = optim.Adam
    elif config.optimizer == "lamb":
        optimizer_cls = optim.LAMB
    else:
        # The f-prefix is required so the offending value is actually shown.
        raise ValueError(
            f"Unsupported value for optimizer: {config.optimizer}")
    common_kwargs = dict(
        learning_rate=config.learning_rate,
        beta1=config.adam_beta1,
        beta2=config.adam_beta2,
        eps=config.adam_epsilon,
    )
    optimizer_decay_def = optimizer_cls(weight_decay=config.weight_decay,
                                        **common_kwargs)
    optimizer_no_decay_def = optimizer_cls(weight_decay=0.0, **common_kwargs)

    def exclude_from_decay(path, _):
        return "bias" in path or "layer_norm" in path or "layernorm" in path

    # The two traversals are exact complements, so every parameter is
    # claimed by exactly one of the two optimizers.
    decay = optim.ModelParamTraversal(
        lambda *args: not exclude_from_decay(*args))
    no_decay = optim.ModelParamTraversal(exclude_from_decay)
    optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                         (no_decay, optimizer_no_decay_def))
    optimizer = optimizer_def.create(params)
    return optimizer
Ejemplo n.º 3
0
def create_optimizer(model, learning_rate=1e-4):
  """Build the replicated Adam MultiOptimizer used for training.

  Parameters whose path contains 'layer_norm' or 'bias' are updated by Adam
  without weight decay; every other parameter is updated by Adam with a
  weight decay of 0.01.

  Args:
    model: JAX model the optimizer is created for.
    learning_rate: base learning rate given to both Adam definitions.

  Returns:
    The optimizer created for `model`, after calling `replicate()` on it.
  """
  def _skips_decay(path):
    return 'layer_norm' in path or 'bias' in path

  adam_kwargs = dict(learning_rate=learning_rate, eps=1e-6)
  with_decay = (
      optim.ModelParamTraversal(lambda path, _: not _skips_decay(path)),
      optim.Adam(weight_decay=0.01, **adam_kwargs))
  without_decay = (
      optim.ModelParamTraversal(lambda path, _: _skips_decay(path)),
      optim.Adam(weight_decay=0.0, **adam_kwargs))

  multi_def = optim.MultiOptimizer(with_decay, without_decay)
  replicated = multi_def.create(model).replicate()
  del model
  return replicated
Ejemplo n.º 4
0
 def test_multi_optimizer_multiple_matches(self):
     """init_state must raise when two traversals select the same param."""
     def selects_x_or_y(path, _):
         return path.endswith(('/x', '/y'))

     def selects_int32_or_z(path, value):
         return value.dtype == jnp.int32 or path.endswith('/z')

     tree = {
         'a': {'x': 0., 'y': 0.},
         'b': {'y': 0, 'z': 0.},
     }
     slow = (optim.ModelParamTraversal(selects_x_or_y),
             optim.GradientDescent(learning_rate=1.))
     fast = (optim.ModelParamTraversal(selects_int32_or_z),
             optim.GradientDescent(learning_rate=10.))
     multi_def = optim.MultiOptimizer(slow, fast)
     expected_error = r"Multiple optimizers match.*'y': \[0, 1\]"
     with self.assertRaisesRegex(ValueError, expected_error):
         jax.jit(multi_def.init_state)(tree)
Ejemplo n.º 5
0
def create_optimizer(model, model_kwargs, learning_rate=1e-4):
  """Build the replicated MultiOptimizer used for training.

  Depending on FLAGS.use_lamb, either BertLAMB or Adam is used. Parameters
  whose path contains 'layer_norm', 'layernorm' or 'bias' are optimized with
  a weight decay of 0.0; all remaining parameters use
  FLAGS.lamb_weight_decay.

  Args:
    model: JAX model the optimizer is created for.
    model_kwargs: Bert model config dictionary; 'num_layers' is read from it
      when LAMB is selected.
    learning_rate: base learning rate given to both optimizer definitions.

  Returns:
    The optimizer created for `model`, replicated via jax_utils.replicate.
  """
  if FLAGS.use_lamb:
    lamb_kwargs = dict(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1,
        beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        num_layers=model_kwargs['num_layers'])
    weight_decay_def = bert_lamb.BertLAMB(
        weight_decay=FLAGS.lamb_weight_decay, **lamb_kwargs)
    no_decay_def = bert_lamb.BertLAMB(weight_decay=0.0, **lamb_kwargs)
  else:
    adam_kwargs = dict(learning_rate=learning_rate, eps=1e-6)
    weight_decay_def = optim.Adam(
        weight_decay=FLAGS.lamb_weight_decay, **adam_kwargs)
    no_decay_def = optim.Adam(weight_decay=0.0, **adam_kwargs)

  def _skips_decay(path):
    return 'layer_norm' in path or 'bias' in path or 'layernorm' in path

  with_decay = optim.ModelParamTraversal(
      lambda path, _: not _skips_decay(path))
  without_decay = optim.ModelParamTraversal(
      lambda path, _: _skips_decay(path))
  multi_def = optim.MultiOptimizer(
      (with_decay, weight_decay_def),
      (without_decay, no_decay_def))

  replicated = jax_utils.replicate(multi_def.create(model))
  del model
  return replicated
Ejemplo n.º 6
0
def create_optimizer(config, model):
    """Create an Adam MultiOptimizer that skips weight decay for biases.

    Parameters whose path contains 'bias' are optimized with a weight decay
    of 0.0; every other parameter uses a weight decay of 0.01. The learning
    rate comes from ``config.learning_rate``.
    """
    adam_kwargs = {
        'learning_rate': config.learning_rate,
        'beta1': 0.9,
        'beta2': 0.999,
        'eps': 1e-6,
    }
    with_decay = (
        optim.ModelParamTraversal(lambda path, _: 'bias' not in path),
        optim.Adam(weight_decay=0.01, **adam_kwargs),
    )
    without_decay = (
        optim.ModelParamTraversal(lambda path, _: 'bias' in path),
        optim.Adam(weight_decay=0.0, **adam_kwargs),
    )
    return optim.MultiOptimizer(with_decay, without_decay).create(model)
Ejemplo n.º 7
0
def rescale_layers(flax_module, layer_rescale_factors):
    """Multiply selected parameters of a flax module by given factors.

    Args:
        flax_module: A flax module whose ``params`` is a nested dictionary.
        layer_rescale_factors: A dictionary mapping flat keys to a
            multiplicative rescale factor. Each matching param x is replaced
            by a * x for its factor a. The keys must be of the form described
            in the flatten_keys documentation.

    Returns:
        The flax module obtained after applying every rescaling.

    Raises:
        ValueError: if a key is not among the flattened keys of the module.
    """
    known_keys = flatten_dict(flax_module.params).keys()
    logging.info('All keys:')
    for known in known_keys:
        logging.info(known)

    for key, factor in layer_rescale_factors.items():
        if key not in known_keys:
            raise ValueError('Module does not have key: {}'.format(key))
        logging.info('Rescaling %s by factor %f', key, factor)
        # Both lambdas are consumed within this iteration, so closing over
        # the loop variables is safe here.
        only_this_key = optim.ModelParamTraversal(
            lambda path, _: path == key)
        flax_module = only_this_key.update(lambda x: x * factor, flax_module)
    return flax_module
Ejemplo n.º 8
0
 def test_param_selection(self):
   """The traversal visits every param path and updates only the matches."""
   seen_paths = []

   def keep_kernels(path, _):
     seen_paths.append(path)  # record each path handed to the filter
     return 'kernel' in path

   inner = {'kernel': 3, 'bias': 4}
   model = nn.Model(None, {'x': {'kernel': 1, 'bias': 2, 'y': inner}})
   traversal = optim.ModelParamTraversal(keep_kernels)

   self.assertEqual(list(traversal.iterate(model)), [1, 3])
   self.assertEqual(
       set(seen_paths),
       {'/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'})

   doubled = traversal.update(lambda value: value + value, model)
   doubled_inner = {'kernel': 6, 'bias': 4}
   expected_model = nn.Model(
       None, {'x': {'kernel': 2, 'bias': 2, 'y': doubled_inner}})
   self.assertEqual(doubled, expected_model)
def create_optimizer(model, learning_rate, weight_decay, layers=None):
    """Instantiates an Adam optimizer, optionally one per layer group.

    Args:
        model: target passed to ``optimizer_def.create``.
        learning_rate: a float when ``layers`` is None; otherwise a sequence
            holding one learning rate per entry of ``layers``.
        weight_decay: a float when ``layers`` is None; otherwise a sequence
            holding one weight decay per entry of ``layers``.
        layers: optional sequence; each entry is forwarded as ``layer`` to
            ``path_inclusion_filter_fn`` to select the params of one Adam
            instance. Entries whose learning rate is not positive get no
            optimizer.

    Returns:
        The optimizer created for ``model``.

    Raises:
        AssertionError: if the arguments do not have the expected types or
            lengths. The checks are raised explicitly so they are not
            stripped when Python runs with ``-O``.
    """
    if layers is None:
        if not (isinstance(learning_rate, float)
                and isinstance(weight_decay, float)):
            raise AssertionError(
                'Specify float values for model learning rate and weight '
                'decay!')
        optimizer_def = optim.Adam(learning_rate=learning_rate,
                                   weight_decay=weight_decay)
        return optimizer_def.create(model)

    if not len(learning_rate) == len(weight_decay) == len(layers):
        raise AssertionError(
            'Number of specified learning rates, weight decays, and layers '
            'must be equal!')

    optimizers = []
    for lr, wd, layer in zip(learning_rate, weight_decay, layers):
        if lr > 0:
            opt = optim.Adam(learning_rate=lr, weight_decay=wd)
            filter_fn = functools.partial(path_inclusion_filter_fn,
                                          layer=layer)
            traversal = optim.ModelParamTraversal(filter_fn)
            optimizers.append((traversal, opt))
    optimizer_def = optim.MultiOptimizer(*optimizers)
    return optimizer_def.create(model)
Ejemplo n.º 10
0
def create_optimizer(config, model, initial_params):
    """Create a replicated Adam optimizer for ``model``.

    A MultiOptimizer definition (weight decay 0.01 everywhere except on
    params whose path contains 'bias') is still constructed, but it is
    immediately replaced by a plain Adam definition; see the TODO below.
    ``initial_params`` is not used by this function.
    """
    adam_kwargs = {
        'learning_rate': config.learning_rate,
        'beta1': 0.9,
        'beta2': 0.999,
        'eps': 1e-6,
    }
    with_decay = (
        optim.ModelParamTraversal(lambda path, _: 'bias' not in path),
        optim.Adam(weight_decay=0.01, **adam_kwargs),
    )
    without_decay = (
        optim.ModelParamTraversal(lambda path, _: 'bias' in path),
        optim.Adam(weight_decay=0.0, **adam_kwargs),
    )
    optimizer_def = optim.MultiOptimizer(with_decay, without_decay)
    # TODO(marcvanzee): MultiOptimizer triggers double XLA compilation on TPU so
    # we use Adam here, but we should investigate why this happens.
    optimizer_def = optim.Adam(learning_rate=config.learning_rate)
    replicated = optimizer_def.create(model).replicate()
    del model  # don't keep a copy of the initial model
    return replicated
Ejemplo n.º 11
0
def create_optimizer(config, model):
    """Build an Adam or LAMB MultiOptimizer that skips decay for biases.

    Parameters whose path contains 'bias' are optimized with a weight decay
    of 0.0; every other parameter uses a weight decay of 0.01.

    Args:
        config: object providing ``optimizer`` ('adam' or 'lamb') and
            ``learning_rate``.
        model: the target passed to ``optimizer_def.create``.

    Returns:
        The optimizer created from the MultiOptimizer definition.

    Raises:
        ValueError: if ``config.optimizer`` is neither 'adam' nor 'lamb'.
    """
    if config.optimizer == 'adam':
        optimizer_cls = optim.Adam
    elif config.optimizer == 'lamb':
        optimizer_cls = optim.LAMB
    else:
        # The f-prefix is required so the offending value is actually shown.
        raise ValueError(
            f'Unsupported value for optimizer: {config.optimizer}')
    common_kwargs = dict(
        learning_rate=config.learning_rate,
        beta1=0.9,
        beta2=0.999,
        eps=1e-6,
    )
    optimizer_decay_def = optimizer_cls(weight_decay=0.01, **common_kwargs)
    optimizer_no_decay_def = optimizer_cls(weight_decay=0.0, **common_kwargs)
    # Complementary filters: each parameter belongs to exactly one optimizer.
    decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
    no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
    optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                         (no_decay, optimizer_no_decay_def))
    optimizer = optimizer_def.create(model)
    return optimizer
Ejemplo n.º 12
0
def rescale_layers(params, layer_rescale_factors):
  """Multiply selected model parameters by given factors.

  Args:
    params: a dict of trainable model parameters.
    layer_rescale_factors: A dictionary mapping flat keys to a multiplicative
      rescale factor. Each matching param x is replaced by a * x for its
      factor a. The keys must be of the form described in the flatten_keys
      documentation.

  Returns:
    The params obtained after applying every rescaling.
  """
  logging.info('All keys:')
  for known in flatten_dict(params).keys():
    logging.info(known)

  for key, factor in layer_rescale_factors.items():
    logging.info('Rescaling %s by factor %f', key, factor)
    # Both lambdas are consumed within this iteration, so closing over the
    # loop variables is safe here.
    only_this_key = optim.ModelParamTraversal(lambda path, _: path == key)
    params = only_this_key.update(lambda x: x * factor, params)
  return params
Ejemplo n.º 13
0
def main(argv):
  """Train a Transformer with periodic evaluation, BLEU and checkpointing.

  The function builds the datasets (either the bucketed "dynamic" pipeline or
  the plain WMT pipeline, selected by FLAGS.dynamic), initializes the model
  and an Adam optimizer, optionally restores a checkpoint, pmaps the step
  functions and then runs the training loop.

  Args:
    argv: command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: if more than one command-line argument is given.
    ValueError: if FLAGS.batch_size is not divisible by the local device
      count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only the first host creates the model directory.
  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  # Fall back to a vocabulary file inside the model directory.
  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    # NOTE(review): this call passes FLAGS.vocab_path rather than the local
    # vocab_path computed above (which the non-dynamic branch uses), so the
    # model_dir fallback is not applied here -- confirm this is intended.
    train_ds_mgr, eval_ds, predict_ds, encoder = input_pipeline.get_dynamic_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=FLAGS.vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_buckets=FLAGS.num_data_buckets)
    if FLAGS.static:
      # FLAGS.static is a comma-separated list of per-bucket weights; using
      # it turns the dynamic resampling off for the rest of the run.
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      assert len(weights) == FLAGS.num_data_buckets
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        # Start by sampling from the first data_selection_size buckets.
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)

  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size)

  # Optional auxiliary eval sets: one per comma-separated dataset name; only
  # the eval split of each is kept.
  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.host_id(),
          shard_count=jax.host_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenize `toks` up to and including the first EOS token."""
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  # The eval and predict configs are derived from the training config.
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  # Parameters are initialized with the deterministic (eval) config.
  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where leftoff
    model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # When an adapter is requested, re-create the optimizer from the current
  # target with a focus on params whose path contains FLAGS.adapter.
  if FLAGS.adapter != NONE:
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  # Only process 0 gets a full writer; the others are limited to logging.
  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  # Write the flags of the first module whose name contains 'wmt.par' as
  # hyper-parameters, if such a module exists.
  flag_key = [k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
             ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # Adapter runs use the 'constant' schedule factors; otherwise the scheduler
  # is configured with cycle, init-step and finetune settings.
  if FLAGS.adapter != NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # pmapped helpers handed to the bucket re-scoring functions used in the
  # dynamic branch of the training loop.
  p_get_diag_grads = jax.pmap(
      functools.partial(
          get_diag_grads,
          config=eval_config),
      axis_name='batch')

  p_get_bucket_score = jax.pmap(
      functools.partial(
          get_diag_score,
          strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  # Progress reporting and profiling hooks run on process 0 only.
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  # Initial values for the "only save better checkpoints" comparison below.
  best_eval_loss = 1000
  curr_eval_loss = 1000
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # NOTE(review): aux_eval_ds is only assigned by the loops that run
        # when FLAGS.aux_eval_dataset is set; with that flag unset the calls
        # below look like they would raise NameError -- confirm.
        # Dynamic macro: use gradient alignment to score different ratios
        # of top k vs bottom N-k bins
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score, aux_eval_ds,
                                              train_ds_mgr, optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins
          # take the top k bins and sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  aux_eval_ds, train_ds_mgr,
                                                  optimizer,
                                                  eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          # NOTE(review): the distribution length is hard-coded to 100 while
          # the rest of the function sizes it with FLAGS.num_data_buckets --
          # confirm the two always agree.
          new_distribution = np.zeros(100)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        try:
          batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
          optimizer, metrics = p_train_step(
              optimizer, batch, dropout_rng=dropout_rngs)
          train_metrics.append(metrics)
        except StopIteration:
          # An exhausted training iterator ends the run after this step's
          # metric/eval/checkpoint handling.
          is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          # NOTE(review): train_metrics may be empty here when StopIteration
          # fired right after a previous reset -- confirm get_metrics copes.
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop('learning_rate').mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          # Every summed metric is normalized by the summed 'denominator'.
          denominator = metrics_sums.pop('denominator')
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary['learning_rate'] = lr
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps)
          # Remembered for the checkpoint decision further down.
          curr_eval_loss = eval_results['loss']
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          # Note: this loop rebinds aux_eval_ds, which the dynamic
          # resampling branch above also reads.
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            exemplars, bleu_score = translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length)
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if FLAGS.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        # curr_eval_loss is only refreshed on evaluation steps, so this
        # compares against the most recent evaluation.
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break
Ejemplo n.º 14
0
 def test_only_works_on_models(self):
   """Iterating the traversal over a plain dict raises ValueError."""
   accept_everything = optim.ModelParamTraversal(lambda *_: True)

   def iterate_over_dict():
     return list(accept_everything.iterate({}))

   self.assertRaises(ValueError, iterate_over_dict)
Ejemplo n.º 15
0
def main(argv):
    """Train and periodically evaluate a ResNet, logging MLPerf events.

    The function derives batch size, epoch count, momentum, weight decay and
    input layout from flags (falling back to batch-size dependent defaults),
    builds a two-part optimizer (weights vs. bias/scale parameters), compiles
    pmapped train/eval functions that read their inputs from the device
    infeed, and then alternates training loops and evaluation passes while the
    host pushes batches to every local device from a background thread pool.

    Args:
      argv: Unused command-line arguments; deleted immediately.
    """
    del argv
    # BEGIN GOOGLE-INTERNAL
    xm.setup_work_unit()
    # END GOOGLE-INTERNAL

    tf.enable_v2_behavior()
    init_mllogger()

    # Static MLPerf submission metadata.
    mllogger.event('cache_clear')
    mllogger.start('init_start')
    mllogger.event('submission_org', 'Google')
    mllogger.event('submission_platform',
                   'TPUv3-{}'.format(jax.device_count()))
    mllogger.event('submission_division', 'closed')
    mllogger.event('submission_status', 'research')
    mllogger.event('submission_benchmark', 'resnet')
    mllogger.event('train_samples', input_pipeline.TRAIN_IMAGES)
    mllogger.event('eval_samples', input_pipeline.EVAL_IMAGES)

    # Only host 0 owns the summary writer and its worker thread; every later
    # use of them is guarded by the same host check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
        # Write summaries in background thread to avoid blocking on device sync
        summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

    if FLAGS.seed is not None:
        seed = FLAGS.seed
    else:
        # Only host 0 contributes a time-based value, all others contribute 0.
        # NOTE(review): per_host_sum_pmap presumably sums across hosts so that
        # every host ends up with host 0's seed -- confirm against the helper.
        seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
        seed = per_host_sum_pmap(seed)

    mllogger.event('seed', int(seed))
    key = random.PRNGKey(seed)

    # A batch size of -1 selects a default that depends on the device count.
    batch_size = FLAGS.batch_size
    if batch_size == -1:
        if jax.device_count() > 4096:
            batch_size = 65536
        else:
            batch_size = min(128 * jax.device_count(), 32768)
    mllogger.event('global_batch_size', batch_size)
    eval_batch_size = min(input_pipeline.EVAL_IMAGES, 256 * jax.device_count())
    device_batch_size = batch_size // jax.device_count()
    # Rounded up, so the per-device eval batches may cover more than
    # eval_batch_size examples in total.
    device_eval_batch_size = int(
        math.ceil(eval_batch_size / jax.device_count()))

    # jnp dtype for the model and infeed shapes, tf dtype for the input
    # pipeline.
    model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
    input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32

    # Unset hyperparameters below fall back to batch-size dependent defaults.
    num_epochs = FLAGS.num_epochs
    if num_epochs is None:
        if batch_size < 32768:
            num_epochs = 56
        elif batch_size < 65536:
            num_epochs = 64
        else:
            num_epochs = 92

    # Note: true division, so steps_per_epoch is a float here.
    steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
    # match TF submission behavior (round steps per loop up)
    steps_per_loop = int(math.ceil(steps_per_epoch * FLAGS.epochs_per_loop))
    # also apply rounding loop up to next step to "epochs" in LR schedule
    # (algebraically this sets steps_per_epoch to
    # steps_per_loop / FLAGS.epochs_per_loop)
    steps_per_epoch *= steps_per_loop / (steps_per_epoch *
                                         FLAGS.epochs_per_loop)

    steps_per_eval = int(
        math.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))

    # Learning rate is scaled linearly with the global batch size (base 256).
    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    beta = FLAGS.momentum
    if beta is None:
        if batch_size < 32768:
            beta = 0.9
        elif batch_size < 65536:
            beta = 0.929
        else:
            beta = 0.9537213777059405
    weight_decay = FLAGS.weight_decay
    if weight_decay is None:
        weight_decay = 2e-4 if batch_size < 32768 else 1e-4

    # Space-to-depth defaults to on only for small per-device batches.
    space_to_depth = FLAGS.space_to_depth
    if space_to_depth is None:
        space_to_depth = device_batch_size <= 8

    image_format = FLAGS.image_format
    if image_format is None:
        if space_to_depth and device_batch_size <= 8:
            image_format = 'HWNC'
        else:
            image_format = 'HWCN'

    # Input shapes start in (batch, height, width, channels) order. With
    # space_to_depth the spatial dims are halved and channels become 12.
    image_size = input_pipeline.IMAGE_SIZE
    if space_to_depth:
        train_input_shape = (device_batch_size, image_size // 2,
                             image_size // 2, 12)
        eval_input_shape = (device_eval_batch_size, image_size // 2,
                            image_size // 2, 12)
    else:
        train_input_shape = (device_batch_size, image_size, image_size, 3)
        eval_input_shape = (device_eval_batch_size, image_size, image_size, 3)
    # Permute the shape tuples to the selected layout; any other image_format
    # value leaves them in batch-first order.
    if image_format == 'HWCN':
        train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 3, 0])
        eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 3, 0])
    elif image_format == 'HWNC':
        train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 0, 3])
        eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 0, 3])

    model, state = create_model(key, device_batch_size, image_size,
                                model_dtype, space_to_depth)

    # Two optimizer definitions are built: one for the weights (with weight
    # decay) and one for bias/scale parameters (weight_decay=0).
    if FLAGS.lars:
        mllogger.event('opt_name', 'lars')
        mllogger.event('lars_opt_weight_decay', weight_decay)
        mllogger.event('lars_opt_momentum', beta)
        mllogger.event('lars_epsilon', 0)
        weight_opt_def = optim.LARS(base_learning_rate,
                                    beta,
                                    weight_decay=weight_decay)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=False)
        learning_rate_fn = polynomial_learning_rate_fn(batch_size,
                                                       steps_per_epoch,
                                                       num_epochs)
    else:
        mllogger.event('opt_name', 'sgd')
        mllogger.event('sgd_opt_momentum', beta)
        weight_opt_def = optim.Momentum(base_learning_rate,
                                        beta,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=True)
        learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                      steps_per_epoch,
                                                      num_epochs)

    # The two filters are exact complements: every parameter path is handled
    # by exactly one of the two optimizers.
    def filter_weights(key, _):
        return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
        return 'bias' in key or 'scale' in key

    weight_traversal = optim.ModelParamTraversal(filter_weights)
    other_traversal = optim.ModelParamTraversal(filter_other)
    optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    optimizer = optimizer_def.create(model)
    del model  # do not keep a copy of the initial model

    # NOTE(review): broadcast/unbroadcast are helpers defined outside this
    # function; presumably they add/remove a leading per-device axis -- confirm.
    optimizer = broadcast(optimizer)
    state = broadcast(state)
    empty_metrics = broadcast({'samples': 0, 'loss': 0., 'accuracy': 0})

    p_allreduce_metrics = jax.pmap(allreduce_metrics, axis_name='batch')

    p_sync_batchnorm_stats = jax.pmap(sync_batchnorm_stats, axis_name='batch')

    # Host-driven training: one pmapped call per step, the batch is read from
    # the device infeed rather than passed as an argument.
    def host_loop_train_step(optimizer, state, metrics):
        token = lax.create_token(optimizer.state[0].step)
        batch, token = lax.infeed(token,
                                  shape=(jax.ShapedArray(
                                      train_input_shape, model_dtype),
                                         jax.ShapedArray((device_batch_size, ),
                                                         jnp.int32)))
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn,
                                               image_format, space_to_depth)
        return optimizer, state, metrics

    # in_axes: the optimizer is passed unmapped (None); state and metrics are
    # mapped over their leading axis.
    p_host_loop_train_step = jax.pmap(host_loop_train_step,
                                      axis_name='batch',
                                      in_axes=(None, 0, 0))

    # Evaluation step; like the train step it pulls its batch from the infeed.
    def host_loop_eval_step(model, state, metrics):
        token = lax.create_token(metrics['samples'])
        batch, token = lax.infeed(
            token,
            shape=(jax.ShapedArray(eval_input_shape, model_dtype),
                   jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
        metrics = eval_step(model, state, batch, metrics, image_format,
                            space_to_depth)
        return metrics

    p_host_loop_eval_step = jax.pmap(host_loop_eval_step,
                                     axis_name='batch',
                                     in_axes=(None, None, 0))

    # Device-driven training: a lax.while_loop runs the train steps of one
    # loop index on device, carrying the infeed token through the loop state.
    def device_train_loop_cond(args):
        # Continue while the current step still belongs to loop index `loop`.
        _, _, _, _, step, loop = args
        return step // steps_per_loop == loop

    def device_train_loop_body(args):
        optimizer, state, metrics, token, step, loop = args
        batch, token = lax.infeed(token,
                                  shape=(jax.ShapedArray(
                                      train_input_shape, model_dtype),
                                         jax.ShapedArray((device_batch_size, ),
                                                         jnp.int32)))
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn,
                                               image_format, space_to_depth)
        step += 1
        return optimizer, state, metrics, token, step, loop

    def device_train_loop(optimizer, state, metrics, step, loop):
        token = lax.create_token(step)
        optimizer, state, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, state, metrics, token, step, loop))
        # Batch-norm sync and metric all-reduce happen inside the same pmapped
        # call here; the host-loop path does them as separate pmapped calls.
        state = sync_batchnorm_stats(state)
        metrics = allreduce_metrics(metrics)
        return optimizer, state, metrics, step

    # Only the metrics argument is mapped; everything else is passed unmapped.
    p_train_loop = jax.pmap(device_train_loop,
                            axis_name='batch',
                            in_axes=(None, None, 0, None, None))

    # BEGIN GOOGLE-INTERNAL
    def maybe_start_xprof(seconds):
        # Starts a profiling session on host 0 (when FLAGS.xprof is set) and
        # ends it from a background thread after `seconds` seconds.
        if jax.host_id() == 0 and FLAGS.xprof:
            xprof = xprof_session.XprofSession()
            xprof.start_session('REDACTED', True, 2)

            def sleep_and_end_xprof():
                time.sleep(seconds)
                logging.info(
                    'Xprof URL: %s',
                    xprof.end_session_and_get_url(
                        tag='flax resnet, {} devices, batch {} per device'.
                        format(jax.device_count(), device_batch_size)))

            thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

    # END GOOGLE-INTERNAL

    # Optional warm-up: call each pmapped function once so compilation happens
    # before the timed run. Zero-filled batches are queued on the infeed for
    # the functions that read from it.
    if FLAGS.precompile:
        logging.info('precompiling step/loop functions')
        if FLAGS.device_loop:
            # the device training loop condition will immediately be false
            p_train_loop(unbroadcast(optimizer), unbroadcast(state),
                         empty_metrics, jnp.array(0, dtype=jnp.int32), 1)
        else:
            for device in jax.local_devices():
                images = np.zeros(train_input_shape, model_dtype)
                labels = np.zeros((device_batch_size, ), np.int32)
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
            # The results are discarded, so the real optimizer/state are not
            # advanced by the warm-up step.
            p_host_loop_train_step(unbroadcast(optimizer), state,
                                   empty_metrics)
            p_sync_batchnorm_stats(state)
        for device in jax.local_devices():
            images = np.zeros(eval_input_shape, model_dtype)
            labels = np.zeros((device_eval_batch_size, ), np.int32)
            infeed_pool.submit(
                partial(device.transfer_to_infeed, (images, labels)))
        p_host_loop_eval_step(unbroadcast(optimizer.target),
                              unbroadcast(state), empty_metrics)
        p_allreduce_metrics(empty_metrics)['accuracy'].block_until_ready()
        logging.info('finished precompiling')

    # BEGIN GOOGLE-INTERNAL
    maybe_start_xprof(20)
    # END GOOGLE-INTERNAL
    # With FLAGS.fake_data no datasets are built; train_iter/eval_iter are then
    # undefined and the loops below feed zero-filled arrays instead.
    if not FLAGS.fake_data:
        logging.info('constructing datasets')
        # pylint: disable=g-complex-comprehension
        train_ds, eval_ds = [
            input_pipeline.load_split(
                device_batch_size if train else device_eval_batch_size,
                dtype=input_dtype,
                train=train,
                image_format=image_format,
                space_to_depth=space_to_depth,
                cache_uncompressed=jax.device_count() > 64)
            for train in (True, False)
        ]
        logging.info('constructing dataset iterators')
        train_iter = iter(train_ds)
        eval_iter = iter(eval_ds)

    local_devices = jax.local_devices()
    # host_step counts batches fed by the host; device_step is the step
    # counter threaded through the on-device training loop.
    host_step, device_step = 0, broadcast(0)
    mllogger.end('init_stop')
    mllogger.start('run_start')
    mllogger.start('block_start',
                   metadata={
                       'first_epoch_num': 1,
                       'epoch_count': FLAGS.epochs_per_loop
                   })
    # NOTE(review): the loop runs two iterations more than
    # ceil(num_epochs / epochs_per_loop); the reason is not visible here.
    for loop in range(int(math.ceil(num_epochs / FLAGS.epochs_per_loop)) + 2):
        # BEGIN GOOGLE-INTERNAL
        if loop == 10: maybe_start_xprof(1)
        # END GOOGLE-INTERNAL
        metrics = empty_metrics
        if FLAGS.device_loop:
            # Dispatch the whole on-device loop first; the host loop below
            # then only has to supply its steps_per_loop batches.
            optimizer, state, metrics, device_step = p_train_loop(
                unbroadcast(optimizer), unbroadcast(state), metrics,
                unbroadcast(device_step), loop)
        while int(host_step // steps_per_loop) == loop:
            if not FLAGS.device_loop:
                optimizer, state, metrics = p_host_loop_train_step(
                    unbroadcast(optimizer), state, metrics)
            # Throttle: wait while more than 100 infeed transfers are queued.
            # pylint: disable=protected-access
            while infeed_pool._work_queue.qsize() > 100:
                time.sleep(0.01)
            # One batch per local device per host step.
            for device in local_devices:
                if FLAGS.fake_data:
                    images = np.zeros(train_input_shape, model_dtype)
                    labels = np.zeros((device_batch_size, ), np.int32)
                else:
                    # pylint: disable=protected-access
                    images, labels = jax.tree_map(lambda x: x._numpy(),
                                                  next(train_iter))
                assert images.shape == train_input_shape and labels.dtype == jnp.int32
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
            host_step += 1
        epoch = (loop + 1) * FLAGS.epochs_per_loop
        if FLAGS.train_metrics:
            # In device_loop mode the metrics were already all-reduced inside
            # device_train_loop.
            if not FLAGS.device_loop:
                metrics = p_allreduce_metrics(metrics)
            if jax.host_id() == 0:
                summary_thread.submit(
                    partial(write_summary, summary_writer, metrics, 'train',
                            epoch))
        if not FLAGS.device_loop:
            state = p_sync_batchnorm_stats(state)
        # Evaluation pass: dispatch each eval step, then queue its input.
        metrics = empty_metrics
        for _ in range(steps_per_eval):
            metrics = p_host_loop_eval_step(unbroadcast(optimizer.target),
                                            unbroadcast(state), metrics)
            for device in local_devices:
                if FLAGS.fake_data:
                    images = np.zeros(eval_input_shape, model_dtype)
                    labels = np.zeros((device_eval_batch_size, ), np.int32)
                else:
                    # pylint: disable=protected-access
                    images, labels = jax.tree_map(lambda x: x._numpy(),
                                                  next(eval_iter))
                assert images.shape == eval_input_shape and labels.dtype == jnp.int32, \
                    'images.shape={}'.format(images.shape)
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
        metrics = p_allreduce_metrics(metrics)
        if jax.host_id() == 0:
            summary_thread.submit(
                partial(write_summary, summary_writer, metrics, 'eval', epoch))
    # Wait until computations are done before exiting
    p_allreduce_metrics(metrics)['accuracy'].block_until_ready()
    if jax.host_id() == 0:
        summary_thread.shutdown()
        # NOTE(review): DONE is a module-level name defined outside this
        # function; presumably it is set once the run has succeeded -- confirm.
        if not DONE:
            mllogger.end('run_stop', metadata={'status': 'aborted'})
Ejemplo n.º 16
0
def compute_is_scores(filename):
    """Compute IS scores for training data.

    Builds a Transformer in evaluation mode, restores its parameters from a
    checkpoint, runs ``eval_for_is_step`` over every batch of the training
    dataset and writes the results as CSV, one row per batch, to two files
    under ``FLAGS.is_save_path``:

    * ``<filename>.txt`` -- the per-example losses.
    * ``<filename>-lengths.txt`` -- the per-example lengths.

    Both files are opened on every host, but only host 0 writes rows.

    Args:
      filename: Base name (without extension) of the two output files.

    Raises:
      ValueError: If ``FLAGS.batch_size`` is not divisible by the number of
        local devices.
      RuntimeError: If ``FLAGS.restore_checkpoints`` is not set; scoring
        requires a trained model.
    """

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    print('Loading data')
    logging.info('Initializing dataset.')
    train_ds, encoder = input_pipeline.get_wmt_is_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        paracrawl_size=FLAGS.paracrawl_size)
    print('Datasets created')

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    print('data iterators created')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    eval_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=True,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # It's possible that is supposed to be per device batch size
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # The optimizer is only used as the container the checkpoint is restored
    # into; no training step is taken in this function.
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, just continue where
        # we left off.
        model_path = FLAGS.pretrained_model_dir or FLAGS.model_dir
        # When loading a checkpoint trained with adapters (ie. frozen weights)
        # restoring from the base optimizer fails. We catch this error and create
        # the optimizer with frozen weights.
        try:
            optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
            # Grab last step.
            start_step = int(optimizer.state.step)
        except ValueError:
            adapter = optim.ModelParamTraversal(
                lambda path, _: FLAGS.adapter in path)
            optimizer = optimizer_def.create(optimizer.target, focus=adapter)
            optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
            start_step = optimizer.state[0].step

    else:
        raise RuntimeError('Must restore checkpoint for IS')

    if FLAGS.adapter != NONE and not isinstance(optimizer,
                                                optim.MultiOptimizer):
        adapter = optim.ModelParamTraversal(
            lambda path, _: FLAGS.adapter in path)
        optimizer = optimizer_def.create(optimizer.target, focus=adapter)
    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    p_eval_step = jax.pmap(functools.partial(eval_for_is_step,
                                             config=eval_config),
                           axis_name='batch')

    logging.info('Start scoring loop.')
    t_loop_start = time.time()

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    save_prefix = FLAGS.is_save_path + '/' + filename

    # Both output files are managed by one `with` statement so that they are
    # flushed and closed even when the scoring loop raises part-way through.
    with tf.io.gfile.GFile(save_prefix + '-lengths.txt', 'w') as length_fp, \
            tf.io.gfile.GFile(save_prefix + '.txt', 'w') as fp:
        lengths_writer = csv.writer(length_fp)
        writer = csv.writer(fp)

        for batch_idx, eval_batch in enumerate(train_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            cur_pred_batch_size = eval_batch['inputs'].shape[0]
            needs_padding = bool(cur_pred_batch_size % n_devices)
            if needs_padding:
                # Pad the batch up to a multiple of n_devices so it can be
                # sharded evenly; the padding rows are dropped again below.
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                eval_batch = jax.tree_map(
                    lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
            eval_batch = common_utils.shard(eval_batch)
            losses, lengths = p_eval_step(optimizer.target, eval_batch)
            if jax.host_id() == 0:
                losses = common.tohost(losses)
                lengths = common.tohost(lengths)
                if needs_padding:
                    writer.writerow(losses[:cur_pred_batch_size])
                    lengths_writer.writerow(lengths[:cur_pred_batch_size])
                else:
                    writer.writerow(losses)
                    lengths_writer.writerow(lengths)

            # Coarse progress report: batch index and elapsed seconds.
            if batch_idx % 500 == 0:
                print('Batch', batch_idx)
                print(time.time() - t_loop_start)
Ejemplo n.º 17
0
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretrainined
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warn('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
Ejemplo n.º 18
0
def main(argv):
    """Train and evaluate the model built by `create_model`, epoch by epoch.

    Sets up summary writing (host 0 only), derives the global / per-host /
    per-device batch sizes, builds a `MultiOptimizer` that treats bias/scale
    parameters separately from all other parameters, optionally precompiles
    the pmapped functions, and then, for `FLAGS.num_epochs` epochs, runs one
    epoch of training followed by `steps_per_eval` evaluation batches.

    Training batches reach the devices in one of two ways:
      * `FLAGS.infeed` set: a per-device `lax.while_loop` pulls batches from
        the infeed while the host pushes them from a background thread pool.
      * otherwise: the host calls the pmapped `train_step` once per batch.

    Args:
        argv: Command-line arguments; unused and deleted immediately.
    """
    del argv
    # BEGIN GOOGLE-INTERNAL
    xm.setup_work_unit()
    # END GOOGLE-INTERNAL

    tf.enable_v2_behavior()

    # `summary_writer` and `summary_thread` only exist on host 0; every later
    # use is guarded by the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
        # Write summaries in background thread to avoid blocking on device sync
        summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    if FLAGS.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    rng = random.PRNGKey(0)

    image_size = 224

    # Batch-size bookkeeping:
    #   batch_size / eval_batch_size  -> global (all hosts, all devices)
    #   local_*                       -> per host
    #   device_*                      -> per device
    batch_size = FLAGS.batch_size
    if batch_size is None:
        # Default: 128 examples per device, capped at 32768 globally.
        batch_size = min(128 * jax.device_count(), 32768)
    eval_batch_size = 128 * jax.device_count()
    local_batch_size = batch_size // jax.host_count()
    local_eval_batch_size = eval_batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()
    device_eval_batch_size = eval_batch_size // jax.device_count()
    # Per-device size of the remainder batch left over when EVAL_IMAGES is not
    # a multiple of the global eval batch size.
    device_last_eval_batch_size = (input_pipeline.EVAL_IMAGES %
                                   eval_batch_size) // jax.device_count()

    model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
    input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32
    # Per-device input shapes. With `transpose_images` the batch dimension is
    # last instead of first. Two eval shapes are kept: the regular eval batch
    # and the remainder batch.
    if FLAGS.transpose_images:
        train_input_shape = (224, 224, 3, device_batch_size)
        eval_input_shapes = [(224, 224, 3, bs)
                             for bs in (device_eval_batch_size,
                                        device_last_eval_batch_size)]
    else:
        train_input_shape = (device_batch_size, 224, 224, 3)
        eval_input_shapes = [(bs, 224, 224, 3)
                             for bs in (device_eval_batch_size,
                                        device_last_eval_batch_size)]

    num_epochs = FLAGS.num_epochs
    # True division: `steps_per_epoch` is a float (logged with %f below) and
    # epoch membership of a step is decided later by floor division.
    steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
    logging.info('steps_per_epoch: %f', steps_per_epoch)
    steps_per_eval = int(np.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))
    logging.info('steps_per_eval: %d', steps_per_eval)

    # The learning-rate flag is scaled linearly with the global batch size,
    # relative to a batch of 256.
    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    beta = FLAGS.momentum
    weight_decay = FLAGS.weight_decay

    logging.info('creating and initializing model and optimizer')
    model, state = create_model(rng, device_batch_size, image_size,
                                model_dtype)
    state = jax_utils.replicate(state)
    # Two optimizer definitions are built: `weight_opt_def` (with weight decay)
    # and `other_opt_def` (weight decay 0). They are assigned to parameter
    # groups by the traversals defined further down.
    if FLAGS.lars:
        weight_opt_def = optim.LARS(base_learning_rate,
                                    beta,
                                    weight_decay=weight_decay)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=False)
        learning_rate_fn = polynomial_learning_rate_fn(batch_size,
                                                       steps_per_epoch,
                                                       num_epochs)
    else:
        weight_opt_def = optim.Momentum(base_learning_rate,
                                        beta,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=True)
        learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                      steps_per_epoch,
                                                      num_epochs)

    def filter_weights(key, _):
        """Select parameters whose key mentions neither 'bias' nor 'scale'."""
        return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
        """Select parameters whose key mentions 'bias' or 'scale'."""
        return 'bias' in key or 'scale' in key

    # The two filters are exact complements, so each parameter is handled by
    # exactly one of the two optimizer definitions.
    weight_traversal = optim.ModelParamTraversal(filter_weights)
    other_traversal = optim.ModelParamTraversal(filter_other)
    optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    optimizer = optimizer_def.create(model)
    optimizer = optimizer.replicate()
    del model  # do not keep a copy of the initial model

    p_train_step = jax.pmap(partial(train_step,
                                    learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def device_train_loop_cond(args):
        """Keep looping while `step` still belongs to `epoch`."""
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        """Pull one (images, labels) pair from the infeed and train on it."""
        optimizer, state, metrics, token, step, epoch = args
        # The declared shapes/dtypes must match what the host pushes with
        # `transfer_to_infeed` in the training loop below.
        (images, labels), token = lax.infeed(
            token,
            shape=(jax.ShapedArray(train_input_shape, model_dtype),
                   jax.ShapedArray((device_batch_size, ), jnp.int32)))
        batch = {'image': images, 'label': labels}
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn)
        step += 1
        return optimizer, state, metrics, token, step, epoch

    def device_train_loop(optimizer, state, metrics, step, epoch):
        """Run training steps in a `lax.while_loop` until the epoch ends."""
        token = lax.create_token(step)
        optimizer, state, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, state, metrics, token, step, epoch))
        return optimizer, state, metrics, step

    p_train_epoch = jax.pmap(device_train_loop, axis_name='batch')

    if FLAGS.precompile:
        # Call every pmapped function once with placeholder inputs so that
        # compilation happens here rather than inside the training loop.
        logging.info('precompiling step/epoch functions')
        if FLAGS.infeed:
            # the device training loop condition will immediately be false
            p_train_epoch(optimizer, state, empty_metrics(),
                          jax_utils.replicate(0), jax_utils.replicate(1))
        else:
            batch = {
                'image':
                jnp.zeros((jax.local_device_count(), ) + train_input_shape,
                          model_dtype),
                'label':
                jnp.zeros((jax.local_device_count(), ) + (device_batch_size, ),
                          jnp.int32)
            }
            p_train_step(optimizer, state, batch, empty_metrics())
        # The eval step is run once per eval shape (regular and remainder).
        for dbs, eis in zip(
            [device_eval_batch_size, device_last_eval_batch_size],
                eval_input_shapes):
            batch = {
                'image':
                jnp.zeros((jax.local_device_count(), ) + eis, model_dtype),
                'label':
                jnp.zeros((jax.local_device_count(), ) + (dbs, ), jnp.int32)
            }
            p_eval_step(optimizer.target, state, batch, empty_metrics())
        allreduce_metrics(empty_metrics())
        # Same pmean-over-state call that is used after each training epoch;
        # the result is discarded here.
        pmean = functools.partial(jax.lax.pmean, axis_name='batch')
        jax.pmap(pmean, axis_name='batch')(state)

    logging.info('constructing datasets')
    # pylint: disable=g-complex-comprehension
    # One pass with train=True and one with train=False, using the per-host
    # train / eval batch size respectively.
    train_ds, eval_ds = [
        input_pipeline.load_split(
            local_batch_size if train else local_eval_batch_size,
            image_size=image_size,
            dtype=input_dtype,
            train=train,
            transpose_images=FLAGS.transpose_images) for train in (True, False)
    ]
    # pylint: enable=g-complex-comprehension
    logging.info('constructing dataset iterators')
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    logging.info('beginning training')
    # `host_step` counts batches handled by this host; `device_step` is the
    # replicated counter advanced inside the on-device loop (infeed mode).
    host_step, device_step = 0, jax_utils.replicate(0)
    for epoch in range(num_epochs):
        device_epoch = jax_utils.replicate(epoch)
        metrics = empty_metrics()
        if FLAGS.infeed:
            # Start the on-device epoch loop before feeding it.
            # NOTE(review): presumably this call returns before the loop has
            # finished (asynchronous dispatch), so the host can go on to push
            # the batches below -- confirm.
            optimizer, state, metrics, device_step = p_train_epoch(
                optimizer, state, metrics, device_step, device_epoch)
        # Same epoch-membership test as `device_train_loop_cond`, evaluated on
        # the host counter.
        while int(host_step // steps_per_epoch) == epoch:
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            if FLAGS.infeed:
                # Push one slice of the host batch to each local device from
                # the background pool.
                for i, device in enumerate(jax.local_devices()):
                    images, labels = batch['image'][i], batch['label'][i]
                    assert images.shape == train_input_shape and labels.dtype == jnp.int32
                    infeed_pool.submit(
                        partial(device.transfer_to_infeed, (images, labels)))
            else:
                optimizer, state, metrics = p_train_step(
                    optimizer, state, batch, metrics)
            host_step += 1
        if FLAGS.train_metrics:
            metrics = allreduce_metrics(metrics)
            if jax.host_id() == 0:
                summary_thread.submit(
                    partial(write_summary, summary_writer, metrics, 'train',
                            epoch + 1))
        # Average `state` across devices before evaluating.
        # NOTE(review): `state` presumably holds batch-norm statistics, given
        # the flag name -- confirm against `create_model`.
        if not FLAGS.distributed_batchnorm:  # otherwise it's already synced
            pmean = functools.partial(jax.lax.pmean, axis_name='batch')
            state = jax.pmap(pmean, axis_name='batch')(state)
        # Evaluation: accumulate metrics over `steps_per_eval` batches, then
        # all-reduce and hand them to the summary thread on host 0.
        metrics = empty_metrics()
        for _ in range(steps_per_eval):
            batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))  # pylint: disable=protected-access
            metrics = p_eval_step(optimizer.target, state, batch, metrics)
        metrics = allreduce_metrics(metrics)
        if jax.host_id() == 0:
            summary_thread.submit(
                partial(write_summary, summary_writer, metrics, 'eval',
                        epoch + 1))
        # TODO(deveci): do something like this from the summary thread:
        # if summary['accuracy'] > TARGET_ACCURACY:
        #   break
    if jax.host_id() == 0:
        summary_thread.shutdown()
    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Ejemplo n.º 19
0
def meta_optimize_scales(loss_fn,
                         fprop,
                         normalized_params,
                         norms,
                         hps,
                         input_shape,
                         output_shape,
                         rng_key,
                         metrics_logger=None,
                         log_every=10):
  """Optimizes the MetaInit objective with respect to the parameter norms.

  For `hps.meta_steps` iterations a batch of random inputs and random one-hot
  targets is drawn, the meta loss is differentiated with respect to the
  current norms, the gradients are passed through
  `model_utils.cross_device_avg`, replaced by their sign, and applied with
  optax SGD (with momentum).

  Args:
    loss_fn: Loss function, called as loss_fn(outputs, targets).
    fprop: Forward pass of the model, called as
      fprop({'params': params}, inputs, train=True) -> outputs.
    normalized_params: Pytree of model parameters, forwarded to `meta_loss`
      together with the norms being optimized.
    norms: Pytree with the initial guess for the norms; the starting point of
      the optimization.
    hps: HParam object. Reads meta_learning_rate, meta_momentum,
      meta_batch_size, meta_steps and epsilon.
    input_shape: Shape of one input example, without the batch dimensions.
    output_shape: Shape of one output; its last entry is the number of classes
      used for the random targets.
    rng_key: jax.PRNGKey; folded with the step index to seed each fake batch.
    metrics_logger: Optional logger exposing `append_scalar_metrics`.
    log_every: Log the meta loss every k steps (and on the final step).

  Returns:
    learned_norms: The unreplicated norms after the last update.
    training_curve: List with one entry per meta step, each being the loss
      array returned by the pmapped update.

  Raises:
    ValueError: If hps.meta_batch_size is not divisible by the device count.
  """
  num_outputs = output_shape[-1]
  if hps.meta_batch_size % jax.device_count() != 0:
    raise ValueError('meta_bs: {}, n_devices: {}'.format(
        hps.meta_batch_size, jax.device_count()))

  # Leading dimensions shared by the fake inputs and the fake targets:
  # (local devices, examples per device).
  per_device_batch = hps.meta_batch_size // jax.device_count()
  batch_dims = (jax.local_device_count(), per_device_batch)

  def make_fake_batch(key):
    """Draws Gaussian inputs and uniformly random one-hot targets."""
    input_key, target_key = jax.random.split(key)
    fake_inputs = jax.random.normal(input_key, batch_dims + input_shape)
    labels = jax.random.randint(target_key, batch_dims, 0, num_outputs)
    return (fake_inputs, jnp.eye(num_outputs)[labels])

  # Only the keys returned by _get_non_bias_params are selected by the
  # traversal below.
  trainable_keys = _get_non_bias_params(normalized_params)
  if jax.process_index() == 0:
    logging.info('MetaInit will optimize the following parameters:')
    for trainable_key in trainable_keys:
      logging.info(trainable_key)

  def is_trainable(path, _):
    return path in trainable_keys

  traversal = optimizers.ModelParamTraversal(is_trainable)

  # The identity update leaves every value as it is in `norms`.
  meta_params = traversal.update(lambda x: x, norms)
  sgd_init, sgd_update = optax.sgd(
      learning_rate=hps.meta_learning_rate,
      momentum=hps.meta_momentum)
  meta_optimizer_state = jax_utils.replicate(sgd_init(meta_params))
  meta_params = jax_utils.replicate(meta_params)

  # The step closes over the static pieces (fprop, normalized_params, hps).
  @functools.partial(jax.pmap, axis_name='batch')
  def meta_step(current_params, current_opt_state, batch_inputs,
                batch_targets):
    """One signed-gradient SGD step on the meta loss."""

    def objective(scales):

      def data_loss(params):
        outputs = fprop({'params': params}, batch_inputs, train=True)
        return loss_fn(outputs, batch_targets)

      return meta_loss(data_loss, scales, normalized_params, hps.epsilon)

    step_loss, step_grads = jax.value_and_grad(objective)(current_params)
    step_grads = model_utils.cross_device_avg(step_grads)
    # Only the sign of each gradient entry is used for the update.
    step_grads = jax.tree_map(jnp.sign, step_grads)
    updates, next_opt_state = sgd_update(
        step_grads, current_opt_state, params=current_params)
    next_params = optax.apply_updates(current_params, updates)
    return next_params, next_opt_state, step_loss

  training_curve = []
  start = time.perf_counter()
  for i in range(hps.meta_steps):
    inputs, targets = make_fake_batch(jax.random.fold_in(rng_key, i))

    meta_params, meta_optimizer_state, loss_value = meta_step(
        meta_params, meta_optimizer_state, inputs, targets)
    training_curve.append(loss_value)

    # Logging happens on process 0 only, every `log_every` steps and on the
    # very last step.
    if jax.process_index() != 0:
      continue
    if i % log_every != 0 and (i + 1) != hps.meta_steps:
      continue

    elapsed = time.perf_counter() - start
    logging.info('Cumulative time (seconds): %d', elapsed)
    # The logged value is the first entry of the per-device loss array.
    first_device_loss = float(loss_value[0])
    logging.info('meta_init step %d, loss: %f', i, first_device_loss)
    if metrics_logger is not None:
      metrics_logger.append_scalar_metrics({
          'global_step': i,
          'meta_loss': first_device_loss
      })

  # Drop the device axis again before handing the norms back.
  learned_norms = jax_utils.unreplicate(meta_params)
  return learned_norms, training_curve
Ejemplo n.º 20
0
def get_optimizer(hps):
    """Return the optimizer definition selected by `hps.optimizer`.

    Every optimizer is created with `learning_rate=None`. Weight decay comes
    from `hps.opt_hparams['weight_decay']` when that key is present and is 0
    otherwise; the 'sgd' choice does not use it.

    Supported names: 'sgd', 'nesterov', 'momentum', 'lamb', 'adam', 'lars',
    'mlperf_lars_resnet' and 'mlperf_lamb'. The two 'mlperf_*' choices return
    a `MultiOptimizer` that applies one optimizer to parameters whose key
    contains 'bias' or 'scale' and another one to all remaining parameters.

    Args:
        hps: HParams object exposing `optimizer`, `opt_hparams` and, for all
            choices other than 'sgd', 'nesterov' and 'momentum',
            `l2_decay_factor`.

    Returns:
        The optimizer definition built from `optimizers`.

    Raises:
        NotImplementedError: If `hps.optimizer` is not a supported name.
    """
    opt_hparams = hps.opt_hparams
    name = hps.optimizer
    weight_decay = (opt_hparams['weight_decay']
                    if 'weight_decay' in opt_hparams else 0)

    def is_bias_or_scale(key, _):
        return 'bias' in key or 'scale' in key

    def is_weight(key, _):
        return not ('bias' in key or 'scale' in key)

    def split_by_param_kind(weight_opt_def, other_opt_def):
        """Pair each optimizer definition with its parameter traversal."""
        weight_traversal = optimizers.ModelParamTraversal(is_weight)
        other_traversal = optimizers.ModelParamTraversal(is_bias_or_scale)
        return optimizers.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))

    if name == 'sgd':
        return optimizers.GradientDescent(learning_rate=None)

    if name in ('nesterov', 'momentum'):
        # Same optimizer class; only the nesterov switch differs.
        return optimizers.Momentum(learning_rate=None,
                                   beta=opt_hparams['momentum'],
                                   nesterov=(name == 'nesterov'),
                                   weight_decay=weight_decay)

    if name in ('lamb', 'adam', 'lars', 'mlperf_lars_resnet', 'mlperf_lamb'):
        # These choices refuse an L2 decay factor combined with a non-zero
        # optimizer weight decay.
        assert hps.l2_decay_factor is None or weight_decay == 0.0

    if name in ('lamb', 'adam'):
        optimizer_cls = optimizers.LAMB if name == 'lamb' else optimizers.Adam
        return optimizer_cls(learning_rate=None,
                             beta1=opt_hparams['beta1'],
                             beta2=opt_hparams['beta2'],
                             eps=opt_hparams['epsilon'],
                             weight_decay=weight_decay)

    if name == 'lars':
        return optimizers.LARS(learning_rate=None,
                               beta=opt_hparams['beta'],
                               weight_decay=weight_decay)

    if name == 'mlperf_lars_resnet':
        return split_by_param_kind(
            optimizers.LARS(learning_rate=None,
                            beta=opt_hparams['beta'],
                            weight_decay=weight_decay),
            optimizers.Momentum(learning_rate=None,
                                beta=opt_hparams['beta'],
                                weight_decay=0,
                                nesterov=False))

    if name == 'mlperf_lamb':
        # Each group reads its own weight-decay hyperparameter here.
        return split_by_param_kind(
            optimizers.LAMB(learning_rate=None,
                            beta1=opt_hparams['beta1'],
                            beta2=opt_hparams['beta2'],
                            eps=opt_hparams['epsilon'],
                            weight_decay=opt_hparams['lamb_weight_decay']),
            optimizers.Adam(learning_rate=None,
                            beta1=opt_hparams['beta1'],
                            beta2=opt_hparams['beta2'],
                            eps=opt_hparams['epsilon'],
                            weight_decay=opt_hparams['adam_weight_decay']))

    raise NotImplementedError('Optimizer {} not implemented'.format(
        hps.optimizer))