def train_model():
  """Runs full-batch HMC sampling for a neural network and logs metrics.

  Reads hyperparameters from the module-level `args` namespace (presumably
  parsed command-line flags — confirm against the surrounding script). Builds
  the dataset, model, and HMC update functions, then iterates HMC steps,
  checkpointing every iteration and writing TensorBoard summaries.

  Side effects:
    Creates a run directory under `args.dir`, writes checkpoints and
    TensorBoard event files there, and prints a metrics table per iteration.
  """
  subdirname = (
      "model_{}_wd_{}_stepsize_{}_trajlen_{}_burnin_{}_{}_mh_{}_temp_{}_"
      "seed_{}".format(args.model_name, args.weight_decay, args.step_size,
                       args.trajectory_len, args.num_burn_in_iterations,
                       args.burn_in_step_size_factor, not args.no_mh,
                       args.temperature, args.seed))
  dirname = os.path.join(args.dir, subdirname)
  os.makedirs(dirname, exist_ok=True)
  tf_writer = tf.summary.create_file_writer(dirname)
  cmd_args_utils.save_cmd(dirname, tf_writer)

  num_devices = len(jax.devices())
  dtype = jnp.float64 if args.use_float64 else jnp.float32
  train_set, test_set, num_classes = data.make_ds_pmap_fullbatch(
      args.dataset_name, dtype)
  net_apply, net_init = models.get_model(args.model_name, num_classes)

  checkpoint_dict, status = checkpoint_utils.initialize(
      dirname, args.init_checkpoint)
  if status == checkpoint_utils.InitStatus.LOADED_PREEMPTED:
    print("Continuing the run from the last saved checkpoint")
    (start_iteration, params, net_state, key, step_size, _, num_ensembled,
     ensemble_predicted_probs) = (
         checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
  else:
    key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
    start_iteration = 0
    num_ensembled = 0
    ensemble_predicted_probs = None
    step_size = args.step_size
    if status == checkpoint_utils.InitStatus.INIT_CKPT:
      print("Resuming the run from the provided init_checkpoint")
      # Only params and net_state are taken from the init checkpoint; the
      # iteration counter, RNG key, step size and ensemble start fresh.
      _, params, net_state, _, _, _, _, _ = (
          checkpoint_utils.parse_hmc_checkpoint_dict(checkpoint_dict))
    elif status == checkpoint_utils.InitStatus.INIT_RANDOM:
      print("Starting from random initialization with provided seed")
      key, net_init_key = jax.random.split(jax.random.PRNGKey(args.seed), 2)
      # Initialize with a single example from the first device's shard.
      init_data = jax.tree_map(lambda elem: elem[0][:1], train_set)
      params, net_state = net_init(net_init_key, init_data, True)
      # Broadcast the (identical) net_state to one copy per device.
      net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
    else:
      raise ValueError("Unknown initialization status: {}".format(status))

  # Manually convert all params to the requested dtype and verify.
  params = jax.tree_map(lambda p: p.astype(dtype), params)
  param_types = tree_utils._get_types(params)
  assert all([p_type == dtype for p_type in param_types]), (
      "Params data types {} do not match specified data type {}".format(
          param_types, dtype))

  trajectory_len = args.trajectory_len
  log_likelihood_fn = nn_loss.make_xent_log_likelihood(
      num_classes, args.temperature)
  log_prior_fn, log_prior_diff_fn = nn_loss.make_gaussian_log_prior(
      args.weight_decay, args.temperature)
  update, get_log_prob_and_grad, evaluate = train_utils.make_hmc_update(
      net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn,
      args.max_num_leapfrog_steps, args.target_accept_rate,
      args.step_size_adaptation_speed)

  log_prob, state_grad, log_likelihood, net_state = (
      get_log_prob_and_grad(train_set, params, net_state))
  assert log_prob.dtype == dtype, (
      "log_prob data type {} does not match specified data type {}".format(
          log_prob.dtype, dtype))
  grad_types = tree_utils._get_types(state_grad)
  assert all([g_type == dtype for g_type in grad_types]), (
      "Gradient data types {} do not match specified data type {}".format(
          grad_types, dtype))

  ensemble_acc = 0
  for iteration in range(start_iteration, args.num_iterations):
    # Linear ramp of the step size during the burn-in phase, from
    # args.step_size down (or up) to burn_in_step_size_factor * step_size.
    if iteration < args.num_burn_in_iterations:
      alpha = iteration / (args.num_burn_in_iterations - 1)
      initial_step_size = args.step_size
      final_step_size = args.burn_in_step_size_factor * args.step_size
      step_size = final_step_size * alpha + initial_step_size * (1 - alpha)
    in_burnin = (iteration < args.num_burn_in_iterations)
    # Metropolis-Hastings correction is disabled during burn-in and when
    # explicitly turned off via --no_mh.
    do_mh_correction = (not args.no_mh) and (not in_burnin)

    start_time = time.time()
    (params, net_state, log_likelihood, state_grad, step_size, key,
     accept_prob, accepted) = (
         update(train_set, params, net_state, log_likelihood, state_grad, key,
                step_size, trajectory_len, do_mh_correction))
    iteration_time = time.time() - start_time

    # Checkpoint every iteration so a preempted run can resume exactly.
    checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
    checkpoint_path = os.path.join(dirname, checkpoint_name)
    checkpoint_dict = checkpoint_utils.make_hmc_checkpoint_dict(
        iteration, params, net_state, key, step_size, accepted, num_ensembled,
        ensemble_predicted_probs)
    checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

    # Add the current sample to the running ensemble if it was accepted
    # post-burn-in (or unconditionally when MH correction is disabled).
    if ((not in_burnin) and accepted) or args.no_mh:
      ensemble_predicted_probs, ensemble_acc, num_ensembled = (
          train_utils.update_ensemble(net_apply, params, net_state, test_set,
                                      num_ensembled,
                                      ensemble_predicted_probs))

    test_log_prob, test_acc, test_ce, _ = evaluate(params, net_state, test_set)
    train_log_prob, train_acc, train_ce, prior = (
        evaluate(params, net_state, train_set))

    tabulate_dict = OrderedDict()
    tabulate_dict["iteration"] = iteration
    tabulate_dict["step_size"] = step_size
    # NOTE(review): `log_prob` here is the value computed before the loop and
    # is never refreshed inside it — confirm whether the current train
    # log-prob was intended instead.
    tabulate_dict["train_logprob"] = log_prob
    tabulate_dict["train_acc"] = train_acc
    tabulate_dict["test_acc"] = test_acc
    tabulate_dict["test_ce"] = test_ce
    tabulate_dict["accept_prob"] = accept_prob
    tabulate_dict["accepted"] = accepted
    tabulate_dict["ensemble_acc"] = ensemble_acc
    tabulate_dict["n_ens"] = num_ensembled
    tabulate_dict["time"] = iteration_time

    with tf_writer.as_default():
      tf.summary.scalar("train/log_prob", train_log_prob, step=iteration)
      tf.summary.scalar("test/log_prob", test_log_prob, step=iteration)
      tf.summary.scalar("train/log_likelihood", train_ce, step=iteration)
      tf.summary.scalar("test/log_likelihood", test_ce, step=iteration)
      tf.summary.scalar("train/accuracy", train_acc, step=iteration)
      tf.summary.scalar("test/accuracy", test_acc, step=iteration)
      # Fix: this scalar was previously written three times per iteration.
      tf.summary.scalar("test/ens_accuracy", ensemble_acc, step=iteration)

      if num_ensembled > 0:
        test_labels = onp.asarray(test_set[1])
        ensemble_nll = metrics.nll(ensemble_predicted_probs, test_labels)
        ensemble_calibration = metrics.calibration_curve(
            ensemble_predicted_probs, test_labels)
        tf.summary.scalar(
            "test/ens_ece", ensemble_calibration["ece"], step=iteration)
        tf.summary.scalar("test/ens_nll", ensemble_nll, step=iteration)

      tf.summary.scalar("telemetry/log_prior", prior, step=iteration)
      tf.summary.scalar("telemetry/accept_prob", accept_prob, step=iteration)
      tf.summary.scalar("telemetry/accepted", accepted, step=iteration)
      tf.summary.scalar("telemetry/n_ens", num_ensembled, step=iteration)
      tf.summary.scalar(
          "telemetry/iteration_time", iteration_time, step=iteration)

      tf.summary.scalar("hypers/step_size", step_size, step=iteration)
      tf.summary.scalar("hypers/trajectory_len", trajectory_len,
                        step=iteration)
      tf.summary.scalar("hypers/weight_decay", args.weight_decay,
                        step=iteration)
      tf.summary.scalar("hypers/temperature", args.temperature, step=iteration)

      tf.summary.scalar("debug/do_mh_correction", float(do_mh_correction),
                        step=iteration)
      tf.summary.scalar("debug/in_burnin", float(in_burnin), step=iteration)

    table = tabulate_utils.make_table(
        tabulate_dict, iteration - start_iteration, args.tabulate_freq)
    print(table)
def train_and_evaluate(config, workdir, vocab_filepath):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of hosts and
      devices, or config is underspecified.
  """
  # Update config before config validation.
  with config.unlocked():
    # Numeric floating point type to use for model computations.
    config.dtype = jnp.float32

  train_utils.validate_config(config)

  # Global batch sizes are split evenly across hosts (processes).
  per_host_train_batch_size = config.train_batch_size // jax.process_count()
  per_host_eval_batch_size = config.eval_batch_size // jax.process_count()

  # Only the first host writes TensorBoard summaries.
  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)

  ds_info = tfds.builder(config.dataset_name).info
  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples
  num_train_steps = int(num_train_examples * config.num_train_epochs //
                        config.train_batch_size)
  num_warmup_steps = int(config.warmup_proportion * num_train_steps)
  # Round up evaluation frequency to the next multiple of 10.
  eval_frequency = int(
      math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

  # STSB is a regression task. COPA and ReCoRD are treated as scalar/regression
  # tasks during training.
  is_regression_task = (config.dataset_name == "glue/stsb" or
                        config.dataset_name == "super_glue/copa" or
                        config.dataset_name == "super_glue/record")
  if is_regression_task:
    num_classes = 1
  else:
    num_classes = ds_info.features["label"].num_classes

  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()
    config.pad_id = tokenizer.pad_id()

  config = ml_collections.FrozenConfigDict(config)
  model = models.SequenceClassificationModel(config, num_classes)
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  params = _init_params(model, init_rng, config)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=num_warmup_steps,
      decay_steps=num_train_steps - num_warmup_steps,
  )

  tx = optax.adamw(
      learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01)
  if config.clipped_grad_norm:
    tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm), tx)

  # jit state creation to ensure arrays are created on same device as input
  # (i.e. CPU).
  state_cpu = jax.jit(
      functools.partial(
          FlaxTrainState.create, apply_fn=model.apply, params=params, tx=tx))()

  # We access model params only via state.params
  del params

  # Expert (MoE) parameters are sharded across devices and matched by name;
  # everything else is replicated.
  if config.num_experts > 1:
    sharded_match_fn = core_utils.match_fn(r".*expert.*")
    not_sharded_match_fn = lambda name: not sharded_match_fn(name)
  else:
    sharded_match_fn = None
    not_sharded_match_fn = lambda name: True

  state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                     sharded_match_fn,
                                                     not_sharded_match_fn,
                                                     config)

  # Regression tasks score with the raw first logit; classification with argmax.
  if is_regression_task:
    scoring_fn = lambda y: y[Ellipsis, 0]
  else:
    scoring_fn = lambda y: y.argmax(-1)
  compute_stats = functools.partial(
      _compute_stats, model=model, scoring_fn=scoring_fn)

  classification_inputs = functools.partial(
      input_pipeline.classification_inputs,
      dataset_name=config.dataset_name,
      max_seq_length=config.max_seq_length,
      tokenizer=tokenizer)
  train_ds = classification_inputs(
      split=tfds.Split.TRAIN,
      batch_size=per_host_train_batch_size,
      training=True)
  train_iter = iter(train_ds)

  if config.dataset_name == "glue/mnli":
    # MNLI contains two validation and test datasets.
    split_suffixes = ["_matched", "_mismatched"]
  else:
    split_suffixes = [""]

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, jax.local_device_count())

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics,
      model=model,
      is_experts_model=config.num_experts > 1,
      auxiliary_loss_factor=config.auxiliary_loss_factor,
      router_z_loss_factor=config.router_z_loss_factor)
  train_step = functools.partial(
      train_utils.pmap_train_step,
      loss_and_metrics_fn=loss_and_metrics_fn,
      axis_name="batch",
      sharded_match_fn=sharded_match_fn,
      gradient_accum_steps=config.gradient_accum_steps)
  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(compute_stats, axis_name="batch")
  eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name)

  train_stats = []
  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, num_train_steps):
    with jax.profiler.StepTraceContext("train", step_num=step):
      train_batch = next(train_iter)
      train_batch = common_utils.shard(train_batch)
      # p_train_step also returns fresh per-device dropout keys for next step.
      state, train_step_stats, rngs = p_train_step(state, train_batch,
                                                   rng=rngs)
      train_stats.append(train_step_stats)

    if ((step > 0 and config.save_checkpoints_steps and
         step % config.save_checkpoints_steps == 0) or
        step == num_train_steps - 1):
      # We allow all hosts to potentially save checkpoints because some model
      # parameters are sharded across devices. Parameters replicated across
      # devices (i.e. not sharded) will only be checkpointed by host 0.
      unreplicated_train_state = jax.tree_map(
          np.array,
          core_utils.tree_unreplicate_by_name(state, not_sharded_match_fn))
      checkpoints.save_checkpoint(
          workdir,
          unreplicated_train_state,
          sharded_match_fn,
          step,
          keep=config.checkpoints_to_keep)
      del unreplicated_train_state  # Only used for checkpointing.

    # Periodic metric handling: skip unless at an eval step or the final step.
    if step % eval_frequency != 0 and step < num_train_steps - 1:
      continue

    logging.info("Gathering training metrics at step: %d", step)
    train_metrics = train_utils.collect_metrics(train_stats)
    train_summary = train_utils.compute_classification_metrics(
        train_metrics, is_regression_task)
    train_summary["learning_rate"] = learning_rate_fn(step)

    if jax.process_index() == 0:
      assert train_summary_writer
      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    # Reset metric accumulation for next training evaluation cycle.
    train_stats = []

    logging.info("Gathering validation metrics at step: %d", step)
    for split_suffix in split_suffixes:
      eval_ds = classification_inputs(
          split=tfds.Split.VALIDATION + split_suffix,
          batch_size=per_host_eval_batch_size,
          training=False)
      eval_stats = []
      for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
        eval_stats.append(_evaluate(p_eval_step, state.params, eval_batch))
      eval_metrics = {}
      for k in eval_stats[0]:  # All batches of output stats are the same size
        eval_metrics[k] = np.concatenate([stat[k] for stat in eval_stats],
                                         axis=0)
      eval_summary = eval_metrics_fn(eval_metrics)

      if jax.process_index() == 0:
        assert eval_summary_writer
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(f"{key}{split_suffix}", val, step)
        eval_summary_writer.flush()
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Raises:
    ValueError: If the global batch size is not divisible by the number of
      local devices.
  """
  tf.io.gfile.makedirs(workdir)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only the first host writes TensorBoard summaries.
  # (Fix: removed a second, redundant makedirs(workdir) call here — the
  # directory is already created unconditionally above.)
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))

  if config.batch_size % n_devices:
    raise ValueError("Batch size must be divisible by the number of devices")

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=config.dataset_name,
      eval_dataset_name=config.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      vocab_path=vocab_path,
      target_vocab_size=config.vocab_size,
      batch_size=config.batch_size,
      max_corpus_chars=config.max_corpus_chars,
      max_length=config.max_target_length,
      max_eval_length=config.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    # Decode token ids up to and including the first EOS into a UTF-8 string.
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  input_shape = (config.batch_size, config.max_target_length)
  target_shape = (config.batch_size, config.max_target_length)
  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config,
          label_smoothing=config.label_smoothing),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info("Starting training loop.")
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, config.num_train_steps),
                         train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (config.save_checkpoints and step % config.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                  step)

    # Periodic metric handling.
    if step % config.eval_frequency != 0 and step > 0:
      continue

    # Training Metrics
    logging.info("Gathering training metrics.")
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop("learning_rate").mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop("denominator")
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary["learning_rate"] = lr
    steps_per_eval = config.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar("steps per second", steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info("train in step: %d, loss: %.4f", step, summary["loss"])

    # Eval Metrics
    logging.info("Gathering evaluation metrics.")
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop("denominator")
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info("eval in step: %d, loss: %.4f", step, eval_summary["loss"])
    logging.info("eval time: %.4f s step %d", time.time() - t_eval_start,
                 step)

    # Translation and BLEU Score.
    logging.info("Translating evaluation dataset.")
    t_inference_start = time.time()
    sources, references, predictions = [], [], []
    for pred_batch in predict_ds:
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch["inputs"].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      cache = p_init_cache(pred_batch["inputs"])
      predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache,
                              eos_id, config.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch["inputs"])
      targets = tohost(pred_batch["targets"])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info("Translation: %d predictions %d references %d sources.",
                 len(predictions), len(references), len(sources))
    logging.info("Translation time: %.4f s step %d.",
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)

    # Save translation samples for tensorboard.
    exemplars = ""
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
    if jax.host_id() == 0:
      eval_summary_writer.scalar("bleu", bleu_score, step)
      eval_summary_writer.text("samples", exemplars, step)
      eval_summary_writer.flush()
    logging.info("Translation BLEU Score %.4f", bleu_score)
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The final (replicated) train state after `num_steps` steps.
  """
  # Only the first host writes summaries and hparams.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

  rng = random.PRNGKey(0)

  image_size = 224

  if config.batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = config.batch_size // jax.host_count()

  platform = jax.local_devices()[0].platform

  # Half precision uses bfloat16 on TPU and float16 elsewhere.
  if config.half_precision:
    if platform == 'tpu':
      input_dtype = tf.bfloat16
    else:
      input_dtype = tf.float16
  else:
    input_dtype = tf.float32

  dataset_builder = tfds.builder(config.dataset)
  train_iter = create_input_iter(
      dataset_builder,
      local_batch_size,
      image_size,
      input_dtype,
      train=True,
      cache=config.cache)
  eval_iter = create_input_iter(
      dataset_builder,
      local_batch_size,
      image_size,
      input_dtype,
      train=False,
      cache=config.cache)

  steps_per_epoch = (
      dataset_builder.info.splits['train'].num_examples // config.batch_size)

  # num_train_steps == -1 means "train for num_epochs".
  if config.num_train_steps == -1:
    num_steps = int(steps_per_epoch * config.num_epochs)
  else:
    num_steps = config.num_train_steps

  # steps_per_eval == -1 means "evaluate on the full validation split".
  if config.steps_per_eval == -1:
    num_validation_examples = dataset_builder.info.splits[
        'validation'].num_examples
    steps_per_eval = num_validation_examples // config.batch_size
  else:
    steps_per_eval = config.steps_per_eval

  steps_per_checkpoint = steps_per_epoch * 10

  # Linear scaling rule: scale base LR with global batch size / 256.
  # NOTE(review): this local appears unused here — presumably consumed inside
  # create_learning_rate_fn via `config`; confirm before removing.
  base_learning_rate = config.learning_rate * config.batch_size / 256.

  model_cls = getattr(models, config.model)
  model = create_model(
      model_cls=model_cls, half_precision=config.half_precision)

  state = create_train_state(rng, config, model, image_size)
  state = restore_checkpoint(state, workdir)
  # step_offset > 0 if restarting from checkpoint
  step_offset = int(state.step)
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(config, steps_per_epoch)

  p_train_step = jax.pmap(
      functools.partial(
          train_step, model.apply, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, model.apply), axis_name='batch')

  epoch_metrics = []
  hooks = []
  if jax.host_id() == 0:
    hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]
  t_loop_start = time.time()
  logging.info('Initial compilation, this might take some minutes...')
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    for h in hooks:
      h(step)
    if step == step_offset:
      logging.info('Initial compilation completed.')
    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        # Write one point per step within the epoch, backdated accordingly.
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state, workdir)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

  return state
def main(unused_argv):
  """Continuously evaluates the latest training checkpoint and logs metrics.

  Polls FLAGS.train_dir for new checkpoints, renders the test (or path)
  dataset with the restored model, computes PSNR/SSIM against ground truth
  (unless rendering a path), and writes images/scalars to TensorBoard and
  optionally to disk. Runs once if FLAGS.eval_once is set, otherwise loops
  until FLAGS.max_steps is reached.
  """
  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    # No new checkpoint yet: poll again. (Busy-waits without sleeping.)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnrs = []
    ssims = []
    if not FLAGS.eval_once:
      # Pick one random image per evaluation round to showcase in TensorBoard.
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      # Path renders have no ground truth, so metrics only apply otherwise.
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnrs.append(float(psnr))
        ssims.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color,
                       path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssims)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      # Persist per-image and mean metrics for this checkpoint step.
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnrs]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssims]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnrs))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssims))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
def compute_preconditioners_from_statistics(self, states, hps, step):
    """Compute preconditioners for statistics.

    Computes the inverse-pth-root preconditioners for every statistic matrix
    collected in `states`, distributing the (expensive) matrix root
    computation across devices, and returns new per-parameter states with the
    refreshed preconditioners installed.

    Args:
        states: Iterable of `_ShampooDefaultParamState`, one per parameter;
            each carries `statistics` (square matrices) and the previous
            `preconditioners` of matching shapes.
        hps: Hyperparameter container; fields read here: `exponent_override`,
            `batch_axis_name`, `matrix_eps`, `preconditioning_compute_steps`.
        step: Current optimizer step; used to decide whether to recompute
            preconditioners this step (`step % preconditioning_compute_steps`).

    Returns:
        List of new `_ShampooDefaultParamState`, parallel to `states`, with
        `preconditioners` replaced (other fields carried over unchanged).
    """
    # Flatten all statistics across parameters into one list so they can be
    # padded to a common size and sharded across devices.
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state in states:
        num_statistics = len(state.statistics)
        num_statistics_per_state.append(num_statistics)
        original_shapes_for_state = []
        if num_statistics > 0:
            for statistic in state.statistics:
                # Exponent 2*num_statistics follows the Shampoo rule (one
                # statistic per tensor dimension) unless explicitly overridden.
                exponents.append(2 * num_statistics if hps.exponent_override ==
                                 0 else hps.exponent_override)
                original_shapes_for_state.append(statistic.shape)
                max_size = max(max_size, statistic.shape[0])
            statistics.extend(state.statistics)
            prev_preconditioners.extend(state.preconditioners)
            original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    def pack(mat, max_size):
        """Pack a matrix to a max_size for inverse on TPUs with static shapes.

        Args:
            mat: Matrix for computing inverse pth root.
            max_size: Matrix size to pack to.

        Returns:
            Given M returns [[M, 0], [0, I]]
        """
        size = mat.shape[0]
        assert size <= max_size
        if size == max_size:
            return mat
        # Embed M in the top-left block; the identity in the bottom-right
        # keeps the padded matrix invertible so the pth root is well-defined.
        pad_size = max_size - size
        zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
        zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
        eye = jnp.eye(pad_size, dtype=mat.dtype)
        mat = jnp.concatenate([mat, zs1], 1)
        mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
        return mat

    # Without a batch axis we are outside pmap and shard manually over local
    # devices; inside pmap, count participants via psum over the axis.
    if not hps.batch_axis_name:
        num_devices = jax.local_device_count()
    else:
        num_devices = lax.psum(1, hps.batch_axis_name)

    # Pad statistics and exponents to next multiple of num_devices.
    packed_statistics = [pack(stat, max_size) for stat in statistics]
    to_pad = -num_statistics % num_devices
    # Identity padding matrices (exponent 1) are harmless fillers whose
    # results are discarded after unbatching.
    packed_statistics.extend([
        jnp.eye(max_size, dtype=packed_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])

    # Batch statistics and exponents so that so that leading axis is
    # num_devices.
    def _batch(statistics, exponents, num_devices):
        assert len(statistics) == len(exponents)
        n = len(statistics)
        b = int(n / num_devices)
        batched_statistics = [
            jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
        ]
        batched_exponents = [
            jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
        ]
        return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

    # Unbatch values across leading axis and return a list of elements.
    def _unbatch(batched_values):
        b1, b2 = batched_values.shape[0], batched_values.shape[1]
        results = []
        for v_array in jnp.split(batched_values, b1, 0):
            for v in jnp.split(jnp.squeeze(v_array), b2, 0):
                results.append(jnp.squeeze(v))
        return results

    all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                           num_devices)

    def _matrix_inverse_pth_root(xs, ps):
        # vmap over the per-device minibatch of (matrix, exponent) pairs.
        mi_pth_root = lambda x, y: matrix_inverse_pth_root(  # pylint: disable=g-long-lambda
            x, y, ridge_epsilon=hps.matrix_eps)
        preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
        return preconditioners, errors

    if not hps.batch_axis_name:
        # Not already inside pmap: launch our own pmap over the device axis.
        preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)(
            all_statistics, all_exponents)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
    else:

        def _internal_inverse_pth_root_all():
            # Already inside pmap: each replica computes its own shard, then
            # all_gather reassembles the full set on every replica.
            preconditioners = jnp.array(all_statistics)
            current_replica = lax.axis_index(hps.batch_axis_name)
            preconditioners, errors = _matrix_inverse_pth_root(
                all_statistics[current_replica],
                all_exponents[current_replica])
            preconditioners = jax.lax.all_gather(preconditioners,
                                                 hps.batch_axis_name)
            errors = jax.lax.all_gather(errors, hps.batch_axis_name)
            preconditioners_flat = _unbatch(preconditioners)
            errors_flat = _unbatch(errors)
            return preconditioners_flat, errors_flat

        if hps.preconditioning_compute_steps == 1:
            preconditioners_flat, errors_flat = _internal_inverse_pth_root_all(
            )
        else:
            # Passing statistics instead of preconditioners as they are similarly
            # shaped tensors, as error we are passing is the threshold these will
            # be ignored.
            preconditioners_init = packed_statistics
            errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] *
                           len(packed_statistics))
            init_state = [preconditioners_init, errors_init]
            perform_step = step % hps.preconditioning_compute_steps == 0
            # NOTE(review): self.fast_cond presumably evaluates the branch
            # only when perform_step is true — confirm against its definition.
            preconditioners_flat, errors_flat = self.fast_cond(
                perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
        # A NaN or above-threshold error marks a failed inverse root.
        return jnp.logical_or(
            jnp.isnan(error),
            error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD).astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
        # Keep the previous preconditioner when the new computation failed.
        return lax.cond(_skip(error),
                        lambda _: old_p,
                        lambda _: new_p,
                        operand=None)

    # Crop each result back to its original (pre-padding) shape and fall back
    # to the previous preconditioner on failure.
    new_preconditioners_flat = []
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
        new_preconditioners_flat.append(
            _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so we that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for num_statistics, state in zip(num_statistics_per_state, states):
        if num_statistics == 0:
            preconditioners_for_states.append([])
        else:
            preconditioners_for_state = new_preconditioners_flat[idx:idx +
                                                                 num_statistics]
            assert len(state.statistics) == len(preconditioners_for_state)
            preconditioners_for_states.append(preconditioners_for_state)
            idx += num_statistics
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
        new_states.append(
            _ShampooDefaultParamState(state.diagonal_statistics,
                                      state.statistics, new_preconditioners,
                                      state.diagonal_momentum, state.momentum))
    return new_states
def main():
    """Entry point: train/evaluate a Flax causal language model.

    Parses CLI (or a single JSON-file) arguments into model/data/training
    dataclasses, loads and tokenizes the dataset, builds the model and an
    optimizer with weight-decay masking, then runs a pmap-parallel training
    loop with periodic logging, evaluation, and checkpoint pushes, followed
    by an optional final evaluation written to `eval_results.json`.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Refuse to clobber a non-empty output dir unless explicitly allowed.
    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )
        if "validation" not in dataset.keys():
            # No validation split: carve one out of the head of the train
            # split, and keep the remainder for training.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        # Local CSV/JSON/TXT files; plain text gets the "text" loader.
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            **dataset_args,
            use_auth_token=True if model_args.use_auth_token else None,
        )
        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                **dataset_args,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            dataset["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                **dataset_args,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForCausalLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger(
        "transformers.tokenization_utils_base")

    def tokenize_function(examples):
        # Tokenize a batch of raw texts; long sequences are fine here because
        # they are chunked later by group_texts.
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Determine the chunk length, clamped to the tokenizer/model limits.
    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k])) for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k:
            [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        # Causal LM: labels are the inputs; the shift happens in loss_fn.
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset),
                                    data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset),
                                   data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    # Global batch sizes: per-device size times all devices across hosts.
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxGPT2.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in
                   [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=optimizer,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        # Next-token prediction: logits for position i predict label i+1.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = optax.softmax_cross_entropy(
            shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients across the pmap "batch" axis before applying.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng,
                                   train_dataset,
                                   train_batch_size,
                                   shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for step in tqdm(range(steps_per_epoch),
                         desc="Training...",
                         position=1,
                         leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            cur_step = epoch * (len(train_dataset) // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = []
                eval_loader = data_loader(input_rng, eval_dataset,
                                          eval_batch_size)
                eval_steps = len(eval_dataset) // eval_batch_size
                for _ in tqdm(range(eval_steps),
                              desc="Evaluating...",
                              position=2,
                              leave=False):
                    # Model forward
                    batch = next(eval_loader)
                    batch = shard(batch)
                    metrics = p_eval_step(state.params, batch)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(
                        eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                # Print metrics and update progress bar
                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                epochs.write(desc)
                epochs.desc = desc

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = []
        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
        eval_steps = len(eval_dataset) // eval_batch_size
        for _ in tqdm(range(eval_steps),
                      desc="Evaluating...",
                      position=2,
                      leave=False):
            # Model forward
            batch = shard(next(eval_loader))
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(),
                                    eval_metrics)

        try:
            eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
        except OverflowError:
            eval_metrics["perplexity"] = float("inf")

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
def f(x, y): z = jax.pmap(np.exp)(x) return x + 2., z + y
def main():
    """Compare ViT-B/16 predictions between the JAX reference and a PyTorch port.

    Loads the same pretrained checkpoint into both a JAX ViT model and a
    PyTorch VisionTransformer, runs both over ImageNet validation batches,
    and prints the absolute difference between their argmax predictions.
    NOTE(review): checkpoint and dataset paths are hard-coded to a local
    machine — parameterize before reuse.
    """
    image_size = 384
    # jax model
    jax_model = models.KNOWN_MODELS['ViT-B_16'].partial(
        num_classes=1000, representation_size=None)
    _, params = jax_model.init_by_shape(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension of the batch for initialization.
        [((4, image_size, image_size, 3), 'float32')])
    params = checkpoint.load_pretrained(
        pretrained_path=
        '/home/hchen/Projects/vision_transformer/weights/jax/imagenet21k+imagenet2012_ViT-B_16.npz',
        init_params=params,
        model_config=models.CONFIGS['ViT-B_16'],
        logger=logger)
    # One copy of the params per local device for the pmapped forward pass.
    params_repl = flax.jax_utils.replicate(params)
    # Then map the call to our model's forward pass onto all available devices.
    vit_apply_repl = jax.pmap(jax_model.call)

    # torch_model: convert the same .npz checkpoint into a PyTorch state dict.
    keys, values = load_jax(
        '/home/hchen/Projects/vision_transformer/weights/jax/imagenet21k+imagenet2012_ViT-B_16.npz'
    )
    state_dict = convert_jax_pytorch(keys, values)
    torch_model = VisionTransformer(image_size=(image_size, image_size),
                                    patch_size=(16, 16),
                                    emb_dim=768,
                                    mlp_dim=3072,
                                    num_heads=12,
                                    num_layers=12,
                                    num_classes=1000,
                                    attn_dropout_rate=0.0,
                                    dropout_rate=0.1)
    torch_model.load_state_dict(state_dict)
    torch_model.eval()

    data_loader = ImageNetDataLoader(
        data_dir='/home/hchen/Projects/vat_contrast/data/ImageNet',
        split='val',
        image_size=image_size,
        batch_size=16,
        num_workers=0)
    for batch_idx, (data, target) in enumerate(data_loader):
        # jax prediction: NCHW -> NHWC, plus a leading device axis for pmap.
        target_numpy = target.cpu().numpy()
        data_numpy = data.cpu().numpy().transpose(0, 2, 3, 1).reshape(
            1, -1, image_size, image_size, 3)
        jax_predicted_logits = vit_apply_repl(params_repl, data_numpy)._value[0]
        jax_predicted = onp.argmax(jax_predicted_logits, axis=-1)

        # torch prediction
        with torch.no_grad():
            torch_predicted = torch_model(data)
        torch_predicted_logits = torch_predicted.cpu().numpy()
        torch_predicted = onp.argmax(torch_predicted_logits, axis=-1)

        # check difference
        # diff = onp.abs(jax_predicted_logits - torch_predicted_logits)
        # assert onp.allclose(jax_predicted_logits, torch_predicted_logits, rtol=1e-1, atol=1e-1), "diff {}, max {}, sum {}".format(diff, onp.max(diff), onp.sum(diff))
        # Zero entries mean both models picked the same class for that image.
        diff = onp.abs(jax_predicted - torch_predicted)
        print(diff)
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             random_key: networks_lib.PRNGKey,
             loss_fn: losses.Loss,
             optimizer: optax.GradientTransformation,
             prefetching_iterator: Iterator[types.Transition],
             num_sgd_steps_per_step: int,
             loss_has_aux: bool = False,
             logger: Optional[loggers.Logger] = None,
             counter: Optional[counting.Counter] = None):
    """Behavior Cloning Learner.

    Args:
      network: Networks with signature for apply: (params, obs, is_training,
        key) -> jnp.ndarray and for init: (rng, is_training) -> params
      random_key: RNG key.
      loss_fn: BC loss to use.
      optimizer: Optax optimizer.
      prefetching_iterator: A sharded prefetching iterator as outputted from
        `acme.jax.utils.sharded_prefetch`. Please see the documentation for
        `sharded_prefetch` for more details.
      num_sgd_steps_per_step: Number of gradient updates per step.
      loss_has_aux: Whether the loss function returns auxiliary metrics as a
        second argument.
      logger: Logger.
      counter: Counter.
    """

    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
        # Differentiate w.r.t. argument 1 (the policy params), not the apply fn.
        loss_and_grad = jax.value_and_grad(loss_fn,
                                           argnums=1,
                                           has_aux=loss_has_aux)

        # Compute losses and their gradients.
        key, key_input = jax.random.split(state.key)
        loss_result, gradients = loss_and_grad(network.apply,
                                               state.policy_params, key_input,
                                               transitions)

        # Combine the gradient across all devices (by taking their mean).
        gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME)

        # Compute and combine metrics across all devices.
        metrics = _create_loss_metrics(loss_has_aux, loss_result, gradients)
        metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME)

        policy_update, optimizer_state = optimizer.update(
            gradients, state.optimizer_state, state.policy_params)
        policy_params = optax.apply_updates(state.policy_params, policy_update)

        new_state = TrainingState(
            optimizer_state=optimizer_state,
            policy_params=policy_params,
            key=key,
            steps=state.steps + 1,
        )

        return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='learner')
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Split the input batch to `num_sgd_steps_per_step` minibatches in order
    # to achieve better performance on accelerators.
    sgd_step = utils.process_multiple_batches(sgd_step,
                                              num_sgd_steps_per_step)
    self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

    # Initialize parameters and optimizer state from a fresh key split.
    random_key, init_key = jax.random.split(random_key)
    policy_params = network.init(init_key)
    optimizer_state = optimizer.init(policy_params)

    # Create initial state.
    state = TrainingState(
        optimizer_state=optimizer_state,
        policy_params=policy_params,
        key=random_key,
        steps=0,
    )
    # Replicated so the pmapped sgd_step can consume it directly.
    self._state = utils.replicate_in_all_devices(state)

    self._timestamp = None

    self._prefetching_iterator = prefetching_iterator
def f(x):
    """Compute exp(x) + 2 elementwise, parallelized over x's leading axis."""
    shifted_exp = lambda v: np.exp(v) + 2.
    return jax.pmap(shifted_exp)(x)
("reserve_rng_keys", lambda: base.reserve_rng_keys(2)), ("with_rng", with_rng_example), ) # JAX transforms and control flow that need to be aware of Haiku internal # state to operate unsurprisingly. # pylint: disable=g-long-lambda JAX_PURE_EXPECTING_FNS = ( # Just-in-time compilation. ("jit", jax.jit), ("make_jaxpr", jax.make_jaxpr), ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))), # Parallelization. # TODO(tomhennigan): Add missing features (e.g. pjit,xmap). ("pmap", lambda f: jax.pmap(f, "i")), # Vectorization. ("vmap", jax.vmap), # Control flow. # TODO(tomhennigan): Enable for associative_scan. # ("associative_scan", lambda f: # (lambda x: jax.lax.associative_scan( # lambda a, b: [f(a + b), a + b][-1], jnp.stack([x, x, x, x])))), ("cond", lambda f: (lambda x: jax.lax.cond(True, f, f, x))), ("fori_loop", lambda f: (lambda x: jax.lax.fori_loop(0, 1, ignore_index(f), x))), ("map", lambda f: (lambda x: jax.lax.map(f, x))), ("scan", lambda f: (lambda x: jax.lax.scan(identity_carry(f), None, x))), ("switch", lambda f: (lambda x: jax.lax.switch(0, [f, f], x))),
def test_model_shape(
    self,
    separate_memory_values=False,
    num_intermediate_layers=None,
):
    """Test loss function runs and produces expected values.

    Builds a MentionMemoryEncoder under a pmap across all forced local
    devices, feeds it one collated pretraining batch, and checks the shapes
    of the encoded output and the target mention encodings.
    """
    # Clone and specialize the shared test config for this parameterization.
    config = copy.deepcopy(self.config)
    config['model_config']['encoder_config'][
        'separate_memory_values'] = separate_memory_values
    config['model_config']['encoder_config'][
        'num_intermediate_layers'] = num_intermediate_layers
    config = ml_collections.FrozenConfigDict(config)

    model_config = config.model_config
    encoder_config = model_config.encoder_config
    rows = encoder_config.rows

    preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
    collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
        config)

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
    dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
    # Replicate inputs/constants onto every device for the pmapped calls.
    dummy_input = jax.device_put_replicated(dummy_input, devices)
    init_rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(init_rng, self.n_devices)

    # Random memory table, shaped (rows, table_size // rows, key_dim);
    # values are the same table flattened back to (table_size, key_dim).
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, devices)
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   devices)
    memory_entity_ids = memory_identifiers
    memory_text_entities = np.zeros(
        (self.table_size, encoder_config.n_memory_text_entities),
        dtype=np.int32)
    memory_text_entities = jax.device_put_replicated(memory_text_entities,
                                                     devices)

    def model_init(*args, **kwargs):
        # Route init through the encoder's `forward` method.
        return model.init(*args, method=model.forward, **kwargs)

    # Argument 2 is the static `deterministic`/training flag — TODO confirm
    # its meaning against MentionMemoryEncoder.forward.
    initial_variables = jax.pmap(model_init,
                                 'batch',
                                 static_broadcasted_argnums=2)(
                                     split_rng,
                                     dummy_input,
                                     True,
                                 )
    initial_variables = {'params': initial_variables['params']}
    # Memory tables are injected as non-trainable 'constants' collections.
    initial_variables['constants'] = {
        'memory_keys': memory_keys,
        'memory_values': memory_values,
        'memory_identifiers': memory_identifiers,
        'memory_entity_ids': memory_entity_ids,
        'memory_text_entities': memory_text_entities,
    }

    # Build a full batch by tiling one synthetic pretraining example.
    raw_example = test_utils.gen_mention_pretraining_sample(
        self.text_length,
        self.n_mentions,
        self.n_linked_mentions,
        max_length=encoder_config.max_length)
    processed_example = preprocess_fn(raw_example)
    batch = {
        key: np.tile(value, (config.per_device_batch_size, 1))
        for key, value in processed_example.items()
    }
    batch = collater_fn(batch)
    batch = {
        key: test_utils.tensor_to_numpy(value)
        for key, value in batch.items()
    }
    batch = {
        key: jax.device_put_replicated(value, devices)
        for key, value in batch.items()
    }

    def model_apply(*args, **kwargs):
        return model.apply(*args, method=model.forward, **kwargs)

    papply = jax.pmap(model_apply, 'batch', static_broadcasted_argnums=(2))
    encoded_output, loss_helpers, _ = papply(
        {
            'params': initial_variables['params'],
            'constants': initial_variables['constants'],
        },
        batch,
        True,
    )

    # Encoded output: one hidden vector per token, per example, per device.
    self.assertEqual(
        encoded_output.shape,
        (self.n_devices, config.per_device_batch_size,
         encoder_config.max_length, encoder_config.hidden_size))

    # Mention encodings use the value dim when memory values are separate,
    # otherwise they share the key dim.
    memory_value_dim = encoder_config.memory_value_dim
    memory_key_dim = encoder_config.memory_key_dim
    memory_size = memory_value_dim if memory_value_dim else memory_key_dim
    self.assertEqual(loss_helpers['target_mention_encodings'].shape,
                     (self.n_devices,
                      config.max_mention_targets * config.per_device_batch_size,
                      memory_size))
def test_load_weights(self, separate_memory_values=False, memory_only=False):
  """Test saving and loading model recovers original parameters.

  Args:
    separate_memory_values: forwarded to the encoder config; whether the
      memory-value table is kept separate from the memory-key table.
    memory_only: if True, only the memory tables are loaded and transformer
      weights are excluded from the comparison.
  """
  config = copy.deepcopy(self.config)
  config['model_config']['encoder_config'][
      'separate_memory_values'] = separate_memory_values
  config = ml_collections.ConfigDict(config)

  model_config = config.model_config
  encoder_config = model_config.encoder_config
  rows = encoder_config.rows
  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()

  model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
  dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
  dummy_input = jax.device_put_replicated(dummy_input, devices)
  init_rng = jax.random.PRNGKey(0)
  split_rng = jax.random.split(init_rng, self.n_devices)

  # Random memory table: keys as (rows, table_size // rows, dim); values are
  # the same data flattened to (table_size, dim).
  memory_table = np.random.rand(rows, self.table_size // rows,
                                encoder_config.memory_key_dim)
  memory_keys = jax.device_put_replicated(memory_table, devices)
  memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
  memory_values = jax.device_put_replicated(memory_values, devices)
  memory_identifiers = np.arange(self.table_size)
  memory_identifiers = jax.device_put_replicated(memory_identifiers, devices)
  memory_entity_ids = memory_identifiers
  memory_text_entities = np.zeros(
      (self.table_size, encoder_config.n_memory_text_entities),
      dtype=np.int32)
  memory_text_entities = jax.device_put_replicated(memory_text_entities,
                                                   devices)

  def model_init(*args, **kwargs):
    return model.init(*args, method=model.forward, **kwargs)

  initial_variables = jax.pmap(
      model_init, 'batch', static_broadcasted_argnums=2)(
          split_rng,
          dummy_input,
          True,
      )
  initial_variables = {'params': initial_variables['params']}
  initial_variables['constants'] = {
      'memory_keys': memory_keys,
      'memory_values': memory_values,
      'memory_identifiers': memory_identifiers,
      'memory_entity_ids': memory_entity_ids,
      'memory_text_entities': memory_text_entities,
  }

  # Write each memory table to disk as `n_shards` numpy shard files.
  n_shards = 4
  tempdir_obj = self.create_tempdir()
  tempdir = tempdir_obj.full_path
  memory_key_base = os.path.join(tempdir, 'memory_keys')
  memory_value_base = os.path.join(tempdir, 'memory_values')
  memory_id_base = os.path.join(tempdir, 'memory_id')
  memory_entity_id_base = os.path.join(tempdir, 'memory_entity_id')
  memory_text_entities_base = os.path.join(tempdir, 'memory_text_entities')

  unreplicated_variables = jax_utils.unreplicate(initial_variables)
  unreplicated_variables['params'] = unreplicated_variables[
      'params'].unfreeze()

  if memory_only:
    load_weights = 'memory_only'
  else:
    load_weights = os.path.join(tempdir, 'weights')
    checkpoint_utils.save_weights(load_weights,
                                  unreplicated_variables['params'])

  memory_keys = initial_variables['constants']['memory_keys']
  memory_keys = memory_keys.reshape(n_shards, -1,
                                    encoder_config.memory_key_dim)
  memory_values = initial_variables['constants']['memory_values']
  memory_values = memory_values.reshape(n_shards, -1,
                                        encoder_config.memory_key_dim)
  memory_ids = initial_variables['constants'][
      'memory_identifiers'].reshape(n_shards, -1)
  memory_entity_ids = initial_variables['constants'][
      'memory_entity_ids'].reshape(n_shards, -1)
  memory_text_entities = initial_variables['constants'][
      'memory_text_entities'].reshape(
          n_shards, -1, encoder_config.n_memory_text_entities)

  for shard in range(n_shards):
    np.save(memory_key_base + str(shard), memory_keys[shard])
    np.save(memory_value_base + str(shard), memory_values[shard])
    np.save(memory_id_base + str(shard), memory_ids[shard])
    # Fixed: this save statement was previously duplicated, writing the same
    # entity-id shard file twice for no effect.
    np.save(memory_entity_id_base + str(shard), memory_entity_ids[shard])
    np.save(memory_text_entities_base + str(shard),
            memory_text_entities[shard])

  config.memory_key_pattern = memory_key_base + '*'
  config.memory_value_pattern = memory_value_base + '*'
  config.memory_id_pattern = memory_id_base + '*'
  config.memory_entity_id_pattern = memory_entity_id_base + '*'
  config.memory_text_entities_pattern = memory_text_entities_base + '*'
  config.load_weights = load_weights

  loaded_variables = mention_memory_encoder.MentionMemoryEncoder.load_weights(
      config)

  arrayeq = lambda x, y: jnp.all(x == y)
  # Without separate memory values, `memory_values` is not part of the loaded
  # variables, so drop it from the expected tree.
  constants = {
      key: value
      for key, value in initial_variables['constants'].items()
      if not (key == 'memory_values' and not separate_memory_values)
  }
  comparison_variables = {'constants': constants}
  if not memory_only:
    comparison_variables['params'] = initial_variables['params'].unfreeze()

  # tree_map also verifies the two trees have identical structure.
  comparison = jax.tree_map(arrayeq, loaded_variables, comparison_variables)
  # Fixed: `assertTrue` on the raw mapped dict always passed (any non-empty
  # dict is truthy); reduce over all leaves so per-array mismatches fail.
  self.assertTrue(all(bool(leaf) for leaf in
                      jax.tree_util.tree_leaves(comparison)))
def sync_batch_stats(state): """Sync the batch statistics across replicas.""" avg = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') return state.replace(model_state=avg(state.model_state))
from timeit import default_timer as timer print(jax.local_device_count()) # 4 def random_walk(key, steps=1000): position = 0 for _ in range(steps): key, subkey = random.split(key) position += random.normal(subkey) return position jit_random_walk = jit(random_walk) p_random_walk = pmap(jit_random_walk) start = timer() jit_random_walk(random.PRNGKey(0)) end = timer() print("compile time serial:", end - start) start = timer() for i in range(4): jit_random_walk(random.PRNGKey(i)) end = timer() print("time elapsed serial:", end - start) keys = np.array([random.PRNGKey(i) for i in range(4)]) start = timer()
def main(argv):
  """Train and evaluate the model, logging metrics and saving checkpoints."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Only the first host writes TensorBoard summaries.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  # Pick compute dtypes; dynamic loss scaling is only needed for fp16 on
  # non-TPU platforms (bfloat16 on TPU does not require it).
  dynamic_scale = None
  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
      dynamic_scale = optim.DynamicScale()
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = create_input_iter(local_batch_size,
                                 image_size,
                                 input_dtype,
                                 train=True,
                                 cache=FLAGS.cache)
  eval_iter = create_input_iter(local_batch_size,
                                image_size,
                                input_dtype,
                                train=False,
                                cache=FLAGS.cache)

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Linear learning-rate scaling rule relative to a base batch size of 256.
  base_learning_rate = FLAGS.learning_rate * batch_size / 256.

  model, model_state = create_model(rng, device_batch_size, image_size,
                                    model_dtype)
  optimizer = optim.Momentum(beta=FLAGS.momentum,
                             nesterov=True).create(model)
  state = TrainState(step=0,
                     optimizer=optimizer,
                     model_state=model_state,
                     dynamic_scale=dynamic_scale)
  del model, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(
      state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch, num_epochs)

  p_train_step = jax.pmap(functools.partial(
      train_step, learning_rate_fn=learning_rate_fn),
                          axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    # At the end of each epoch: log train metrics, then run evaluation.
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
      epoch_metrics = []
      eval_metrics = []
      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    # Periodic checkpointing (every 10 epochs) plus a final checkpoint.
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main():
    """Pretrain a Flax T5 model with span-masked language modeling.

    Parses arguments, loads and tokenizes the datasets, builds the model,
    optimizer and learning-rate schedules, then runs the pmap-parallel
    train/eval loop with periodic logging, evaluation and checkpointing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        # NOTE(review): the two adjacent string literals concatenate without a
        # separating space ("...is not empty.Use --overwrite...").
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name,
                                data_args.dataset_config_name,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            # No validation split: carve one off the front of the train split.
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension,
                                data_files=data_files,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(model_args.config_name,
                                          cache_dir=model_args.cache_dir,
                                          vocab_size=len(tokenizer))
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(model_args.model_name_or_path,
                                          cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k])) for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length //
                            expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [
                t[i:i + expanded_inputs_length]
                for i in range(0, total_length, expanded_inputs_length)
            ] for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    # One dropout PRNG stream per local device for pmap.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(config,
                                               seed=training_args.seed,
                                               dtype=getattr(
                                                   jnp, model_args.dtype))

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule: linear warmup followed by linear decay.
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in
                   [("layer_norm", "scale"), ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices before applying the update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))
        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # NOTE(review): donate_argnums=(0,) donates the `params` buffers to the
    # eval computation, yet the same params are passed on every eval call —
    # confirm this donation is intended.
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...",
                             position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples)

                    # Model forward
                    model_inputs = shard(model_inputs.data)
                    metrics = p_eval_step(state.params, model_inputs)
                    eval_metrics.append(metrics)

                # get eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                # Update progress bar
                epochs.write(
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
                )

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    # Take the parameters from device 0 of the replicated state.
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(),
                                    eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
def main():
    """Fine-tune a Flax sequence-classification model on GLUE or custom data.

    Parses arguments, loads and preprocesses the dataset, builds the model and
    train state, then runs the pmap-parallel train/eval loop and saves the
    final checkpoint on the main process.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(
        logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # Fixed: the fallback previously read `args.valid_file`, an attribute
        # that does not exist (the parsed argument is `validation_file`), so
        # validation-only runs crashed with AttributeError.
        extension = (args.train_file if args.train_file is not None else
                     args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in [
            "float32", "float64"
        ]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path,
                                        num_labels=num_labels,
                                        finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [
            name for name in raw_datasets["train"].column_names
            if name != "label"
        ]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (model.config.label2id !=
            PretrainedConfig(num_labels=num_labels).label2id
            and args.task_name is not None and not is_regression):
        # Some have all caps in their config, some don't.
        label_name_to_id = {
            k.lower(): v for k, v in model.config.label2id.items()
        }
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!")
            label_to_id = {
                i: label_name_to_id[label_list[i]] for i in range(num_labels)
            }
        else:
            # Fixed: this warning previously passed the second half of the
            # message as a positional logging argument (stray comma), so it was
            # silently dropped from the log output.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result.")
    elif args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = ((examples[sentence1_key],) if sentence2_key is None else
                 (examples[sentence1_key], examples[sentence2_key]))
        result = tokenizer(*texts,
                           padding="max_length",
                           max_length=args.max_length,
                           truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names)

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets[
        "validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(args.output_dir)
    summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        # Log train metrics per step (back-dated) and eval metrics at `step`.
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(len(train_dataset),
                                               train_batch_size,
                                               args.num_train_epochs,
                                               args.num_warmup_steps,
                                               args.learning_rate)

    state = create_train_state(model,
                               learning_rate_fn,
                               is_regression,
                               num_labels=num_labels)

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array],
        dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss, logits

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grad = grad_fn(state.params)
        # Average gradients across devices before applying the update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0,))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        # make sure weights are replicated on each device
        state = replicate(state)

        train_start = time.time()
        train_metrics = []
        rng, input_rng, dropout_rng = jax.random.split(rng, 3)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset,
                                              train_batch_size):
            dropout_rngs = shard_prng_key(dropout_rng)
            state, metrics = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")
        rng, input_rng = jax.random.split(rng)

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions),
                             references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # put weights on single device
            state = unreplicate(state)

            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(state, batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

    # save last checkpoint
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(args.output_dir, params=params)
def f(x):
    """Map ``variable`` (with the fixed name 'x') over ``x`` across local devices."""
    wrap = lambda leaf: variable(leaf, name='x')
    pmapped = jax.pmap(wrap)
    return pmapped(x)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    # Keep tokens up to and including the first EOS, then detokenize.
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  # BUG FIX: was `start_step == 1`, which never fires on a fresh run (step 0)
  # and only accidentally fires when resuming from a step-1 checkpoint.
  # Hyper-parameters should be recorded once, at the start of a new run.
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          config=eval_config,
          label_smoothing=config.label_smoothing),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceContext("train", step_num=step):
        batch = common_utils.shard(
            jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      # BUG FIX: was `step % config.checkpoint_every_steps or is_last_step`,
      # which is truthy on every step that is NOT a multiple of the frequency.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      # BUG FIX: was `... and jax.host_id():` — truthy on every host EXCEPT
      # host 0, i.e. the exact opposite of "save on one host".
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(
              workdir, jax_utils.unreplicate(optimizer), step)
def main(argv):
  """Entry point: trains (and periodically evaluates) a WMT translation model.

  Supports static or dynamic (gradient-alignment-based) training-data
  selection across data buckets, optional adapter fine-tuning, auxiliary
  evaluation datasets, and best-eval-loss checkpointing. All configuration
  comes from absl FLAGS.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    # Bucketed dataset manager: training data is drawn from num_data_buckets
    # buckets whose sampling weights can be changed during training.
    train_ds_mgr, eval_ds, predict_ds, encoder = input_pipeline.get_dynamic_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=FLAGS.vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_buckets=FLAGS.num_data_buckets)
    if FLAGS.static:
      # Fixed, user-supplied bucket weights: disable dynamic re-sampling.
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      assert len(weights) == FLAGS.num_data_buckets
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      # Start by sampling uniformly from the first data_selection_size buckets.
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)
  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size)

  if FLAGS.aux_eval_dataset:
    # Extra held-out sets evaluated alongside the main eval set.
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.host_id(),
          shard_count=jax.host_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    # Keep tokens up to and including the first EOS, then detokenize.
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that; else, just continue from
    # where the last run in model_dir left off.
    model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  if FLAGS.adapter != NONE:
    # Restrict the optimizer to parameters whose path mentions the adapter
    # name (adapter-only fine-tuning).
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  # Record this module's flags as hparams (module key matched by 'wmt.par').
  flag_key = [k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.adapter != NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant
  p_get_diag_grads = jax.pmap(
      functools.partial(get_diag_grads, config=eval_config),
      axis_name='batch')
  p_get_bucket_score = jax.pmap(
      functools.partial(get_diag_score, strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  best_eval_loss = 1000
  curr_eval_loss = 1000
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1
      # Periodically re-choose which data buckets to train on.
      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # Dynamic macro: use gradient alignment to score different ratios
        # of top k vs bottom N-k bins
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score, aux_eval_ds,
                                              train_ds_mgr, optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins
          # take the top k bins and sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  aux_eval_ds, train_ds_mgr,
                                                  optimizer, eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          new_distribution = np.zeros(100)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        try:
          batch = common_utils.shard(
              jax.tree_map(np.asarray, next(train_iter)))
          optimizer, metrics = p_train_step(
              optimizer, batch, dropout_rng=dropout_rngs)
          train_metrics.append(metrics)
        except StopIteration:
          # Dataset exhausted: finish up as if this were the final step.
          is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop('learning_rate').mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop('denominator')
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary['learning_rate'] = lr
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps)
          curr_eval_loss = eval_results['loss']
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            exemplars, bleu_score = translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length)
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if FLAGS.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        # Only keep checkpoints that improve on the best eval loss so far.
        if curr_eval_loss < best_eval_loss:
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir,
                jax_utils.unreplicate(optimizer),
                step,
                keep=FLAGS.chkpts_to_keep,
                overwrite=True)

      if is_last_step:
        break
def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
    """
    Run the MCMC samplers and collect samples.

    :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
        For multi-chains, a batch of `num_chains` keys can be supplied. If `rng_key`
        does not have batch_size, it will be split in to a batch of `num_chains` keys.
    :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.
        These are typically the arguments needed by the `model`.
    :param extra_fields: Extra fields (aside from `z`, `diverging`) from :data:`numpyro.infer.mcmc.HMCState`
        to collect during the MCMC run.
    :type extra_fields: tuple or list
    :param init_params: Initial parameters to begin sampling. The type must be consistent
        with the input type to `potential_fn`.
    :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
        method. These are typically the keyword arguments needed by the `model`.

    .. note:: jax allows python code to continue even when the compiled code has not finished yet.
        This can cause troubles when trying to profile the code for speed.
        See https://jax.readthedocs.io/en/latest/async_dispatch.html and
        https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.
    """
    self._args = args
    self._kwargs = kwargs
    init_state = self._get_cached_init_state(rng_key, args, kwargs)
    # One key per chain: split a single key when a batch wasn't supplied.
    if self.num_chains > 1 and rng_key.ndim == 1:
        rng_key = random.split(rng_key, self.num_chains)

    if self._warmup_state is not None:
        # Warmup already ran: skip it and collect num_samples draws directly.
        self._set_collection_params(0, self.num_samples, self.num_samples)
        init_state = self._warmup_state._replace(rng_key=rng_key)

    chain_method = self.chain_method
    if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
        # Not enough devices for one chain each: fall back to sequential.
        chain_method = 'sequential'
        warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                      ' Chains will be drawn sequentially. If you are running MCMC in CPU,'
                      ' consider to use `numpyro.set_host_device_count({})` at the beginning'
                      ' of your program.'
                      .format(self.num_chains, xla_bridge.device_count(), self.num_chains))

    if init_params is not None and self.num_chains > 1:
        # Per-chain init values must be stacked along a leading num_chains axis.
        prototype_init_val = tree_flatten(init_params)[0][0]
        if jnp.shape(prototype_init_val)[0] != self.num_chains:
            raise ValueError('`init_params` must have the same leading dimension'
                             ' as `num_chains`.')

    assert isinstance(extra_fields, (tuple, list))
    # De-duplicate: sample field and default fields may overlap extra_fields.
    collect_fields = tuple(set((self._sample_field,) + tuple(self._default_fields)
                               + tuple(extra_fields)))
    partial_map_fn = partial(self._single_chain_mcmc,
                             args=args,
                             kwargs=kwargs,
                             collect_fields=collect_fields)
    map_args = (rng_key, init_state, init_params)
    if self.num_chains == 1:
        states_flat, last_state = partial_map_fn(map_args)
        # Add a leading chain axis so `states` always has shape
        # (num_chains, num_samples, ...).
        states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    else:
        if chain_method == 'sequential':
            # _laxmap keeps a Python loop so the progress bar can update.
            if self.progress_bar:
                states, last_state = _laxmap(partial_map_fn, map_args)
            else:
                states, last_state = lax.map(partial_map_fn, map_args)
        elif chain_method == 'parallel':
            states, last_state = pmap(partial_map_fn)(map_args)
            # TODO: remove when https://github.com/google/jax/issues/3597 is resolved
            states = device_put(states)
        else:
            assert chain_method == 'vectorized'
            states, last_state = partial_map_fn(map_args)
            # swap num_samples x num_chains to num_chains x num_samples
            states = tree_map(lambda x: jnp.swapaxes(x, 0, 1), states)
        # Flatten chains into a single (num_chains * num_samples, ...) axis.
        states_flat = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), states)
    self._last_state = last_state
    self._states = states
    self._states_flat = states_flat
    # Reset collection params to their defaults for subsequent calls.
    self._set_collection_params()
def train(optimizer: flax.optim.Optimizer,
          state: flax.nn.Collection,
          dataset_source: dataset_source_lib.DatasetSource,
          training_dir: str, num_epochs: int):
  """Trains the model.

  Args:
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    dataset_source: Container for the training dataset.
    training_dir: Parent directory where the tensorboard logs and model
      checkpoints should be saved.
    num_epochs: Number of epochs for which we want to train the model.
  """
  checkpoint_dir = os.path.join(training_dir, 'checkpoints')
  summary_writer = tensorboard.SummaryWriter(training_dir)
  if jax.host_id() != 0:  # Don't log if not first host.
    summary_writer.scalar = lambda *args: None
  prng_key = jax.random.PRNGKey(FLAGS.run_seed)

  if FLAGS.ema_decay:
    # NOTE(review): end_warmup_step is hard-coded; presumably tuned for a
    # specific schedule — confirm before reusing with other setups.
    end_warmup_step = 1560
    moving_averages = efficientnet_optim.ExponentialMovingAverage(
        (optimizer.target, state), FLAGS.ema_decay,
        end_warmup_step)  # pytype:disable=wrong-arg-count

    def update_ema(optimizer, state, ema):
      # Fold the current (params, state) into the EMA at the current step.
      step = optimizer.state.step
      return ema.update_moving_average((optimizer.target, state), step)

    pmapped_update_ema = jax.pmap(update_ema, axis_name='batch')
  else:
    pmapped_update_ema = moving_averages = None

  # Log initial results:
  if gfile.exists(checkpoint_dir):
    # Resume: the EMA (if enabled) is checkpointed together with the state.
    if FLAGS.ema_decay:
      optimizer, (state, moving_averages), epoch_last_checkpoint = restore_checkpoint(
          optimizer, (state, moving_averages), checkpoint_dir)
    else:
      optimizer, state, epoch_last_checkpoint = restore_checkpoint(
          optimizer, state, checkpoint_dir)
    # If last checkpoint was saved at the end of epoch n, then the first
    # training epochs to do when we resume training is n+1.
    initial_epoch = epoch_last_checkpoint + 1
    info = 'Resuming training from epoch {}'.format(initial_epoch)
    logging.info(info)
  else:
    initial_epoch = jnp.array(0, dtype=jnp.int32)
    logging.info('Starting training from scratch.')

  # Replicate everything across local devices for pmapped steps.
  optimizer = jax_utils.replicate(optimizer)
  state = jax_utils.replicate(state)
  if FLAGS.ema_decay:
    moving_averages = jax_utils.replicate(moving_averages)

  if FLAGS.use_learning_rate_schedule:
    if FLAGS.lr_schedule == 'cosine':
      learning_rate_fn = get_cosine_schedule(num_epochs, FLAGS.learning_rate,
                                             dataset_source.num_training_obs,
                                             dataset_source.batch_size)
    elif FLAGS.lr_schedule == 'exponential':
      learning_rate_fn = get_exponential_schedule(
          num_epochs, FLAGS.learning_rate, dataset_source.num_training_obs,
          dataset_source.batch_size)
    else:
      raise ValueError('Wrong schedule: ' + FLAGS.lr_schedule)
  else:
    learning_rate_fn = lambda step: FLAGS.learning_rate

  # pmap the training and evaluation functions.
  pmapped_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          l2_reg=FLAGS.weight_decay),
      axis_name='batch',
      donate_argnums=(0, 1))
  pmapped_eval_step = jax.pmap(eval_step, axis_name='batch')

  time_at_last_checkpoint = time.time()
  for epochs_id in range(initial_epoch, num_epochs):
    if epochs_id in FLAGS.additional_checkpoints_at_epochs:
      # To save additional checkpoints that will not be erase by later version,
      # we save them in a new directory.
      c_path = os.path.join(checkpoint_dir,
                            'additional_ckpt_' + str(epochs_id))
      save_checkpoint(optimizer, state, c_path, epochs_id)
    tick = time.time()

    optimizer, state, moving_averages = train_for_one_epoch(
        dataset_source, optimizer, state, prng_key, pmapped_train_step,
        pmapped_update_ema, moving_averages, summary_writer)

    tock = time.time()
    info = 'Epoch {} finished in {:.2f}s.'.format(epochs_id, tock - tick)
    logging.info(info)

    # Evaluate the model on the test set, and optionally the training set.
    if (epochs_id + 1) % FLAGS.evaluate_every == 0:
      info = 'Evaluating at end of epoch {} (0-indexed)'.format(epochs_id)
      logging.info(info)
      tick = time.time()
      # Step is replicated across devices; read it from device 0.
      current_step = int(optimizer.state.step[0])
      if FLAGS.also_eval_on_training_set:
        train_ds = dataset_source.get_train(use_augmentations=False)
        train_metrics = eval_on_dataset(optimizer.target, state, train_ds,
                                        pmapped_eval_step)
        for metric_name, metric_value in train_metrics.items():
          summary_writer.scalar('eval_on_train_' + metric_name, metric_value,
                                current_step)
        summary_writer.flush()

      if FLAGS.ema_decay:
        # Evaluate with the EMA copy of (params, state) instead of the raw one.
        logging.info('Evaluating with EMA.')
        ema_model, ema_state = moving_averages.param_ema  # pytype:disable=attribute-error
        test_ds = dataset_source.get_test()
        test_metrics = eval_on_dataset(ema_model, ema_state, test_ds,
                                       pmapped_eval_step)
        for metric_name, metric_value in test_metrics.items():
          summary_writer.scalar('ema_test_' + metric_name, metric_value,
                                current_step)
        summary_writer.flush()
      else:
        test_ds = dataset_source.get_test()
        test_metrics = eval_on_dataset(optimizer.target, state, test_ds,
                                       pmapped_eval_step)
        for metric_name, metric_value in test_metrics.items():
          summary_writer.scalar('test_' + metric_name, metric_value,
                                current_step)
        summary_writer.flush()
      tock = time.time()
      info = 'Evaluated model in {:.2f}.'.format(tock - tick)
      logging.info(info)

    # Save new checkpoint if the last one was saved more than
    # `save_progress_seconds` seconds ago.
    sec_from_last_ckpt = time.time() - time_at_last_checkpoint
    if sec_from_last_ckpt > FLAGS.save_progress_seconds:
      if FLAGS.ema_decay:
        save_checkpoint(optimizer, (state, moving_averages), checkpoint_dir,
                        epochs_id)
      else:
        save_checkpoint(optimizer, state, checkpoint_dir, epochs_id)
      time_at_last_checkpoint = time.time()
      logging.info('Saved checkpoint.')

  # Always save final checkpoint
  if FLAGS.ema_decay:
    save_checkpoint(optimizer, (state, moving_averages), checkpoint_dir,
                    epochs_id)
  else:
    save_checkpoint(optimizer, state, checkpoint_dir, epochs_id)
def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0,
         static_broadcasted_argnums=(), devices=None, backend=None,
         axis_size=None, donate_argnums=(), global_arg_shapes=None,
         reduce_func=None):
  """Parallel compilation for class objects.

  Parallel compile a function or a module to run on multiple devices in
  parallel.

  Parameters
  ----------
  func : DynamicalSystem, Base, callable
    The function or object to parallelize.
  dyn_vars : optional
    The dynamical variables used by ``func``. If ``None``, they are collected
    automatically from the object (when possible).
  axis_name, in_axes, out_axes, static_broadcasted_argnums, devices, backend,
  axis_size, donate_argnums, global_arg_shapes
    Passed through to ``jax.pmap`` (argnums are shifted by 2 to account for
    the dyn/rand variable arguments prepended by the wrapper).
  reduce_func : callable, optional
    Function used to combine per-device results; defaults to
    ``jnp.concatenate``.

  Returns
  -------
  The parallelized object or function.
  """
  from brainpy.building.brainobjects import DynamicalSystem

  def _split_vars(all_vars):
    # Split collected variables into random-state vs. ordinary dynamical vars.
    dyn, rand = TensorCollector(), TensorCollector()
    for key, val in all_vars.items():
      if isinstance(val, RandomState):
        rand[key] = val
      else:
        dyn[key] = val
    return dyn, rand

  def _shift_argnums(argnums):
    # The wrapped step receives (dyn_vars, rand_vars) as its first two
    # positional arguments, so user-facing static argnums shift by 2.
    if argnums is None:
      argnums = ()
    elif isinstance(argnums, int):
      argnums = (argnums + 2,)
    elif isinstance(argnums, (tuple, list)):
      argnums = tuple(argnum + 2 for argnum in argnums)
    assert isinstance(argnums, (tuple, list))
    return argnums

  if isinstance(func, DynamicalSystem):
    if len(func.steps):  # DynamicalSystem has step functions

      # dynamical variables
      all_vars = (dyn_vars or func.vars().unique())
      dyn_vars, rand_vars = _split_vars(all_vars)

      # reduce function
      if reduce_func is None:
        reduce_func = jnp.concatenate

      # static broadcast-ed arguments
      static_broadcasted_argnums = _shift_argnums(static_broadcasted_argnums)

      # jit functions
      for key in func.steps.keys():
        step = func.steps[key]
        func.steps[key] = _make_pmap(dyn_vars=dyn_vars,
                                     rand_vars=rand_vars,
                                     func=step,
                                     axis_name=axis_name,
                                     in_axes=in_axes,
                                     out_axes=out_axes,
                                     static_broadcasted_argnums=static_broadcasted_argnums,
                                     devices=devices,
                                     backend=backend,
                                     axis_size=axis_size,
                                     donate_argnums=donate_argnums,
                                     global_arg_shapes=global_arg_shapes,
                                     reduce_func=reduce_func,
                                     f_name=key)
      return func

  if callable(func):
    if dyn_vars is not None:
      dyn_vars = dyn_vars
    elif isinstance(func, Base):  # Base has '__call__()' implementation
      dyn_vars = func.vars().unique()
    elif hasattr(func, '__self__'):
      if isinstance(func.__self__, Base):
        dyn_vars = func.__self__.vars().unique()

    if dyn_vars is None:
      return jax.pmap(func,
                      axis_name=axis_name,
                      in_axes=in_axes,
                      out_axes=out_axes,
                      static_broadcasted_argnums=static_broadcasted_argnums,
                      devices=devices,
                      backend=backend,
                      axis_size=axis_size,
                      donate_argnums=donate_argnums,
                      global_arg_shapes=global_arg_shapes)
    else:
      # dynamical variables
      # BUG FIX: the original re-bound `dyn_vars = TensorCollector()` and then
      # iterated the freshly-emptied collector, so the collected variables
      # were silently discarded and the rand/dyn split never happened.
      dyn_vars, rand_vars = _split_vars(dyn_vars)

      # static broadcast-ed arguments
      static_broadcasted_argnums = _shift_argnums(static_broadcasted_argnums)

      # reduce function
      if reduce_func is None:
        reduce_func = jnp.concatenate

      # jit function
      func.__call__ = _make_pmap(dyn_vars=dyn_vars,
                                 rand_vars=rand_vars,
                                 func=func,
                                 axis_name=axis_name,
                                 in_axes=in_axes,
                                 out_axes=out_axes,
                                 static_broadcasted_argnums=static_broadcasted_argnums,
                                 devices=devices,
                                 backend=backend,
                                 axis_size=axis_size,
                                 donate_argnums=donate_argnums,
                                 global_arg_shapes=global_arg_shapes,
                                 reduce_func=reduce_func)
      return func

  else:
    raise errors.BrainPyError(
        f'Only support instance of {Base.__name__}, or a callable function, '
        f'but we got {type(func)}.')
def train():
  """Train the PixelCNN++ model.

  Loads the dataset, initializes the model with a data-dependent init batch,
  optionally resumes optimizer and EMA parameters from a checkpoint, then
  runs the pmapped train/eval loop, logging metrics to TensorBoard and
  checkpointing every 10 epochs.
  """
  batch_size = FLAGS.batch_size
  n_devices = jax.device_count()
  if jax.host_count() > 1:
    raise ValueError('PixelCNN++ example should not be run on more than 1 host'
                     ' (for now)')
  if batch_size % n_devices > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_summary_writer, eval_summary_writer = get_summary_writers()

  # Load dataset
  data_source = input_pipeline.DataSource(train_batch_size=batch_size,
                                          eval_batch_size=batch_size)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Compute steps per epoch and nb of eval steps
  steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
  steps_per_eval = data_source.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * FLAGS.num_epochs

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert FLAGS.init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:FLAGS.init_batch_size]

  rng = random.PRNGKey(FLAGS.rng)
  rng, init_rng = random.split(rng)
  rng, dropout_rng = random.split(rng)

  initial_variables = model().init({
      'params': init_rng,
      'dropout': dropout_rng
  }, init_batch)['params']
  optimizer_def = optim.Adam(learning_rate=FLAGS.learning_rate,
                             beta1=0.95,
                             beta2=0.9995)
  optimizer = optimizer_def.create(initial_variables)

  # BUG FIX: the original code assigned ``ema = initial_variables`` *after*
  # restoring from the checkpoint, which discarded the restored EMA
  # parameters on every resume. ``restore_checkpoint`` already receives
  # ``initial_variables`` to use as the EMA for a fresh run, so no
  # re-assignment is needed.
  optimizer, ema = restore_checkpoint(optimizer, initial_variables)
  step_offset = int(optimizer.state.step)
  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule
  learning_rate_fn = lambda step: FLAGS.learning_rate * FLAGS.lr_decay**step

  # pmap the train and eval functions
  p_train_step = jax.pmap(partial(train_step,
                                  learning_rate_fn=learning_rate_fn),
                          axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # Gather metrics
  train_metrics = []
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)

    # Generate a PRNG key that will be rolled into the batch.
    rng, step_rng = random.split(rng)
    sharded_rngs = common_utils.shard_prng_key(step_rng)

    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_rngs)
    train_metrics.append(metrics)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      save_checkpoint(optimizer, ema, step)
# Setup optimizer optimizer = Adam( learning_rate=training_args.learning_rate, weight_decay=training_args.weight_decay, beta1=training_args.adam_beta1, beta2=training_args.adam_beta2, ).create(model.params) # Create learning rate scheduler # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent. lr_scheduler_fn = create_learning_rate_scheduler( base_learning_rate=training_args.learning_rate, warmup_steps=min(training_args.warmup_steps, 1) ) # Create parallel version of the training and evaluation steps p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,)) p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) # Replicate the optimizer on each device optimizer = jax_utils.replicate(optimizer) # Store some constant nb_epochs = int(training_args.num_train_epochs) batch_size = int(training_args.train_batch_size) eval_batch_size = int(training_args.eval_batch_size) epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ # Create sampling rng
def main():
    """Fine-tune a Flax/JAX extractive question-answering model.

    Parses (model, data, training) arguments, loads a QA dataset from the hub
    or local CSV/JSON files, tokenizes questions/contexts into overlapping
    features, trains with a pmapped train step, periodically evaluates with
    SQuAD-style metrics, and saves (optionally pushes) checkpoints.
    """
    # region Argument parsing
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )
    # endregion

    # region Logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # endregion

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # region Load Data
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(data_args.dataset_name,
                                    data_args.dataset_config_name,
                                    cache_dir=model_args.cache_dir)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        raw_datasets = load_dataset(extension,
                                    data_files=data_files,
                                    field="data",
                                    cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    # endregion

    # region Load pretrained model and tokenizer
    #
    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=True,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    # endregion

    # region Tokenizer check: this script requires a fast tokenizer.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError(
            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
            "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
            "requirement")
    # endregion

    # region Preprocessing the datasets
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names
    # Fall back to positional columns when the canonical names are absent.
    question_column_name = "question" if "question" in column_names else column_names[
        0]
    context_column_name = "context" if "context" in column_names else column_names[
        1]
    answer_column_name = "answers" if "answers" in column_names else column_names[
        2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        # Tokenizes question/context pairs with overflow windows and labels
        # each feature with start/end token positions of the answer span.
        #
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[
                question_column_name if pad_on_right else context_column_name],
            examples[
                context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right
                                                          else 0):
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right
                                                        else 0):
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char
                        and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[
                            token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(
                        token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(
                        token_end_index + 1)

        return tokenized_examples

    processed_raw_datasets = dict()
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            # We will select sample from whole data if argument is specified
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        # Create train feature from dataset
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_train_samples is not None:
            # Number of samples might increase during Feature Creation, We select only specified max samples
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))
        processed_raw_datasets["train"] = train_dataset

    # Validation preprocessing
    def prepare_validation_features(examples):
        # Same tokenization as training, but keeps example ids and masked
        # offset mappings so predictions can be mapped back to context text.
        #
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[
                question_column_name if pad_on_right else context_column_name],
            examples[
                context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(
                examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_examples = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            # We will select sample from whole data
            eval_examples = eval_examples.select(
                range(data_args.max_eval_samples))
        # Validation Feature Creation
        eval_dataset = eval_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_eval_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))
        processed_raw_datasets["validation"] = eval_dataset

    if training_args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_examples = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            # We will select sample from whole data
            predict_examples = predict_examples.select(
                range(data_args.max_predict_samples))
        # Predict Feature Creation
        predict_dataset = predict_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_predict_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            predict_dataset = predict_dataset.select(
                range(data_args.max_predict_samples))
        processed_raw_datasets["test"] = predict_dataset
    # endregion

    # region Metrics and Post-processing:
    def post_processing_function(examples, features, predictions,
                                 stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
            output_dir=training_args.output_dir,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if data_args.version_2_with_negative:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v,
                "no_answer_probability": 0.0
            } for k, v in predictions.items()]
        else:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v
            } for k, v in predictions.items()]

        references = [{
            "id": ex["id"],
            "answers": ex[answer_column_name]
        } for ex in examples]
        return EvalPrediction(predictions=formatted_predictions,
                              label_ids=references)

    metric = load_metric(
        "squad_v2" if data_args.version_2_with_negative else "squad")

    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions,
                              references=p.label_ids)

    # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
    def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
        """
        Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor

        Args:
            start_or_end_logits(:obj:`tensor`):
                This is the output predictions of the model. We can only enter either start or end logits.
            eval_dataset: Evaluation dataset
            max_len(:obj:`int`):
                The maximum length of the output tensor. ( See the model.eval() part for more details )
        """
        step = 0
        # create a numpy array and fill it with -100.
        logits_concat = np.full((len(dataset), max_len),
                                -100,
                                dtype=np.float64)
        # Now since we have create an array now we will populate it with the outputs of the model.
        for i, output_logit in enumerate(
                start_or_end_logits):  # populate columns
            # We have to fill it such that we have to take the whole tensor and replace it on the newly created array
            # And after every iteration we have to change the step
            batch_size = output_logit.shape[0]
            cols = output_logit.shape[1]
            if step + batch_size < len(dataset):
                logits_concat[step:step + batch_size, :cols] = output_logit
            else:
                # Final (possibly partial) batch: only copy what fits.
                logits_concat[step:, :cols] = output_logit[:len(dataset) -
                                                           step]
            step += batch_size
        return logits_concat

    # endregion

    # region Training steps and logging init
    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
    summary_writer.hparams({
        **training_args.to_dict(),
        **vars(model_args),
        **vars(data_args)
    })

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        # Writes one scalar per recorded step so TensorBoard shows a
        # continuous curve rather than one point per logging interval.
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count(
    )
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count(
    )
    # endregion

    # region Load model
    model = FlaxAutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
    )

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model,
                               learning_rate_fn,
                               num_labels=max_seq_length,
                               training_args=training_args)
    # endregion

    # region Define train step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array],
        dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains the model on `batch`, returning `(new_state, metrics, new_dropout_rng)`."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        start_positions = batch.pop("start_positions")
        end_positions = batch.pop("end_positions")
        targets = (start_positions, end_positions)

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices before applying them.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0, ))
    # endregion

    # region Define eval step functions
    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")
    # endregion

    # region Define train and eval loop
    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    # NOTE: train_time is (redundantly) re-initialized here; kept as-is.
    train_time = 0
    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
                tqdm(
                    train_data_collator(input_rng, train_dataset,
                                        train_batch_size),
                    total=step_per_epoch,
                    desc="Training...",
                    position=1,
                ),
                1,
        ):
            state, train_metric, dropout_rngs = p_train_step(
                state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * step_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if (training_args.do_eval
                    and (cur_step % training_args.eval_steps == 0
                         or cur_step % step_per_epoch == 0) and cur_step > 0):

                eval_metrics = {}
                all_start_logits = []
                all_end_logits = []
                # evaluate
                for batch in tqdm(
                        eval_data_collator(eval_dataset, eval_batch_size),
                        total=len(eval_dataset) // eval_batch_size,
                        desc="Evaluating ...",
                        position=2,
                ):
                    # example_id / offset_mapping are only needed for
                    # post-processing, not by the model forward pass.
                    _ = batch.pop("example_id")
                    _ = batch.pop("offset_mapping")
                    predictions = p_eval_step(state, batch)
                    start_logits = np.array(
                        [pred for pred in chain(*predictions[0])])
                    end_logits = np.array(
                        [pred for pred in chain(*predictions[1])])
                    all_start_logits.append(start_logits)
                    all_end_logits.append(end_logits)

                # evaluate also on leftover examples (not divisible by batch_size)
                num_leftover_samples = len(eval_dataset) % eval_batch_size

                # make sure leftover batch is evaluated on one device
                if num_leftover_samples > 0 and jax.process_index() == 0:
                    # take leftover samples
                    batch = eval_dataset[-num_leftover_samples:]
                    batch = {k: np.array(v) for k, v in batch.items()}
                    _ = batch.pop("example_id")
                    _ = batch.pop("offset_mapping")

                    # Runs the unreplicated eval_step on a single device.
                    predictions = eval_step(unreplicate(state), batch)
                    start_logits = np.array(
                        [pred for pred in predictions[0]])
                    end_logits = np.array([pred for pred in predictions[1]])
                    all_start_logits.append(start_logits)
                    all_end_logits.append(end_logits)

                max_len = max([x.shape[1] for x in all_start_logits
                               ])  # Get the max_length of the tensor

                # concatenate the numpy array
                start_logits_concat = create_and_fill_np_array(
                    all_start_logits, eval_dataset, max_len)
                end_logits_concat = create_and_fill_np_array(
                    all_end_logits, eval_dataset, max_len)

                # delete the list of numpy arrays
                del all_start_logits
                del all_end_logits

                outputs_numpy = (start_logits_concat, end_logits_concat)
                prediction = post_processing_function(eval_examples,
                                                      eval_dataset,
                                                      outputs_numpy)
                eval_metrics = compute_metrics(prediction)

                logger.info(
                    f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})"
                )

                if jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0
                    and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
def device_broadcast(x, num_devices):
  """Replicate ``x`` onto every device.

  A constant function is pmapped over a dummy per-device input, so each of
  the ``num_devices`` devices receives its own copy of ``x``; the copies are
  stacked along a new leading device axis.
  """
  constant_fn = lambda _unused: x
  dummy_inputs = jnp.arange(num_devices)
  return jax.pmap(constant_fn)(dummy_inputs)
def train_model():
  """Run stochastic-gradient MCMC training (SGLD-style updates).

  Sets up data/model/metrics from the module-level ``args``, builds an SGLD
  gradient-update optimizer with a learning-rate schedule and optional
  preconditioner, optionally resumes from a checkpoint, then runs the
  epoch loop: train, periodically evaluate, accumulate an ensemble of
  posterior-sample predictions, checkpoint, and log to TensorBoard/stdout.
  """
  # Initialize training directory
  dirname, tf_writer = get_dirname_tfwriter(args)

  # Initialize data, model, losses and metrics
  (train_set, test_set, net_apply, params, net_state, key, log_likelihood_fn,
   log_prior_fn, _, predict_fn, ensemble_upd_fn, metrics_fns,
   tabulate_metrics) = script_utils.get_data_model_fns(args)

  # Initialize step-size schedule and optimizer
  num_batches, total_steps = script_utils.get_num_batches_total_steps(
      args, train_set)
  num_devices = len(jax.devices())
  lr_schedule = get_lr_schedule(num_batches, args)
  preconditioner = get_preconditioner(args)
  optimizer = sgmcmc.sgld_gradient_update(
      lr_schedule,
      momentum_decay=args.momentum_decay,
      seed=args.seed,
      preconditioner=preconditioner)

  # Initialize variables
  opt_state = optimizer.init(params)
  # Replicate the network state onto every device via a constant pmap.
  net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
  # One PRNG key per device.
  key = jax.random.split(key, num_devices)
  init_dict = checkpoint_utils.make_sgmcmc_checkpoint_dict(
      -1, params, net_state, opt_state, key, 0, None, None)
  # get_initialization_dict may replace init_dict with a saved checkpoint.
  init_dict = script_utils.get_initialization_dict(dirname, args, init_dict)
  (start_iteration, params, net_state, opt_state, key, num_ensembled, _,
   ensemble_predictions) = (
       checkpoint_utils.parse_sgmcmc_checkpoint_dict(init_dict))
  # Checkpoint stores the last completed iteration (-1 for a fresh run).
  start_iteration += 1

  # Define train epoch
  sgmcmc_train_epoch = script_utils.time_fn(
      train_utils.make_sgd_train_epoch(net_apply, log_likelihood_fn,
                                       log_prior_fn, optimizer, num_batches))

  # Train
  for iteration in range(start_iteration, args.num_epochs):

    (params, net_state, opt_state, logprob_avg, key), iteration_time = (
        sgmcmc_train_epoch(params, net_state, opt_state, train_set, key))

    is_evaluation_epoch, is_ensembling_epoch, is_save_epoch = (
        is_eval_ens_save_epoch(iteration, args))

    # Evaluate the model
    train_stats, test_stats = {"log_prob": logprob_avg}, {}
    if is_evaluation_epoch or is_ensembling_epoch:
      _, test_predictions, train_predictions, test_stats, train_stats_ = (
          script_utils.evaluate(net_apply, params, net_state, train_set,
                                test_set, predict_fn, metrics_fns,
                                log_prior_fn))
      train_stats.update(train_stats_)

    # Ensemble predictions
    if is_ensembling_epoch:
      ensemble_predictions = ensemble_upd_fn(ensemble_predictions,
                                             num_ensembled, test_predictions)
      ensemble_stats = train_utils.evaluate_metrics(ensemble_predictions,
                                                    test_set[1], metrics_fns)
      num_ensembled += 1
    else:
      ensemble_stats = {}
      # NOTE(review): on evaluation-only epochs the fresh test_predictions
      # are discarded here, so the checkpoint below stores None — presumably
      # intentional (only ensembled predictions are checkpointed); confirm.
      test_predictions = None

    # Save checkpoint
    if is_save_epoch:
      checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
      checkpoint_path = os.path.join(dirname, checkpoint_name)
      checkpoint_dict = checkpoint_utils.make_sgmcmc_checkpoint_dict(
          iteration, params, net_state, opt_state, key, num_ensembled,
          test_predictions, ensemble_predictions)
      checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

    # Log results
    other_logs = script_utils.get_common_logs(iteration, iteration_time, args)
    other_logs["hypers/step_size"] = lr_schedule(opt_state.count)
    other_logs["hypers/momentum"] = args.momentum_decay
    other_logs["telemetry/num_ensembled"] = num_ensembled
    logging_dict = logging_utils.make_logging_dict(train_stats, test_stats,
                                                   ensemble_stats)
    logging_dict.update(other_logs)
    script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)

    tabulate_dict = script_utils.get_tabulate_dict(tabulate_metrics,
                                                   logging_dict)
    tabulate_dict["lr"] = lr_schedule(opt_state.count)
    table = logging_utils.make_table(tabulate_dict,
                                     iteration - start_iteration,
                                     args.tabulate_freq)
    print(table)