Example #1
0
def train_model():
  """Runs full-batch HMC sampling with checkpointing, ensembling and logging.

  All configuration is read from the module-level ``args`` namespace. The run
  directory is ``args.dir`` joined with a name built from the main
  hyperparameters. Every iteration performs one HMC update, saves a
  checkpoint, optionally adds the current sample to the running ensemble,
  evaluates on the train and test sets, writes TensorBoard scalars and prints
  a progress table.

  If ``checkpoint_utils.initialize`` reports a preempted run
  (``InitStatus.LOADED_PREEMPTED``), sampling continues from the saved
  iteration, key, step size and ensemble. Otherwise sampling starts at
  iteration 0 with parameters taken either from ``args.init_checkpoint``
  (``INIT_CKPT``) or from a random initialization seeded by ``args.seed``
  (``INIT_RANDOM``).

  Raises:
    ValueError: If checkpoint initialization returns an unknown status.
    AssertionError: If the parameters, the initial log-probability or its
      gradient do not have the requested dtype.
  """
  subdirname = (
    "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
    "seed_{}".format(
      args.model_name, args.weight_decay, args.step_size, args.trajectory_len,
      args.num_burn_in_iterations, args.burn_in_step_size_factor,
      not args.no_mh, args.temperature, args.seed
    ))
  dirname = os.path.join(args.dir, subdirname)
  os.makedirs(dirname, exist_ok=True)
  tf_writer = tf.summary.create_file_writer(dirname)
  cmd_args_utils.save_cmd(dirname, tf_writer)
  num_devices = len(jax.devices())

  dtype = jnp.float64 if args.use_float64 else jnp.float32
  train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
    args.dataset_name, dtype)

  net_apply, net_init = models.get_model(args.model_name, num_classes)

  checkpoint_dict, status = checkpoint_utils.initialize(
      dirname, args.init_checkpoint)

  if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
    print("Continuing the run from the last saved checkpoint")
    (start_iteration, params, net_state, key, step_size, _, num_ensembled,
     ensemble_predicted_probs) = (
        checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
  else:
    key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
    start_iteration = 0
    num_ensembled = 0
    ensemble_predicted_probs = None
    step_size = args.step_size

    if status == checkpoint_utils.InitStatus.INIT_CKPT:
      print("Resuming the run from the provided init_checkpoint")
      # Only the parameters and network state are taken from the checkpoint;
      # iteration, key, step size and ensemble start fresh.
      _, params, net_state, _, _, _, _, _ = (
        checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
    elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
      print("Starting from random initialization with provided seed")
      # Minimal init input: `elem[0][:1]` of every leaf of train_set
      # (assumes a leading per-device axis — TODO confirm).
      init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
      params, net_state = net_init(net_init_key, init_data, True)
      # Replicate the network state once per device.
      net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    else:
      raise ValueError("Unknown initialization status: {}".format(status))

  # Manually convert all params to dtype.
  params = jax.tree_map(lambda p: p.astype(dtype), params)

  param_types = tree_utils._get_types(params)
  assert all(p_type == dtype for p_type in param_types), (
    "Params data types {} do not match specified data type {}".format(
      param_types, dtype))

  trajectory_len = args.trajectory_len

  log_likelihood_fn = nn_loss.make_xent_log_likelihood(
    num_classes, args.temperature)
  log_prior_fn, log_prior_diff_fn = nn_loss.make_gaussian_log_prior(
    args.weight_decay, args.temperature)

  update, get_log_prob_and_grad, evaluate = train_utils.make_hmc_update(
    net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
    args.max_num_leapfrog_steps, args.target_accept_rate,
    args.step_size_adaptation_speed)

  # Initial log-probability and gradient at the starting parameters.
  # `state_grad` and `log_likelihood` are then threaded through `update`.
  log_prob, state_grad, log_likelihood, net_state = (
      get_log_prob_and_grad(train_set, params, net_state))

  assert log_prob.dtype == dtype, (
    "log_prob data type {} does not match specified data type {}".format(
        log_prob.dtype, dtype))

  grad_types = tree_utils._get_types(state_grad)
  assert all(g_type == dtype for g_type in grad_types), (
    "Gradient data types {} do not match specified data type {}".format(
      grad_types, dtype))

  ensemble_acc = 0

  for iteration in range(start_iteration, args.num_iterations):
    in_burnin = iteration < args.num_burn_in_iterations
    # The Metropolis-Hastings correction is skipped during burn-in and
    # whenever --no_mh is set.
    do_mh_correction = (not args.no_mh) and (not in_burnin)

    # During burn-in, linearly interpolate the step size from args.step_size
    # (first iteration) to burn_in_step_size_factor * args.step_size (last
    # burn-in iteration). max(..., 1) avoids dividing by zero when there is
    # exactly one burn-in iteration.
    if in_burnin:
      alpha = iteration / max(args.num_burn_in_iterations - 1, 1)
      initial_step_size = args.step_size
      final_step_size = args.burn_in_step_size_factor * args.step_size
      step_size = final_step_size * alpha + initial_step_size * (1 - alpha)

    start_time = time.time()
    (params, net_state, log_likelihood, state_grad, step_size, key,
     accept_prob, accepted) = (
        update(train_set, params, net_state, log_likelihood, state_grad,
               key, step_size, trajectory_len, do_mh_correction))
    iteration_time = time.time() - start_time

    checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
    checkpoint_path = os.path.join(dirname, checkpoint_name)
    checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
        iteration, params, net_state, key, step_size, accepted, num_ensembled,
        ensemble_predicted_probs)
    checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

    # Add the sample to the ensemble when it was accepted after burn-in;
    # with --no_mh every sample (burn-in included) is added.
    if ((not in_burnin) and accepted) or args.no_mh:
      ensemble_predicted_probs, ensemble_acc, num_ensembled = (
          train_utils.update_ensemble(
              net_apply, params, net_state, test_set, num_ensembled,
              ensemble_predicted_probs))

    test_log_prob, test_acc, test_ce, _ = evaluate(params, net_state, test_set)
    train_log_prob, train_acc, train_ce, prior = (
        evaluate(params, net_state, train_set))

    tabulate_dict = OrderedDict()
    tabulate_dict["iteration"] = iteration
    tabulate_dict["step_size"] = step_size
    # Report the current value; `log_prob` only holds the value at the
    # starting parameters and never changes inside this loop.
    tabulate_dict["train_logprob"] = train_log_prob
    tabulate_dict["train_acc"] = train_acc
    tabulate_dict["test_acc"] = test_acc
    tabulate_dict["test_ce"] = test_ce
    tabulate_dict["accept_prob"] = accept_prob
    tabulate_dict["accepted"] = accepted
    tabulate_dict["ensemble_acc"] = ensemble_acc
    tabulate_dict["n_ens"] = num_ensembled
    tabulate_dict["time"] = iteration_time

    with tf_writer.as_default():
      tf.summary.scalar("train/log_prob", train_log_prob, step=iteration)
      tf.summary.scalar("test/log_prob", test_log_prob, step=iteration)
      tf.summary.scalar("train/log_likelihood", train_ce, step=iteration)
      tf.summary.scalar("test/log_likelihood", test_ce, step=iteration)
      tf.summary.scalar("train/accuracy", train_acc, step=iteration)
      tf.summary.scalar("test/accuracy", test_acc, step=iteration)
      tf.summary.scalar("test/ens_accuracy", ensemble_acc, step=iteration)

      # Ensemble NLL / calibration only exist once something was ensembled.
      if num_ensembled > 0:
        test_labels = onp.asarray(test_set[1])
        ensemble_nll = metrics.nll(ensemble_predicted_probs, test_labels)
        ensemble_calibration = metrics.calibration_curve(
            ensemble_predicted_probs, test_labels)
        tf.summary.scalar(
            "test/ens_ece", ensemble_calibration["ece"], step=iteration)
        tf.summary.scalar("test/ens_nll", ensemble_nll, step=iteration)

      tf.summary.scalar("telemetry/log_prior", prior, step=iteration)
      tf.summary.scalar("telemetry/accept_prob", accept_prob, step=iteration)
      tf.summary.scalar("telemetry/accepted", accepted, step=iteration)
      tf.summary.scalar("telemetry/n_ens", num_ensembled, step=iteration)
      tf.summary.scalar("telemetry/iteration_time", iteration_time,
                        step=iteration)

      tf.summary.scalar("hypers/step_size", step_size, step=iteration)
      tf.summary.scalar("hypers/trajectory_len", trajectory_len,
                        step=iteration)
      tf.summary.scalar("hypers/weight_decay", args.weight_decay,
                        step=iteration)
      tf.summary.scalar("hypers/temperature", args.temperature,
                        step=iteration)

      tf.summary.scalar("debug/do_mh_correction", float(do_mh_correction),
                        step=iteration)
      tf.summary.scalar("debug/in_burnin", float(in_burnin),
                        step=iteration)

    table = tabulate_utils.make_table(
      tabulate_dict, iteration - start_iteration, args.tabulate_freq)
    print(table)
Example #2
0
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

    Args:
      config: Model and training configuration.
      workdir: Working directory for checkpoints and TensorBoard summaries. If
        this contains a checkpoint, training will be resumed from the latest
        checkpoint.
      vocab_filepath: Absolute path to SentencePiece vocab model.

    Raises:
      ValueError: If training or eval batch sizes won't fit number of hosts and
        devices, or config is underspecified.
    """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    per_host_train_batch_size = config.train_batch_size // jax.process_count()
    per_host_eval_batch_size = config.eval_batch_size // jax.process_count()

    # Only host 0 writes TensorBoard summaries.
    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round the evaluation frequency up to a multiple of 10 steps. It is at
    # least 10 so that the `step % eval_frequency` check below can never
    # divide by zero (e.g. when eval_proportion * num_train_steps is 0).
    eval_frequency = max(
        int(math.ceil(config.eval_proportion * num_train_steps / 10)), 1) * 10

    # STSB is a regression task. COPA and ReCoRD are treated as scalar/regression
    # tasks during training.
    is_regression_task = (config.dataset_name == "glue/stsb"
                          or config.dataset_name == "super_glue/copa"
                          or config.dataset_name == "super_glue/record")
    if is_regression_task:
        num_classes = 1
    else:
        num_classes = ds_info.features["label"].num_classes

    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()

    config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config, num_classes)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    tx = optax.adamw(learning_rate_fn,
                     b1=0.9,
                     b2=0.999,
                     eps=1e-6,
                     weight_decay=0.01)
    if config.clipped_grad_norm:
        tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm),
                         tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(FlaxTrainState.create,
                          apply_fn=model.apply,
                          params=params,
                          tx=tx))()

    # We access model params only via state.params
    del params

    # Expert parameters (names matching ".*expert.*") are sharded across
    # devices; everything else is treated as replicated.
    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)

    if is_regression_task:
        scoring_fn = lambda y: y[Ellipsis, 0]
    else:
        scoring_fn = lambda y: y.argmax(-1)
    compute_stats = functools.partial(_compute_stats,
                                      model=model,
                                      scoring_fn=scoring_fn)

    classification_inputs = functools.partial(
        input_pipeline.classification_inputs,
        dataset_name=config.dataset_name,
        max_seq_length=config.max_seq_length,
        tokenizer=tokenizer)
    train_ds = classification_inputs(split=tfds.Split.TRAIN,
                                     batch_size=per_host_train_batch_size,
                                     training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI contains two validation and test datasets.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")
    p_eval_step = jax.pmap(compute_stats, axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name)

    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            state, train_step_stats, rngs = p_train_step(state,
                                                         train_batch,
                                                         rng=rngs)

            train_stats.append(train_step_stats)

        if ((step > 0 and config.save_checkpoints_steps
             and step % config.save_checkpoints_steps == 0)
                or step == num_train_steps - 1):
            # We allow all hosts to potentially save checkpoints because some model
            # parameters are sharded across devices. Parameters replicated across
            # devices (i.e. not sharded) will only be checkpointed by host 0.
            unreplicated_train_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(workdir,
                                        unreplicated_train_state,
                                        sharded_match_fn,
                                        step,
                                        keep=config.checkpoints_to_keep)
            del unreplicated_train_state  # Only used for checkpointing.

        # Periodic metric handling.
        if step % eval_frequency != 0 and step < num_train_steps - 1:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_classification_metrics(
            train_metrics, is_regression_task)
        train_summary["learning_rate"] = learning_rate_fn(step)

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            split_name = tfds.Split.VALIDATION + split_suffix
            eval_ds = classification_inputs(
                split=split_name,
                batch_size=per_host_eval_batch_size,
                training=False)

            eval_stats = [
                _evaluate(p_eval_step, state.params, eval_batch)
                for _, eval_batch in zip(range(config.max_num_eval_steps),
                                         eval_ds)
            ]
            if not eval_stats:
                # No batches were produced (e.g. max_num_eval_steps <= 0 or an
                # empty split); indexing eval_stats[0] below would fail.
                logging.warning(
                    "No evaluation batches for split %s at step %d; skipping.",
                    split_name, step)
                continue

            # Every batch's stats dict has the same keys as the first one.
            eval_metrics = {
                k: np.concatenate([stat[k] for stat in eval_stats], axis=0)
                for k in eval_stats[0]
            }
            eval_summary = eval_metrics_fn(eval_metrics)

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()
Example #3
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Trains a Transformer translation model on the configured WMT datasets. At
  step 0 and every `config.eval_frequency` steps it logs training metrics,
  evaluation metrics and the BLEU score of decoded translations; TensorBoard
  summaries are written by host 0 only.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If
      `config.restore_checkpoints` is set and this contains a checkpoint,
      training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 writes TensorBoard summaries.
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  if config.batch_size % n_devices:
    raise ValueError("Batch size must be divisible by the number of devices")

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=config.dataset_name,
      eval_dataset_name=config.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      vocab_path=vocab_path,
      target_vocab_size=config.vocab_size,
      batch_size=config.batch_size,
      max_corpus_chars=config.max_corpus_chars,
      max_length=config.max_target_length,
      max_eval_length=config.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    # Keep tokens up to and including the first EOS before detokenizing.
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  input_shape = (config.batch_size, config.max_target_length)
  target_shape = (config.batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32),
                  jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config,
          label_smoothing=config.label_smoothing),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info("Starting training loop.")
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, config.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (config.save_checkpoints and step % config.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                  step)

    # Periodic metric handling.
    if step % config.eval_frequency != 0 and step > 0:
      continue

    # Training Metrics
    logging.info("Gathering training metrics.")
    # One entry was appended per training step since the last report, so this
    # is the exact number of steps covered by the timing below. It equals 1 at
    # step 0 and eval_frequency afterwards, and stays correct after resuming
    # from a checkpoint at an arbitrary step.
    steps_since_last_report = len(metrics_all)
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop("learning_rate").mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop("denominator")
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary["learning_rate"] = lr
    steps_per_sec = steps_since_last_report / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar("steps per second", steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info("train in step: %d, loss: %.4f", step, summary["loss"])

    # Eval Metrics
    logging.info("Gathering evaluation metrics.")
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop("denominator")
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info("eval in step: %d, loss: %.4f", step, eval_summary["loss"])
    logging.info("eval time: %.4f s step %d", time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info("Translating evaluation dataset.")
    t_inference_start = time.time()
    sources, references, predictions = [], [], []
    for pred_batch in predict_ds:
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch["inputs"].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      cache = p_init_cache(pred_batch["inputs"])
      predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache,
                              eos_id, config.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch["inputs"])
      targets = tohost(pred_batch["targets"])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info("Translation: %d predictions %d references %d sources.",
                 len(predictions), len(references), len(sources))
    logging.info("Translation time: %.4f s step %d.",
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)
    # Save translation samples for tensorboard. np.random.choice raises on an
    # empty population, so only sample when there is something to show.
    exemplars = ""
    if predictions:
      for n in np.random.choice(np.arange(len(predictions)), 8):
        exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
    if jax.host_id() == 0:
      eval_summary_writer.scalar("bleu", bleu_score, step)
      eval_summary_writer.text("samples", exemplars, step)
      eval_summary_writer.flush()
    logging.info("Translation BLEU Score %.4f", bleu_score)
Example #4
0
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

    Trains for `config.num_train_steps` steps (or `config.num_epochs` epochs
    when `num_train_steps == -1`). At the end of every epoch it logs the
    per-step training metrics and runs an evaluation pass. A checkpoint is
    saved every 10 epochs and after the last step.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to. It is
        also where checkpoints are restored from at start-up and saved to
        during training, and (on host 0) the profiler log directory.

    Returns:
      The final train state, replicated across the local devices.
    """

    # Only host 0 creates a summary writer; every later use of
    # `summary_writer` is guarded by the same host check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(workdir)
        summary_writer.hparams(dict(config))

    # NOTE(review): the initialization seed is hard-coded to 0 rather than
    # read from config.
    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # Each host loads its share of the global batch.
    local_batch_size = config.batch_size // jax.host_count()

    platform = jax.local_devices()[0].platform

    # Half precision uses bfloat16 on TPU and float16 on other platforms.
    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder(config.dataset)
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)

    # num_train_steps == -1 means "derive the step count from num_epochs".
    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    # steps_per_eval == -1 means "evaluate on the whole validation split".
    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    # Checkpoint every 10 epochs (see the save condition at the end of the
    # loop).
    steps_per_checkpoint = steps_per_epoch * 10

    # NOTE(review): computed but not used anywhere in this function;
    # learning_rate_fn below is built from `config` directly.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    state = create_train_state(rng, config, model, image_size)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(config, steps_per_epoch)

    p_train_step = jax.pmap(functools.partial(
        train_step, model.apply, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, model.apply),
                           axis_name='batch')

    epoch_metrics = []
    # Per-step hooks; host 0 profiles 5 steps into workdir.
    hooks = []
    if jax.host_id() == 0:
        hooks += [
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    t_loop_start = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        if step == step_offset:
            logging.info('Initial compilation completed.')
        epoch_metrics.append(metrics)
        # End of an epoch: report training metrics, then run evaluation.
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            # The timed interval also includes the previous epoch's
            # evaluation, since t_loop_start is reset before eval runs.
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    # Write each per-step value at the step it came from.
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        # Batch statistics are synced before saving so the checkpoint holds
        # replica-consistent values.
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state
Example #5
0
def main(unused_argv):
  """Evaluates checkpoints found in FLAGS.train_dir on the "test" split.

  Restores the newest checkpoint from FLAGS.train_dir, renders every example
  of the "test" dataset and, on host 0 only:
    * computes PSNR/SSIM against the ground-truth pixels (unless
      FLAGS.render_path is set),
    * logs the metrics plus one randomly chosen showcase image to TensorBoard
      (unless FLAGS.eval_once is set),
    * optionally writes the rendered images and metric files to disk
      (FLAGS.save_output).
  With FLAGS.eval_once the function stops after one evaluation; otherwise it
  keeps waiting for newer checkpoints until one at or beyond FLAGS.max_steps
  has been evaluated.

  Args:
    unused_argv: leftover command-line arguments (ignored).

  Raises:
    ValueError: if FLAGS.train_dir or FLAGS.data_dir is not set.
  """
  import time  # pylint: disable=g-import-not-at-top

  # Seconds to wait before checking again for a checkpoint newer than the one
  # evaluated last. Without a delay the loop re-reads the checkpoint directory
  # as fast as it can while training has not produced a new step.
  checkpoint_poll_seconds = 10

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  # The optimizer only provides a state skeleton for restore_checkpoint below.
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      # No checkpoint newer than the last evaluated one yet; wait and retry.
      # NOTE(review): with --eval_once this still waits until a checkpoint
      # past step 0 exists.
      print(f"Checkpoint step {step} <= last step {last_step}, sleeping.")
      time.sleep(checkpoint_poll_seconds)
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnrs = []
    ssims = []
    if not FLAGS.eval_once:
      # One randomly chosen example per evaluation is logged as images.
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnrs.append(float(psnr))
        ssims.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssims)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      # Per-image metrics for this step, plus the latest averages.
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnrs]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssims]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnrs))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssims))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
示例#6
0
    def compute_preconditioners_from_statistics(self, states, hps, step):
        """Compute preconditioners for statistics.

        All statistic matrices of all ``states`` are padded to the largest
        statistic size as ``[[M, 0], [0, I]]``. The list is then padded with
        identity matrices up to a multiple of the device count.
        ``matrix_inverse_pth_root`` is evaluated for every matrix in parallel
        across devices: through ``jax.pmap`` when ``hps.batch_axis_name`` is
        unset, otherwise one slice per replica followed by
        ``lax.all_gather``. A new preconditioner is kept only if its reported
        error is neither NaN nor ``>= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD``;
        otherwise the previous preconditioner is reused.

        Args:
          states: optimizer states; each exposes ``statistics`` and
            ``preconditioners`` lists plus ``diagonal_statistics``,
            ``diagonal_momentum`` and ``momentum``.
          hps: hyper-parameters; reads ``exponent_override``,
            ``batch_axis_name``, ``matrix_eps`` and
            ``preconditioning_compute_steps``.
          step: current step. Only used when ``hps.batch_axis_name`` is set
            and ``hps.preconditioning_compute_steps != 1``; roots are then
            recomputed only on steps that are multiples of it.

        Returns:
          A list with one new ``_ShampooDefaultParamState`` per input state,
          carrying the updated preconditioners and the other fields unchanged.
        """
        statistics = []
        num_statistics_per_state = []
        original_shapes = []
        exponents = []
        max_size = 0
        prev_preconditioners = []
        for state in states:
            num_statistics = len(state.statistics)
            num_statistics_per_state.append(num_statistics)
            original_shapes_for_state = []
            if num_statistics > 0:
                for statistic in state.statistics:
                    # Default exponent p = 2 * (number of statistics of this
                    # state), unless explicitly overridden.
                    exponents.append(2 * num_statistics
                                     if hps.exponent_override == 0 else
                                     hps.exponent_override)
                    original_shapes_for_state.append(statistic.shape)
                    max_size = max(max_size, statistic.shape[0])
                statistics.extend(state.statistics)
                prev_preconditioners.extend(state.preconditioners)
                original_shapes.extend(original_shapes_for_state)
        num_statistics = len(statistics)

        def _rebuild_states(new_preconditioners_flat):
            """Split the flat preconditioner list back into per-state states."""
            assert len(states) == len(num_statistics_per_state)
            assert len(new_preconditioners_flat) == len(statistics)
            new_states = []
            idx = 0
            for count, state in zip(num_statistics_per_state, states):
                preconditioners_for_state = (
                    new_preconditioners_flat[idx:idx + count])
                assert len(state.statistics) == len(preconditioners_for_state)
                idx += count
                new_states.append(
                    _ShampooDefaultParamState(state.diagonal_statistics,
                                              state.statistics,
                                              preconditioners_for_state,
                                              state.diagonal_momentum,
                                              state.momentum))
            return new_states

        if num_statistics == 0:
            # Nothing to invert. The batching below cannot handle an empty
            # list (range() with a zero step, stacking an empty list).
            return _rebuild_states([])

        def pack(mat, max_size):
            """Pack a matrix to max_size for inverse on TPUs with static shapes.

            Args:
              mat: Matrix for computing inverse pth root.
              max_size: Matrix size to pack to.

            Returns:
              Given M returns [[M, 0], [0, I]]
            """
            size = mat.shape[0]
            assert size <= max_size
            if size == max_size:
                return mat
            pad_size = max_size - size
            zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
            zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
            eye = jnp.eye(pad_size, dtype=mat.dtype)
            mat = jnp.concatenate([mat, zs1], 1)
            mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
            return mat

        if not hps.batch_axis_name:
            num_devices = jax.local_device_count()
        else:
            # psum of the constant 1 gives the size of the named axis.
            num_devices = lax.psum(1, hps.batch_axis_name)

        # Pad statistics and exponents to the next multiple of num_devices with
        # identity matrices (exponent 1). Their results are dropped by the
        # zip() with original_shapes below.
        packed_statistics = [pack(stat, max_size) for stat in statistics]
        to_pad = -num_statistics % num_devices
        packed_statistics.extend([
            jnp.eye(max_size, dtype=packed_statistics[0].dtype)
            for _ in range(to_pad)
        ])
        exponents.extend([1 for _ in range(to_pad)])

        # Batch statistics and exponents so that the leading axis is
        # num_devices.
        def _batch(statistics, exponents, num_devices):
            assert len(statistics) == len(exponents)
            n = len(statistics)
            # n is a multiple of num_devices thanks to the padding above.
            b = n // num_devices
            batched_statistics = [
                jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
            ]
            batched_exponents = [
                jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
            ]
            return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

        # Unbatch values across the two leading axes and return a list of
        # elements. Flatten and index instead of split + squeeze: squeeze also
        # drops other size-1 axes. For example, the errors slice (1, 1) when a
        # device holds a single statistic becomes 0-d, and splitting it fails.
        def _unbatch(batched_values):
            b1, b2 = batched_values.shape[0], batched_values.shape[1]
            flat = jnp.reshape(batched_values,
                               (b1 * b2,) + tuple(batched_values.shape[2:]))
            return [flat[i] for i in range(b1 * b2)]

        all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                               num_devices)

        def _matrix_inverse_pth_root(xs, ps):
            def mi_pth_root(x, y):
                return matrix_inverse_pth_root(x, y,
                                               ridge_epsilon=hps.matrix_eps)

            preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
            return preconditioners, errors

        if not hps.batch_axis_name:
            preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)(
                all_statistics, all_exponents)
            preconditioners_flat = _unbatch(preconditioners)
            errors_flat = _unbatch(errors)
        else:

            def _internal_inverse_pth_root_all():
                # Each replica inverts its own slice, then all slices are
                # gathered so every replica holds every preconditioner.
                current_replica = lax.axis_index(hps.batch_axis_name)
                preconditioners, errors = _matrix_inverse_pth_root(
                    all_statistics[current_replica],
                    all_exponents[current_replica])
                preconditioners = jax.lax.all_gather(preconditioners,
                                                     hps.batch_axis_name)
                errors = jax.lax.all_gather(errors, hps.batch_axis_name)
                preconditioners_flat = _unbatch(preconditioners)
                errors_flat = _unbatch(errors)
                return preconditioners_flat, errors_flat

            if hps.preconditioning_compute_steps == 1:
                preconditioners_flat, errors_flat = (
                    _internal_inverse_pth_root_all())
            else:
                # Passing statistics instead of preconditioners as they are
                # similarly shaped tensors; since the error we pass is the
                # failure threshold, these will be ignored.
                preconditioners_init = packed_statistics
                errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] *
                               len(packed_statistics))
                init_state = [preconditioners_init, errors_init]
                perform_step = step % hps.preconditioning_compute_steps == 0
                preconditioners_flat, errors_flat = self.fast_cond(
                    perform_step, _internal_inverse_pth_root_all, init_state)

        def _skip(error):
            # Boolean predicate for lax.cond: True when the root failed.
            return jnp.logical_or(
                jnp.isnan(error),
                error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD)

        def _select_preconditioner(error, new_p, old_p):
            return lax.cond(_skip(error),
                            lambda _: old_p,
                            lambda _: new_p,
                            operand=None)

        new_preconditioners_flat = []
        for p, shape, prev_p, error in zip(preconditioners_flat,
                                           original_shapes,
                                           prev_preconditioners, errors_flat):
            # Cut the packed result back to the original statistic shape.
            new_preconditioners_flat.append(
                _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

        # Add back empty preconditioners so that we can set the optimizer state.
        return _rebuild_states(new_preconditioners_flat)
示例#7
0
def main():
    """Train (and optionally evaluate) a Flax causal language model.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` from the command line (or from a single ``.json``
    argument). It then loads and tokenizes the dataset, groups the tokens into
    ``block_size`` chunks and runs data-parallel training with ``jax.pmap``.
    Along the way it periodically logs, evaluates (when ``--do_eval`` is set)
    and saves the model, optionally pushing to the Hub. With ``--do_eval``, a
    final evaluation is written to ``eval_results.json`` in the output
    directory.

    Raises:
        ValueError: if the output directory is non-empty without
            ``--overwrite_output_dir``, if no tokenizer can be determined, or
            if the split required by ``--do_train`` / ``--do_eval`` is
            missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = (
            parser.parse_args_into_dataclasses())

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in dataset.keys():
            # Carve a validation split out of the start of "train".
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # NOTE(review): requires --train_file here; None.split would raise.
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            **dataset_args,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in dataset.keys():
            # Carve a validation split out of the start of "train".
            dataset["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                **dataset_args,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            dataset["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                **dataset_args,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForCausalLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger(
        "transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k:
            [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        # Causal LM: labels are the inputs; loss_fn does the one-token shift.
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset),
                                    data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset),
                                   data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when TensorBoard is really unavailable, not merely because
        # this is a non-main process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxGPT2.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"),
                                                            ("ln_2", "scale"),
                                                            ("ln_f", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=optimizer,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        # Predict token t+1 from positions <= t.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = optax.softmax_cross_entropy(
            shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients across the pmapped devices.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng,
                                   train_dataset,
                                   train_batch_size,
                                   shuffle=True)
        # train
        for step in tqdm(range(steps_per_epoch),
                         desc="Training...",
                         position=1,
                         leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            cur_step = epoch * steps_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                # Restart the timer; otherwise every log re-adds the whole
                # time since the start of the epoch and train_time over-counts.
                train_start = time.time()
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            # eval_dataset only exists with --do_eval.
            if (training_args.do_eval
                    and cur_step % training_args.eval_steps == 0
                    and cur_step > 0):
                # ======================== Evaluating ==============================
                eval_metrics = []
                eval_loader = data_loader(input_rng, eval_dataset,
                                          eval_batch_size)
                eval_steps = len(eval_dataset) // eval_batch_size
                for _ in tqdm(range(eval_steps),
                              desc="Evaluating...",
                              position=2,
                              leave=False):
                    # Model forward
                    batch = next(eval_loader)
                    batch = shard(batch)
                    metrics = p_eval_step(state.params, batch)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                # Print metrics and update progress bar
                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                epochs.write(desc)
                epochs.desc = desc

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = []
        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
        eval_steps = len(eval_dataset) // eval_batch_size
        for _ in tqdm(range(eval_steps),
                      desc="Evaluating...",
                      position=2,
                      leave=False):
            # Model forward
            batch = shard(next(eval_loader))
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)

        try:
            eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
        except OverflowError:
            eval_metrics["perplexity"] = float("inf")

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
示例#8
0
 def f(x, y):
     """Return ``(x + 2., exp(x) + y)``, with the exponential run under ``jax.pmap``."""
     mapped_exp = jax.pmap(np.exp)
     exp_of_x = mapped_exp(x)
     shifted = x + 2.
     return shifted, exp_of_x + y
示例#9
0
def main():
    """Compare ViT-B/16 predictions of the JAX reference model and the PyTorch port.

    Both models are loaded from the same pretrained checkpoint file. Every batch
    of the ImageNet 'val' split is run through both, and the number of batch
    elements whose top-1 predictions disagree is printed per batch.
    """
    image_size = 384
    # Single source of truth for the checkpoint loaded into both models.
    weights_path = ('/home/hchen/Projects/vision_transformer/weights/jax/'
                    'imagenet21k+imagenet2012_ViT-B_16.npz')

    # jax model
    jax_model = models.KNOWN_MODELS['ViT-B_16'].partial(
        num_classes=1000, representation_size=None)
    _, params = jax_model.init_by_shape(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension of the batch for initialization.
        [((4, image_size, image_size, 3), 'float32')])
    params = checkpoint.load_pretrained(
        pretrained_path=weights_path,
        init_params=params,
        model_config=models.CONFIGS['ViT-B_16'],
        logger=logger)
    # Each batch is fed to pmap with a leading axis of size 1 (see the reshape
    # below), so the params must also carry a leading axis of size 1.
    # Replicating onto *all* local devices would give a mismatched leading axis
    # (and a pmap error) on multi-device hosts.
    params_repl = flax.jax_utils.replicate(params,
                                           devices=jax.local_devices()[:1])
    # Then map the call to our model's forward pass onto the selected device.
    vit_apply_repl = jax.pmap(jax_model.call)

    # torch_model
    keys, values = load_jax(weights_path)
    state_dict = convert_jax_pytorch(keys, values)

    torch_model = VisionTransformer(image_size=(image_size, image_size),
                                    patch_size=(16, 16),
                                    emb_dim=768,
                                    mlp_dim=3072,
                                    num_heads=12,
                                    num_layers=12,
                                    num_classes=1000,
                                    attn_dropout_rate=0.0,
                                    dropout_rate=0.1)
    torch_model.load_state_dict(state_dict)
    torch_model.eval()

    data_loader = ImageNetDataLoader(
        data_dir='/home/hchen/Projects/vat_contrast/data/ImageNet',
        split='val',
        image_size=image_size,
        batch_size=16,
        num_workers=0)

    for batch_idx, (data, target) in enumerate(data_loader):

        # jax prediction: NCHW -> NHWC, plus a leading pmap axis of size 1.
        target_numpy = target.cpu().numpy()
        data_numpy = data.cpu().numpy().transpose(0, 2, 3, 1).reshape(
            1, -1, image_size, image_size, 3)
        # onp.asarray fetches the result to host without relying on the
        # private DeviceArray `_value` attribute; [0] drops the pmap axis.
        jax_predicted_logits = onp.asarray(
            vit_apply_repl(params_repl, data_numpy))[0]
        jax_predicted = onp.argmax(jax_predicted_logits, axis=-1)

        # torch prediction
        with torch.no_grad():
            torch_logits = torch_model(data)
        torch_predicted_logits = torch_logits.cpu().numpy()
        torch_predicted = onp.argmax(torch_predicted_logits, axis=-1)

        # Count top-1 disagreements; the numeric difference between two class
        # indices carries no meaning, only whether they are equal.
        num_mismatch = int(onp.count_nonzero(jax_predicted != torch_predicted))
        print(f"batch {batch_idx}: {num_mismatch}/{len(target_numpy)} "
              f"top-1 predictions differ")
示例#10
0
文件: learning.py 项目: deepmind/acme
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 random_key: networks_lib.PRNGKey,
                 loss_fn: losses.Loss,
                 optimizer: optax.GradientTransformation,
                 prefetching_iterator: Iterator[types.Transition],
                 num_sgd_steps_per_step: int,
                 loss_has_aux: bool = False,
                 logger: Optional[loggers.Logger] = None,
                 counter: Optional[counting.Counter] = None):
        """Behavior Cloning Learner.

        Args:
          network: Networks with signature for apply: (params, obs, is_training,
            key) -> jnp.ndarray and for init: (rng, is_training) -> params
          random_key: RNG key.
          loss_fn: BC loss to use.
          optimizer: Optax optimizer.
          prefetching_iterator: A sharded prefetching iterator as outputted from
            `acme.jax.utils.sharded_prefetch`. Please see the documentation for
            `sharded_prefetch` for more details.
          num_sgd_steps_per_step: Number of gradient updates per step.
          loss_has_aux: Whether the loss function returns auxiliary metrics as a
            second argument.
          logger: Logger.
          counter: Counter.
        """

        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Apply one optimizer update using this device's `transitions`."""
            # Differentiate the loss w.r.t. argument 1, the policy params.
            grad_fn = jax.value_and_grad(loss_fn,
                                         argnums=1,
                                         has_aux=loss_has_aux)

            # One half of the split is consumed now, the other carried forward.
            next_key, step_key = jax.random.split(state.key)
            loss_out, grads = grad_fn(network.apply, state.policy_params,
                                      step_key, transitions)

            # Average gradients, then the metrics derived from them, over devices.
            grads = jax.lax.pmean(grads, axis_name=_PMAP_AXIS_NAME)
            step_metrics = jax.lax.pmean(
                _create_loss_metrics(loss_has_aux, loss_out, grads),
                axis_name=_PMAP_AXIS_NAME)

            updates, new_opt_state = optimizer.update(
                grads, state.optimizer_state, state.policy_params)
            next_state = TrainingState(
                optimizer_state=new_opt_state,
                policy_params=optax.apply_updates(state.policy_params, updates),
                key=next_key,
                steps=state.steps + 1,
            )
            return next_state, step_metrics

        # Bookkeeping: the default logger takes its steps key from the counter,
        # so the counter must exist first.
        self._counter = counter or counting.Counter(prefix='learner')
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Each call splits its input into `num_sgd_steps_per_step` minibatches
        # and runs one sgd_step per minibatch, pmapped across devices.
        self._sgd_step = jax.pmap(
            utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step),
            axis_name=_PMAP_AXIS_NAME)

        state_key, params_key = jax.random.split(random_key)
        initial_params = network.init(params_key)
        initial_opt_state = optimizer.init(initial_params)

        # Every device holds an identical copy of the starting state.
        self._state = utils.replicate_in_all_devices(
            TrainingState(
                optimizer_state=initial_opt_state,
                policy_params=initial_params,
                key=state_key,
                steps=0,
            ))

        self._timestamp = None

        self._prefetching_iterator = prefetching_iterator
示例#11
0
 def f(x):
     """Return ``exp(x) + 2`` computed under ``jax.pmap``."""
     def exp_plus_two(v):
         return np.exp(v) + 2.
     mapped = jax.pmap(exp_plus_two)
     return mapped(x)
示例#12
0
    ("reserve_rng_keys", lambda: base.reserve_rng_keys(2)),
    ("with_rng", with_rng_example),
)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", jax.jit),
    ("make_jaxpr", jax.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    ("pmap", lambda f: jax.pmap(f, "i")),

    # Vectorization.
    ("vmap", jax.vmap),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(
    #      lambda a, b: [f(a + b), a + b][-1], jnp.stack([x, x, x, x])))),
    ("cond", lambda f: (lambda x: jax.lax.cond(True, f, f, x))),
    ("fori_loop", lambda f:
     (lambda x: jax.lax.fori_loop(0, 1, ignore_index(f), x))),
    ("map", lambda f: (lambda x: jax.lax.map(f, x))),
    ("scan", lambda f: (lambda x: jax.lax.scan(identity_carry(f), None, x))),
    ("switch", lambda f: (lambda x: jax.lax.switch(0, [f, f], x))),
示例#13
0
    def test_model_shape(
        self,
        separate_memory_values=False,
        num_intermediate_layers=None,
    ):
        """Check the pmapped encoder forward pass produces correctly shaped outputs.

        Initializes MentionMemoryEncoder across `self.n_devices` devices with a
        random memory table, runs `model.forward` on a tiled pretraining sample,
        and asserts the shapes of the encoded output and of
        `loss_helpers['target_mention_encodings']`.
        """
        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'separate_memory_values'] = separate_memory_values
        config['model_config']['encoder_config'][
            'num_intermediate_layers'] = num_intermediate_layers
        config = ml_collections.FrozenConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
            config)

        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        # Random memory table of shape (rows, table_size // rows, key_dim); the
        # values are the same table flattened to (table_size, key_dim).
        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, devices)
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       devices)
        memory_entity_ids = memory_identifiers
        memory_text_entities = np.zeros(
            (self.table_size, encoder_config.n_memory_text_entities),
            dtype=np.int32)
        memory_text_entities = jax.device_put_replicated(
            memory_text_entities, devices)

        def model_init(*args, **kwargs):
            return model.init(*args, method=model.forward, **kwargs)

        # Argument 2 (the boolean flag) is static and broadcast to all devices.
        initial_variables = jax.pmap(model_init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        variables = {
            'params': initial_variables['params'],
            'constants': {
                'memory_keys': memory_keys,
                'memory_values': memory_values,
                'memory_identifiers': memory_identifiers,
                'memory_entity_ids': memory_entity_ids,
                'memory_text_entities': memory_text_entities,
            },
        }

        raw_example = test_utils.gen_mention_pretraining_sample(
            self.text_length,
            self.n_mentions,
            self.n_linked_mentions,
            max_length=encoder_config.max_length)
        processed_example = preprocess_fn(raw_example)
        # Tile the single example to a per-device batch, collate, then convert
        # to numpy and replicate each feature onto every device.
        batch = collater_fn({
            key: np.tile(value, (config.per_device_batch_size, 1))
            for key, value in processed_example.items()
        })
        batch = {
            key: jax.device_put_replicated(test_utils.tensor_to_numpy(value),
                                           devices)
            for key, value in batch.items()
        }

        def model_apply(*args, **kwargs):
            return model.apply(*args, method=model.forward, **kwargs)

        papply = jax.pmap(model_apply, 'batch', static_broadcasted_argnums=2)
        encoded_output, loss_helpers, _ = papply(variables, batch, True)

        self.assertEqual(
            encoded_output.shape,
            (self.n_devices, config.per_device_batch_size,
             encoder_config.max_length, encoder_config.hidden_size))

        # Mention encodings use the value dim when one is configured, else the
        # key dim.
        memory_value_dim = encoder_config.memory_value_dim
        memory_key_dim = encoder_config.memory_key_dim
        memory_size = memory_value_dim if memory_value_dim else memory_key_dim
        self.assertEqual(loss_helpers['target_mention_encodings'].shape,
                         (self.n_devices, config.max_mention_targets *
                          config.per_device_batch_size, memory_size))
示例#14
0
    def test_load_weights(self,
                          separate_memory_values=False,
                          memory_only=False):
        """Test saving and loading model recovers original parameters.

        Writes the initial params (unless `memory_only`) and the memory tables,
        split into `n_shards` `.npy` shards, to a temp dir. Reloads them through
        `MentionMemoryEncoder.load_weights` and asserts that every loaded leaf
        equals the original. `memory_values` is only compared when
        `separate_memory_values` is set.
        """

        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'separate_memory_values'] = separate_memory_values
        config = ml_collections.ConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config
        rows = encoder_config.rows
        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()
        model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, devices)
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       devices)
        memory_entity_ids = memory_identifiers
        memory_text_entities = np.zeros(
            (self.table_size, encoder_config.n_memory_text_entities),
            dtype=np.int32)
        memory_text_entities = jax.device_put_replicated(
            memory_text_entities, devices)

        def model_init(*args, **kwargs):
            return model.init(*args, method=model.forward, **kwargs)

        initial_variables = jax.pmap(model_init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }
        n_shards = 4

        tempdir_obj = self.create_tempdir()
        tempdir = tempdir_obj.full_path

        memory_key_base = os.path.join(tempdir, 'memory_keys')
        memory_value_base = os.path.join(tempdir, 'memory_values')
        memory_id_base = os.path.join(tempdir, 'memory_id')
        memory_entity_id_base = os.path.join(tempdir, 'memory_entity_id')
        memory_text_entities_base = os.path.join(tempdir,
                                                 'memory_text_entities')

        unreplicated_variables = jax_utils.unreplicate(initial_variables)
        unreplicated_variables['params'] = unreplicated_variables[
            'params'].unfreeze()

        if memory_only:
            load_weights = 'memory_only'
        else:
            load_weights = os.path.join(tempdir, 'weights')
            checkpoint_utils.save_weights(load_weights,
                                          unreplicated_variables['params'])

        memory_keys = initial_variables['constants']['memory_keys']
        memory_keys = memory_keys.reshape(n_shards, -1,
                                          encoder_config.memory_key_dim)
        memory_values = initial_variables['constants']['memory_values']
        memory_values = memory_values.reshape(n_shards, -1,
                                              encoder_config.memory_key_dim)
        memory_ids = initial_variables['constants'][
            'memory_identifiers'].reshape(n_shards, -1)
        memory_entity_ids = initial_variables['constants'][
            'memory_entity_ids'].reshape(n_shards, -1)
        memory_text_entities = initial_variables['constants'][
            'memory_text_entities'].reshape(
                n_shards, -1, encoder_config.n_memory_text_entities)

        # Each table is written as `<base><shard>.npy`, one file per shard.
        # (The entity-id shard used to be written twice by a copy-paste slip.)
        shard_specs = (
            (memory_key_base, memory_keys),
            (memory_value_base, memory_values),
            (memory_id_base, memory_ids),
            (memory_entity_id_base, memory_entity_ids),
            (memory_text_entities_base, memory_text_entities),
        )
        for shard in range(n_shards):
            for base, sharded_array in shard_specs:
                np.save(base + str(shard), sharded_array[shard])

        config.memory_key_pattern = memory_key_base + '*'
        config.memory_value_pattern = memory_value_base + '*'
        config.memory_id_pattern = memory_id_base + '*'
        config.memory_entity_id_pattern = memory_entity_id_base + '*'
        config.memory_text_entities_pattern = memory_text_entities_base + '*'
        config.load_weights = load_weights

        loaded_variables = mention_memory_encoder.MentionMemoryEncoder.load_weights(
            config)

        def arrayeq(x, y):
            return jnp.all(x == y)

        # memory_values is only compared when values are stored separately.
        constants = {
            key: value
            for key, value in initial_variables['constants'].items()
            if not (key == 'memory_values' and not separate_memory_values)
        }
        comparison_variables = {'constants': constants}
        if not memory_only:
            comparison_variables['params'] = initial_variables[
                'params'].unfreeze()

        # tree_map returns a pytree (a non-empty dict, hence always truthy) of
        # per-leaf results, so it must be reduced over its leaves for the
        # assertion to actually check anything.
        leaf_matches = jax.tree_map(arrayeq, loaded_variables,
                                    comparison_variables)
        self.assertTrue(jax.tree_util.tree_all(leaf_matches))
示例#15
0
文件: train.py 项目: wdevazelhes/flax
def sync_batch_stats(state):
    """Return `state` with `model_state` replaced by its mean over all replicas."""

    def _mean_over_replicas(leaf):
        return lax.pmean(leaf, 'x')

    averaged_model_state = jax.pmap(_mean_over_replicas,
                                    axis_name='x')(state.model_state)
    return state.replace(model_state=averaged_model_state)
from timeit import default_timer as timer

print(jax.local_device_count())  # machine-dependent; e.g. 4 on a 4-accelerator host


def random_walk(key, steps=1000):
    """Sum `steps` standard-normal draws, splitting a fresh subkey for each draw.

    Returns the integer 0 when `steps` is 0, otherwise a JAX scalar.
    """
    total = 0
    current_key = key
    for _ in range(steps):
        current_key, draw_key = random.split(current_key)
        total = total + random.normal(draw_key)
    return total


jit_random_walk = jit(random_walk)

p_random_walk = pmap(jit_random_walk)

# JAX dispatches computations asynchronously: a call returns as soon as the
# work is enqueued. Each timed call therefore blocks on its result; otherwise
# the timers would measure dispatch (plus compilation), not execution.
start = timer()
jit_random_walk(random.PRNGKey(0)).block_until_ready()
end = timer()
print("compile time serial:", end - start)

start = timer()
for i in range(4):
    jit_random_walk(random.PRNGKey(i)).block_until_ready()
end = timer()
print("time elapsed serial:", end - start)

# One key per device, stacked along a new leading axis for pmap.
keys = np.array([random.PRNGKey(i) for i in range(4)])

start = timer()
示例#17
0
文件: train.py 项目: wdevazelhes/flax
def main(argv):
    """Train and evaluate the image-classification model configured by FLAGS.

    Builds train/eval input iterators and creates the model, optimizer and
    TrainState, restoring from a checkpoint when one exists. The state is
    replicated across local devices and pmapped train steps are run. At every
    epoch boundary, train and eval metrics are logged (and written to
    TensorBoard by process 0). Checkpoints are saved every 10 epochs and at
    the final step.

    Args:
      argv: Positional command-line arguments remaining after flag parsing;
        only the program name is accepted.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If FLAGS.batch_size is not divisible by the device count.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()
    # make sure tf does not allocate gpu memory
    tf.config.experimental.set_visible_devices([], 'GPU')

    # Only process 0 writes summaries; every use of `summary_writer` below is
    # guarded by this flag. `jax.process_index`/`jax.process_count` replace the
    # deprecated `jax.host_id`/`jax.host_count`.
    is_main_process = jax.process_index() == 0
    if is_main_process:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

    rng = random.PRNGKey(0)

    image_size = 224

    batch_size = FLAGS.batch_size
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # Global batch split first across processes, then across devices.
    local_batch_size = batch_size // jax.process_count()
    device_batch_size = batch_size // jax.device_count()

    platform = jax.local_devices()[0].platform

    # Half precision: bfloat16 on TPU, float16 (with dynamic loss scaling)
    # elsewhere.
    dynamic_scale = None
    if FLAGS.half_precision:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
            input_dtype = tf.bfloat16
        else:
            model_dtype = jnp.float16
            input_dtype = tf.float16
            dynamic_scale = optim.DynamicScale()
    else:
        model_dtype = jnp.float32
        input_dtype = tf.float32

    train_iter = create_input_iter(local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=FLAGS.cache)
    eval_iter = create_input_iter(local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=FLAGS.cache)

    num_epochs = FLAGS.num_epochs
    steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    # Learning rate scaled linearly with the global batch size (reference 256).
    base_learning_rate = FLAGS.learning_rate * batch_size / 256.

    model, model_state = create_model(rng, device_batch_size, image_size,
                                      model_dtype)
    optimizer = optim.Momentum(beta=FLAGS.momentum,
                               nesterov=True).create(model)
    state = TrainState(step=0,
                       optimizer=optimizer,
                       model_state=model_state,
                       dynamic_scale=dynamic_scale)
    del model, model_state  # do not keep a copy of the initial model

    state = restore_checkpoint(state)
    step_offset = int(
        state.step)  # step_offset > 0 if restarting from checkpoint
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch, num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if is_main_process:
                # Write one point per training step of the finished epoch.
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if is_main_process:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
示例#18
0
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name,
                                data_args.dataset_config_name,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension,
                                data_files=data_files,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(model_args.config_name,
                                          cache_dir=model_args.cache_dir,
                                          vocab_size=len(tokenizer))
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(model_args.model_name_or_path,
                                          cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length //
                            expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [
                t[i:i + expanded_inputs_length]
                for i in range(0, total_length, expanded_inputs_length)
            ]
            for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(config,
                                               seed=training_args.seed,
                                               dtype=getattr(
                                                   jnp, model_args.dtype))

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in [("layer_norm", "scale"),
                                         ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...",
                             position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples)

                    # Model forward
                    model_inputs = shard(model_inputs.data)
                    metrics = p_eval_step(state.params, model_inputs)
                    eval_metrics.append(metrics)

                # get eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                # Update progress bar
                epochs.write(
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
                )

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(),
                                    eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
示例#19
0
def main():
    """Fine-tune a Flax sequence-classification model on a GLUE task or local data.

    Reads command-line arguments from ``parse_args()``, loads either the GLUE
    task ``args.task_name`` or the CSV/JSON files given by ``args.train_file`` /
    ``args.validation_file``, tokenizes them, trains for
    ``int(args.num_train_epochs)`` epochs with ``jax.pmap`` data parallelism,
    evaluates after every epoch (scalars are written to a TensorBoard summary
    writer in ``args.output_dir``) and finally saves the trained parameters to
    ``args.output_dir`` from process 0.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not
    # named label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # The loader name ("csv"/"json") is taken from whichever file extension is available.
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # A single concatenated message: extra positional args to logger.warning
            # would be treated as %-format arguments and break the log call.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None and not is_regression:
        # Local classification data: map each (sorted) label value to its index.
        # Regression data has no label_list and keeps its raw float labels.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(args.output_dir)
    summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
    )

    state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels)

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss, logits

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # Replicate the weights on each local device exactly once; the pmapped train
    # step keeps `state` replicated, so it is never re-replicated per epoch.
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng, dropout_rng = jax.random.split(rng, 3)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            # Use a fresh dropout key every step so dropout masks differ between steps.
            dropout_rng, step_dropout_rng = jax.random.split(dropout_rng)
            dropout_rngs = shard_prng_key(step_dropout_rng)
            state, metrics = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # Use an unreplicated copy so the replicated `state` stays intact for
            # the next epoch and for the final checkpoint.
            single_device_state = unreplicate(state)

            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(single_device_state, batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

    # save last checkpoint
    if jax.process_index() == 0:
        # `state` is replicated, so take the copy from the first device.
        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(args.output_dir, params=params)
示例#20
0
 def f(x):
     """Run ``variable`` on each leading-axis slice of ``x`` via ``jax.pmap``.

     ``variable`` is not defined in this snippet; it is called with the
     per-device slice and the keyword ``name='x'``.
     """
     mapped_fn = jax.pmap(lambda shard: variable(shard, name='x'))
     return mapped_fn(x)
示例#21
0
文件: train.py 项目: tokusumi/flax
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest checkpoint.
    """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        # Keep tokens up to and including the first EOS, then detokenize.
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)
    target_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    # Record hyperparameters once, on a fresh (non-resumed) run. Steps start at 0,
    # so comparing against 1 would never fire.
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("translate_and_bleu"):
                    exemplars, bleu_score = translate_and_calculate_bleu(
                        p_pred_step=p_pred_step,
                        p_init_cache=p_init_cache,
                        target=optimizer.target,
                        predict_ds=predict_ds,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_scalars(step, {"bleu": bleu_score})
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on host 0 every checkpoint_every_steps steps and
            # at the end of training.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
示例#22
0
def main(argv):
  """Trains and periodically evaluates a Transformer translation model.

  All configuration comes from FLAGS. The training data is either a regular
  WMT dataset or, with `--dynamic`, a bucketed dataset whose sampling weights
  are re-scored every `resample_freq` steps with gradient-alignment helpers
  against an auxiliary evaluation set.

  Args:
    argv: Positional command-line arguments left after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If the batch size is not divisible by the local device count,
      or if dynamic data selection is requested without `--aux_eval_dataset`
      (the dataset used to score buckets).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    train_ds_mgr, eval_ds, predict_ds, encoder = input_pipeline.get_dynamic_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        # Use the resolved path: FLAGS.vocab_path may be None, in which case
        # the default under model_dir must be used (as in the static branch).
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_buckets=FLAGS.num_data_buckets)
    if FLAGS.static:
      # User-provided, fixed bucket weights: turn off periodic resampling.
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      assert len(weights) == FLAGS.num_data_buckets
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        # Start from the first `data_selection_size` buckets, each weight 1.0.
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)

  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size)

  # Dynamic resampling scores buckets against an auxiliary eval set, so fail
  # fast with a clear message instead of an unbound-name error mid-training.
  if FLAGS.dynamic and not FLAGS.aux_eval_dataset:
    raise ValueError(
        'Dynamic data selection requires --aux_eval_dataset to score buckets.')

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.host_id(),
          shard_count=jax.host_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  # Buckets are scored against the last auxiliary eval dataset.
  scoring_ds = aux_datasets[-1] if FLAGS.dynamic else None

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenizes `toks` up to and including the first EOS token."""
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # NOTE(review): this may be intended to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # Prefer a pretrained model if given; otherwise continue from model_dir.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  if FLAGS.adapter != NONE:
    # Only optimize parameters whose path contains the adapter name.
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [
      k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
  ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.adapter != NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  p_get_diag_grads = jax.pmap(
      functools.partial(
          get_diag_grads,
          config=eval_config),
      axis_name='batch')

  p_get_bucket_score = jax.pmap(
      functools.partial(
          get_diag_score,
          strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  best_eval_loss = 1000
  curr_eval_loss = 1000
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # Dynamic macro: use gradient alignment to score different ratios
        # of top k vs bottom N-k bins
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score, scoring_ds,
                                              train_ds_mgr, optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins
          # take the top k bins and sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  scoring_ds, train_ds_mgr,
                                                  optimizer,
                                                  eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          # One weight per bucket; must match the bucket count used to build
          # the datasets (previously hard-coded to 100).
          new_distribution = np.zeros(FLAGS.num_data_buckets)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        try:
          batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
          optimizer, metrics = p_train_step(
              optimizer, batch, dropout_rng=dropout_rngs)
          train_metrics.append(metrics)
        except StopIteration:
          # Data exhausted: finish up (metrics, eval, checkpoint) and stop.
          is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        # Nothing was collected if the data ran out right after the previous
        # summary; get_metrics cannot stack an empty list.
        if train_metrics:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        with report_progress.timed('eval'):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps)
          curr_eval_loss = eval_results['loss']
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            exemplars, bleu_score = translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length)
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if FLAGS.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break
示例#23
0
    def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
        """
        Draw MCMC samples and store them on this object.

        :param random.PRNGKey rng_key: PRNG key driving the sampler. With several
            chains, either pass a batch of `num_chains` keys, or a single key that
            is then split into `num_chains` keys.
        :param args: Positional arguments forwarded to
            :meth:`numpyro.infer.mcmc.MCMCKernel.init`; usually the inputs the
            `model` needs.
        :param extra_fields: Names of additional :data:`numpyro.infer.mcmc.HMCState`
            fields (besides `z` and `diverging`) to record at every sample.
        :type extra_fields: tuple or list
        :param init_params: Optional starting parameters; must match the input
            type expected by `potential_fn`. With several chains the leading
            dimension must equal `num_chains`.
        :param kwargs: Keyword arguments forwarded to
            :meth:`numpyro.infer.mcmc.MCMCKernel.init`; usually the keyword
            inputs the `model` needs.

        .. note:: jax returns control to python before compiled computations
            finish (asynchronous dispatch), which makes naive timing misleading.
            See https://jax.readthedocs.io/en/latest/async_dispatch.html and
            https://jax.readthedocs.io/en/latest/profiling.html for how to
            profile jax programs.
        """
        self._args, self._kwargs = args, kwargs
        init_state = self._get_cached_init_state(rng_key, args, kwargs)

        # A single key is fanned out into one key per chain.
        if rng_key.ndim == 1 and self.num_chains > 1:
            rng_key = random.split(rng_key, self.num_chains)

        # Continue from a previous warmup run when one is available.
        if self._warmup_state is not None:
            self._set_collection_params(0, self.num_samples, self.num_samples)
            init_state = self._warmup_state._replace(rng_key=rng_key)

        chain_method = self.chain_method
        if chain_method == 'parallel':
            num_devices = xla_bridge.device_count()
            if num_devices < self.num_chains:
                # Not enough devices for one chain each: fall back.
                chain_method = 'sequential'
                warnings.warn(
                    'There are not enough devices to run parallel chains: expected {} but got {}.'
                    ' Chains will be drawn sequentially. If you are running MCMC in CPU,'
                    ' consider to use `numpyro.set_host_device_count({})` at the beginning'
                    ' of your program.'.format(self.num_chains,
                                               num_devices,
                                               self.num_chains))

        if init_params is not None and self.num_chains > 1:
            first_leaf = tree_flatten(init_params)[0][0]
            if jnp.shape(first_leaf)[0] != self.num_chains:
                raise ValueError(
                    '`init_params` must have the same leading dimension'
                    ' as `num_chains`.')

        assert isinstance(extra_fields, (tuple, list))
        requested = ((self._sample_field, ) + tuple(self._default_fields)
                     + tuple(extra_fields))
        collect_fields = tuple(set(requested))

        chain_fn = partial(self._single_chain_mcmc,
                           args=args,
                           kwargs=kwargs,
                           collect_fields=collect_fields)
        chain_inputs = (rng_key, init_state, init_params)

        if self.num_chains == 1:
            states_flat, last_state = chain_fn(chain_inputs)
            # Add a leading chain axis of size 1.
            states = tree_map(lambda leaf: leaf[jnp.newaxis, ...], states_flat)
        else:
            if chain_method == 'sequential':
                mapper = _laxmap if self.progress_bar else lax.map
                states, last_state = mapper(chain_fn, chain_inputs)
            elif chain_method == 'parallel':
                states, last_state = pmap(chain_fn)(chain_inputs)
                # TODO: remove when https://github.com/google/jax/issues/3597 is resolved
                states = device_put(states)
            else:
                assert chain_method == 'vectorized'
                states, last_state = chain_fn(chain_inputs)
                # Vectorized output is (num_samples, num_chains, ...); make it
                # (num_chains, num_samples, ...).
                states = tree_map(lambda leaf: jnp.swapaxes(leaf, 0, 1), states)
            # Merge the chain and sample axes.
            states_flat = tree_map(
                lambda leaf: jnp.reshape(leaf, (-1, ) + leaf.shape[2:]), states)

        self._last_state = last_state
        self._states = states
        self._states_flat = states_flat
        self._set_collection_params()
示例#24
0
def train(optimizer: flax.optim.Optimizer,
          state: flax.nn.Collection,
          dataset_source: dataset_source_lib.DatasetSource,
          training_dir: str, num_epochs: int):
  """Trains the model.

  If `<training_dir>/checkpoints` exists, training resumes at the epoch after
  the last saved one. Checkpoints are written when more than
  `FLAGS.save_progress_seconds` have passed since the previous one, into a
  separate sub-directory at each epoch in
  `FLAGS.additional_checkpoints_at_epochs`, and once after the last epoch
  trained by this call. With `FLAGS.ema_decay`, the exponential moving
  average is saved next to the model state and used for test evaluation.

  Args:
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    dataset_source: Container for the training dataset.
    training_dir: Parent directory where the tensorboard logs and model
      checkpoints should be saved.
    num_epochs: Number of epochs for which we want to train the model.
  """
  checkpoint_dir = os.path.join(training_dir, 'checkpoints')
  summary_writer = tensorboard.SummaryWriter(training_dir)
  if jax.host_id() != 0:  # Don't log if not first host.
    summary_writer.scalar = lambda *args: None
  # NOTE(review): the same key is handed to every epoch below; unless
  # train_for_one_epoch mixes in the epoch, randomness repeats per epoch.
  prng_key = jax.random.PRNGKey(FLAGS.run_seed)

  if FLAGS.ema_decay:
    end_warmup_step = 1560
    moving_averages = efficientnet_optim.ExponentialMovingAverage(
        (optimizer.target, state), FLAGS.ema_decay, end_warmup_step)  # pytype:disable=wrong-arg-count

    def update_ema(optimizer, state, ema):
      step = optimizer.state.step
      return ema.update_moving_average((optimizer.target, state), step)

    pmapped_update_ema = jax.pmap(update_ema, axis_name='batch')
  else:
    pmapped_update_ema = moving_averages = None

  def _save(opt, model_state, ema, directory, epoch):
    """Saves a checkpoint, bundling the EMA with the state when it is enabled."""
    if FLAGS.ema_decay:
      save_checkpoint(opt, (model_state, ema), directory, epoch)
    else:
      save_checkpoint(opt, model_state, directory, epoch)

  # Restore from an existing checkpoint, if any.
  if gfile.exists(checkpoint_dir):
    if FLAGS.ema_decay:
      optimizer, (state, moving_averages), epoch_last_checkpoint = (
          restore_checkpoint(optimizer, (state, moving_averages),
                             checkpoint_dir))
    else:
      optimizer, state, epoch_last_checkpoint = restore_checkpoint(
          optimizer, state, checkpoint_dir)
    # If last checkpoint was saved at the end of epoch n, then the first
    # training epochs to do when we resume training is n+1.
    initial_epoch = epoch_last_checkpoint + 1
    info = 'Resuming training from epoch {}'.format(initial_epoch)
    logging.info(info)
  else:
    initial_epoch = jnp.array(0, dtype=jnp.int32)
    logging.info('Starting training from scratch.')

  optimizer = jax_utils.replicate(optimizer)
  state = jax_utils.replicate(state)
  if FLAGS.ema_decay:
    moving_averages = jax_utils.replicate(moving_averages)

  if FLAGS.use_learning_rate_schedule:
    if FLAGS.lr_schedule == 'cosine':
      learning_rate_fn = get_cosine_schedule(num_epochs, FLAGS.learning_rate,
                                             dataset_source.num_training_obs,
                                             dataset_source.batch_size)
    elif FLAGS.lr_schedule == 'exponential':
      learning_rate_fn = get_exponential_schedule(
          num_epochs, FLAGS.learning_rate, dataset_source.num_training_obs,
          dataset_source.batch_size)
    else:
      raise ValueError('Wrong schedule: ' + FLAGS.lr_schedule)
  else:
    learning_rate_fn = lambda step: FLAGS.learning_rate

  # pmap the training and evaluation functions.
  pmapped_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          l2_reg=FLAGS.weight_decay),
      axis_name='batch',
      donate_argnums=(0, 1))
  pmapped_eval_step = jax.pmap(eval_step, axis_name='batch')

  time_at_last_checkpoint = time.time()
  # Stays None when there is no epoch left to train (e.g. resuming a finished
  # run); the final save below is then skipped instead of failing.
  epochs_id = None
  for epochs_id in range(initial_epoch, num_epochs):
    if epochs_id in FLAGS.additional_checkpoints_at_epochs:
      # To save additional checkpoints that will not be erased by later
      # versions, we save them in a new directory. The EMA is included, like
      # in every other checkpoint, so these can be restored the same way.
      c_path = os.path.join(checkpoint_dir, 'additional_ckpt_' + str(epochs_id))
      _save(optimizer, state, moving_averages, c_path, epochs_id)
    tick = time.time()

    optimizer, state, moving_averages = train_for_one_epoch(
        dataset_source, optimizer, state, prng_key, pmapped_train_step,
        pmapped_update_ema, moving_averages, summary_writer)

    tock = time.time()
    info = 'Epoch {} finished in {:.2f}s.'.format(epochs_id, tock - tick)
    logging.info(info)

    # Evaluate the model on the test set, and optionally the training set.
    if (epochs_id + 1) % FLAGS.evaluate_every == 0:
      info = 'Evaluating at end of epoch {} (0-indexed)'.format(epochs_id)
      logging.info(info)
      tick = time.time()
      current_step = int(optimizer.state.step[0])
      if FLAGS.also_eval_on_training_set:
        train_ds = dataset_source.get_train(use_augmentations=False)
        train_metrics = eval_on_dataset(
            optimizer.target, state, train_ds, pmapped_eval_step)
        for metric_name, metric_value in train_metrics.items():
          summary_writer.scalar('eval_on_train_' + metric_name,
                                metric_value, current_step)
        summary_writer.flush()

      # Test evaluation uses the EMA weights when enabled.
      if FLAGS.ema_decay:
        logging.info('Evaluating with EMA.')
        eval_model, eval_state = moving_averages.param_ema  # pytype:disable=attribute-error
        prefix = 'ema_test_'
      else:
        eval_model, eval_state = optimizer.target, state
        prefix = 'test_'
      test_ds = dataset_source.get_test()
      test_metrics = eval_on_dataset(
          eval_model, eval_state, test_ds, pmapped_eval_step)
      for metric_name, metric_value in test_metrics.items():
        summary_writer.scalar(prefix + metric_name, metric_value, current_step)
      summary_writer.flush()

      # Report evaluation time for both the EMA and the plain model.
      tock = time.time()
      info = 'Evaluated model in {:.2f}.'.format(tock - tick)
      logging.info(info)

    # Save new checkpoint if the last one was saved more than
    # `save_progress_seconds` seconds ago.
    sec_from_last_ckpt = time.time() - time_at_last_checkpoint
    if sec_from_last_ckpt > FLAGS.save_progress_seconds:
      _save(optimizer, state, moving_averages, checkpoint_dir, epochs_id)
      time_at_last_checkpoint = time.time()
      logging.info('Saved checkpoint.')

  # Always save final checkpoint (if this call trained at least one epoch).
  if epochs_id is not None:
    _save(optimizer, state, moving_averages, checkpoint_dir, epochs_id)
示例#25
0
def _pmap_split_rand_vars(all_vars):
    """Split `all_vars` into (non-random variables, RandomState variables).

    Returns two new ``TensorCollector`` objects; entries that are
    ``RandomState`` instances go to the second one, all others to the first.
    """
    dyn_vars = TensorCollector()
    rand_vars = TensorCollector()
    for key, val in all_vars.items():
        if isinstance(val, RandomState):
            rand_vars[key] = val
        else:
            dyn_vars[key] = val
    return dyn_vars, rand_vars


def _pmap_shift_static_argnums(static_broadcasted_argnums):
    """Normalize `static_broadcasted_argnums` to a tuple offset by 2.

    NOTE(review): the +2 presumably accounts for two leading arguments that
    ``_make_pmap`` adds in front of the user's arguments — confirm there.

    Raises:
        TypeError: If the value is not None, an int, a tuple, or a list.
    """
    if static_broadcasted_argnums is None:
        return ()
    if isinstance(static_broadcasted_argnums, int):
        return (static_broadcasted_argnums + 2, )
    if isinstance(static_broadcasted_argnums, (tuple, list)):
        return tuple(argnum + 2 for argnum in static_broadcasted_argnums)
    raise TypeError(
        f'static_broadcasted_argnums must be None, an int, or a tuple/list '
        f'of ints, but we got {type(static_broadcasted_argnums)}.')


def pmap(func,
         dyn_vars=None,
         axis_name=None,
         in_axes=0,
         out_axes=0,
         static_broadcasted_argnums=(),
         devices=None,
         backend=None,
         axis_size=None,
         donate_argnums=(),
         global_arg_shapes=None,
         reduce_func=None):
    """Parallel compilation for class objects.

    Parallel-compile a function or a module so it runs on multiple devices.

    Parameters
    ----------
    func : DynamicalSystem or callable
        For a ``DynamicalSystem`` with step functions, every entry of
        ``func.steps`` is replaced by a pmapped version and ``func`` is
        returned. Any other callable is pmapped with ``jax.pmap`` when no
        variables are involved, otherwise via ``_make_pmap``.
    dyn_vars : optional
        Variables used by ``func``. If omitted they are taken from
        ``func.vars().unique()`` for ``Base`` objects and bound methods of
        ``Base`` objects. ``RandomState`` entries are handled separately.
    axis_name, in_axes, out_axes, devices, backend, axis_size, donate_argnums,
    global_arg_shapes
        Forwarded to ``jax.pmap`` / ``_make_pmap``.
    static_broadcasted_argnums : int, tuple or list
        Static arguments, numbered as in ``func``'s own signature.
    reduce_func : callable, optional
        Passed to ``_make_pmap``; defaults to ``jnp.concatenate``.

    Returns
    -------
    The pmapped function (plain callables without variables), or ``func``
    itself with its step functions / ``__call__`` replaced.

    Raises
    ------
    errors.BrainPyError
        If ``func`` is neither a ``DynamicalSystem`` nor callable.
    """
    from brainpy.building.brainobjects import DynamicalSystem

    # DynamicalSystem with step functions: pmap each step in place.
    if isinstance(func, DynamicalSystem) and len(func.steps):
        dyn_vars, rand_vars = _pmap_split_rand_vars(
            dyn_vars or func.vars().unique())
        if reduce_func is None:
            reduce_func = jnp.concatenate
        static_broadcasted_argnums = _pmap_shift_static_argnums(
            static_broadcasted_argnums)

        for key in tuple(func.steps.keys()):
            func.steps[key] = _make_pmap(
                dyn_vars=dyn_vars,
                rand_vars=rand_vars,
                func=func.steps[key],
                axis_name=axis_name,
                in_axes=in_axes,
                out_axes=out_axes,
                static_broadcasted_argnums=static_broadcasted_argnums,
                devices=devices,
                backend=backend,
                axis_size=axis_size,
                donate_argnums=donate_argnums,
                global_arg_shapes=global_arg_shapes,
                reduce_func=reduce_func,
                f_name=key)
        return func

    if not callable(func):
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable function, '
            f'but we got {type(func)}.')

    if dyn_vars is None:
        if isinstance(func, Base):  # Base has '__call__()' implementation
            dyn_vars = func.vars().unique()
        elif hasattr(func, '__self__') and isinstance(func.__self__, Base):
            dyn_vars = func.__self__.vars().unique()

    if dyn_vars is None:
        # No variables involved: plain jax.pmap with the user's argnums.
        return jax.pmap(
            func,
            axis_name=axis_name,
            in_axes=in_axes,
            out_axes=out_axes,
            static_broadcasted_argnums=static_broadcasted_argnums,
            devices=devices,
            backend=backend,
            axis_size=axis_size,
            donate_argnums=donate_argnums,
            global_arg_shapes=global_arg_shapes)

    # Split the *given/derived* variables (previously they were discarded by
    # iterating over a freshly created, empty collector).
    dyn_vars, rand_vars = _pmap_split_rand_vars(dyn_vars)
    static_broadcasted_argnums = _pmap_shift_static_argnums(
        static_broadcasted_argnums)
    if reduce_func is None:
        reduce_func = jnp.concatenate

    # NOTE(review): Python looks up `__call__` on the type, so assigning it on
    # the instance does not change what `func(...)` runs, and bound methods
    # reject attribute assignment entirely. Confirm the intended contract.
    func.__call__ = _make_pmap(
        dyn_vars=dyn_vars,
        rand_vars=rand_vars,
        func=func,
        axis_name=axis_name,
        in_axes=in_axes,
        out_axes=out_axes,
        static_broadcasted_argnums=static_broadcasted_argnums,
        devices=devices,
        backend=backend,
        axis_size=axis_size,
        donate_argnums=donate_argnums,
        global_arg_shapes=global_arg_shapes,
        reduce_func=reduce_func)
    return func
示例#26
0
def train():
    """Train the PixelCNN++ model, evaluating and checkpointing periodically.

    Builds the model with data-dependent initialization on an unsharded init
    batch. It then passes the optimizer and initial params through
    ``restore_checkpoint`` and runs a pmapped training loop from the restored
    step up to ``steps_per_epoch * FLAGS.num_epochs``. At the end of each
    epoch, per-step training metrics are written to TensorBoard and the EMA
    parameters are evaluated on ``steps_per_eval`` eval batches. A checkpoint
    is saved every 10 epochs and at the final step.

    Raises:
        ValueError: if run on more than one host, if ``FLAGS.batch_size`` is
            not divisible by ``jax.device_count()``, or if
            ``FLAGS.init_batch_size`` exceeds ``FLAGS.batch_size``.
    """
    batch_size = FLAGS.batch_size
    n_devices = jax.device_count()
    if jax.host_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_summary_writer, eval_summary_writer = get_summary_writers()

    # Load dataset
    data_source = input_pipeline.DataSource(train_batch_size=batch_size,
                                            eval_batch_size=batch_size)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds

    # Create dataset batch iterators
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Compute steps per epoch and nb of eval steps
    steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
    steps_per_eval = data_source.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * FLAGS.num_epochs

    # Create the model using data-dependent initialization. Don't shard the init
    # batch. Note that this consumes one batch from `train_iter`.
    # Raise instead of `assert` so the check survives `python -O`.
    if FLAGS.init_batch_size > batch_size:
        raise ValueError(
            'init_batch_size must not exceed batch_size')
    init_batch = next(train_iter)['image']._numpy()[:FLAGS.init_batch_size]

    rng = random.PRNGKey(FLAGS.rng)
    rng, init_rng = random.split(rng)
    rng, dropout_rng = random.split(rng)

    initial_variables = model().init(
        {
            'params': init_rng,
            'dropout': dropout_rng
        }, init_batch)['params']
    optimizer_def = optim.Adam(learning_rate=FLAGS.learning_rate,
                               beta1=0.95,
                               beta2=0.9995)
    optimizer = optimizer_def.create(initial_variables)

    # Use the EMA returned by restore_checkpoint as-is. Overwriting it with
    # `initial_variables` would throw away a restored EMA on resume while the
    # optimizer state *is* restored, leaving the two out of sync.
    # NOTE(review): assumes restore_checkpoint falls back to
    # `initial_variables` for the EMA when no checkpoint exists — confirm.
    optimizer, ema = restore_checkpoint(optimizer, initial_variables)
    step_offset = int(optimizer.state.step)

    # Copy optimizer and EMA params onto every local device for pmap.
    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule: exponential decay per step.
    learning_rate_fn = lambda step: FLAGS.learning_rate * FLAGS.lr_decay**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(partial(train_step,
                                    learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # Gather metrics
    train_metrics = []

    for step, batch in zip(range(step_offset, num_steps), train_iter):
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)

        # Generate a PRNG key that will be rolled into the batch.
        rng, step_rng = random.split(rng)
        sharded_rngs = common_utils.shard_prng_key(step_rng)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_rngs)
        train_metrics.append(metrics)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send per-step stats to Tensorboard, ending at the current step.
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val,
                                                step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation (uses the EMA parameters, not the raw optimizer target)
            eval_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            save_checkpoint(optimizer, ema, step)
示例#27
0
    # Setup optimizer
    optimizer = Adam(
        learning_rate=training_args.learning_rate,
        weight_decay=training_args.weight_decay,
        beta1=training_args.adam_beta1,
        beta2=training_args.adam_beta2,
    ).create(model.params)

    # Create learning rate scheduler
    # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
    lr_scheduler_fn = create_learning_rate_scheduler(
        base_learning_rate=training_args.learning_rate, warmup_steps=min(training_args.warmup_steps, 1)
    )

    # Create parallel version of the training and evaluation steps
    p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    # Replicate the optimizer on each device
    optimizer = jax_utils.replicate(optimizer)

    # Store some constant
    nb_epochs = int(training_args.num_train_epochs)
    batch_size = int(training_args.train_batch_size)
    eval_batch_size = int(training_args.eval_batch_size)

    epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
    for epoch in epochs:

        # ======================== Training ================================
        # Create sampling rng
def main():
    # region Argument parsing
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )
    # endregion

    # region Logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # endregion

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # region Load Data
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(data_args.dataset_name,
                                    data_args.dataset_config_name,
                                    cache_dir=model_args.cache_dir)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]

        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        raw_datasets = load_dataset(extension,
                                    data_files=data_files,
                                    field="data",
                                    cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    # endregion

    # region Load pretrained model and tokenizer
    #
    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=True,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    # endregion

    # region Tokenizer check: this script requires a fast tokenizer.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError(
            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
            "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
            "requirement")
    # endregion

    # region Preprocessing the datasets
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names
    question_column_name = "question" if "question" in column_names else column_names[
        0]
    context_column_name = "context" if "context" in column_names else column_names[
        1]
    answer_column_name = "answers" if "answers" in column_names else column_names[
        2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[
                question_column_name if pad_on_right else context_column_name],
            examples[
                context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right
                                                          else 0):
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else
                                                        0):
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char
                        and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[
                            token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(
                        token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(
                        token_end_index + 1)

        return tokenized_examples

    processed_raw_datasets = dict()
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            # We will select sample from whole data if agument is specified
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        # Create train feature from dataset
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_train_samples is not None:
            # Number of samples might increase during Feature Creation, We select only specified max samples
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        processed_raw_datasets["train"] = train_dataset

    # Validation preprocessing
    def prepare_validation_features(examples):
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[
                question_column_name if pad_on_right else context_column_name],
            examples[
                context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(
                examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_examples = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            # We will select sample from whole data
            eval_examples = eval_examples.select(
                range(data_args.max_eval_samples))
        # Validation Feature Creation
        eval_dataset = eval_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_eval_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))
        processed_raw_datasets["validation"] = eval_dataset

    if training_args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_examples = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            # We will select sample from whole data
            predict_examples = predict_examples.select(
                range(data_args.max_predict_samples))
        # Predict Feature Creation
        predict_dataset = predict_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_predict_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            predict_dataset = predict_dataset.select(
                range(data_args.max_predict_samples))
        processed_raw_datasets["test"] = predict_dataset
    # endregion

    # region Metrics and Post-processing:
    def post_processing_function(examples,
                                 features,
                                 predictions,
                                 stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
            output_dir=training_args.output_dir,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if data_args.version_2_with_negative:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v,
                "no_answer_probability": 0.0
            } for k, v in predictions.items()]
        else:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v
            } for k, v in predictions.items()]

        references = [{
            "id": ex["id"],
            "answers": ex[answer_column_name]
        } for ex in examples]
        return EvalPrediction(predictions=formatted_predictions,
                              label_ids=references)

    metric = load_metric(
        "squad_v2" if data_args.version_2_with_negative else "squad")

    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions,
                              references=p.label_ids)

    # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
    def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
        """
        Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor

        Batches are written row-wise in order; columns a batch does not cover
        stay at the -100 fill value. Once ``len(dataset)`` rows are filled, any
        surplus rows (in the last batch or in extra batches) are dropped.

        Args:
            start_or_end_logits: iterable of 2-D arrays ``(batch_size, cols)``.
                This is the output predictions of the model. We can only enter
                either start or end logits.
            dataset: the dataset the logits were computed on; only its length
                is used.
            max_len(:obj:`int`):
                The maximum length of the output tensor. ( See the model.eval() part for more details )

        Returns:
            ``np.ndarray`` of shape ``(len(dataset), max_len)``, dtype float64.
        """
        num_rows = len(dataset)
        # -100 marks positions not covered by any batch's columns.
        logits_concat = np.full((num_rows, max_len), -100, dtype=np.float64)

        step = 0
        for output_logit in start_or_end_logits:
            if step >= num_rows:
                # All rows are filled. Slicing with a negative bound here would
                # otherwise raise a shape/broadcast error.
                break
            batch_size = output_logit.shape[0]
            cols = output_logit.shape[1]
            # Truncate the batch if it would run past the last output row.
            rows = min(batch_size, num_rows - step)
            logits_concat[step:step + rows, :cols] = output_logit[:rows]
            step += batch_size

        return logits_concat

    # endregion

    # region Training steps and logging init
    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
    summary_writer.hparams({
        **training_args.to_dict(),
        **vars(model_args),
        **vars(data_args)
    })

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Log the cumulative train time and every buffered per-step train metric.

        ``train_metrics`` is the list of per-step metric trees accumulated since
        the last log (presumably stacked per key by ``get_metrics``). For a key
        with ``n`` values, the ``k``-th value is written at step
        ``step - n + 1 + k``, so the newest value lands on ``step``.
        """
        summary_writer.scalar("train_time", train_time, step)

        stacked = get_metrics(train_metrics)
        for name, history in stacked.items():
            first_step = step - len(history) + 1
            for offset, value in enumerate(history):
                summary_writer.scalar(f"train_{name}", value, first_step + offset)

    def write_eval_metric(summary_writer, eval_metrics, step):
        """Write each evaluation metric as an ``eval_<name>`` scalar at ``step``."""
        for name, metric_value in eval_metrics.items():
            tag = f"eval_{name}"
            summary_writer.scalar(tag, metric_value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count(
    )
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count(
    )
    # endregion

    # region Load model
    model = FlaxAutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
    )

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model,
                               learning_rate_fn,
                               num_labels=max_seq_length,
                               training_args=training_args)

    # endregion

    # region Define train step functions
    def train_step(
            state: train_state.TrainState, batch: Dict[str, Array],
            dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one data-parallel optimization step on ``batch``.

        Meant to run under ``jax.pmap(..., axis_name="batch")``: gradients and
        metrics are averaged across devices with ``jax.lax.pmean`` over the
        ``"batch"`` axis.

        Args:
            state: train state providing ``params``, ``apply_fn``, ``loss_fn``
                and ``apply_gradients``.
            batch: model inputs plus ``"start_positions"`` and
                ``"end_positions"`` labels. The two label entries are popped
                from the dict before the remaining keys are fed to the model.
            dropout_rng: PRNG key used for dropout in this step.

        Returns:
            ``(new_state, metrics, new_dropout_rng)``. ``metrics`` holds the
            device-averaged ``"loss"`` and ``"learning_rate"``, and
            ``new_dropout_rng`` is the key to use for the next step.
        """
        # Split so the key returned for the next step differs from the one used now.
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        start_positions = batch.pop("start_positions")
        end_positions = batch.pop("end_positions")
        targets = (start_positions, end_positions)

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        # The learning rate is reported for the pre-update step that produced this loss.
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0, ))

    # endregion

    # region Define eval step functions
    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")
    # endregion

    # region Define train and eval loop
    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    train_time = 0
    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
                tqdm(
                    train_data_collator(input_rng, train_dataset,
                                        train_batch_size),
                    total=step_per_epoch,
                    desc="Training...",
                    position=1,
                ),
                1,
        ):
            state, train_metric, dropout_rngs = p_train_step(
                state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * step_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if (training_args.do_eval
                    and (cur_step % training_args.eval_steps == 0
                         or cur_step % step_per_epoch == 0) and cur_step > 0):

                eval_metrics = {}
                all_start_logits = []
                all_end_logits = []
                # evaluate
                for batch in tqdm(
                        eval_data_collator(eval_dataset, eval_batch_size),
                        total=len(eval_dataset) // eval_batch_size,
                        desc="Evaluating ...",
                        position=2,
                ):
                    _ = batch.pop("example_id")
                    _ = batch.pop("offset_mapping")
                    predictions = p_eval_step(state, batch)
                    start_logits = np.array(
                        [pred for pred in chain(*predictions[0])])
                    end_logits = np.array(
                        [pred for pred in chain(*predictions[1])])
                    all_start_logits.append(start_logits)
                    all_end_logits.append(end_logits)

                # evaluate also on leftover examples (not divisible by batch_size)
                num_leftover_samples = len(eval_dataset) % eval_batch_size

                # make sure leftover batch is evaluated on one device
                if num_leftover_samples > 0 and jax.process_index() == 0:
                    # take leftover samples
                    batch = eval_dataset[-num_leftover_samples:]
                    batch = {k: np.array(v) for k, v in batch.items()}
                    _ = batch.pop("example_id")
                    _ = batch.pop("offset_mapping")

                    predictions = eval_step(unreplicate(state), batch)
                    start_logits = np.array([pred for pred in predictions[0]])
                    end_logits = np.array([pred for pred in predictions[1]])
                    all_start_logits.append(start_logits)
                    all_end_logits.append(end_logits)

                max_len = max([x.shape[1] for x in all_start_logits
                               ])  # Get the max_length of the tensor

                # concatenate the numpy array
                start_logits_concat = create_and_fill_np_array(
                    all_start_logits, eval_dataset, max_len)
                end_logits_concat = create_and_fill_np_array(
                    all_end_logits, eval_dataset, max_len)

                # delete the list of numpy arrays
                del all_start_logits
                del all_end_logits
                outputs_numpy = (start_logits_concat, end_logits_concat)
                prediction = post_processing_function(eval_examples,
                                                      eval_dataset,
                                                      outputs_numpy)
                eval_metrics = compute_metrics(prediction)

                logger.info(
                    f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})"
                )

                if jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0
                    and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
示例#29
0
def device_broadcast(x, num_devices):
  """Replicate ``x`` across ``num_devices`` devices using ``jax.pmap``.

  The mapped function ignores its per-device input and returns ``x``
  unchanged. As a result, every leaf of the output gains a leading axis of
  size ``num_devices`` that holds identical copies of the original value.

  Args:
    x: Array (or pytree of arrays) to replicate.
    num_devices: Number of replicas. ``jax.pmap`` requires this to be at most
      the number of available local devices.

  Returns:
    ``x`` stacked ``num_devices`` times along a new leading axis.
  """
  def _return_x(_unused_replica_id):
    return x

  replica_ids = jnp.arange(num_devices)
  return jax.pmap(_return_x)(replica_ids)
示例#30
0
def train_model():
  """Run SG-MCMC training with periodic evaluation, ensembling and saving.

  Configuration comes from the module-level ``args``. Each iteration runs one
  pass of the SG-MCMC train epoch over ``train_set``. Depending on
  ``is_eval_ens_save_epoch(iteration, args)``, the iteration then:

  * evaluates on the train and test sets (evaluation or ensembling epochs);
  * folds the test predictions into the running ensemble (ensembling epochs);
  * writes a checkpoint (save epochs).

  Every iteration logs its statistics to TensorBoard and prints a table.
  Training starts at the iteration after the one stored in the
  initialization dict returned by ``script_utils.get_initialization_dict``.
  The default dict stores -1, so a fresh run starts at iteration 0.
  """
  # Output directory and TensorBoard writer.
  dirname, tf_writer = get_dirname_tfwriter(args)

  # Data, network, likelihood/prior functions, prediction and metric helpers.
  (train_set, test_set, net_apply, params, net_state, key, log_likelihood_fn,
   log_prior_fn, _, predict_fn, ensemble_upd_fn, metrics_fns,
   tabulate_metrics) = script_utils.get_data_model_fns(args)

  # Learning-rate schedule and the SGLD-style optimizer built on it.
  num_batches, total_steps = script_utils.get_num_batches_total_steps(
      args, train_set)
  num_devices = len(jax.devices())
  lr_schedule = get_lr_schedule(num_batches, args)
  optimizer = sgmcmc.sgld_gradient_update(
      lr_schedule,
      momentum_decay=args.momentum_decay,
      seed=args.seed,
      preconditioner=get_preconditioner(args))

  # Fresh state: optimizer state, network state replicated across devices,
  # and one RNG key per device.
  opt_state = optimizer.init(params)
  net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
  key = jax.random.split(key, num_devices)

  # Wrap the fresh state as an iteration -1 checkpoint dict. The helper may
  # substitute a different initialization dict.
  default_ckpt = checkpoint_utils.make_sgmcmc_checkpoint_dict(
      -1, params, net_state, opt_state, key, 0, None, None)
  restored_ckpt = script_utils.get_initialization_dict(
      dirname, args, default_ckpt)
  (last_iteration, params, net_state, opt_state, key, num_ensembled, _,
   ensemble_predictions) = checkpoint_utils.parse_sgmcmc_checkpoint_dict(
       restored_ckpt)
  start_iteration = last_iteration + 1

  # One full pass over the training data; time_fn also returns the duration.
  train_epoch = script_utils.time_fn(
      train_utils.make_sgd_train_epoch(net_apply, log_likelihood_fn,
                                       log_prior_fn, optimizer, num_batches))

  for iteration in range(start_iteration, args.num_epochs):
    (params, net_state, opt_state, logprob_avg, key), epoch_time = train_epoch(
        params, net_state, opt_state, train_set, key)

    is_evaluation_epoch, is_ensembling_epoch, is_save_epoch = (
        is_eval_ens_save_epoch(iteration, args))

    # The train log-prob is always reported. Full evaluation runs only on
    # epochs that need it.
    train_stats = {"log_prob": logprob_avg}
    test_stats = {}
    if is_evaluation_epoch or is_ensembling_epoch:
      (_, test_predictions, _, test_stats,
       extra_train_stats) = script_utils.evaluate(
           net_apply, params, net_state, train_set, test_set, predict_fn,
           metrics_fns, log_prior_fn)
      train_stats.update(extra_train_stats)

    # Add this epoch's test predictions to the running ensemble.
    ensemble_stats = {}
    if is_ensembling_epoch:
      ensemble_predictions = ensemble_upd_fn(ensemble_predictions,
                                             num_ensembled, test_predictions)
      ensemble_stats = train_utils.evaluate_metrics(ensemble_predictions,
                                                    test_set[1], metrics_fns)
      num_ensembled += 1
    else:
      # Per-epoch test predictions go into checkpoints only on ensembling
      # epochs.
      test_predictions = None

    if is_save_epoch:
      ckpt_path = os.path.join(dirname,
                               checkpoint_utils.make_checkpoint_name(iteration))
      checkpoint_utils.save_checkpoint(
          ckpt_path,
          checkpoint_utils.make_sgmcmc_checkpoint_dict(
              iteration, params, net_state, opt_state, key, num_ensembled,
              test_predictions, ensemble_predictions))

    # TensorBoard scalars: common logs plus hyperparameters and telemetry.
    other_logs = script_utils.get_common_logs(iteration, epoch_time, args)
    other_logs["hypers/step_size"] = lr_schedule(opt_state.count)
    other_logs["hypers/momentum"] = args.momentum_decay
    other_logs["telemetry/num_ensembled"] = num_ensembled
    logging_dict = logging_utils.make_logging_dict(train_stats, test_stats,
                                                   ensemble_stats)
    logging_dict.update(other_logs)
    script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)

    # Console table: a selected subset of the logged metrics plus the LR.
    tabulate_dict = script_utils.get_tabulate_dict(tabulate_metrics,
                                                   logging_dict)
    tabulate_dict["lr"] = lr_schedule(opt_state.count)
    table = logging_utils.make_table(tabulate_dict,
                                     iteration - start_iteration,
                                     args.tabulate_freq)
    print(table)