def save_unreplicated_checkpoint_background(train_dir,
                                            optimizer_state,
                                            params,
                                            batch_stats,
                                            training_metrics_state,
                                            global_step,
                                            preemption_count,
                                            sum_train_cost,
                                            max_to_keep=1):
  """Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
  logging.info('Saving checkpoint to ckpt_%d', global_step)
  unreplicated_optimizer_state = jax.device_get(
      jax_utils.unreplicate(optimizer_state))
  unreplicated_params = jax.device_get(jax_utils.unreplicate(params))
  unreplicated_batch_stats = jax.device_get(jax_utils.unreplicate(batch_stats))
  unreplicated_training_metrics_state = jax.device_get(
      jax_utils.unreplicate(training_metrics_state))
  state = dict(global_step=global_step,
               preemption_count=preemption_count,
               sum_train_cost=sum_train_cost,
               optimizer_state=unreplicated_optimizer_state,
               params=unreplicated_params,
               batch_stats=unreplicated_batch_stats,
               training_metrics_grabber=unreplicated_training_metrics_state)
  save_checkpoint_background(train_dir,
                             global_step,
                             state,
                             max_to_keep=max_to_keep)
  logging.info('Done saving checkpoint.')
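# A minimal sketch (not from the sources above) of the replicate/unreplicate
# round trip that checkpointing helpers like the one above rely on:
# `replicate` stacks a copy of the pytree per local device, `unreplicate`
# takes replica 0 back off, and `device_get` moves it to host memory.
import jax
import jax.numpy as jnp
from flax import jax_utils

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros(3)}
replicated = jax_utils.replicate(params)  # leading axis = local device count
assert replicated['w'].shape == (jax.local_device_count(), 3, 3)
restored = jax.device_get(jax_utils.unreplicate(replicated))  # replica 0 on host
assert restored['w'].shape == (3, 3)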
def write_summaries(self, example, logits, summary_writer, info, step, state):
  example = jax_utils.unreplicate(example)
  outputs = jnp.argmax(jax_utils.unreplicate(logits), axis=-1)
  text = summary_utils.human_readable_texts(example, outputs, info)
  summary_writer.text('predictions', '<pre>{}</pre>'.format(text), step)
  self.generate_plots(state, summary_writer, step)
def maybe_copy_model_from_pretraining(optimizer, pretrain_optimizer, step,
                                      adam_opt_def):
  """Copy model parameters from pretraining."""
  if step < FLAGS.num_pretrain_steps:
    optimizer = jax_utils.unreplicate(optimizer)
    state_dict = adam_opt_def.state_dict(
        target=jax_utils.unreplicate(pretrain_optimizer).target,
        state=optim.OptimizerState(
            jnp.asarray(step, dtype=jnp.int32),
            optimizer.state.param_states))
    optimizer = optimizer.restore_state(state_dict)
    optimizer = jax_utils.replicate(optimizer)
  return optimizer
def load_checkpoint(self, ckpt_dir):
  """Loads optimizer from ckpt_dir."""
  target = jax_utils.unreplicate(self._optimizer)
  optimizer = checkpoints.restore_checkpoint(ckpt_dir, target=target)
  if optimizer is target:
    raise ValueError('Unable to load checkpoint from %s' % ckpt_dir)
  self.set_weights(optimizer)
def test_train_one_step(self):
  """Tests training loop over one step."""
  iterator = self._dataset.get_train()
  batch = next(iterator)
  state = jax_utils.replicate(self._state)
  optimizer = jax_utils.replicate(self._optimizer.create(self._model))
  self._rng, step_key = jax.random.split(self._rng)
  batch = training._shard_batch(batch)
  sharded_keys = common_utils.shard_prng_key(step_key)
  p_train_step = jax.pmap(
      functools.partial(training.train_step,
                        learning_rate_fn=self._learning_rate_fn),
      axis_name='batch')
  _, _, loss, gradient_norm = p_train_step(optimizer, batch, sharded_keys,
                                           state)
  loss = jnp.mean(loss)
  gradient_norm = jax_utils.unreplicate(gradient_norm)
  with self.subTest(name='test_loss_range'):
    self.assertBetween(loss, self._min_loss, self._max_loss)
  with self.subTest(name='test_gradient_norm'):
    self.assertGreaterEqual(gradient_norm, 0)
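# Hedged sketch of the sharding helpers the test above uses: `shard` splits
# the host batch across local devices (the batch dimension must be divisible
# by the device count) and `shard_prng_key` hands each device its own PRNG
# key, so the pmapped step sees per-device inputs and randomness.
import jax
import jax.numpy as jnp
from flax.training import common_utils

key = jax.random.PRNGKey(0)
per_device_keys = common_utils.shard_prng_key(key)  # one key per local device
batch = common_utils.shard(jnp.zeros((8, 16)))      # (n_devices, 8 // n_devices, 16)
assert per_device_keys.shape[0] == jax.local_device_count()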
def maybe_log_training_metrics(metrics_state, metrics_summary_fn,
                               metrics_logger):
  """If appropriate, send a summary tree of training metrics to the logger."""
  if metrics_state:
    unreplicated_metrics_state = jax_utils.unreplicate(metrics_state)
    summary_tree = metrics_summary_fn(unreplicated_metrics_state)
    metrics_logger.append_pytree(summary_tree)
def write_images(H, ema_params, viz_batch_original, viz_batch_processed, fname,
                 logprint):
  rng = random.PRNGKey(H.seed_sample)
  ema_apply = partial(VAE(H).apply,
                      {'params': jax_utils.unreplicate(ema_params)})
  forward_get_latents = partial(ema_apply, method=VAE(H).forward_get_latents)
  forward_samples_set_latents = partial(
      ema_apply, method=VAE(H).forward_samples_set_latents)
  forward_uncond_samples = partial(
      ema_apply, method=VAE(H).forward_uncond_samples)
  zs = [s['z'] for s in forward_get_latents(viz_batch_processed, rng)]
  batches = [viz_batch_original.numpy()]
  mb = viz_batch_processed.shape[0]
  lv_points = np.floor(
      np.linspace(0, 1, H.num_variables_visualize + 2) *
      len(zs)).astype(int)[1:-1]
  for i in lv_points:
    batches.append(forward_samples_set_latents(mb, zs[:i], rng, t=0.1))
  for t in [1.0, 0.9, 0.8, 0.7][:H.num_temperatures_visualize]:
    batches.append(forward_uncond_samples(mb, rng, t=t))
  n_rows = len(batches)
  im = np.concatenate(batches, axis=0).reshape(
      (n_rows, mb, *viz_batch_processed.shape[1:])).transpose(
          [0, 2, 1, 3, 4]).reshape([
              n_rows * viz_batch_processed.shape[1],
              mb * viz_batch_processed.shape[2], 3
          ])
  logprint(f'printing samples to {fname}')
  Image.fromarray(im).save(fname)
def save_model(path, optimizer, ema, epoch, H):
  optimizer, ema = jax_utils.unreplicate((optimizer, ema))
  checkpoints.save_checkpoint(path, (optimizer, epoch), optimizer.state.step)
  checkpoints.save_checkpoint(path + '_ema', ema, optimizer.state.step)
  from_log = os.path.join(H.save_dir, 'log.jsonl')
  to_log = f'{os.path.dirname(path)}/{os.path.basename(path)}-log.jsonl'
  subprocess.check_output(['cp', from_log, to_log])
def run_eval(self, flax_module, batch_stats, optimizer_state, global_step):
  """Computes the loss hessian and returns the max eigenvalue.

  Note, the full lanczos tridiagonal matrix is saved via the logger to
  train_dir/checkpoints/config['name'].

  Args:
    flax_module: Replicated flax module.
    batch_stats: Replicated batch_stats from the trainer.
    optimizer_state: Replicated optimizer state from the trainer.
    global_step: Current training step.

  Returns:
    Max eigenvalue of the loss (full tridiag is saved to disk).
  """
  del batch_stats
  if self.callback_config.get('precondition'):
    precondition_config = self.callback_config.get(
        'precondition_config', default=FrozenConfigDict())
    diag_preconditioner = precondition.make_diag_preconditioner(
        self.hps.optimizer, self.hps.opt_hparams,
        jax_utils.unreplicate(optimizer_state), precondition_config)
  else:
    diag_preconditioner = None
  hessian_metrics, _, _ = self.hessian_evaluator.evaluate_spectrum(
      flax_module, global_step, diag_preconditioner=diag_preconditioner)
  if jax.host_id() == 0:
    self.logger.append_pytree(hessian_metrics)
  max_eig_key = self.name + '/max_eig'
  return {max_eig_key: hessian_metrics['max_eig_hess']}
def save_checkpoint(self, ckpt_dir):
  """Saves unreplicated optimizer to ckpt_dir."""
  optimizer = jax_utils.unreplicate(self._optimizer)
  checkpoints.save_checkpoint(
      ckpt_dir,
      target=optimizer,
      step=self._train_step,
  )
def evaluate(self, state, dataset):
  dataloader = get_batched_dataset(dataset, self.args.batch_size)
  total = len(dataset) // self.args.batch_size
  running_loss = jnp.array(0, dtype=jnp.float32)
  i = 0
  for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
    batch = self.data_collator(batch)
    metrics = self.val_step_fn(state, **batch)
    running_loss += jax_utils.unreplicate(metrics["loss"])
    i += 1
  return running_loss / i
def save_checkpoint(self, save_dir, state):
  state = jax_utils.unreplicate(state)
  print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
  self.model_save_fn(save_dir, params=state.params)
  with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
    f.write(to_bytes(state.opt_state))
  joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
  joblib.dump(self.data_collator,
              os.path.join(save_dir, "data_collator.joblib"))
  with open(os.path.join(save_dir, "training_state.json"), "w") as f:
    json.dump({"step": state.step.item()}, f)
  print("DONE")
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparamters,
    input_batch: spec.Tensor,
    label_batch: spec.Tensor,
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del eval_results
  del global_step

  num_devices = jax.local_device_count()
  input_shape = input_batch.shape
  reshaped_input_batch = jnp.reshape(
      input_batch,
      (num_devices, input_shape[0] // num_devices, *input_shape[1:]))
  reshaped_label_batch = jnp.reshape(
      label_batch,
      (num_devices, label_batch.shape[0] // num_devices,
       *label_batch.shape[1:]))
  # TODO(znado) we should be more efficient than replicating state each step.
  new_optimizer_state, updated_params, new_model_state = pmapped_update_params(
      workload,
      jax_utils.replicate(current_param_container),
      jax_utils.replicate(model_state),
      hyperparameters,
      reshaped_input_batch,
      reshaped_label_batch,
      jax_utils.replicate(optimizer_state),
      rng,
      jnp.arange(num_devices))
  return (jax_utils.unreplicate(new_optimizer_state),
          jax_utils.unreplicate(updated_params),
          jax_utils.unreplicate(new_model_state))
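# A hedged note on the reshape above: `flax.training.common_utils.shard`
# performs the same leading-axis split for every leaf of a pytree, provided
# the batch dimension is divisible by the local device count. `global_batch`
# is an illustrative name, not from the source.
import jax
import jax.numpy as jnp
from flax.training import common_utils

global_batch = jnp.zeros((32, 28, 28))      # host-level batch
sharded = common_utils.shard(global_batch)  # (n_devices, 32 // n_devices, 28, 28)
assert sharded.shape[0] == jax.local_device_count()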
def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
  """Helper function to replicate_and_maybe_restore a checkpoint."""
  (_, ret_params, _, _,
   ret_global_step, ret_sum_train_cost,
   ret_preemption_count,
   ret_is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
       {}, params, {}, {}, train_dir, external_checkpoint_path)
  ret_params_unrep = jax.device_get(jax_utils.unreplicate(ret_params))
  return (ret_params_unrep, ret_global_step, ret_sum_train_cost,
          ret_preemption_count, ret_is_restored)
def evaluate(model, state, eval_ds, num_eval_steps=-1):
  """Evaluate the model on the given dataset."""
  logging.info("Starting evaluation.")
  eval_metrics = None
  with StepTraceContextHelper("eval", 0) as trace_context:
    for step, batch in enumerate(eval_ds):  # pytype: disable=wrong-arg-types
      batch = jax.tree_map(np.asarray, batch)
      metrics_update = flax_utils.unreplicate(eval_step(model, state, batch))
      eval_metrics = (
          metrics_update
          if eval_metrics is None else eval_metrics.merge(metrics_update))
      if num_eval_steps > 0 and step + 1 == num_eval_steps:
        break
      trace_context.next_step()
  return eval_metrics
def train(self, state, tr_dataset, val_dataset):
  args = self.args
  total = len(tr_dataset) // args.batch_size

  rng = jax.random.PRNGKey(0)
  drp_rng = jax.random.split(rng, jax.device_count())
  for epoch in range(args.max_epochs):
    running_loss = jnp.array(0, dtype=jnp.float32)
    tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size,
                                        seed=epoch)
    i = 0
    for batch in tqdm(tr_dataloader, total=total,
                      desc=f"Running EPOCH-{epoch}"):
      batch = self.data_collator(batch)
      state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
      running_loss += jax_utils.unreplicate(metrics["loss"])
      i += 1
      if i % args.logging_steps == 0:
        state_step = jax_utils.unreplicate(state.step)
        tr_loss = running_loss.item() / i
        lr = self.scheduler_fn(state_step - 1)

        eval_loss = self.evaluate(state, val_dataset)
        logging_dict = dict(step=state_step.item(),
                            eval_loss=eval_loss.item(),
                            tr_loss=tr_loss,
                            lr=lr.item())
        tqdm.write(str(logging_dict))
        self.logger.log(logging_dict, commit=True)

      if i % args.save_steps == 0:
        self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)
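# Illustrative sketch (assumed, not from the trainer above) of why taking
# `unreplicate(metrics["loss"])` is lossless: after `lax.pmean` inside the
# pmapped step, every replica holds the same averaged value, so keeping
# replica 0 drops only duplicates.
import functools
import jax
import jax.numpy as jnp
from flax import jax_utils

@functools.partial(jax.pmap, axis_name='batch')
def p_mean_loss(per_device_losses):
  return jax.lax.pmean(jnp.mean(per_device_losses), axis_name='batch')

losses = jnp.ones((jax.local_device_count(), 8))
loss = jax_utils.unreplicate(p_mean_loss(losses))  # scalar, identical everywhere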
def maybe_reset_train_state(self):
  optimizer = jax_utils.unreplicate(self.train_state.optimizer)

  if self.hparams.get('reinitilize_params_at_each_step', False):
    # Grab the module definition before dropping the old parameters.
    module = optimizer.target.module
    del optimizer.target
    (flax_model, _, _) = pipeline_utils.create_flax_module(
        module,
        self.task.dataset.meta_data['input_shape'],
        self.hparams,
        nn.make_rng(),
        self.task.dataset.meta_data.get('input_dtype', jnp.float32))
  else:
    flax_model = optimizer.target

  # Reset optimizer.
  if self.hparams.get('reinitialize_optimizer_at_each_step', False):
    self.lr_start_step = jax_utils.unreplicate(self.train_state.global_step)
    optimizer = optimizers.get_optimizer(self.hparams).create(flax_model)
  else:
    optimizer = optimizer.replace(target=flax_model)

  optimizer = jax_utils.replicate(optimizer)
  self.train_state = self.train_state.replace(optimizer=optimizer)
def evaluate(self, workdir, dir_name='eval', ckpt_name=None):
  """Perform one evaluation."""
  checkpoint_dir = os.path.join(workdir, 'checkpoints-0')
  ckpt = checkpoint.Checkpoint(checkpoint_dir)
  state_dict = ckpt.restore_dict(os.path.join(checkpoint_dir, ckpt_name))
  ema_params = flax.core.FrozenDict(state_dict['ema_params'])
  step = int(state_dict['step'])

  # Distribute evaluation across devices.
  ema_params = flax_utils.replicate(ema_params)

  eval_logdir = os.path.join(workdir, dir_name)
  tf.io.gfile.makedirs(eval_logdir)
  writer = metric_writers.create_default_writer(
      eval_logdir, just_logging=jax.process_index() > 0)

  outputs = self._eval_epoch(params=ema_params)
  outputs = flax_utils.unreplicate(outputs)
  scalars, images = outputs['scalars'], outputs['images']
  writer.write_scalars(step, scalars)
  writer.write_images(step, images)
def save_checkpoint(experiment_dir, train_state, keep=3):
  """Saves a checkpoint.

  First syncs the model state across replicas, then unreplicates it by taking
  the train state of the first replica, and saves it as a checkpoint.

  Args:
    experiment_dir: str; Experiment directory for saving the checkpoint.
    train_state: Dataclass; An instance of TrainState that holds the state of
      training.
    keep: int; Number of checkpoints to keep.
  """
  if jax.host_id() == 0:
    # Get the train state from the first replica.
    checkpoint_state = jax.device_get(jax_utils.unreplicate(train_state))
    ckpt_path = checkpoint_path(experiment_dir,
                                int(checkpoint_state.global_step))
    if not tf.io.gfile.exists(ckpt_path):
      checkpoints.save_checkpoint(experiment_dir, checkpoint_state,
                                  int(checkpoint_state.global_step), keep=keep)
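# Hedged sketch of the guard pattern above: once state is synced across
# replicas, every host holds identical values, so a single process writes the
# checkpoint. `write_fn` is an illustrative stand-in, not a real API.
import jax
from flax import jax_utils

def save_on_host0(write_fn, replicated_state):
  if jax.host_id() == 0:  # newer JAX spells this jax.process_index()
    write_fn(jax.device_get(jax_utils.unreplicate(replicated_state)))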
def save_checkpoint(
    model,
    save_dir,
    state,
    cur_step: int,
    with_opt: bool = True,
    push_to_hub: bool = False,
):
    state = jax_utils.unreplicate(state)
    if with_opt:
        logger.info(f"Saving optimizer and training state in {save_dir}...")
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
    logger.info(
        f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}'
    )
    model.save_pretrained(
        save_dir,
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs of step {cur_step}",
    )
def get_env_aligned_pairs_idx(self, env_reps, env_batches, env_ids):
  """Computes alignments between all environment pairs.

  Args:
    env_reps: jnp array; Reps for different environments (sharded).
    env_batches: list of dict; Batches of different environments (sharded).
    env_ids: jnp array; Environment ids.

  Returns:
    Alignment between batches of environment pairs (sharded).
  """
  # TODO(riannevdberg, samiraabnar): Aligning is done on the total unsharded
  # batch, but that requires access between local batches when computing the
  # loss. Unsure why this works! To be compatible with random alignment and
  # sinkhorn soft alignment we should do alignment only within local batches.
  env_reps = shard_util.unshard_env_batch(env_reps)
  env_batches = shard_util.unshard(env_batches)
  with nn.stochastic(jax_utils.unreplicate(self.train_state.rng)):
    alignments = self.task.get_env_aligned_pairs_idx(env_reps, env_batches,
                                                     env_ids)
  alignments = dataset_utils.shard(alignments)
  return alignments
def run_train(run_configuration):
  """Runs the training workflow."""
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  log_dir = os.path.join(run_dir, 'train')
  checkpoint_path = run_configuration.original_checkpoint_path
  dataset = run_configuration.dataset_info.dataset
  info = run_configuration.dataset_info.info

  random_seed = 0
  rng = jax.random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)
  dropout_rngs = jax.random.split(rng, jax.local_device_count())

  # Set up optimizer.
  optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

  # Set up train step.
  train_step = adapter.make_train_step()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(log_dir)

  # Set up checkpointing.
  # TODO(dbieber): Set up phoenix.
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  if checkpoint_path is None:
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
  optimizer = checkpoint_utils.handle_restart_behavior(
      checkpoint_path, optimizer, config)

  start_step = int(optimizer.state.step)
  num_train_steps = config.train.total_steps

  # Replicate optimizer.
  optimizer = flax.jax_utils.replicate(optimizer)

  # Begin training loop.
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)

  summary_freq = config.logging.summary_freq
  metrics_all = []
  tick = time.time()
  for step, example in zip(range(start_step, num_train_steps), dataset_iter):
    train_inputs = adapter.get_train_inputs(example)
    optimizer, metrics, dropout_rngs, logits, state = train_step(
        optimizer, train_inputs, dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint.
    if ((step % config.logging.save_freq == 0 and step > 0)
        or step == num_train_steps - 1):
      if jax.host_id() == 0 and config.logging.save_freq:
        # Save unreplicated optimizer + model state.
        checkpoint_utils.save_checkpoint(
            checkpoint_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if summary_freq and step % summary_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = summary_freq / (tock - tick)
        examples_per_sec = denominator / (tock - tick)
        tick = tock
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar(key, val, step)
        adapter.write_summaries(example, logits, summary_writer, info, step,
                                state)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = "tpu_driver"
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # This seems to be necessary even when importing TF2?
  tf.enable_v2_behavior()

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')

  # Load Dataset.
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_token = 2  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer.
  transformer_kwargs = {
      'vocab_size': vocab_size,
      'output_vocab_size': vocab_size,
      'emb_dim': 1024,
      'num_heads': 16,
      'num_layers': 6,
      'qkv_dim': 1024,
      'mlp_dim': 4096,
      'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      'share_embeddings': FLAGS.share_embeddings,
      'logits_via_embedding': FLAGS.logits_via_embedding,
  }

  start_step = 0
  rng = random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  model, cache_def = create_model(init_rng, input_shape, target_shape,
                                  transformer_kwargs)
  optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay)
  # We access model only from optimizer below via optimizer.target.
  del model

  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # We init the first set of dropout PRNG keys, but update it afterwards
  # inside the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info('Starting training loop.')
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir,
                                  jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % FLAGS.eval_frequency != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['learning_rate'] = lr
    steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
    logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info('Translating evaluation dataset.')
    t_inference_start = time.time()
    predict_iter = iter(predict_ds)
    sources, references, predictions = [], [], []
    for _, pred_batch in enumerate(predict_iter):
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      per_device_batchsize = pred_batch['inputs'].shape[1]
      cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
      cache = jax_utils.replicate(
          cache_def.initialize_cache(
              (per_device_batchsize, FLAGS.max_predict_length),
              dtype=cache_dtype))
      predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                              eos_token, FLAGS.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch['inputs'])
      targets = tohost(pred_batch['targets'])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))
    logging.info('Translation time: %.4f s step %d.',
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)

    # Save translation samples for tensorboard.
    exemplars = ''
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
    if jax.host_id() == 0:
      eval_summary_writer.scalar('bleu', bleu_score, step)
      eval_summary_writer.text('samples', exemplars, step)
      eval_summary_writer.flush()
    logging.info('Translation BLEU Score %.4f', bleu_score)
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a
        # json file, let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate
    # resources to maintain them. The information sent is the one passed as
    # arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_t5_mlm", model_args, data_args,
                           framework="flax")

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Setup logging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )
    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation.
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training
    # and evaluation files (see below) or just provide the name of one of the
    # public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the
    # first column if no column called 'text' is found. You can easily tweak
    # this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset (from
    # files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer.
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Otherwise, we tokenize every text, then concatenate them together before
    # splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no
    # attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked
    # tokens to a single sentinel token. To ensure that the input length is
    # `max_seq_length`, we need to increase the maximum length according to
    # `mlm_probability` and `mean_noise_span_length`. We can also define the
    # label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our
    # dataset and generate chunks of expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k])) for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could pad instead if the model
        # supported it. You can customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length //
                            expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [
                t[i:i + expanded_inputs_length]
                for i in range(0, total_length, expanded_inputs_length)
            ] for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together,
    # so group_texts throws away a remainder for each of those groups of 1,000
    # texts. You can adjust that batch_size here, but a higher value might be
    # slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of
    # the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node.
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    # Initialize our training.
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator.
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constants.
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    num_of_hosts = jax.process_count()
    current_host_idx = jax.process_index()

    # Create learning rate schedule.
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # Find out all LayerNorm parameters.
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set([
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        ])
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in layer_norm_named_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # Create the optimizer.
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor. For more
        # details about the parameters, please check
        # https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state.
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn.
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            # Compute loss.
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step.
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn.
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # Compute loss.
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # Compute accuracy.
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # Summarize metrics.
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device.
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng.
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train
        # dataset.
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training.
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step.
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            local_host_model_inputs = {
                key: np.split(model_inputs.data[key], num_of_hosts,
                              axis=0)[current_host_idx]
                for key, value in model_inputs.data.items()
            }

            # Model forward.
            model_inputs = shard(local_host_model_inputs)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics.
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ======================
                num_eval_samples = len(tokenized_datasets["validation"])
                # Avoid using jax.numpy here in case of TPU training.
                eval_samples_idx = np.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size,
                                                       drop_last=False)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples)

                    # Model forward.
                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                        state.params,
                        model_inputs.data,
                        min_device_batch=per_device_eval_batch_size)
                    eval_metrics.append(metrics)

                # Get eval metrics.
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                # Update progress bar.
                epochs.write(
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
                )

                # Save metrics.
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save checkpoint after each epoch and push checkpoint to
                # the hub.
                if jax.process_index() == 0:
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training.
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training.
        eval_samples_idx = np.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size,
                                               drop_last=False)

        eval_metrics = []
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward.
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params,
                model_inputs.data,
                min_device_batch=per_device_eval_batch_size)
            eval_metrics.append(metrics)

        # Get eval metrics.
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(),
                                    eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
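# Hedged sketch of the per-host slicing used in the training loop above: every
# process materializes the same global batch, keeps only its own slice, then
# shards that slice across its local devices. `keep_local_slice` is an
# illustrative helper, not part of the original script.
import jax
import numpy as np

def keep_local_slice(global_batch):
    host_idx = jax.process_index()
    num_hosts = jax.process_count()
    return {
        key: np.split(value, num_hosts, axis=0)[host_idx]
        for key, value in global_batch.items()
    }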
def train(config, workdir):
  """Runs a training loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  # Create directories for experimental logs.
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels.
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(
      step=0,
      optimizer=optimizer,
      lr=config.optim.lr,
      model_state=init_model_state,
      ema_rate=config.model.ema_rate,
      params_ema=initial_params,
      rng=rng)  # pytype: disable=wrong-keyword-args
  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint.
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(checkpoint_dir, max_to_keep=None)
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically.
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(checkpoint_meta_dir, max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(
        losses.ddpm_loss,
        ddpm_params=ddpm_params,
        train=True,
        optimize_fn=optimize_fn)
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels.
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is a JAX integer on the
    # GPU/TPU devices.
    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)
    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if (step % config.training.snapshot_freq_for_preemption == 0 and
        jax.host_id() == 0):
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if ((step + 1) % config.training.snapshot_freq == 0 or
        step == num_train_steps):
      # Save the checkpoint.
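# Hedged sketch of the RNG handling in the loop above: one host key is split
# into one key per local device each step, and the stacked keys feed the
# pmapped step so every replica draws independent noise.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
rng, *device_rngs = jax.random.split(rng, num=jax.local_device_count() + 1)
device_rngs = jnp.asarray(device_rngs)  # shape (n_devices, 2), ready for pmap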
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples.
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(
            sample_rng,
            config,
            flax_utils.unreplicate(state),
            init_shape,
            scaler,
            inverse_scaler,
            class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples.
          sample = samples[-1]
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
# Gather the indexes for creating the batch and do a training step.
for step, batch_idx in enumerate(
        tqdm(train_batch_idx, desc="Training...", position=1)):
    samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
    model_inputs = data_collator(samples, pad_to_multiple_of=16)

    # Model forward.
    model_inputs = shard(model_inputs.data)
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs,
                                                     dropout_rngs)
    train_metrics.append(train_metric)

    cur_step = epoch * (num_train_samples // train_batch_size) + step

    if cur_step % training_args.logging_steps == 0 and cur_step > 0:
        # Save metrics.
        train_metric = jax_utils.unreplicate(train_metric)
        train_time += time.time() - train_start
        if has_tensorboard and jax.process_index() == 0:
            write_train_metric(summary_writer, train_metrics, train_time,
                               cur_step)

        epochs.write(
            f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        train_metrics = []

    if cur_step % training_args.eval_steps == 0 and cur_step > 0:
        # ======================== Evaluating ==============================
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)
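# Hypothetical reconstruction of `generate_batch_splits` as the call sites
# above use it (the real helper lives in the HF Flax example scripts): chop a
# shuffled index array into near-equal, batch-sized chunks via np.array_split,
# optionally dropping the remainder.
import numpy as np

def generate_batch_splits(samples_idx, batch_size, drop_last=True):
    num_samples = len(samples_idx)
    if drop_last:
        samples_to_remove = num_samples % batch_size
        if samples_to_remove:
            samples_idx = samples_idx[:-samples_to_remove]
        sections_split = max(len(samples_idx) // batch_size, 1)
    else:
        sections_split = max(int(np.ceil(num_samples / batch_size)), 1)
    return np.array_split(samples_idx, sections_split)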
def main():
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging: we only want one process per machine to log things on
    # the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation.
    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(args.output_dir).absolute().name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo = Repository(args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and
    # evaluation files (see below) or specify a GLUE benchmark task (the
    # dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called
    # 'label' and as pair of sentences the sentences in columns called
    # 'sentence1' and 'sentence2' if such columns exist, or the first two
    # columns not named 'label' if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does
    # single sentence classification on this single column. You can easily
    # tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only
    # one local process can concurrently download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from a local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = (args.train_file if args.train_file is not None else
                     args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels.
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your
        # needs.
        is_regression = raw_datasets["train"].features["label"].dtype in [
            "float32", "float64"
        ]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism.
            num_labels = len(label_list)

    # Load pretrained model and tokenizer.
    config = AutoConfig.from_pretrained(args.model_name_or_path,
                                        num_labels=num_labels,
                                        finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=config)

    # Preprocessing the datasets.
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak
        # to your use case.
        non_label_column_names = [
            name for name in raw_datasets["train"].column_names
            if name != "label"
        ]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure
    # we do use it.
    label_to_id = None
    if (model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
            and args.task_name is not None and not is_regression):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!")
            label_to_id = {
                i: label_name_to_id[label_list[i]] for i in range(num_labels)
            }
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result.")
    elif args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts.
        texts = ((examples[sentence1_key],) if sentence2_key is None else
                 (examples[sentence1_key], examples[sentence2_key]))
        result = tokenizer(*texts,
                           padding="max_length",
                           max_length=args.max_length,
                           truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks).
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model
                # will expect that.
result["labels"] = examples["label"] return result processed_datasets = raw_datasets.map( preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names ) train_dataset = processed_datasets["train"] eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(args.output_dir) summary_writer.hparams(vars(args)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) def write_metric(train_metrics, eval_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) num_epochs = int(args.num_train_epochs) rng = jax.random.PRNGKey(args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = args.per_device_train_batch_size * jax.local_device_count() eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count() learning_rate_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate ) state = create_train_state( model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay ) # define step functions def train_step( state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey ) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) targets = batch.pop("labels") def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = state.loss_fn(logits, targets) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") return new_state, metrics, new_dropout_rng p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) def eval_step(state, batch): logits = state.apply_fn(**batch, params=state.params, train=False)[0] return state.logits_fn(logits) p_eval_step = jax.pmap(eval_step, axis_name="batch") if args.task_name is not None: metric = load_metric("glue", args.task_name) else: metric = load_metric("accuracy") logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 # make sure weights are replicated on each device state = replicate(state) for epoch in range(1, num_epochs + 1): logger.info(f"Epoch {epoch}") logger.info(" Training...") train_start = time.time() 
train_metrics = [] rng, input_rng = jax.random.split(rng) # train for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size): state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs) train_metrics.append(metrics) train_time += time.time() - train_start logger.info(f" Done! Training metrics: {unreplicate(metrics)}") logger.info(" Evaluating...") # evaluate for batch in glue_eval_data_collator(eval_dataset, eval_batch_size): labels = batch.pop("labels") predictions = p_eval_step(state, batch) metric.add_batch(predictions=chain(*predictions), references=chain(*labels)) # evaluate also on leftover examples (not divisible by batch_size) num_leftover_samples = len(eval_dataset) % eval_batch_size # make sure leftover batch is evaluated on one device if num_leftover_samples > 0 and jax.process_index() == 0: # take leftover samples batch = eval_dataset[-num_leftover_samples:] batch = {k: jnp.array(v) for k, v in batch.items()} labels = batch.pop("labels") predictions = eval_step(unreplicate(state), batch) metric.add_batch(predictions=predictions, references=labels) eval_metric = metric.compute() logger.info(f" Done! Eval metrics: {eval_metric}") cur_step = epoch * (len(train_dataset) // train_batch_size) # Save metrics if has_tensorboard and jax.process_index() == 0: write_metric(train_metrics, eval_metric, train_time, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) model.save_pretrained(args.output_dir, params=params) tokenizer.save_pretrained(args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) # save the eval metrics in json if jax.process_index() == 0: eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()} path = os.path.join(args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metric, f, indent=4, sort_keys=True)
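A minimal sketch of the label inference above for a local CSV/JSON file with string labels, using a toy in-memory dataset (the column values here are invented for illustration): datasets.Dataset.unique collects the distinct label values, which are then sorted for a deterministic class ordering.

from datasets import Dataset

toy = Dataset.from_dict({
    "sentence1": ["good movie", "bad movie", "fine movie"],
    "label": ["positive", "negative", "positive"],
})
label_list = sorted(toy.unique("label"))  # ['negative', 'positive'], sorted for determinism
num_labels = len(label_list)              # 2
label_to_id = {v: i for i, v in enumerate(label_list)}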
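The jax.lax.pmean calls inside train_step above average gradients and metrics across devices within the pmap, so every replica applies the same update. A self-contained sketch of that collective (values are illustrative only):

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name="batch")
def mean_across_devices(x):
    # Every device contributes its local value; pmean returns the same average on each device.
    return jax.lax.pmean(x, axis_name="batch")

n = jax.local_device_count()
out = mean_across_devices(jnp.arange(float(n)))
# out has shape [n] and every entry equals (0 + 1 + ... + (n-1)) / n.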
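The checkpointing block above unreplicates the parameters by hand, indexing the first device's copy; flax.jax_utils.unreplicate does the same thing, sketched here for comparison (assuming a replicated train state as in the loop above):

import jax
from flax import jax_utils

# Equivalent to jax.device_get(jax.tree_map(lambda x: x[0], state.params)):
params = jax.device_get(jax_utils.unreplicate(state.params))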
def train_loop(config, dropout_rngs, eval_ds, eval_freq, num_eval_steps, num_train_steps, optimizer, state, p_eval_step, p_train_step, start_step, train_iter, summary_writer): """Training loop. Args: config: experiment config. dropout_rngs: array; per-device JAX PRNG keys. eval_ds: tf.dataset; Evaluation dataset. eval_freq: int; Evaluation frequency. num_eval_steps: int; Number of evaluation steps. num_train_steps: int; Number of training steps. optimizer: flax optimizer. state: model state, e.g. batch statistics. p_eval_step: fn; Pmapped evaluation step function. p_train_step: fn; Pmapped train step function. start_step: int; Global training step to resume from. train_iter: iter(tf.dataset); Training data iterator. summary_writer: tensorflow summary writer. Returns: optimizer, model state, global training step. """ metrics_all = [] tick = time.time() logging.info('Starting training') logging.info('====================') step = 0 for step, batch in zip(range(start_step, num_train_steps), train_iter): batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, state, metrics, dropout_rngs = p_train_step( optimizer, state, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Save a checkpoint. if ((step % config.checkpoint_freq == 0 and step > 0) or step == num_train_steps - 1): if jax.host_id() == 0 and config.save_checkpoints: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint(FLAGS.model_dir, (jax_utils.unreplicate(optimizer), jax_utils.unreplicate(state)), step) # Periodic metric handling (see the stack-sum-normalize sketch after this function). if step % eval_freq == 0 and step > 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr logging.info('train in step: %d, loss: %.4f, acc: %.4f', step, summary['loss'], summary['accuracy']) if jax.host_id() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock summary_writer.scalar('examples_per_second', steps_per_sec * config.batch_size, step) for key, val in summary.items(): summary_writer.scalar(f'train_{key}', val, step) summary_writer.flush() # Reset metric accumulation for the next evaluation cycle. metrics_all = [] # Eval metrics. eval_metrics = [] eval_iter = iter(eval_ds) if num_eval_steps == -1: num_iter = itertools.repeat(1) else: num_iter = range(num_eval_steps) for _, eval_batch in zip(num_iter, eval_iter): # pylint: disable=protected-access eval_batch = common_utils.shard( jax.tree_map(lambda x: x._numpy(), eval_batch)) # pylint: enable=protected-access metrics = p_eval_step(optimizer.target, state, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step, eval_summary['loss'], eval_summary['accuracy']) if jax.host_id() == 0: for key, val in eval_summary.items(): summary_writer.scalar(f'val_{key}', val, step) summary_writer.flush() return optimizer, state, step
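The periodic metric handling above follows a stack-sum-normalize pattern: common_utils.get_metrics stacks the per-step metric trees, the sums are divided by the accumulated denominator, and the buffer is reset. A simplified toy sketch with fabricated numbers (get_metrics additionally pulls metrics off the devices, which is omitted here):

import jax
import jax.numpy as jnp

# Two steps' worth of metrics; values are invented for illustration.
metrics_all = [{"loss": jnp.array(2.0), "denominator": jnp.array(4.0)},
               {"loss": jnp.array(6.0), "denominator": jnp.array(4.0)}]
metrics_sums = jax.tree_map(lambda *xs: jnp.sum(jnp.stack(xs)), *metrics_all)
denominator = metrics_sums.pop("denominator")                    # 8.0
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # {'loss': 1.0}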
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) # BOS special attention only makes sense if we are using relative attention # and it's not the baseline. if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention or FLAGS.attention_mask_type == 'baseline'): raise ValueError( "bos_special_attention doesn't work when use_relative_attention={} and " 'attention_mask_type={}'.format(FLAGS.use_relative_attention, FLAGS.attention_mask_type)) if not gfile.isdir(FLAGS.save_dir): gfile.makedirs(FLAGS.save_dir) hparam_str_dict = json.loads(FLAGS.xm_parameters) hparam_str = ','.join([ '%s=%s' % (shorten(k), str(hparam_str_dict[k])) for k in hparam_str_dict.keys() ]) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) batch_size = FLAGS.per_device_batch_size * n_devices io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.max_characters) predict_io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.predict_max_characters) program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length) # Setup DSL # --------------------------------------------------------------------------- # Build token tables. id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)} char_id_table = {char: id for id, char in id_char_table.items()} id_token_table, token_id_table = dsl_tokens.build_token_tables() io_vocab_size = len(char_id_table) + 1 # For padding. program_vocab_size = len(token_id_table) + 1 bos_token = token_id_table[dsl.BOS] eos_token = token_id_table[dsl.EOS] # Parse io and program token sequences (for eval). def decode_io(inputs, outputs): """Decode io examples tokens.""" def decode_str(s): """Decode string tokens.""" return ''.join([id_char_table[c_id] for c_id in s if c_id > 0]) inps, outs = [], [] for inp, out in zip(inputs, outputs): inps.append(decode_str(inp)) outs.append(decode_str(out)) return inps, outs def decode_program(program): """Decode program tokens.""" program = program[:np.argmax(program == eos_token) + 1].astype( np.int32) program = program[program != bos_token] try: return dsl.decode_program(program.tolist(), id_token_table) except: # pylint: disable=bare-except return None # Program does not compile. # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if not FLAGS.dataset_filepattern: raise ValueError('Must specify filepattern to dataset.') # Training dataset. logging.info('Loading dataset from %s', FLAGS.dataset_filepattern) padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:]) logging.info('padded_shapes: %s', padded_shapes) dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.dataset_filepattern, token_id_table, char_id_table) dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes, drop_remainder=True) # Split evaluation and training. eval_ds = dataset.take(FLAGS.num_eval_steps) # Decrease batch of predict dataset to handle beam search. 
predict_padded_shapes = (predict_io_shape[1:], predict_io_shape[1:], program_shape[1:]) logging.info('predict_padded_shapes: %s', predict_padded_shapes) predict_ds = eval_ds.unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes) train_ds = dataset.skip(FLAGS.num_eval_steps) if FLAGS.train_set_batches > 0: train_ds = train_ds.take(FLAGS.train_set_batches) train_ds = train_ds.repeat() test_dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.test_dataset_filepattern, token_id_table, char_id_table) test_dataset = test_dataset.padded_batch( batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False) quick_test_dataset = (test_dataset.take( FLAGS.num_quick_test_steps).unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)) final_test_dataset = (test_dataset.take( FLAGS.num_final_test_steps).unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)) # Build Model and Optimizer # --------------------------------------------------------------------------- default_config = base_models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size) base_config = base_models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_program_length), dropout_rate=FLAGS.dropout_rate, attention_dropout_rate=FLAGS.attention_dropout_rate, use_relative_attention=FLAGS.use_relative_attention, deterministic=False, decode=False, bos_token=bos_token, num_input_relative_position_buckets=FLAGS.num_position_buckets, max_input_distance=min(FLAGS.max_distance, default_config.max_input_distance), num_output_relative_position_buckets=FLAGS.num_position_buckets, max_output_distance=min(FLAGS.max_distance, default_config.max_output_distance), num_input_cross_output_relative_position_buckets=( FLAGS.num_position_buckets), max_input_cross_output_distance=min( FLAGS.max_distance, default_config.max_input_cross_output_distance), num_program_relative_position_buckets=FLAGS.num_position_buckets, max_program_distance=min(FLAGS.max_distance, default_config.max_program_distance), num_program_cross_embed_relative_position_buckets=( FLAGS.num_position_buckets), max_program_cross_embed_distance=min( FLAGS.max_distance, default_config.max_program_cross_embed_distance), bidirectional_program_attention=FLAGS.bidirectional_program_attention) train_config = models.DecomposeAttentionTransformerConfig( base_config=base_config, attention_mask_type=FLAGS.attention_mask_type, bos_special_attention=FLAGS.bos_special_attention) eval_config = models.DecomposeAttentionTransformerConfig( base_config=base_config.replace(deterministic=True), attention_mask_type=FLAGS.attention_mask_type, bos_special_attention=FLAGS.bos_special_attention) predict_config = models.DecomposeAttentionTransformerConfig( base_config=base_config.replace( shift=False, deterministic=True, decode=not FLAGS.slow_decode, max_len=max(FLAGS.max_characters, FLAGS.max_program_length, FLAGS.predict_max_characters)), attention_mask_type=FLAGS.attention_mask_type, bos_special_attention=FLAGS.bos_special_attention) rng = jax.random.PRNGKey(FLAGS.seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) dropout_rng = jax.random.split(rng, jax.local_device_count()) del rng m = 
models.DecomposeAttentionTransformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(program_shape, jnp.float32)) optimizer_def = optim.Adam(FLAGS.lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) del initial_variables # Don't keep a copy of the initial model. start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer) # Grab last step. start_step = int(optimizer.state.step) logging.info('Found model checkpointed at step %d.', start_step) if FLAGS.finetune_start_step > 0: logging.info( 'Checking that start_step (%s) >= finetune_start_step (%s)', start_step, FLAGS.finetune_start_step) assert start_step >= FLAGS.finetune_start_step steps_to_skip = start_step - FLAGS.finetune_start_step else: steps_to_skip = start_step # TODO(kshi): It is likely that this code can lead to the job stalling for # 10+ hours when restarting from a checkpoint that had been trained a long # time, possibly because dataset skipping is slow. logging.info('Skipping %s steps...', steps_to_skip) train_ds = train_ds.skip(steps_to_skip) dummy_p_train_step = jax.pmap( lambda dropout_rng: jax.random.split(dropout_rng)[1]) for _ in range(steps_to_skip): dropout_rng = dummy_p_train_step(dropout_rng) logging.info('Finished skipping steps') logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) # TODO(jxihong): Implement fast decoding. assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.' if FLAGS.finetune_start_step <= 0: learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr) else: # Constant LR for finetuning. learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr, factors='constant') p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn, config=train_config), axis_name='batch') p_eval_step = jax.pmap(functools.partial(eval_step, eos_token=eos_token, config=eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=FLAGS.max_program_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap(functools.partial( predict_step, eos_token=eos_token, max_decode_len=FLAGS.max_program_length, config=predict_config, slow_decode=FLAGS.slow_decode), axis_name='batch', static_broadcasted_argnums=(4, )) # Main Train Loop # --------------------------------------------------------------------------- logging.info('Starting training!') metrics_all = [] tick = time.time() train_iter = train_ds.as_numpy_iterator() for step in range(start_step, FLAGS.num_train_steps): inputs, outputs, programs = common_utils.shard(next(train_iter)) optimizer, metrics, dropout_rng = p_train_step(optimizer, inputs, outputs, programs, dropout_rng=dropout_rng) metrics_all.append(metrics) is_last_step = step == FLAGS.num_train_steps - 1 # Save a Checkpoint if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step: if jax.host_id() == 0: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), jax_utils.unreplicate(optimizer), step) # Periodic metric handling.
# Training Metrics if (step and step % FLAGS.log_freq == 0) or is_last_step: logging.info('Gathering training metrics.') metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums) summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) if jax.host_id() == 0: logging.info('Train in step: %d, loss: %.4f', step, summary['loss']) tock = time.time() steps_per_sec = FLAGS.log_freq / (tock - tick) tick = tock summary_writer.scalar('train/steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar('train/' + key, val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Evaluation Metrics if (step and step % FLAGS.eval_freq == 0) or is_last_step: logging.info('Gathering evaluation metrics.') t_evaluation_start = time.time() eval_metrics = [] for batches in eval_ds.as_numpy_iterator(): inputs, outputs, programs = common_utils.shard(batches) metrics = p_eval_step(optimizer.target, inputs, outputs, programs) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: logging.info('Evaluation time: %.4f s step %d, loss: %.4f.', time.time() - t_evaluation_start, step, eval_summary['loss']) for key, val in eval_summary.items(): summary_writer.scalar('eval/' + key, val, step) summary_writer.flush() # Beam search metrics. if (step and step % FLAGS.predict_freq == 0) or is_last_step: logging.info('Gathering beam search metrics.') test_ds = final_test_dataset if is_last_step else quick_test_dataset for dataset, predict_or_test in [(predict_ds, 'predict'), (test_ds, 'test')]: for beam_size in [1, 10]: t_inference_start = time.time() total_successes = 0 total_denominator = 0 pred_successes = collections.defaultdict(int) pred_denominators = collections.defaultdict(int) ios, targets, predictions, top_of_beams = [], [], [], [] for batches in dataset.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch[0].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) # pylint: disable=cell-var-from-loop pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) inputs, outputs, programs = common_utils.shard( pred_batch) cache = (p_init_cache(inputs, outputs, programs) if not FLAGS.slow_decode else None) predicted = p_pred_step(optimizer.target, inputs, outputs, cache, beam_size) predicted = tohost(predicted) inputs, outputs, programs = map( tohost, (inputs, outputs, programs)) for i, beams in enumerate(predicted): inps, outs = decode_io(inputs[i], outputs[i]) p, p_score = eval_predicted( beams, inps, outs, parse_beam_fn=decode_program) # Split by length of program. 
program = programs[i] num_expressions = len( decode_program(program).expressions) pred_denominators[num_expressions] += 1 total_denominator += 1 if p_score >= len(inps): pred_successes[num_expressions] += 1 total_successes += 1 ios.append(' ; '.join(map(str, zip(inps, outs)))) targets.append( decode_program(programs[i]).to_string()) try: predictions.append(p.to_string()) except: # pylint: disable=bare-except predictions.append('Did not compile') logging.info('ios: %s', ios[-1]) logging.info('target: %s', targets[-1]) beams_log = [] for beam in beams: try: beams_log.append( decode_program(beam).to_string()) except: # pylint: disable=bare-except beams_log.append('Did not compile') logging.info('predicted beam: %s', '\n'.join(beams_log)) top_of_beam = [] for index, beam in enumerate(beams[:-5:-1]): try: decoded_program = decode_program( beam).to_string() except: # pylint: disable=bare-except decoded_program = 'Did not compile' top_of_beam.append( 'index: {}, decoded: {}, tokens: {}'. format(index, decoded_program, beam)) top_of_beams.append('\n\n'.join(top_of_beam)) all_total_successes, all_total_denominator = per_host_sum_pmap( jax.tree_map(np.array, (total_successes, total_denominator))) all_pred_successes, all_pred_denominators = per_host_sum_pmap( jax.tree_map(np.array, (pred_successes, pred_denominators))) # Record beam search results as text summaries. message = [] for n in np.random.choice(np.arange(len(predictions)), 8): text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n' f'predicted: {predictions[n]}\n\n' f'top of beam:\n\n{top_of_beams[n]}\n\n') message.append(text) # Write to tensorboard. if jax.host_id() == 0: accuracy = 100 * all_total_successes / all_total_denominator logging.info( '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)', predict_or_test, step, beam_size, all_total_successes, all_total_denominator, accuracy, time.time() - t_inference_start) summary_writer.scalar( '{}/beam-size-{}'.format(predict_or_test, beam_size), accuracy, step) for length in sorted(all_pred_successes.keys()): this_length_accuracy = ( 100 * all_pred_successes[length] / all_pred_denominators[length]) logging.info( ' accuracy for length %s: %s / %s = %.2f%%', length, all_pred_successes[length], all_pred_denominators[length], this_length_accuracy) summary_writer.scalar( '{}-by-length/beam-size-{}-length-{}'.format( predict_or_test, beam_size, length), this_length_accuracy, step) summary_writer.text( '{}-samples-beam-{}'.format( predict_or_test, beam_size), '\n------\n'.join(message), step) summary_writer.flush()
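decode_program above trims a padded token sequence to its first EOS and strips the BOS before handing it to the DSL decoder; a standalone sketch on made-up token ids:

import numpy as np

bos_token, eos_token = 1, 2
program = np.array([1, 7, 8, 2, 0, 0])                   # bos, body, eos, padding
program = program[:np.argmax(program == eos_token) + 1]  # keep everything through the first eos
program = program[program != bos_token]                  # drop bos -> array([7, 8, 2])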
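When restoring from a checkpoint, the setup above both skips the already-consumed dataset batches and advances the per-device dropout keys the same number of times, so the resumed run sees the same randomness stream it would have seen without preemption. A minimal sketch of the key-advancing half:

import jax

steps_to_skip = 3  # illustrative value
dropout_rng = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())
advance = jax.pmap(lambda key: jax.random.split(key)[1])
for _ in range(steps_to_skip):
    # Mirrors what one real train step would do to each device's key.
    dropout_rng = advance(dropout_rng)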
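The beam-search loop above pads the final odd-sized batch up to a device-divisible size instead of dropping it. A plausible body for pad_examples, tiling the last example (this is an assumption matching the call site, not necessarily the exact helper used here):

import numpy as np

def pad_examples(x, desired_batch_size):
    """Pads a batch by repeating its last example; padded rows are redundant downstream."""
    batch_pad = desired_batch_size - x.shape[0]
    tile_dims = (batch_pad,) + (1,) * (x.ndim - 1)
    return np.concatenate([x, np.tile(x[-1], tile_dims)], axis=0)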
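per_host_sum_pmap above sums host-local Python counters (successes and denominators) across hosts. A common sketch of such a helper, adapted from the Flax examples, runs a psum over one device per host; the exact body used by this codebase is an assumption:

import collections

import jax
import jax.numpy as jnp

def per_host_sum_pmap(in_tree):
    """Sums the leaves of in_tree across hosts, using one device per host."""
    host2devices = collections.defaultdict(list)
    for d in jax.devices():
        host2devices[d.process_index].append(d)
    devices = [v[0] for v in host2devices.values()]
    host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i', devices=devices)
    # Add a leading length-1 device axis, psum across hosts, then strip the axis again.
    pre = jax.tree_map(lambda x: jnp.broadcast_to(x, (1,) + jnp.asarray(x).shape), in_tree)
    return jax.tree_map(lambda x: x[0], host_psum(pre))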
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) if not gfile.isdir(FLAGS.save_dir): gfile.mkdir(FLAGS.save_dir) hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr) # Get hyperparameters. if FLAGS.xm_parameters: for key, value in json.loads(FLAGS.xm_parameters).items(): if key not in hparam_str_dict: hparam_str_dict[key] = value hparam_str = ','.join([ '%s=%s' % (k, str(hparam_str_dict[k])) for k in sorted(hparam_str_dict.keys()) ]) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) batch_size = FLAGS.per_device_batch_size * n_devices io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.max_characters) program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length) # Setup DSL # --------------------------------------------------------------------------- # Build token tables. id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)} char_id_table = {char: id for id, char in id_char_table.items()} id_token_table, token_id_table = dsl_tokens.build_token_tables() io_vocab_size = len(char_id_table) + 1 # For padding. program_vocab_size = len(token_id_table) + 1 bos_token = token_id_table[dsl.BOS] eos_token = token_id_table[dsl.EOS] def decode_io(inputs, outputs): """Decode io examples tokens.""" def decode_str(s): """Decode string tokens.""" return ''.join([id_char_table[c_id] for c_id in s if c_id > 0]) io_string = '' inps, outs = [], [] for inp, out in zip(inputs, outputs): inps.append(decode_str(inp)) outs.append(decode_str(out)) io_string += inps[-1] + ' < ' + outs[-1] + ' > ' return inps, outs, io_string[:-3] # Remove last separator. def decode_program(program): """Decode program tokens.""" program = program[:np.argmax(program == eos_token) + 1].astype( np.int32) try: p = dsl.decode_program(program, id_token_table) return p, p.to_string() except: # pylint: disable=bare-except return None, '' # Program does not compile. # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if not FLAGS.dataset_filepattern: raise ValueError('Must specify filepattern to dataset.') # Training dataset. dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.dataset_filepattern, token_id_table, char_id_table) dataset = dataset.padded_batch(batch_size, padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]), drop_remainder=True) # Split evaluation and training. eval_ds = dataset.take(FLAGS.num_eval_steps) # Decrease batch of predict dataset to handle beam search.
predict_ds = eval_ds.unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:])) train_ds = dataset.skip(FLAGS.num_eval_steps).repeat() train_iter = train_ds.as_numpy_iterator() # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_program_length), use_relative_attention=FLAGS.use_relative_attention, deterministic=False, decode=False, bos_token=bos_token) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(shift=False, deterministic=True, decode=True) rng = jax.random.PRNGKey(FLAGS.seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) m = models.ProgramTransformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(program_shape, jnp.float32)) optimizer_def = optim.Adam(FLAGS.lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) del initial_variables # Don't keep a copy of the initial model. start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer) # Grab last step. start_step = int(optimizer.state.step) logging.info('Found model checkpointed at step %d.', start_step) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = train_lib.create_learning_rate_scheduler( base_learning_rate=FLAGS.lr) p_train_step = jax.pmap(functools.partial( train_lib.train_step, learning_rate_fn=learning_rate_fn, config=train_config), axis_name='batch') p_eval_step = jax.pmap(functools.partial(train_lib.eval_step, config=eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( train_lib.initialize_cache, max_decode_len=FLAGS.max_program_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap(functools.partial(train_lib.predict_step, config=predict_config), axis_name='batch', static_broadcasted_argnums=(4, 5, 6)) # Main Train Loop # --------------------------------------------------------------------------- train_rngs = jax.random.split(rng, jax.local_device_count()) del rng metrics_all = [] tick = time.time() for step in range(start_step, FLAGS.num_train_steps): inputs, outputs, programs = common_utils.shard(next(train_iter)) optimizer, metrics, train_rngs = p_train_step(optimizer, inputs, outputs, programs, train_rng=train_rngs) metrics_all.append(metrics) # Save a Checkpoint if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or step == FLAGS.num_train_steps - 1): if jax.host_id() == 0: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if not step or step % FLAGS.log_freq != 0: continue logging.info('Gathering training metrics.') # Training Metrics metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums) summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) if jax.host_id() == 0: logging.info('Train in step: %d, loss: %.4f', step, summary['loss']) tock = time.time() steps_per_sec = FLAGS.log_freq / (tock - tick) tick = tock summary_writer.scalar('train/steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar('train/' + key, val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Evaluation Metrics logging.info('Gathering evaluation metrics.') t_evaluation_start = time.time() eval_metrics = [] for batches in eval_ds.as_numpy_iterator(): inputs, outputs, programs = common_utils.shard(batches) metrics = p_eval_step(optimizer.target, inputs, outputs, programs) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: logging.info('Evaluation time: %.4f s step %d, loss: %.4f.', time.time() - t_evaluation_start, step, eval_summary['loss']) for key, val in eval_summary.items(): summary_writer.scalar('eval/' + key, val, step) summary_writer.flush() # Beam search metrics. logging.info('Gathering beam search metrics.') for beam_size in [10, 100]: t_inference_start = time.time() pred_acc = 0 pred_denominator = 0 ios, targets, predictions = [], [], [] for batches in predict_ds.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch[0].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) # pylint: disable=cell-var-from-loop pred_batch = jax.tree_map( lambda x: train_lib.pad_examples(x, padded_size), pred_batch) inputs, outputs, programs = common_utils.shard(pred_batch) cache = p_init_cache(inputs, outputs, programs) predicted = p_pred_step(optimizer.target, inputs, outputs, cache, eos_token, programs.shape[-1], beam_size) predicted = train_lib.tohost(predicted) inputs, outputs, programs = map(train_lib.tohost, (inputs, outputs, programs)) pred_denominator += programs.shape[0] for i, beams in enumerate(predicted): inps, outs, io_string = decode_io(inputs[i], outputs[i]) p, p_score = train_lib.eval_predicted( beams, inps, outs, parse_beam_fn=lambda x: decode_program(x)[0]) if p_score >= len(inps): pred_acc += 1 ios.append(io_string) targets.append(decode_program(programs[i])[1]) predictions.append(p.to_string() if p else '') all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap( jax.tree_map(np.array, (pred_acc, pred_denominator))) # Record beam search results as text summaries. 
message = [] for n in np.random.choice(np.arange(len(predictions)), 8): text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n' f'predicted: {predictions[n]}\n\n') message.append(text) # Write to tensorboard. if jax.host_id() == 0: logging.info( 'Prediction time (beam %d): %.4f s step %d, score %.4f.', beam_size, time.time() - t_inference_start, step, all_pred_acc / all_pred_denominator) summary_writer.scalar('predict/score-{}'.format(beam_size), all_pred_acc / all_pred_denominator, step) summary_writer.text('samples-{}'.format(beam_size), '\n------\n'.join(message), step) summary_writer.flush()
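Both pipelines above re-batch the prediction set to roughly a tenth of the training batch size because beam search multiplies decoder activations by beam_size. A toy tf.data sketch of the unbatch-then-smaller-padded_batch trick (shapes are invented for illustration):

import numpy as np
import tensorflow as tf

batch_size = 8
ds = tf.data.Dataset.from_tensor_slices(np.zeros((40, 5), np.int32)).batch(batch_size)
predict_ds = ds.unbatch().padded_batch(
    int(np.ceil(batch_size / 10)),  # -> 1 here, keeping beam_size * batch memory bounded
    padded_shapes=[5])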
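The success criterion p_score >= len(inps) above treats a predicted program as correct only when it reproduces every I/O example. A hedged sketch of what that score means; score_program and program_fn are hypothetical names, since eval_predicted's real implementation (which returns the best-scoring beam and its score) is not shown here:

def score_program(program_fn, inps, outs):
    """Counts how many I/O examples a candidate program reproduces."""
    return sum(1 for inp, out in zip(inps, outs) if program_fn(inp) == out)

# A prediction "succeeds" when score_program(...) == len(inps), i.e. all examples match.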