def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  vocab_dir = os.path.dirname(vocab_path)
  # A bare filename has no directory component, so there is nothing to create.
  if vocab_dir:
    tf.io.gfile.makedirs(vocab_dir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenizes `toks` up to and including the first EOS token.

    If no EOS token is present the whole sequence is decoded. (`np.argmax`
    alone would return 0 in that case and keep only the first token.)
    """
    eos_positions = np.flatnonzero(toks == eos_id)
    end = eos_positions[0] + 1 if eos_positions.size else len(toks)
    valid_toks = toks[:end].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  def encode_strings(strs, max_len):
    """Tokenizes `strs` into a zero-padded (len(strs), max_len) int32 array.

    The last token of every tokenized string (the EOS token) is dropped.
    Prompts that still exceed `max_len` tokens are truncated with a warning
    instead of failing with a broadcast error.
    """
    tokenized_batch = np.zeros((len(strs), max_len), np.int32)
    for i, s in enumerate(strs):
      toks = encoder.tokenize(s).numpy()
      # Remove EOS token in prompt.
      toks = toks[:-1]
      if toks.shape[0] > max_len:
        logging.warning("Prompt has %d tokens; truncating to max_len=%d.",
                        toks.shape[0], max_len)
        toks = toks[:max_len]
      tokenized_batch[i, :toks.shape[0]] = toks
    return tokenized_batch

  tokenized_prompts = encode_strings(
      [config.prompts], config.max_predict_length)

  logging.info("Initializing model, optimizer, and step functions.")
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  rng, inference_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.TransformerLM(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step, config=train_config, learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name="batch")

  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          config=predict_config,
          temperature=config.sampling_temperature,
          top_k=config.sampling_top_k),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # NOTE(review): p_train_step does not return new keys here, so train_step
  # presumably derives per-step keys internally -- confirm.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceContext("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary["perplexity"] = jnp.clip(
              jnp.exp(summary["loss"]), a_max=1.0e4)
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          # (clipped) perplexity after averaging log-perplexities
          eval_results["perplexity"] = jnp.clip(
              jnp.exp(eval_results["loss"]), a_max=1.0e4)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("generate_text"):
          exemplars = generate_prediction(
              p_pred_step=p_pred_step,
              target=optimizer.target,
              tokenized_prompts=tokenized_prompts,
              eos_id=eos_id,
              inference_rng=inference_rng,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_every_steps steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir,
                                      jax_utils.unreplicate(optimizer), step)
def main(argv):
  """Trains and periodically evaluates/tests a two-class matching classifier.

  Data comes from `input_pipeline.get_matching_datasets`; the model is built
  via `train_utils.get_model` with `classifier=True, num_classes=2`. When
  `FLAGS.test_only` is set, the last checkpoint is restored, test metrics are
  written to `<model_dir>/results.json` and the function returns.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra command-line arguments are given.
    ValueError: If the batch size is not divisible by the device count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  config = FLAGS.config
  logging.info('===========Config Dict============')
  logging.info(config)
  batch_size = config.batch_size
  learning_rate = config.learning_rate
  num_train_steps = config.num_train_steps
  num_eval_steps = config.num_eval_steps
  eval_freq = config.eval_frequency
  random_seed = config.random_seed
  model_type = config.model_type
  max_length = config.max_length

  # Only process 0 writes summaries; every later use is guarded the same way.
  if jax.process_index() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'summary'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets(
      n_devices=jax.local_device_count(),
      task_name=FLAGS.task_name,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      fixed_vocab=None,
      max_length=max_length,
      tokenizer=config.tokenizer,
      vocab_file_path=FLAGS.vocab_file_path)

  vocab_size = encoder.vocab_size
  logging.info('Vocab Size: %d', vocab_size)

  train_ds = train_ds.repeat()

  train_iter = iter(train_ds)
  input_shape = (batch_size, max_length)

  model_kwargs = {
      'vocab_size': vocab_size,
      'emb_dim': config.emb_dim,
      'num_heads': config.num_heads,
      'num_layers': config.num_layers,
      'qkv_dim': config.qkv_dim,
      'mlp_dim': config.mlp_dim,
      'max_len': max_length,
      'classifier': True,
      'num_classes': 2,
      'classifier_pool': config.pooling_mode
  }

  rng = random.PRNGKey(random_seed)
  # Split off the init key *before* folding in the process index so that every
  # host initializes identical parameters (gradients are averaged across all
  # devices, so diverging initial params would never reconcile). Only the
  # dropout keys below should differ between hosts.
  rng, init_rng = random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  model = train_utils.get_model(model_type, create_model, model_kwargs,
                                init_rng, input_shape)

  optimizer = create_optimizer(
      model, learning_rate, weight_decay=FLAGS.config.weight_decay)
  del model  # Don't keep a copy of the initial model.

  start_step = 0
  if config.restore_checkpoints or FLAGS.test_only:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors=config.factors,
      base_learning_rate=learning_rate,
      warmup_steps=config.warmup)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  def run_eval(eval_ds, num_eval_steps=-1):
    """Evaluates the current replicated params on `eval_ds`.

    Args:
      eval_ds: Dataset to iterate over.
      num_eval_steps: Number of batches to evaluate; -1 evaluates the whole
        dataset.

    Returns:
      Dict of summed metrics divided by the summed 'denominator', plus a
      'perplexity' entry clipped at 1e4.
    """
    eval_metrics = []
    eval_iter = iter(eval_ds)
    if num_eval_steps == -1:
      num_iter = itertools.count()
    else:
      num_iter = range(num_eval_steps)
    for _, eval_batch in zip(num_iter, eval_iter):
      # pylint: disable=protected-access
      eval_batch = common_utils.shard(
          jax.tree_map(lambda x: x._numpy(), eval_batch))
      # pylint: enable=protected-access
      # `optimizer` is looked up in the enclosing scope at call time, so the
      # latest trained params are evaluated.
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    # Calculate (clipped) perplexity after averaging log-perplexities:
    eval_summary['perplexity'] = jnp.clip(
        jnp.exp(eval_summary['loss']), a_max=1.0e4)
    return eval_summary

  if FLAGS.test_only:
    # Evaluate before opening the output file so that a failing evaluation
    # does not leave an empty results.json behind.
    test_summary = run_eval(test_ds)
    with tf.io.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.json'),
                           'w') as f:
      json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
    return

  metrics_all = []
  tick = time.time()
  logging.info('Starting training')
  logging.info('====================')

  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)
    logging.info('train in step: %d', step)

    # Save a Checkpoint
    if ((step % config.checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.process_index() == 0 and config.save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(FLAGS.model_dir,
                                    jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                   summary['loss'], summary['accuracy'])
      if jax.process_index() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar(f'train_{key}', val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_summary = run_eval(eval_ds, num_eval_steps)
      logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if jax.process_index() == 0:
        for key, val in eval_summary.items():
          summary_writer.scalar(f'eval_{key}', val, step)
        summary_writer.flush()

      # Test eval
      # Eval Metrics
      logging.info('Testing...')
      test_summary = run_eval(test_ds, num_eval_steps)
      logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                   test_summary['loss'], test_summary['accuracy'])
      if jax.process_index() == 0:
        for key, val in test_summary.items():
          summary_writer.scalar(f'test_{key}', val, step)
        summary_writer.flush()
def main(argv):
  """Trains an image classifier (224x224 inputs), evaluating every epoch.

  Training resumes from the checkpoint returned by `restore_checkpoint`.
  Evaluation runs at every epoch boundary; checkpoints are saved every 10
  epochs and after the final step.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra command-line arguments are given.
    ValueError: If the batch size is not divisible by the device count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Only host 0 writes summaries; every later use is guarded the same way.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=True,
      cache=FLAGS.cache)
  eval_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=False,
      cache=FLAGS.cache)

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  # Linear scaling of the learning rate with batch size (reference batch 256).
  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  base_learning_rate = base_learning_rate / FLAGS.loss_scaling

  model, model_state = create_model(
      rng, device_batch_size, image_size, model_dtype)
  optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(model)
  state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
  del model, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, num_epochs)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    epoch_ended = (step + 1) % steps_per_epoch == 0
    if epoch_ended:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)

      epoch_metrics = []
      eval_metrics = []

      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state)
    if epoch_ended:
      # Restart the throughput timer only after eval and checkpointing, so the
      # next epoch's steps-per-second measures training steps only.
      t_loop_start = time.time()

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(_):
  """Trains a LatentProgramTransformer together with a latent predictor.

  Training and evaluation metrics are logged every `FLAGS.log_freq` steps.
  Beam-search accuracy is measured for beam sizes 10/50/100. Checkpoints are
  written under `<save_dir>/checkpoints/<hparam_str>`.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # makedirs (unlike mkdir) also creates any missing parent directories.
  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 writes summaries; every later use is guarded the same way.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables. Character ids start at 1; 0 is reserved for padding.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: char_id for char_id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens; returns (None, '') if decoding fails."""
    # NOTE(review): without an EOS token np.argmax returns 0, so only the
    # first token is kept -- confirm this is acceptable for beam outputs.
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except Exception:  # pylint: disable=broad-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.LatentTransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      latent_vocab_size=FLAGS.latent_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True, train_vq=False)
  predict_config = train_config.replace(
      shift=False, deterministic=True, train_vq=False, decode=True)

  # Latent Predictor.
  lp_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=FLAGS.latent_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  lp_eval_config = lp_train_config.replace(deterministic=True)
  lp_predict_config = lp_train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(0)
  # Split off the init key *before* folding in the host id so that every host
  # initializes identical parameters; only the training keys differ per host.
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.host_id())

  m = models.LatentProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))
  lp_m = models.ProgramTransformer(lp_eval_config)
  lp_initial_variables = jax.jit(lp_m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])
  lp_optimizer = optimizer_def.create(lp_initial_variables['params'])

  state = TrainState(step=0,
                     optimizer=optimizer,
                     model_state=initial_variables['vqvae'],
                     lp_optimizer=lp_optimizer)
  # Don't keep a copy of the initial model.
  del initial_variables, lp_initial_variables

  train_rngs = jax.random.split(rng, jax.local_device_count())

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    state = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), state)
    # Grab last step.
    start_step = int(state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          bos_token=bos_token,
          eos_token=eos_token,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          lp_config=lp_train_config),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          bos_token=bos_token,
          eos_token=eos_token,
          config=eval_config,
          lp_config=lp_eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          bos_token=bos_token,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(5,))

  metrics_all = []
  latent_metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    # The 5th positional argument (pretraining flag) is static, so crossing
    # FLAGS.num_pretrain_steps triggers one recompilation.
    state, metrics, latent_metrics, train_rngs = p_train_step(
        state, inputs, outputs, programs,
        step <= FLAGS.num_pretrain_steps,
        train_rng=train_rngs)
    metrics, latent_metrics = jax.tree_map(np.array, (metrics, latent_metrics))
    metrics_all.append(metrics)
    latent_metrics_all.append(latent_metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(state),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    latent_metrics_all = common_utils.get_metrics(latent_metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary.update(jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums))

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f, acc: %.4f',
                   step, summary['loss'], summary['accuracy'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []
    latent_metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    latent_eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      all_metrics = p_eval_step(state, inputs, outputs, programs)
      metrics, latent_metrics = jax.tree_map(np.array, all_metrics)
      eval_metrics.append(metrics)
      latent_eval_metrics.append(latent_metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary.update(jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums))

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f',
                   time.time() - t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 50, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions, latent_predictions = [], [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache, lp_cache = p_init_cache(inputs, outputs, programs)
        predicted, latent_predicted = p_pred_step(
            state, inputs, outputs, cache, lp_cache, beam_size)
        predicted, latent_predicted = map(
            tohost, (predicted, latent_predicted))
        inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

        # Score only the real examples: rows past `cur_pred_batch_size` are
        # the padding added above and must not count towards the accuracy.
        pred_denominator += cur_pred_batch_size
        for i, beams in enumerate(predicted[:cur_pred_batch_size]):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_idx, p_score = eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')
          latent_predictions.append(' '.join(
              list(np.array(latent_predicted[i, p_idx]).astype(str))))

      all_pred_acc, all_pred_denominator = per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      # np.random.choice raises on an empty population.
      if predictions:
        for n in np.random.choice(np.arange(len(predictions)), 8):
          text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                  f'predicted: {predictions[n]}\n\n'
                  f'latent_predicted: {latent_predictions[n]}\n\n')
          message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def replicate(self):
  """Returns a device-replicated copy of this state with sharded dropout keys.

  The whole state is replicated with `jax_utils.replicate`; its `dropout_rng`
  field is then replaced by `shard_prng_key(self.dropout_rng)`. That helper
  presumably yields one key per local device -- confirm against its
  definition.
  """
  replicated_state = jax_utils.replicate(self)
  per_device_dropout_rng = shard_prng_key(self.dropout_rng)
  return replicated_state.replace(dropout_rng=per_device_dropout_rng)
beta1=training_args.adam_beta1, beta2=training_args.adam_beta2, ).create(model.params) # Create learning rate scheduler # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent. lr_scheduler_fn = create_learning_rate_scheduler( base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)) # Create parallel version of the training and evaluation steps p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0, )) p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, )) # Replicate the optimizer on each device optimizer = jax_utils.replicate(optimizer) # Store some constant nb_epochs = int(training_args.num_train_epochs) batch_size = int( training_args.per_device_train_batch_size) * jax.device_count() eval_batch_size = int( training_args.per_device_eval_batch_size) * jax.device_count() epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ # Create sampling rng
def main(config, output_dir):
  """Trains a ViT model with an optional Gaussian-process (SNGP) head.

  The function runs the whole experiment:
  - Builds the train and validation pipelines. Optional extras are a test
    split, a CIFAR-10H or ImageNet-ReaL split, and OOD datasets.
  - Initialises the model and optimizer, or restores them from a checkpoint.
  - Runs the training loop with periodic checkpointing, training-progress
    logging, validation/OOD evaluation, and few-shot evaluation.

  Args:
    config: Experiment configuration, read both via `config.get(...)` and via
      attribute access (e.g. `config.batch_size`, `config.dataset`).
    output_dir: Directory for metric summaries. When `checkpoint_steps` is set,
      it also holds the `checkpoint.npz` checkpoint.

  Returns:
    A tuple `(train_loss, val_loss, fewshot_results)`:
    - `train_loss`: the last reported training loss.
    - `val_loss`: a dict with the last validation loss per eval split.
    - `fewshot_results`: the last few-shot results.
    These values are returned for reproducibility tests.

  Raises:
    ValueError: If the train or eval batch size is not divisible by the device
      count, or if the global batch size exceeds the number of training
      examples (which would leave zero steps per epoch).
  """
  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d dev per host (%d dev total), that is a %d per-device batch size.',
      batch_size, jax.process_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  # Each process gets its own data-shuffling key.
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.train_split,
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))
  logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_utils.get_num_examples(
        dataset,
        split=split,
        process_batch_size=local_batch_size_eval,
        drop_remainder=False,
        data_dir=data_dir)
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info('Running validation for %d steps for %s, %s', val_steps,
                 dataset, split)

    if isinstance(pp_eval, str):
      pp_eval = preprocess_spec.parse(
          spec=pp_eval, available_ops=preprocess_utils.all_ops())

    val_ds = input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=pp_eval,
        cache=config.get('val_cache', 'batched'),
        num_epochs=1,
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=False,
        data_dir=data_dir)

    return val_ds

  val_ds_splits = {
      'val':
          _get_val_split(config.dataset, config.val_split, config.pp_eval,
                         config.get('data_dir'))
  }

  if config.get('test_split'):
    val_ds_splits.update({
        'test':
            _get_val_split(
                config.dataset,
                split=config.test_split,
                pp_eval=config.pp_eval,
                data_dir=config.get('data_dir'))
    })

  if config.get('eval_on_cifar_10h'):
    cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
        config.get('data_dir', None))
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
    val_ds_splits['cifar_10h'] = _get_val_split(
        'cifar10',
        split=config.get('cifar_10h_split') or 'test',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
  elif config.get('eval_on_imagenet_real'):
    imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_imagenet_real,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
    val_ds_splits['imagenet_real'] = _get_val_split(
        'imagenet2012_real',
        split=config.get('imagenet_real_split') or 'validation',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))

  ood_ds = {}
  # Only load OOD data when both the datasets and a non-empty list of OOD
  # methods are configured.
  if config.get('ood_datasets') and config.get('ood_methods'):
    logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
    ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
        config.dataset,
        config.ood_datasets,
        config.ood_split,
        config.pp_eval,
        config.pp_eval_ood,
        config.ood_methods,
        config.train_split,
        config.get('data_dir'),
        _get_val_split,
    )

  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = int(ntrain_img / batch_size)
  # `steps_per_epoch` is used as a modulus (precision-matrix reset schedule)
  # and as the eval period below, so it must be positive.
  if steps_per_epoch == 0:
    raise ValueError(
        f'Global batch size ({batch_size}) exceeds the number of training '
        f'examples ({ntrain_img}); at least one step per epoch is required.')

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))

  # Specify Gaussian process layer configs.
  use_gp_layer = config.get('use_gp_layer', True)
  gp_config = config.get('gp_layer', {})
  gp_layer_kwargs = get_gp_kwargs(gp_config)

  # Process ViT backbone model configs.
  vit_kwargs = config.get('model')

  model = ub.models.vision_transformer_gp(
      num_classes=config.num_classes,
      use_gp_layer=use_gp_layer,
      vit_kwargs=vit_kwargs,
      gp_layer_kwargs=gp_layer_kwargs)

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    # Split model parameters into trainable and untrainable collections.
    states, params = variables.pop('params')
    del variables

    # Set bias in the head to a low value, such that loss is small initially.
    params = flax.core.unfreeze(params)
    if use_gp_layer:
      # Modify the head parameter in the GP head.
      params['head']['output_layer']['bias'] = jnp.full_like(
          params['head']['output_layer']['bias'],
          config.get('init_head_bias', 0))
    else:
      params['head']['bias'] = jnp.full_like(
          params['head']['bias'], config.get('init_head_bias', 0))

    return params, states

  rng, rng_init = jax.random.split(rng)
  params_cpu, states_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, states, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    # Note that logits and labels are usually of the shape [batch,num_classes].
    # But for OOD data, when num_classes_ood > num_classes_ind, we need to
    # adjust labels to labels[:, :config.num_classes] to match the shape of
    # logits. That is just to avoid shape mismatch. The output losses does not
    # have any meaning for OOD data, because OOD not belong to any IND class.
    losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits,
        labels=labels[:, :config.num_classes],
        reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args

  @partial(jax.pmap, axis_name='batch')
  def cifar_10h_evaluation_fn(params, states, images, labels, mask):
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(
        one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name='batch')
  def representation_fn(params, images, labels, mask, states):
    variable_dict = {'params': flax.core.freeze(params), **states}
    _, outputs = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, 'batch')
    labels = jax.lax.all_gather(labels, 'batch')
    mask = jax.lax.all_gather(mask, 'batch')
    return representation, labels, mask

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
    """Update step."""

    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, states, images, labels):
      # Specify mutable collection to update untrainable GP parameters.
      variable_dict = {'params': flax.core.freeze(params), **states}
      model_results, updated_states = model.apply(
          variable_dict,
          images,
          train=True,
          rngs={'dropout': rng_model_local},
          mutable=list(states.keys()),
          mean_field_factor=gp_config.get('mean_field_factor', -1.))

      logits, _ = model_results
      loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)
      return loss, updated_states

    # Performs exact covariance update (i.e., reset precision matrix resetting
    # at begining of new epoch) if covmat_momentum is a null value.
    if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0:
      # Resets precision matrix to Identity * ridge_penalty if at the begining
      # of a new epoch. This should be done before accumulate gradient.
      ridge_penalty = gp_config.get('ridge_penalty', 1.)
      prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix']
      prec_mat_new = (
          (1. - reset_covmat) * prec_mat_old +
          reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)

      states = flax.core.unfreeze(states)
      states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix'] = prec_mat_new
      states = flax.core.freeze(states)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, s), g = train_utils.accumulate_gradient_with_states(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images,
        labels, config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    do_grad_clip = config.get('grad_clip_norm', -1.) > 0.
    if config.get('grad_accum_steps', 1) == 1 or do_grad_clip:
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if do_grad_clip:
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
    measurements['reset_covmat'] = reset_covmat

    return opt, s, l, rng, measurements

  default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias',
                           'head/kernel', 'head/bias')
  rng, train_loop_rngs = jax.random.split(rng)
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=states_cpu,
      default_reinit_params=default_reinit_params,
      config=config)
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  states_cpu = checkpoint_data.fixed_model_states
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu, loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir,
      first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))
  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
  # The schedule already starts at `first_step`, so unlike `lr_iter` it does
  # not need to be advanced when resuming.
  reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
  reset_covmat_iter = train_utils.prefetch_scalar(
      map(reset_covmat_fn, range(first_step, total_steps)),
      nprefetch=config.get('prefetch_to_device', 1))

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax_utils.replicate(opt_cpu)
  states_repl = flax_utils.replicate(states_cpu)

  write_note(f'Initializing few-shotters...\n{chrono.note}')
  fewshotter = None
  if 'fewshot' in config and fewshot is not None:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get('batch_size') or batch_size_eval)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  train_loss = -jnp.inf
  val_loss = {val_name: -jnp.inf for val_name in val_ds_splits}
  fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.

  # Makes sure log_eval_steps is same as steps_per_epoch. This is because
  # the precision matrix needs to be updated fully (at the end of each epoch)
  # when eval takes place.
  log_eval_steps = steps_per_epoch
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl, reset_covmat_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter,
      reset_covmat_iter):
    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, train_loop_rngs,
       extra_measurements) = update_fn(
           opt_repl,
           states_repl,
           lr_repl,
           reset_covmat_repl,
           train_batch['image'],
           train_batch['labels'],
           rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving
    if train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      # For GP layer, we will also do the same for untrainable parameters
      # (`states`). This is ok since `random features` are frozen throughout
      # pre-training, and `precision matrix` is a finetuning-specific parameters
      # that will be re-learned in the finetuning task.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX datataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          optimizer=opt_cpu,
          fixed_model_states=states_cpu,
          train_loop_rngs=train_loop_rngs,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()
      for val_name, val_ds in val_ds_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
        # TODO(jereliu): Extend to support soft multi-class probabilities.
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.005, num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.01, num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.02, num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.05, num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))
        ncorrect, loss, nseen = 0, 0, 0
        for batch in val_iter:
          if val_name == 'cifar_10h':
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                cifar_10h_evaluation_fn(opt_repl.target, states_repl,
                                        batch['image'], batch['labels'],
                                        batch['mask']))
          else:
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                evaluation_fn(opt_repl.target, states_repl, batch['image'],
                              batch['labels'], batch['mask']))
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
          if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
            # Here we parse batch_metric_args to compute uncertainty metrics.
            # (e.g., ECE or Calibration AUC).
            logits, labels, _, masks = batch_metric_args
            # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
            masks = np.array(masks[0], dtype=bool)
            logits = np.array(logits[0])
            probs = jax.nn.softmax(logits)
            # From one-hot to integer labels, as required by ECE.
            int_labels = np.argmax(np.array(labels[0]), axis=-1)
            int_preds = np.argmax(logits, axis=-1)
            confidence = np.max(probs, axis=-1)
            for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                            int_preds, masks, labels[0]):
              ece.add_batch(p[m, :], label=l[m])
              calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
              oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])

              if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                (batch_label_diversity, batch_sample_diversity,
                 batch_ged) = data_uncertainty_utils.generalized_energy_distance(
                     label[m], p[m, :], config.num_classes)
                label_diversity.update_state(batch_label_diversity)
                sample_diversity.update_state(batch_sample_diversity)
                ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name]
        }
        if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
          val_measurements[f'{val_name}_ece'] = ece.result()['ece']
          val_measurements[f'{val_name}_calib_auc'] = calib_auc.result()[
              'calibration_auc']
          val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result()[
              'collaborative_auc']
          val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result()[
              'collaborative_auc']
          val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result()[
              'collaborative_auc']
          val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result()[
              'collaborative_auc']
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
          cifar_10h_measurements = {
              f'{val_name}_label_diversity': label_diversity.result(),
              f'{val_name}_sample_diversity': sample_diversity.result(),
              f'{val_name}_ged': ged.result(),
          }
          writer.write_scalars(step, cifar_10h_measurements)

      # OOD eval
      # The ood_ds dict has two entries (in-dist, ood), and this section
      # computes metrics using both pieces. This is in contrast to
      # normal validation eval above where we eval metrics separately for each
      # val split in val_ds.
      if ood_ds and config.ood_methods:

        def make_sngp_eval_fn(states):

          def sngp_eval_fn(params, images, labels, mask):
            return evaluation_fn(
                params=params,
                states=states,
                images=images,
                labels=labels,
                mask=mask)

          return sngp_eval_fn

        ood_measurements = ood_utils.eval_ood_metrics(
            ood_ds,
            ood_ds_names,
            config.ood_methods,
            make_sngp_eval_fn(states_repl),
            opt_repl.target,
            n_prefetch=config.get('prefetch_to_device', 1))
        writer.write_scalars(step, ood_measurements)

      chrono.resume()

    if 'fewshot' in config and fewshotter is not None:
      # Compute few-shot on-the-fly evaluation.
      if train_utils.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f'Few-shot evaluation...\n{chrono.note}')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(
            opt_repl.target,
            datasets=config.fewshot.datasets,
            states=states_repl)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):

          def writer_measure(name, value):
            writer.write_scalars(step, {name: value})

          return writer_measure

        fewshotter.walk_results(
            make_writer_measure_fn(step), fewshot_results, best_l2)
        chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  # Return final training loss, validation loss, and fewshot results for
  # reproducibility test cases.
  return train_loss, val_loss, fewshot_results
def create_synchronized_rng_seed():
  """Draws a random int64 seed on this host and combines it across devices.

  Steps:
  1. Read 8 bytes from `os.urandom` and decode them as a signed int64.
  2. Replicate that value to the local devices.
  3. Pass the replicas through `_sum_seeds_pmapped`. This is defined
     elsewhere; presumably it is a cross-device sum, so every host ends up
     with the same seed (confirm).
  4. Add up the per-device results with `np.sum` and return the total.
  """
  (host_seed,) = struct.unpack('q', os.urandom(8))
  per_device_sums = _sum_seeds_pmapped(jax_utils.replicate(np.int64(host_seed)))
  return np.sum(per_device_sums)
def train_and_evaluate(self, workdir):
  """Runs a training and evaluation loop.

  Setup:
  - Restores the train state from `workdir/checkpoints` if possible (at most
    5 checkpoints are kept), then replicates it across local devices.

  Loop:
  - Each call to `self._update_func` advances `state.step` by
    `config.training.substeps`; this is asserted after every call.
  - Every `config.logs.log_loss_every_steps` steps, the substep-averaged
    training scalars are logged, except on eval steps.
  - Every `config.logs.eval_full_every_steps` steps, and on the last step,
    `self._eval_epoch` is run on `state.ema_params`.
  - Every `config.logs.checkpoint_every_steps` steps, and on the last step,
    a checkpoint is saved.

  Args:
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)
  config = self.config
  substeps = config.training.substeps

  num_train_steps = config.training.num_train_steps
  logging.info('num_train_steps=%d', num_train_steps)

  # Get train state
  state = self._train_state

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Distribute training.
  state = flax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 0:
    writer.write_hparams(dict(config))

  logging.info('Starting training loop at step %d.', initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    ]

  def avg_over_substeps(x):
    # Each training output is stacked along a leading substep axis; average it.
    assert x.shape[0] == substeps
    return float(x.mean(axis=0))

  step = initial_step
  with metric_writers.ensure_flushes(writer):
    while step < num_train_steps:
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      is_last_step = step + substeps >= num_train_steps

      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        inputs = jax.tree_map(np.asarray, next(self._train_iter))
        state, outputs = self._update_func(state, inputs)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      new_step = int(state.step[0])
      assert new_step == step + substeps
      step = new_step

      is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
      if step % config.logs.log_loss_every_steps == 0 and not is_eval:
        # Extract scalars and images.
        outputs = flax_utils.unreplicate(outputs)
        outputs = jax.tree_map(avg_over_substeps, outputs)
        scalars = outputs['scalars']
        writer.write_scalars(step, scalars)

      if is_eval:
        with report_progress.timed('eval_full'):
          outputs = self._eval_epoch(params=state.ema_params)
          outputs = flax_utils.unreplicate(outputs)
          scalars = outputs['scalars']
          writer.write_scalars(step, scalars)

      if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed('checkpoint'):
          ckpt.save(flax_utils.unreplicate(state))

  # Report the step actually reached. It can exceed `num_train_steps` when that
  # is not a multiple of `substeps`, or when resuming past the end.
  logging.info('Finishing training at step %d', step)
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
    use_deprecated_checkpointing=True,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over the checkpoints in `checkpoint_dir` yielded by
  `iterate_checkpoints` (optionally filtered by global step), restores each
  one, logs its eval metrics and evaluates the Hessian spectrum on it (plus
  gradient/update directions, stats and interpolations). One dict of results
  per checkpoint is appended with `utils.MetricLogger` to
  checkpoint_dir/hessian_eval_config['name']/<suffix>. <suffix> is empty when
  `min_global_step` is None and '<min_global_step>_<max_global_step>'
  otherwise.

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: Passed (twice, positionally) to `trainer.eval_metrics`.
      NOTE(review): the name suggests a number of eval batches rather than a
      batch size -- confirm against `trainer.eval_metrics`. Set to None to
      evaluate on the whole test set.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval. Keys read here: 'name', 'compute_stats', 'compute_interps'.
    min_global_step: Lower bound on what steps to filter checkpoints. Set to
      None to evaluate all checkpoints in the directory. The hparams/config
      JSON files are only written when this is exactly 0.
    max_global_step: Upper bound on what steps to filter checkpoints.
    use_deprecated_checkpointing: Whether to use deprecated checkpointing.
  """
  rng, init_rng = jax.random.split(rng)
  # Make the data rng host-specific.
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, data_rng = jax.random.split(rng)

  # 'noop' initializer: parameters come from the checkpoints loaded below.
  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  flax_module, batch_stats = trainer.initialize(model.flax_module_def,
                                                initializer, model.loss_fn,
                                                hps.input_shape,
                                                hps.output_shape, hps, init_rng,
                                                None)

  # Fold in a the unreplicated batch_stats and rng into the loss used by
  # hessian eval.
  # NOTE(review): `batch_stats` is looked up when batch_loss is *called*, not
  # when it is defined. It is rebound to the replicated stats on the next line
  # and to the restored/synced stats inside the checkpoint loop, so the closure
  # sees those values rather than the unreplicated ones this comment describes.
  # Confirm which one CurvatureEvaluator expects.
  def batch_loss(module, batch_rng):
    # batch_rng is a (batch, rng) pair; only the scalar cost is returned.
    batch, rng = batch_rng
    return model.training_cost(module, batch_stats, batch, rng)[0]
  batch_stats = jax_utils.replicate(batch_stats)

  if jax.host_id() == 0:
    utils.log_pytree_shape_and_statistics(flax_module.params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the hessian computation hps to the experiment directory
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    # Only the run starting at step 0 writes the config files (presumably so
    # step-range shards of the same sweep don't all rewrite them -- confirm).
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  # Replicated optimizer used as the restore target for every checkpoint.
  optimizer = trainer.get_optimizer(hps).create(flax_module)
  optimizer = jax_utils.replicate(optimizer)
  data_rng = jax.random.fold_in(data_rng, 0)

  # The global batch must split evenly across all devices.
  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmapped per-batch evaluation, consumed by trainer.eval_metrics below.
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.host_id() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.host_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      optimizer.target,
      hessian_eval_config,
      dataset,
      batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=(optimizer, batch_stats),
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    results = trainer.restore_checkpoint(
        ckpt,
        (optimizer, batch_stats),
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    optimizer, batch_stats = results[0]
    # pylint: disable=protected-access
    batch_stats = trainer._maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(optimizer.target, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.host_id() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    # Every host runs the curvature computations; only host 0 persists them.
    row = {}
    grads, updates = [], []
    hess_evecs, cov_evecs = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        optimizer.target, step)
    row.update(stats)
    # Directions are only computed when either option is enabled. Otherwise
    # the two calls below receive empty lists (presumably they check the same
    # config flags internally -- confirm).
    if hessian_eval_config[
        'compute_stats'] or hessian_eval_config['compute_interps']:
      grads, updates = hessian_evaluator.compute_dirs(optimizer)
    row.update(hessian_evaluator.evaluate_stats(optimizer.target, grads,
                                                updates, hess_evecs,
                                                cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(optimizer.target,
                                                        grads, updates,
                                                        hess_evecs,
                                                        cov_evecs, step))
    if jax.host_id() == 0:
      logger.append_pytree(row)
def compute_is_scores(filename):
  """Compute IS scores for training data.

  Restores a trained Transformer (falling back to an adapter-focused optimizer
  when the plain restore raises ValueError). It then iterates `train_ds` from
  `input_pipeline.get_wmt_is_datasets` until the dataset is exhausted, applying
  `eval_for_is_step` to every batch. Host 0 writes one CSV row per batch:
  per-example losses go to `<FLAGS.is_save_path>/<filename>.txt` and the
  matching lengths to `<FLAGS.is_save_path>/<filename>-lengths.txt`.

  Args:
    filename: Base name (without extension) of the two output files.

  Raises:
    ValueError: If `FLAGS.batch_size` is not divisible by the number of local
      devices.
    RuntimeError: If `FLAGS.restore_checkpoints` is false; scoring needs a
      trained checkpoint.
  """

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, encoder = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size)
  print('Datasets created')

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if not FLAGS.restore_checkpoints:
    raise RuntimeError('Must restore checkpoint for IS')

  logging.info('Restoring checkpoint.')
  # If we have a pretrained model, use that. Else, just continue where leftoff
  model_path = FLAGS.pretrained_model_dir or FLAGS.model_dir
  # When loading a checkpoint trained with adapters (ie. frozen weights)
  # restoring from the base optimizer fails. We catch this error and create
  # the optimizer with frozen weights.
  try:
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
  except ValueError:
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)

  if FLAGS.adapter != NONE and not isinstance(optimizer, optim.MultiOptimizer):
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(eval_for_is_step, config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  lengths_path = os.path.join(FLAGS.is_save_path, filename + '-lengths.txt')
  losses_path = os.path.join(FLAGS.is_save_path, filename + '.txt')
  # NOTE(review): every host opens (and truncates) both files, but only host 0
  # writes rows. If is_save_path is shared across hosts, confirm another
  # host's empty file cannot clobber host 0's output.
  # Both files are opened with `with` so they are closed even if scoring
  # raises part-way through.
  with tf.io.gfile.GFile(lengths_path, 'w') as length_fp:
    lengths_writer = csv.writer(length_fp)
    with tf.io.gfile.GFile(losses_path, 'w') as fp:
      writer = csv.writer(fp)

      for batch_idx, eval_batch in enumerate(train_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        # A short final batch is padded up to a multiple of the device count
        # so it can be sharded; the padded rows are dropped before writing.
        is_padded = cur_pred_batch_size % n_devices
        if is_padded:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          eval_batch = jax.tree_map(
              lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
        eval_batch = common_utils.shard(eval_batch)
        losses, lengths = p_eval_step(optimizer.target, eval_batch)
        if jax.host_id() == 0:
          losses = common.tohost(losses)
          lengths = common.tohost(lengths)
          if is_padded:
            writer.writerow(losses[:cur_pred_batch_size])
            lengths_writer.writerow(lengths[:cur_pred_batch_size])
          else:
            writer.writerow(losses)
            lengths_writer.writerow(lengths)

        if batch_idx % 500 == 0:
          print('Batch', batch_idx)
          print(time.time() - t_loop_start)
def inference_time(config: ml_collections.ConfigDict, workdir: str):
  """Runs a number of steps and measures inference time.

  Builds a VisionTransformer from `config_lib.MODEL_CONFIGS[config.model_name]`
  and replicates its CPU-initialized parameters to all local devices. It then
  runs one compile pass and `config.initial_steps` untimed warm-up passes,
  followed by `config.steps` timed forward passes on a constant all-ones batch.
  Per-core throughput (images/sec: min, max, mean, std) is logged and written
  as scalars at step 0, together with the config as hparams.

  Args:
    config: Configuration; reads `batch`, `num_classes`, `image_size`,
      `model_name`, `initial_steps` and `steps`.
    workdir: Directory for the metric writer.
  """
  assert config.batch, f'Expected --config.batch={config.batch} > 0'
  assert config.num_classes, (
      f'Expected --config.num_classes={config.num_classes} > 0')
  assert config.image_size, (
      f'Expected --config.image_size={config.image_size} > 0')
  # The statistics below need at least one timed step.
  assert config.steps, f'Expected --config.steps={config.steps} > 0'

  # Build VisionTransformer architecture
  model_config = config_lib.MODEL_CONFIGS[config.model_name]
  model = models.VisionTransformer(
      num_classes=config.num_classes, **model_config)

  # Make sure initial model parameters (before replication) are on CPU only.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    return model.init(
        rng,
        # Discard the "num_local_devices" dimension for initialization.
        inputs=jnp.ones([1, config.image_size, config.image_size, 3],
                        jnp.float32),
        train=False)

  variables = init(jax.random.PRNGKey(0))

  params_repl = flax_utils.replicate(variables['params'])
  # Built once so the timed region covers only the forward pass.
  model_vars = flax.core.FrozenDict(params=params_repl)

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
  n_devices = jax.local_device_count()
  # Images each core really processes per step. This value is also used for
  # the throughput numbers so that they match the batch actually fed in.
  per_core_batch = config.batch // n_devices
  images = jnp.ones(
      [n_devices, per_core_batch, config.image_size, config.image_size, 3],
      jnp.float32)

  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())

  logging.info('Starting training loop; initial compile can take a while...')
  logits = vit_fn_repl(model_vars, images)
  logits.block_until_ready()
  logging.info('Done.')

  logging.info('Going to run %d inferences WITHOUT measuring...',
               config.initial_steps)
  for _ in range(config.initial_steps):
    logits = vit_fn_repl(model_vars, images)
    logits.block_until_ready()

  logging.info('Going to run %d inferences measuring...', config.steps)
  times = []
  for _ in range(config.steps):
    t0 = time.perf_counter()
    logits = vit_fn_repl(model_vars, images)
    # Wait for the async dispatch to finish so the wall time is real.
    logits.block_until_ready()
    times.append(time.perf_counter() - t0)
  logging.info('times=%s', times)
  imgs_sec_core = per_core_batch / np.array(times)
  logging.info('imgs_sec_core_min=%f', imgs_sec_core.min())
  logging.info('imgs_sec_core_max=%f', imgs_sec_core.max())
  logging.info('imgs_sec_core_mean=%f', imgs_sec_core.mean())
  logging.info('imgs_sec_core_std=%f', imgs_sec_core.std())
  writer.write_scalars(
      0,
      dict(
          imgs_sec_core_min=imgs_sec_core.min(),
          imgs_sec_core_max=imgs_sec_core.max(),
          imgs_sec_core_mean=imgs_sec_core.mean(),
          imgs_sec_core_std=imgs_sec_core.std(),
      ))
def predict_and_evaluate(config, workdir, ckpt_path=None):
  """Restores a checkpoint and reports decoding metrics on every test split.

  Args:
    config: Configuration to use. Reads `seed`, `model`, `num_test_steps` and
      the `dataset.test_*` fields, and is forwarded to the helper functions.
    workdir: Working directory. Checkpoints are read from
      `<workdir>/checkpoints`, and the metric writer is rooted here.
    ckpt_path: The checkpoint to evaluate. If not specified, use the latest
      checkpoint. If restoring it raises FileNotFoundError, the state comes
      from `restore_or_initialize` instead.
  """
  logging.info('Starting testing at %s', workdir)
  tf.io.gfile.makedirs(workdir)

  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline. A data key is still split off so the model key
  # derived below stays the same; the test pipeline does not consume it.
  rng, _ = jax.random.split(rng)
  dataset_cfg = config.dataset
  test_ds = [
      input_pipeline.create_val_dataset(
          dataset_cfg, split, dataset_cfg.test_per_device_batch_size,
          dataset_cfg.test_pad_last_batch)
      for split in dataset_cfg.test_splits
  ]

  # Initialize the model in decoding mode.
  init_inputs = train_utils.get_init_inputs(test_ds[0])
  rng, model_rng = jax.random.split(rng)
  predict_config = models.TransformerConfig(
      **config.model.to_dict()).replace(decode=True)
  state = train_utils.create_train_state(
      models.Model(predict_config), config, model_rng, inputs=init_inputs)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)

  # Set up checkpointing of the model and the input pipeline.
  ckpt = checkpoint.MultihostCheckpoint(
      os.path.join(workdir, 'checkpoints'), max_to_keep=3)

  logging.info('Testing and evaluating checkpoint %s', ckpt_path)
  try:
    state = ckpt.restore(state, ckpt_path)
  except FileNotFoundError:
    state = ckpt.restore_or_initialize(state)
  step = int(state.step)

  p_pred_step = jax.pmap(
      functools.partial(predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(3,))
  p_init_cache = jax.pmap(
      functools.partial(init_cache, config=predict_config),
      axis_name='batch')

  # Distribute testing.
  replicated_state = flax_utils.replicate(state)
  with metric_writers.ensure_flushes(writer):
    test_metrics = {}
    for split_name, split_ds in zip(dataset_cfg.test_splits, test_ds):
      split_metrics = evaluate_sequence_accuracy(
          p_pred_step, p_init_cache, replicated_state, split_ds, config,
          split_name, workdir, config.num_test_steps)
      # Suffix every metric with its split so the splits don't collide.
      test_metrics.update(
          {f'{key}_{split_name}': val for key, val in split_metrics.items()})
    writer.write_scalars(step, test_metrics)
def main(argv):
  """Trains and periodically evaluates the model described by `FLAGS.config`.

  Training state is initialized in this order:
    1. `<FLAGS.workdir>/checkpoint.npz`, if `checkpoint_steps` is set and the
       file exists.
    2. Otherwise `config.resume`, if set.
    3. Otherwise the model parameters are loaded from `config.model_init`, if
       set.
    4. Otherwise the model is trained from scratch.
  Checkpoints are written asynchronously on a thread pool.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: On inconsistent checkpoint, batch-size or step settings.
  """
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir
  logging.info("Workdir: %s", workdir)

  save_checkpoint_path = None
  if config.get("checkpoint_steps"):
    tf.io.gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  def write_note(note):
    if jax.host_id() == 0:
      logging.info("NOTE: %s", note)
  write_note("Initializing...")

  # Verify settings to make sure no checkpoints are accidentally missed.
  # (Explicit raises rather than asserts so the check survives `python -O`.)
  if config.get("keep_checkpoint_steps"):
    if not config.get("checkpoint_steps"):
      raise ValueError("Specify `checkpoint_steps`.")
    if config.keep_checkpoint_steps % config.checkpoint_steps != 0:
      raise ValueError(
          f"`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be "
          f"divisible by `checkpoint_steps` ({config.checkpoint_steps}).")

  batch_size = config.batch_size
  batch_size_eval = config.get("batch_size_eval", batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must "
                     f"be divisible by device number ({jax.device_count()})")

  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      "Global batch size %d on %d hosts results in %d local batch size. "
      "With %d dev per host (%d dev total), that's a %d per-device batch size.",
      batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      split=config.train_split,
      data_dir=fillin(config.get("dataset_dir")),
      batch_size=local_batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch=config.get("prefetch_to_host", 2),
      cache=False)

  # Start prefetching already.
  train_iter = u.start_input_pipeline(
      train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size)
  # We always pad to local_batch_size_eval even when less would be enough in
  # order to minimize memory fragmentation.

  write_note("Initializing val dataset(s)...")
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s",
                 val_steps, dataset, split)

    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)

    return (val_it, val_steps)

  if isinstance(config.val_split, str):
    val_ds = {"val": _get_val_split(config.dataset, config.val_split,
                                    config.pp_eval, config.get("dataset_dir"))}
  else:
    val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  steps_per_epoch = ntrain_img / batch_size

  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    if config.get("total_steps"):
      raise ValueError("Set either num_epochs or total_steps")
  else:
    total_steps = config.total_steps

  logging.info(
      "Running for %d steps, that means %f epochs and %f steps per epoch",
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
  mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    mw.measure("num_params", num_params)

  @partial(jax.pmap, axis_name="batch")
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)
    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name="batch")
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({"params": flax.core.freeze(params)}, images)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, "batch")
    labels = jax.lax.all_gather(labels, "batch")
    mask = jax.lax.all_gather(mask, "batch")
    return representation, labels, mask

  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr/config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl,
          lr_repl,
          train_batch["image"],
          train_batch["labels"],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()
    mw.step_end()

  write_note(f"Done!\n{chrono.note}")
  pool.close()
  pool.join()
  mw.close()
  # `apply_async` stores an exception raised by the save instead of raising it.
  # Earlier writers are presumably checked via `u.checkpointing_timeout`; the
  # final one would otherwise be dropped silently, so re-raise it here.
  if checkpoint_writer is not None:
    checkpoint_writer.get()
def test_train(self):
  """Trains a tiny ProgramTransformer on the RobustFill test records.

  Runs 1000 training steps on a small model, then evaluates one pass over
  the (non-repeated) dataset and checks that accuracy exceeds 0.1.
  """
  tf.enable_v2_behavior()

  # Seed every RNG source the pipeline might touch (TF, NumPy, stdlib).
  for seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    seed_fn(0)

  dataset_filepattern = os.path.join(
      os.path.dirname(__file__),
      'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*')
  print('dataset_filepattern = {}'.format(dataset_filepattern))

  batch_size = 4
  num_strings_per_task = 4
  max_characters = 10
  max_program_length = 15
  io_shape = (num_strings_per_task, max_characters)

  # Token tables. Character ids start at 1; the +1 in the vocab size reserves
  # an extra id for padding.
  id_char_table = dict(enumerate(dsl.CHARACTER, start=1))
  char_id_table = {ch: idx for idx, ch in id_char_table.items()}
  _, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1
  bos_token = token_id_table[dsl.BOS]

  # Load and batch the dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      dataset_filepattern, token_id_table, char_id_table).padded_batch(
          batch_size,
          padded_shapes=(io_shape, io_shape, (max_program_length,)),
          drop_remainder=True)
  train_batches = dataset.repeat().as_numpy_iterator()

  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=32,
      num_heads=4,
      num_layers=2,
      qkv_dim=32,
      mlp_dim=32,
      max_len=max(max_characters, max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)

  rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
  io_ones = jnp.ones((batch_size,) + io_shape, jnp.float32)
  initial_variables = jax.jit(models.ProgramTransformer(eval_config).init)(
      init_rng, io_ones, io_ones,
      jnp.ones((batch_size, max_program_length), jnp.float32))

  optimizer = optim.Adam(
      1e-2, beta1=0.9, beta2=0.98, eps=1e-9,
      weight_decay=0.1).create(initial_variables['params'])
  del initial_variables  # Don't keep a copy of the initial model.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=1e-2)
  p_train_step = jax.pmap(
      functools.partial(train_lib.train_step,
                        learning_rate_fn=learning_rate_fn,
                        config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')

  # Training loop with one rng key per local device.
  rngs = jax.random.split(rng, jax.local_device_count())
  del rng
  for _ in range(1000):
    inputs, outputs, programs = common_utils.shard(next(train_batches))
    optimizer, _, rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=rngs)

  # Evaluation over a single pass of the dataset.
  per_batch_metrics = [
      p_eval_step(optimizer.target, *common_utils.shard(batch))
      for batch in dataset.as_numpy_iterator()
  ]
  metric_sums = jax.tree_map(jnp.sum,
                             common_utils.get_metrics(per_batch_metrics))
  denominator = metric_sums.pop('denominator')
  eval_summary = jax.tree_map(lambda x: x / denominator, metric_sums)

  if jax.host_id() == 0:
    self.assertGreater(eval_summary['accuracy'], 0.1)
def main():
    """Pre-train a ``FlaxWav2Vec2ForPreTraining`` model with ``jax.pmap``.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` from the command line and loads the audio dataset.
    If the dataset has no validation split, one is carved out of the train
    split using ``validation_split_percentage``. The audio is decoded and
    resampled, over-long clips are dropped, and features are normalized. The
    contrastive + diversity objective is then trained data-parallel. After
    every epoch the model is evaluated on the validation split and, on
    process 0, saved to ``training_args.output_dir`` (optionally pushed to the
    hub).

    Raises:
        ValueError: if the model config is not the stable layer-norm
            architecture, which pre-training requires.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

    if "validation" not in datasets.keys():
        # make sure only "validation" and "train" keys remain"
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
        )
    else:
        # make sure only "validation" and "train" keys remain"
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split="validation",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}",
            cache_dir=model_args.cache_dir,
        )

    # only normalized-inputs-training is supported
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
    )

    def prepare_dataset(batch):
        # Decode each audio file, resampled to the feature extractor's sampling rate.
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
    )

    # filter audio files that are too long
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    )

    def normalize(batch):
        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
            " ``config.feat_extract_norm='layer'``"
        )

    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

    # Activate gradient checkpointing if needed
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when the package is actually missing. Processes other than
        # 0 skip TensorBoard on purpose.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    # Derive independent sub-keys. Splitting the same `rng` twice would give
    # the dropout and Gumbel samplers identical per-device keys.
    rng, dropout_rng, gumbel_rng = jax.random.split(rng, 3)
    dropout_rngs = jax.random.split(dropout_rng, jax.local_device_count())
    gumbel_rngs = jax.random.split(gumbel_rng, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            negative_indices = batch.pop("sampled_negative_indices")

            # Exponentially decayed Gumbel temperature, floored at the minimum.
            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
                a_min=model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # summarize metrics
        metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    # No buffer donation here. eval_step returns only metrics, and the params
    # passed in (state.params) are still needed for training and saving.
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(vectorized_datasets["train"])
        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs
            )
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                # Add only the time since the last measurement, then restart
                # the clock. Otherwise every log re-counts the epoch so far.
                train_time += time.time() - train_start
                train_start = time.time()
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
            samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
        )

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
            write_eval_metric(summary_writer, eval_metrics, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest checkpoint.

    Raises:
      ValueError: if more than one host is used, if ``config.batch_size`` is
        not divisible by the device count, or if ``config.init_batch_size``
        exceeds ``config.batch_size``.
    """
    tf.io.gfile.makedirs(workdir)

    batch_size = config.batch_size
    n_devices = jax.device_count()
    if jax.process_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if config.init_batch_size > batch_size:
        raise ValueError(
            'init_batch_size ({}) must not exceed batch_size ({})'.format(
                config.init_batch_size, batch_size))

    train_summary_writer, eval_summary_writer = get_summary_writers(workdir)
    # Load dataset
    data_source = input_pipeline.DataSource(config)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds
    steps_per_epoch = data_source.ds_info.splits[
        'train'].num_examples // config.batch_size
    # Create dataset batch iterators
    train_iter = iter(train_ds)
    num_train_steps = train_ds.cardinality().numpy()
    steps_per_checkpoint = 1000

    # Create the model using data-dependent initialization. Don't shard the init
    # batch.
    init_batch = next(train_iter)['image'].numpy()[:config.init_batch_size]

    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng, dropout_rng = jax.random.split(rng, 3)

    initial_variables = model(config).init(
        {
            'params': init_rng,
            'dropout': dropout_rng
        }, init_batch, train=False)['params']
    optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
    optimizer = optimizer_def.create(initial_variables)

    # `initial_variables` is passed as the EMA restore target, presumably
    # returned unchanged when no checkpoint exists (confirm in
    # restore_checkpoint). The restored EMA must be kept as-is: overwriting it
    # with `initial_variables` would reset the parameter average on every
    # resume.
    optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
    step_offset = int(optimizer.state.step)

    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule
    learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(
        functools.partial(train_step, config, learning_rate_fn),
        axis_name='batch')
    p_eval_step = jax.pmap(
        functools.partial(eval_step, config=config), axis_name='batch')

    # Gather metrics
    train_metrics = []

    for step, batch in zip(range(step_offset, num_train_steps), train_iter):
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)

        # Generate a PRNG key that will be rolled into the batch.
        rng, step_rng = jax.random.split(rng)
        sharded_rngs = common_utils.shard_prng_key(step_rng)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_rngs)
        train_metrics.append(metrics)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation
            eval_metrics = []
            for eval_batch in eval_ds:
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_train_steps:
            save_checkpoint(workdir, optimizer, ema, step)
def main(unused_argv):
    """Train a NeRF model as configured by FLAGS.

    Validates FLAGS, builds the train/test datasets and model, and restores
    from ``FLAGS.train_dir`` if a checkpoint exists. It then runs the pmapped
    training loop with periodic test-image rendering, scalar logging and
    checkpointing. Logging and checkpoint writes happen on host 0 only.

    Raises:
      ValueError: if the batch size is not divisible by the device count, or
        if ``train_dir`` / ``data_dir`` are unset.
    """
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError("Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, rays.origins, rays.directions, rays.viewdirs),
            axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )

    rng, key = random.split(rng)
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(
        step=0, optimizer=optimizer, model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    offset = state.step + 1
    state = jax_utils.replicate(state)
    del init_model, init_state

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()

    learning_rate_fn = functools.partial(
        utils.learning_rate_decay,
        init_lr=FLAGS.lr,
        decay_steps=FLAGS.lr_decay * 1000,
        decay_rate=0.1)

    ptrain_step = jax.pmap(
        train_step,
        axis_name="batch",
        in_axes=(0, 0, 0, None),
        donate_argnums=2)
    # Prefetch_buffer_size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)

        if step % FLAGS.gc_every == 0:
            gc.collect()

        # --- Train logs start ---
        # Put the training time visualization before the host_id check as in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() == 0:
                summary_writer.image("pred_color", pred_color, step)
                summary_writer.image("pred_disp", pred_disp, step)
                summary_writer.image("pred_acc", pred_acc, step)
                summary_writer.image("target", test_case["pixels"], step)

        if jax.host_id() != 0:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("loss", stats[0].loss[0], step)
            summary_writer.scalar("psnr", stats[0].psnr[0], step)
            summary_writer.scalar("learning_rate", lr, step)
            if len(stats) > 1:
                summary_writer.scalar("loss_coarse", stats[1].loss[0], step)
                summary_writer.scalar("psnr_coarse", stats[1].psnr[0], step)
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " + f"loss={stats[0].loss[0]:0.5f}, " +
                  f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(
                FLAGS.train_dir, state_to_save, state_to_save.step, keep=100)
        # --- Train logs end ---

    # Restore automatic garbage collection now that the hot loop is done.
    gc.enable()

    # Final save, like the periodic ones, is host-0 only so that several hosts
    # never write the same checkpoint directory concurrently.
    if FLAGS.max_steps % FLAGS.save_every != 0 and jax.host_id() == 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state, int(state.step), keep=100)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to. It is
        also the directory checkpoints are restored from and saved to.

    Raises:
      ValueError: if ``config.batch_size`` is not divisible by the device count.
    """
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(workdir)
        summary_writer.hparams(dict(config))

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError('Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.host_count()

    platform = jax.local_devices()[0].platform

    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder('imagenet2012:5.*.*')
    train_iter = create_input_iter(
        dataset_builder, local_batch_size, image_size, input_dtype, train=True,
        cache=config.cache)
    eval_iter = create_input_iter(
        dataset_builder, local_batch_size, image_size, input_dtype, train=False,
        cache=config.cache)

    steps_per_epoch = (
        dataset_builder.info.splits['train'].num_examples // config.batch_size
    )

    if config.num_train_steps == -1:
        num_steps = steps_per_epoch * config.num_epochs
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(
        model_cls=model_cls, half_precision=config.half_precision)

    state = create_train_state(rng, config, model, image_size)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(
        base_learning_rate, steps_per_epoch, config.num_epochs)

    p_train_step = jax.pmap(
        functools.partial(train_step, model.apply,
                          learning_rate_fn=learning_rate_fn),
        axis_name='batch')
    p_eval_step = jax.pmap(
        functools.partial(eval_step, model.apply),
        axis_name='batch')

    epoch_metrics = []
    t_loop_start = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        epoch_metrics.append(metrics)
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
                         epoch, summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val, step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                         epoch, summary['loss'], summary['accuracy'] * 100)
            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()

            # Restart the throughput clock only after evaluation, so the next
            # epoch's "steps per second" measures training alone.
            t_loop_start = time.time()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main():
    """Fine-tune a Flax sequence-classification model on GLUE or local CSV/JSON files.

    Arguments come from ``parse_args()``. The model is trained data-parallel
    with ``jax.pmap``. After every epoch it is evaluated with the ``glue``
    metric for the task, or with ``accuracy`` for local files. Metrics are
    written to TensorBoard under ``args.output_dir``, and the final parameters
    are saved there from process 0.

    Raises:
        ValueError: if neither ``task_name`` nor a train/validation file is given.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        if not data_files:
            raise ValueError("Either task_name or a train_file/validation_file must be provided.")
        # The loader name ("csv"/"json") comes from the extension of whichever file was given.
        reference_file = args.train_file if args.train_file is not None else args.validation_file
        extension = reference_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # One concatenated message. If the pieces were passed as separate
            # positional args, `logging` would treat them as %-format args and
            # fail with "not all arguments converted".
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result.",
            )
    elif args.task_name is None and not is_regression:
        # `label_list` only exists for classification data; regression labels stay as-is.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],)
            if sentence2_key is None
            else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(args.output_dir)
    summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            # Spread the per-step training values over the steps that produced them.
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`.

        Returns the updated state, the device-averaged metrics (loss and
        learning rate) and a fresh dropout PRNG key for the next step.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info(" Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f" Done! Training metrics: {unreplicate(metrics)}")

        logger.info(" Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f" Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

    # save last checkpoint
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(args.output_dir, params=params)
def main():
    """Fine-tune a Flax question-answering model and evaluate it.

    Parses `ModelArguments`, `DataTrainingArguments` and `TrainingArguments`
    (from a single JSON file argument or from the command line), loads and
    tokenizes the datasets, trains with a pmapped train step, periodically
    evaluates / checkpoints, and finally writes `eval_results.json` to
    `training_args.output_dir` when `--do_eval` is set.
    """
    # region Argument parsing
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # endregion

    # region Logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # endregion

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # region Load Data
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(data_args.dataset_name,
                                    data_args.dataset_config_name,
                                    cache_dir=model_args.cache_dir)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        raw_datasets = load_dataset(extension,
                                    data_files=data_files,
                                    field="data",
                                    cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    # endregion

    # region Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=True,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    # endregion

    # region Tokenizer check: this script requires a fast tokenizer.
    # Offset mappings and overflow-to-sample mappings are only produced by fast tokenizers.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError(
            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
            "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
            "requirement")
    # endregion

    # region Preprocessing the datasets
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names
    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        """Tokenize a batch of examples and label answer start/end token positions."""
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char
                        and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while (token_start_index < len(offsets)
                           and offsets[token_start_index][0] <= start_char):
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)

        return tokenized_examples

    processed_raw_datasets = {}
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            # We will select sample from whole data if argument is specified
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        # Create train feature from dataset
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_train_samples is not None:
            # Number of samples might increase during Feature Creation, We select only specified max samples
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        processed_raw_datasets["train"] = train_dataset

    # Validation preprocessing
    def prepare_validation_features(examples):
        """Tokenize a batch of examples, keeping example ids and context-only offsets."""
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [
            q.lstrip() for q in examples[question_column_name]
        ]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        context_index = 1 if pad_on_right else 0
        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_examples = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            # We will select sample from whole data
            eval_examples = eval_examples.select(range(data_args.max_eval_samples))
        # Validation Feature Creation
        eval_dataset = eval_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_eval_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        processed_raw_datasets["validation"] = eval_dataset

    if training_args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_examples = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            # We will select sample from whole data
            predict_examples = predict_examples.select(range(data_args.max_predict_samples))
        # Predict Feature Creation
        predict_dataset = predict_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_predict_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        processed_raw_datasets["test"] = predict_dataset
    # endregion

    # region Metrics and Post-processing:
    def post_processing_function(examples, features, predictions, stage="eval"):
        """Map start/end logits back to answer strings and pair them with references."""
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
            output_dir=training_args.output_dir,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if data_args.version_2_with_negative:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v,
                "no_answer_probability": 0.0
            } for k, v in predictions.items()]
        else:
            formatted_predictions = [{
                "id": k,
                "prediction_text": v
            } for k, v in predictions.items()]

        references = [{
            "id": ex["id"],
            "answers": ex[answer_column_name]
        } for ex in examples]
        return EvalPrediction(predictions=formatted_predictions,
                              label_ids=references)

    metric = load_metric(
        "squad_v2" if data_args.version_2_with_negative else "squad")

    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions, references=p.label_ids)

    def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
        """Stack per-batch logits into one array of shape (len(dataset), max_len).

        Args:
          start_or_end_logits: List of per-batch 2-D logit arrays (either all start
            logits or all end logits), each with at most `max_len` columns.
          dataset: Evaluation dataset; its length fixes the number of rows.
          max_len: Number of columns of the output array.

        Returns:
          A float64 array pre-filled with -100. Columns beyond a batch's width and
          rows not covered by any batch keep the -100 fill value; logits beyond
          `len(dataset)` rows are dropped.
        """
        num_rows = len(dataset)
        logits_concat = np.full((num_rows, max_len), -100, dtype=np.float64)
        step = 0
        for output_logit in start_or_end_logits:
            if step >= num_rows:
                # Surplus logits: slicing with a negative bound below would
                # produce a shape mismatch, so stop once every row is filled.
                break
            batch_size = output_logit.shape[0]
            cols = output_logit.shape[1]
            if step + batch_size < num_rows:
                logits_concat[step:step + batch_size, :cols] = output_logit
            else:
                logits_concat[step:, :cols] = output_logit[:num_rows - step]
            step += batch_size
        return logits_concat

    # endregion

    # region Training steps and logging init
    train_dataset = processed_raw_datasets["train"]
    # Only present when --do_eval was requested; evaluation is gated on it below.
    eval_dataset = processed_raw_datasets.get("validation")

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            summary_writer.hparams({
                **training_args.to_dict(),
                **vars(model_args),
                **vars(data_args)
            })
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when tensorboard is actually missing, not merely because
        # this is a non-zero process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Write the train time and each buffered per-step train metric."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                # Buffered values belong to the steps leading up to `step`.
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
    # endregion

    # region Load model
    model = FlaxAutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
    )

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model,
                               learning_rate_fn,
                               num_labels=max_seq_length,
                               training_args=training_args)
    # endregion

    # region Define train step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array],
        dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Run one optimizer step on `batch`.

        Returns a triple `(new_state, metrics, new_dropout_rng)`, where `metrics`
        holds the batch-averaged loss and the current learning rate.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        start_positions = batch.pop("start_positions")
        end_positions = batch.pop("end_positions")
        targets = (start_positions, end_positions)

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)
            return state.loss_fn(logits, targets)

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices before applying them.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
    # endregion

    # region Define eval step functions
    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    def _flatten_device_axis(x):
        """Merge the leading (device, per-device batch) axes of a pmap output."""
        x = np.asarray(x)
        return x.reshape((-1,) + x.shape[2:])

    def run_evaluation(replicated_state):
        """Evaluate `replicated_state` on `eval_dataset` and return the SQuAD metrics.

        Full global batches go through the pmapped `p_eval_step`; the remainder
        that does not fill a global batch is evaluated un-replicated on
        process 0 only.
        """
        all_start_logits = []
        all_end_logits = []
        for batch in tqdm(
                eval_data_collator(eval_dataset, eval_batch_size),
                total=len(eval_dataset) // eval_batch_size,
                desc="Evaluating ...",
                position=2,
        ):
            # Bookkeeping columns needed only by post-processing, not the model.
            _ = batch.pop("example_id")
            _ = batch.pop("offset_mapping")
            predictions = p_eval_step(replicated_state, batch)
            all_start_logits.append(_flatten_device_axis(predictions[0]))
            all_end_logits.append(_flatten_device_axis(predictions[1]))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size
        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            leftover = eval_dataset[-num_leftover_samples:]
            # Drop the bookkeeping columns BEFORE converting to arrays:
            # `offset_mapping` mixes (start, end) pairs with None, a ragged
            # nested sequence that recent NumPy versions refuse to convert.
            batch = {
                k: np.array(v)
                for k, v in leftover.items()
                if k not in ("example_id", "offset_mapping")
            }
            predictions = eval_step(unreplicate(replicated_state), batch)
            all_start_logits.append(np.asarray(predictions[0]))
            all_end_logits.append(np.asarray(predictions[1]))

        # Pad every batch to a common width and stack into one array per head.
        max_len = max(x.shape[1] for x in all_start_logits)
        start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
        end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)
        # Release the per-batch copies before the (memory-heavy) post-processing.
        del all_start_logits, all_end_logits

        outputs_numpy = (start_logits_concat, end_logits_concat)
        prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
        return compute_metrics(prediction)

    # endregion

    # region Define train and eval loop
    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
                tqdm(
                    train_data_collator(input_rng, train_dataset, train_batch_size),
                    total=step_per_epoch,
                    desc="Training...",
                    position=1,
                ),
                1,
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * step_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                # Accumulate only the time elapsed since the previous
                # measurement; adding the full time since the epoch start on
                # every log would count the same seconds repeatedly.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if (training_args.do_eval
                    and (cur_step % training_args.eval_steps == 0
                         or cur_step % step_per_epoch == 0)
                    and cur_step > 0):
                eval_metrics = run_evaluation(state)

                logger.info(
                    f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})"
                )

                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0
                    and cur_step > 0) or cur_step == total_steps:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=f"Saving weights and logs of step {cur_step}",
                            blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
    # endregion

    # Eval after training
    if training_args.do_eval:
        eval_metrics = run_evaluation(state)

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

    Training switches to a randomly chosen training scene every
    `config.train.switch_scene_iter` steps (3 when `config.dev_run`), logs
    losses, periodically renders one evaluation view, and checkpoints the
    state to `workdir`.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest checkpoint.

    Raises:
      ValueError: If `config.dataset.batch_size` is not divisible by the total
        number of devices.
    """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    # NOTE(review): `data_rng` is not consumed below; the split is kept so the
    # remaining `rng` stream stays the same.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())

    scene_path_list = train_utils.get_train_scene_list(config)

    train_ds = datasets.create_train_dataset(config, scene_path_list[0])
    _, eval_ds_dict = datasets.create_eval_dataset(config)
    _, eval_ds = eval_ds_dict.popitem()
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.

    # Check if the job was stopped and relaunched.
    latest_ckpt = checkpoints.latest_checkpoint(workdir)
    if latest_ckpt is None:
        # No previous checkpoint. Then check for pretrained weights.
        if config.train.pretrain_dir:
            state = checkpoints.restore_checkpoint(config.train.pretrain_dir,
                                                   state)
    else:
        state = checkpoints.restore_checkpoint(workdir, state)

    initial_step = int(state.step) + 1
    step_per_scene = config.train.switch_scene_iter
    if config.dev_run:
        jnp.set_printoptions(precision=2)
        np.set_printoptions(precision=2)
        step_per_scene = 3

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks.append(report_progress)
    train_metrics = None

    # Prefetch_buffer_size = 6 x batch_size
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    # Make the random stream distinct per host. `fold_in` is the JAX API for
    # deriving a key from an integer; arithmetic on key arrays is not
    # supported for typed keys.
    rng = jax.random.fold_in(rng, jax.process_index())
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
            # devices.
            if step % step_per_scene == 0:
                scene_idx = np.random.randint(len(scene_path_list))
                curr_scene = scene_path_list[scene_idx]
                logging.info("Loading scene %s", curr_scene)
                if config.dataset.name == "dtu":
                    # lighting can take values between 0 and 6 (both included)
                    config.dataset.dtu_light_idx = np.random.randint(low=0, high=7)
                train_ds = datasets.create_train_dataset(config, curr_scene)
                ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)

            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    ssim = 0.  # Placeholder: SSIM is not computed here.
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index() == 0) and (
                    step % config.train.checkpoint_every_steps == 0 or is_last_step):
                # Write the most recent step's metrics to file.
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = {
                        k: v.item() for k, v in metric_update.compute().items()
                    }
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    # Take device 0's copy of the replicated state.
                    state_to_save = jax.device_get(flax_utils.unreplicate(state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)
def experiment(
    model_dir='.',  # pylint: disable=dangerous-default-value
    imagenet_subset_dir=None,
    dataset='cifar10',
    batch_size=256,
    eval_batch_size=1024,
    num_epochs=200,
    learning_rate=0.1,
    aug_imagenet_apply_colour_jitter=False,
    aug_imagenet_greyscale_prob=0.0,
    sgd_momentum=0.9,
    sgd_nesterov=True,
    lr_schedule='stepped',
    lr_sched_steps=[[60, 0.2], [120, 0.04], [160, 0.008]],
    lr_sched_halfcoslength=400.0,
    lr_sched_warmup=5.0,
    l2_reg=0.0005,
    weight_decay=0.0,
    architecture='wrn22_10',
    n_val=5000,
    n_sup=1000,
    teacher_alpha=0.999,
    anneal_teacher_alpha=False,
    unsupervised_regularizer='none',
    cons_weight=1.0,
    conf_thresh=0.97,
    conf_avg=False,
    cut_backg_noise=1.0,
    cut_prob=1.0,
    box_reg_scale_mode='fixed',
    box_reg_scale=0.25,
    box_reg_random_aspect_ratio=False,
    cow_sigma_range=(4.0, 8.0),
    cow_prop_range=(0.25, 1.0),
    mix_regularizer='none',
    mix_aug_separately=False,
    mix_logits=True,
    mix_weight=1.0,
    mix_conf_thresh=0.97,
    mix_conf_avg=True,
    mix_conf_mode='mix_prob',
    ict_alpha=0.1,
    mix_box_reg_scale_mode='fixed',
    mix_box_reg_scale=0.25,
    mix_box_reg_random_aspect_ratio=False,
    mix_cow_sigma_range=(4.0, 8.0),
    mix_cow_prop_range=(0.0, 1.0),
    subset_seed=12345,
    val_seed=131,
    run_seed=None,
    log_fn=print,
    checkpoints='on',
    on_epoch_finished_fn=None,
    debug=False):
    """Run experiment.

    Trains a student network together with a teacher copy on a
    semi-supervised split of ``dataset`` and evaluates both once per epoch.
    The teacher starts from the student's initial parameters and batch stats.
    Its per-step update happens inside ``train_step``, which is not visible
    here, and is parameterised by ``teacher_alpha_fn``. It is presumably an
    exponential moving average of the student (TODO confirm against
    train_step).

    Args:
      model_dir: Directory for TensorBoard summaries and for the
        ``checkpoint.pkl`` resume file.
      imagenet_subset_dir: Directory holding the ImageNet subsets; required
        when ``dataset == 'imagenet'``.
      dataset: One of 'svhn', 'cifar10', 'cifar100', 'imagenet'.
      batch_size: Global training batch size; must be divisible by
        ``jax.device_count()``.
      eval_batch_size: Global evaluation batch size; must be divisible by
        ``jax.device_count()``.
      num_epochs: Number of epochs. Total steps are
        ``(n_train // batch_size) * num_epochs``.
      learning_rate: Reference LR. The schedules are built from
        ``learning_rate * batch_size / 256``.
      aug_imagenet_apply_colour_jitter, aug_imagenet_greyscale_prob:
        Forwarded to ``ImageNetDataSource`` (ImageNet only).
      sgd_momentum, sgd_nesterov: Parameters of the ``optim.Momentum``
        optimizer.
      lr_schedule: One of 'constant', 'stepped', 'cosine'.
      lr_sched_steps: Forwarded to ``create_stepped_learning_rate_fn``.
        Presumably ``[epoch, lr_factor]`` pairs (TODO confirm). The default
        list object is shared across calls; it is not mutated here.
      lr_sched_halfcoslength: Forwarded to ``create_cosine_learning_rate_fn``
        as ``halfcoslength_epochs``.
      lr_sched_warmup: Warmup length for the stepped/cosine LR schedules.
        It is not applied to the teacher-alpha schedule.
      l2_reg, weight_decay: Forwarded to ``train_step``.
      architecture: Model name passed to ``create_model``; must be in the
        allow-list checked below.
      n_val: Number of held-out validation examples. When 0, the data source
        loads the test set and evaluation runs on it instead.
      n_sup: Forwarded to the data source; presumably the size of the
        labelled subset (logged as ``|sup|``).
      teacher_alpha: Teacher coefficient produced by ``teacher_alpha_fn``.
      anneal_teacher_alpha: If True, ``1 - teacher_alpha`` follows the same
        schedule family as the LR (without warmup) instead of staying
        constant.
      unsupervised_regularizer, cut_backg_noise, cut_prob, box_reg_scale_mode,
        box_reg_scale, box_reg_random_aspect_ratio, cow_sigma_range,
        cow_prop_range: Forwarded to ``build_pert_reg``.
      cons_weight, conf_thresh, conf_avg: Forwarded to ``train_step``.
      mix_regularizer, ict_alpha, mix_box_reg_scale_mode, mix_box_reg_scale,
        mix_box_reg_random_aspect_ratio, mix_cow_sigma_range,
        mix_cow_prop_range: Forwarded to ``build_mix_reg``.
      mix_aug_separately, mix_logits, mix_weight, mix_conf_thresh,
        mix_conf_avg: Forwarded to ``train_step``.
      mix_conf_mode: One of 'mix_prob', 'mix_conf'; forwarded to
        ``train_step``.
      subset_seed, val_seed: Seeds forwarded to the data source.
      run_seed: Seed for model init and training randomness. Defaults to
        ``subset_seed << 32 | n_val``.
      log_fn: Callable that receives progress strings.
      checkpoints: Checkpointing mode:
        - 'none' disables checkpointing.
        - 'on' writes ``checkpoint.pkl`` after every epoch and deletes it
          once training finishes.
        - 'retain' writes it every epoch and keeps it.
      on_epoch_finished_fn: Optional callback invoked after each epoch's
        evaluation with ``epoch`` plus keyword error rates (and top-5 error
        rates for ImageNet).
      debug: If True, log the step index and the time elapsed since the last
        epoch boundary on every step.

    Raises:
      ValueError: Raised in any of these cases:
        - an unknown ``checkpoints``, ``dataset``, ``architecture``,
          ``lr_schedule`` or ``mix_conf_mode`` value;
        - a batch size not divisible by the device count;
        - ``imagenet_subset_dir`` missing when ``dataset == 'imagenet'``.
    """
    if checkpoints not in {'none', 'on', 'retain'}:
        raise ValueError('checkpoints should be one of (none|on|retain)')

    if checkpoints != 'none':
        checkpoint_path = os.path.join(model_dir, 'checkpoint.pkl')
        # Checkpoints are first written here, then renamed over
        # checkpoint_path.
        checkpoint_new_path = os.path.join(model_dir, 'checkpoint.pkl.new')
    else:
        checkpoint_path = None
        checkpoint_new_path = None

    if dataset not in {'svhn', 'cifar10', 'cifar100', 'imagenet'}:
        raise ValueError('Unknown dataset \'{}\''.format(dataset))

    # NOTE(review): the default architecture='wrn22_10' is NOT in this set, so
    # calling experiment() without overriding `architecture` raises
    # ValueError. Confirm the intended default against create_model.
    if architecture not in {
            'wrn20_10', 'wrn26_10', 'wrn26_2', 'wrn20_6_shakeshake',
            'wrn26_6_shakeshake', 'wrn26_2_shakeshake', 'pyramid',
            'resnet50', 'resnet101', 'resnet152', 'resnet50x2',
            'resnet101x2', 'resnet152x2', 'resnet50x4', 'resnet101x4',
            'resnet152x4', 'resnext50_32x4d', 'resnext101_32x8d',
            'resnext152_32x4d'
    }:
        raise ValueError('Unknown architecture \'{}\''.format(architecture))

    if lr_schedule not in {'constant', 'stepped', 'cosine'}:
        raise ValueError('Unknown LR schedule \'{}\''.format(lr_schedule))

    if mix_conf_mode not in {'mix_prob', 'mix_conf'}:
        raise ValueError('Unknown mix_conf_mode \'{}\''.format(mix_conf_mode))

    # Only host 0 writes TensorBoard summaries.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(model_dir)
    else:
        summary_writer = None

    # Perturbation-based unsupervised regularizer. Its second return value
    # tells the data source whether to produce two augmentations per sample.
    unsup_reg, augment_twice = build_pert_reg(
        unsupervised_regularizer,
        cut_backg_noise=cut_backg_noise,
        cut_prob=cut_prob,
        box_reg_scale_mode=box_reg_scale_mode,
        box_reg_scale=box_reg_scale,
        box_reg_random_aspect_ratio=box_reg_random_aspect_ratio,
        cow_sigma_range=cow_sigma_range,
        cow_prop_range=cow_prop_range)

    mix_reg = build_mix_reg(
        mix_regularizer,
        ict_alpha=ict_alpha,
        box_reg_scale_mode=mix_box_reg_scale_mode,
        box_reg_scale=mix_box_reg_scale,
        box_reg_random_aspect_ratio=mix_box_reg_random_aspect_ratio,
        cow_sigma_range=mix_cow_sigma_range,
        cow_prop_range=mix_cow_prop_range)

    # Derive a default seed from the data-split parameters so runs on the same
    # split are reproducible.
    if run_seed is None:
        run_seed = subset_seed << 32 | n_val
    train_rng = jax.random.PRNGKey(run_seed)
    init_rng, train_rng = jax.random.split(train_rng)

    if batch_size % jax.device_count() > 0:
        raise ValueError('Train batch size must be divisible by the number of '
                         'devices')
    if eval_batch_size % jax.device_count() > 0:
        raise ValueError('Eval batch size must be divisible by the number of '
                         'devices')
    # Per-host batch sizes feed the input pipeline; the per-device size is
    # used to build the model.
    local_batch_size = batch_size // jax.host_count()
    local_eval_batch_size = eval_batch_size // jax.host_count()
    device_batch_size = batch_size // jax.device_count()

    if dataset == 'svhn':
        image_size = 32
        top5_err_required = False
        data_source = small_image_data_source.SVHNDataSource(
            n_val=n_val,
            n_sup=n_sup,
            train_batch_size=local_batch_size,
            eval_batch_size=local_eval_batch_size,
            augment_twice=augment_twice,
            subset_seed=subset_seed,
            val_seed=val_seed)
    elif dataset == 'cifar10':
        image_size = 32
        top5_err_required = False
        data_source = small_image_data_source.CIFAR10DataSource(
            n_val=n_val,
            n_sup=n_sup,
            train_batch_size=local_batch_size,
            eval_batch_size=local_eval_batch_size,
            augment_twice=augment_twice,
            subset_seed=subset_seed,
            val_seed=val_seed)
    elif dataset == 'cifar100':
        image_size = 32
        top5_err_required = False
        data_source = small_image_data_source.CIFAR100DataSource(
            n_val=n_val,
            n_sup=n_sup,
            train_batch_size=local_batch_size,
            eval_batch_size=local_eval_batch_size,
            augment_twice=augment_twice,
            subset_seed=subset_seed,
            val_seed=val_seed)
    elif dataset == 'imagenet':
        image_size = 224
        top5_err_required = True
        if imagenet_subset_dir is None:
            raise ValueError(
                'Please provide a directory to the imagenet_subset_dir '
                'command line arg to specify where the ImageNet '
                'subsets are stored')
        data_source = imagenet_data_source.ImageNetDataSource(
            imagenet_subset_dir,
            n_val,
            n_sup,
            local_batch_size,
            local_eval_batch_size,
            augment_twice,
            apply_colour_jitter=aug_imagenet_apply_colour_jitter,
            greyscale_prob=aug_imagenet_greyscale_prob,
            load_test_set=(n_val == 0),
            image_size=image_size,
            subset_seed=subset_seed,
            val_seed=val_seed)
    else:
        # Unreachable: `dataset` was validated above.
        raise RuntimeError

    n_train = data_source.n_train
    train_ds = data_source.train_semisup_ds

    # With no validation split, evaluate on the test set instead.
    if n_val == 0:
        eval_ds = data_source.test_ds
        n_eval = data_source.n_test
    else:
        eval_ds = data_source.val_ds
        n_eval = data_source.n_val

    log_fn(
        'DATA: |train|={}, |sup|={}, |eval|={}, (|val|={}, |test|={})'.format(
            data_source.n_train, data_source.n_sup, n_eval, data_source.n_val,
            data_source.n_test))

    log_fn('Loaded dataset')

    steps_per_epoch = n_train // batch_size
    # Round up so the final partial eval batch is included (it is padded in
    # the eval loop).
    steps_per_eval = n_eval // eval_batch_size
    if n_eval % eval_batch_size > 0:
        steps_per_eval += 1
    num_steps = steps_per_epoch * num_epochs

    # Create model
    model_stu, state_stu = create_model(init_rng, architecture,
                                        device_batch_size, image_size,
                                        data_source.n_classes)
    state_stu = jax_utils.replicate(state_stu)
    log_fn('Built model')

    # Create optimizer
    optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                   beta=sgd_momentum,
                                   nesterov=sgd_nesterov)

    optimizer_stu = optimizer_def.create(model_stu)
    optimizer_stu = optimizer_stu.replicate()
    del model_stu  # don't keep a copy of the initial model

    # Create learning rate function
    # Linear scaling of the LR relative to a reference batch size of 256.
    base_learning_rate = learning_rate * batch_size / 256.
    if lr_schedule == 'constant':
        learning_rate_fn = create_constant_learning_rate_fn(base_learning_rate)
    elif lr_schedule == 'stepped':
        learning_rate_fn = create_stepped_learning_rate_fn(
            base_learning_rate,
            steps_per_epoch,
            lr_sched_steps=lr_sched_steps,
            warmup_length=lr_sched_warmup)
    elif lr_schedule == 'cosine':
        learning_rate_fn = create_cosine_learning_rate_fn(
            base_learning_rate,
            steps_per_epoch,
            halfcoslength_epochs=lr_sched_halfcoslength,
            warmup_length=lr_sched_warmup)
    else:
        # Unreachable: `lr_schedule` was validated above.
        raise RuntimeError

    # Optionally schedule (1 - alpha) with the same shape as the LR, so the
    # teacher coefficient changes over training instead of staying fixed.
    if anneal_teacher_alpha:
        if lr_schedule == 'constant':
            one_minus_alpha_fn = create_constant_learning_rate_fn(
                1.0 - teacher_alpha)
        elif lr_schedule == 'stepped':
            one_minus_alpha_fn = create_stepped_learning_rate_fn(
                1.0 - teacher_alpha,
                steps_per_epoch,
                lr_sched_steps=lr_sched_steps)
        elif lr_schedule == 'cosine':
            one_minus_alpha_fn = create_cosine_learning_rate_fn(
                1.0 - teacher_alpha,
                steps_per_epoch,
                halfcoslength_epochs=lr_sched_halfcoslength)
        else:
            raise RuntimeError
        teacher_alpha_fn = lambda step: 1.0 - one_minus_alpha_fn(step)
    else:
        teacher_alpha_fn = lambda step: teacher_alpha

    log_fn('Built optimizer')

    # Teacher model is just the student as we duplicate it when we modify it
    model_tea = optimizer_stu.target
    # Replicate batch stats
    # (identity tree_map builds a new container tree; the leaf arrays are
    # shared with state_stu, not copied).
    state_tea = jax.tree_map(lambda x: x, state_stu)

    # Set up epoch and step counter
    # Load existing checkpoint if available
    epoch = 1
    step = 0

    if checkpoints != 'none':
        if tf.io.gfile.exists(checkpoint_path):
            # NOTE(review): pickle.load can execute arbitrary code. This is
            # only safe because checkpoint_path is written by this function;
            # never point model_dir at untrusted files.
            with tf.io.gfile.GFile(checkpoint_path, 'rb') as f_in:
                check = pickle.load(f_in)

                # Student optimizer and batch stats
                optimizer_stu = util.restore_state_list(
                    optimizer_stu, check['optimizer_stu'])

                state_stu = util.restore_state_list(state_stu,
                                                    check['state_stu'])

                # Teacher model and batch stats
                model_tea = util.restore_state_list(model_tea,
                                                    check['model_tea'])

                state_tea = util.restore_state_list(state_tea,
                                                    check['state_tea'])

                epoch = check['epoch']
                step = check['step']

                log_fn('Loaded checkpoint from {}'.format(checkpoint_path))

    #
    # Training and evaluation step functions
    #
    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        l2_reg=l2_reg,
        weight_decay=weight_decay,
        teacher_alpha_fn=teacher_alpha_fn,
        unsup_reg=unsup_reg,
        cons_weight=cons_weight,
        conf_thresh=conf_thresh,
        conf_avg=conf_avg,
        mix_reg=mix_reg,
        mix_aug_separately=mix_aug_separately,
        mix_logits=mix_logits,
        mix_weight=mix_weight,
        mix_conf_thresh=mix_conf_thresh,
        mix_conf_avg=mix_conf_avg,
        mix_conf_mode=mix_conf_mode),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eval_top_5=top5_err_required),
                           axis_name='batch')

    # Create dataset batch iterators
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    #
    # Training loop
    #

    log_fn('Training...')
    epoch_metrics_stu = []
    t1 = time.time()
    while step < num_steps:
        # A fresh key per step is passed to shard().
        train_rng, iter_rng = jax.random.split(train_rng)
        batch = next(train_iter)
        # TF to NumPy
        batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
        batch = shard(batch, iter_rng)

        optimizer_stu, state_stu, metrics_stu, model_tea, state_tea = p_train_step(
            optimizer_stu, state_stu, model_tea, state_tea, batch)

        if debug:
            log_fn('Step {} time {}'.format(step, time.time() - t1))

        epoch_metrics_stu.append(metrics_stu)
        if (step + 1) % steps_per_epoch == 0:
            epoch_metrics_stu = util.get_metrics(epoch_metrics_stu)
            train_epoch_metrics = jax.tree_map(lambda x: x.mean(),
                                               epoch_metrics_stu)
            if summary_writer is not None:
                for key, vals in epoch_metrics_stu.items():
                    tag = 'train_%s' % key
                    # Log each per-step value at the step it was produced.
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)

            epoch_metrics_stu = []
            eval_stu_metrics = []
            eval_tea_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # TF to NumPy
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                # Pad short batches
                eval_batch = util.pad_classification_batch(
                    eval_batch, local_eval_batch_size)
                # Shard across local devices
                eval_batch = shard(eval_batch)
                # Evaluate student and teacher on the same batch.
                metrics_stu = p_eval_step(optimizer_stu.target, state_stu,
                                          eval_batch)
                metrics_tea = p_eval_step(model_tea, state_tea, eval_batch)
                eval_stu_metrics.append(metrics_stu)
                eval_tea_metrics.append(metrics_tea)
            eval_stu_metrics = util.get_metrics(eval_stu_metrics)
            eval_tea_metrics = util.get_metrics(eval_tea_metrics)
            # Sum over batches, then let avg_eval_metrics normalise the sums.
            eval_stu_epoch_metrics = jax.tree_map(lambda x: x.sum(),
                                                  eval_stu_metrics)
            eval_tea_epoch_metrics = jax.tree_map(lambda x: x.sum(),
                                                  eval_tea_metrics)
            eval_stu_epoch_metrics = avg_eval_metrics(eval_stu_epoch_metrics)
            eval_tea_epoch_metrics = avg_eval_metrics(eval_tea_epoch_metrics)

            t2 = time.time()

            if top5_err_required:
                log_fn(
                    'EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
                    'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
                    'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
                    'top-5-err={:.3%}, TEA Eval loss={:.6f}, err={:.3%}, '
                    'top-5-err={:.3%}'.format(
                        epoch,
                        t2 - t1,
                        train_epoch_metrics['loss'],
                        train_epoch_metrics['error_rate'],
                        train_epoch_metrics['cons_loss'],
                        train_epoch_metrics['conf_rate'],
                        train_epoch_metrics['mix_loss'],
                        train_epoch_metrics['mix_conf_rate'],
                        eval_stu_epoch_metrics['loss'],
                        eval_stu_epoch_metrics['error_rate'],
                        eval_stu_epoch_metrics['top5_error_rate'],
                        eval_tea_epoch_metrics['loss'],
                        eval_tea_epoch_metrics['error_rate'],
                        eval_tea_epoch_metrics['top5_error_rate'],
                    ))
            else:
                log_fn(
                    'EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
                    'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
                    'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
                    'TEA Eval loss={:.6f}, err={:.3%}'.format(
                        epoch,
                        t2 - t1,
                        train_epoch_metrics['loss'],
                        train_epoch_metrics['error_rate'],
                        train_epoch_metrics['cons_loss'],
                        train_epoch_metrics['conf_rate'],
                        train_epoch_metrics['mix_loss'],
                        train_epoch_metrics['mix_conf_rate'],
                        eval_stu_epoch_metrics['loss'],
                        eval_stu_epoch_metrics['error_rate'],
                        eval_tea_epoch_metrics['loss'],
                        eval_tea_epoch_metrics['error_rate'],
                    ))

            if on_epoch_finished_fn is not None:
                if top5_err_required:
                    on_epoch_finished_fn(
                        epoch,
                        eval_stu_err=eval_stu_epoch_metrics['error_rate'],
                        eval_tea_err=eval_tea_epoch_metrics['error_rate'],
                        eval_stu_top5_err=eval_stu_epoch_metrics[
                            'top5_error_rate'],
                        eval_tea_top5_err=eval_tea_epoch_metrics[
                            'top5_error_rate'],
                    )
                else:
                    on_epoch_finished_fn(
                        epoch,
                        eval_stu_err=eval_stu_epoch_metrics['error_rate'],
                        eval_tea_err=eval_tea_epoch_metrics['error_rate'],
                    )

            # Epoch timing restarts after eval/logging of this epoch.
            t1 = t2

            if summary_writer is not None:
                summary_writer.scalar('eval_stu_loss',
                                      eval_stu_epoch_metrics['loss'], epoch)
                summary_writer.scalar('eval_stu_error_rate',
                                      eval_stu_epoch_metrics['error_rate'],
                                      epoch)
                summary_writer.scalar('eval_tea_loss',
                                      eval_tea_epoch_metrics['loss'], epoch)
                summary_writer.scalar('eval_tea_error_rate',
                                      eval_tea_epoch_metrics['error_rate'],
                                      epoch)
                if top5_err_required:
                    summary_writer.scalar(
                        'eval_stu_top5_error_rate',
                        eval_stu_epoch_metrics['top5_error_rate'], epoch)
                    summary_writer.scalar(
                        'eval_tea_top5_error_rate',
                        eval_tea_epoch_metrics['top5_error_rate'], epoch)
                summary_writer.flush()

            epoch += 1

            if checkpoints != 'none':
                if jax.host_id() == 0:
                    # Write to new checkpoint file so that we don't immediately
                    # overwrite the old one
                    with tf.io.gfile.GFile(checkpoint_new_path, 'wb') as f_out:
                        # Store the *next* epoch and step so a resumed run
                        # continues after the epoch just completed.
                        check = dict(
                            optimizer_stu=util.to_state_list(optimizer_stu),
                            state_stu=util.to_state_list(state_stu),
                            model_tea=util.to_state_list(model_tea),
                            state_tea=util.to_state_list(state_tea),
                            epoch=epoch,
                            step=step + 1,
                        )
                        pickle.dump(check, f_out)
                        del check
                    # Remove old checkpoint and rename
                    if tf.io.gfile.exists(checkpoint_path):
                        tf.io.gfile.remove(checkpoint_path)
                    tf.io.gfile.rename(checkpoint_new_path, checkpoint_path)

        step += 1

    # In 'on' mode the checkpoint only serves to resume an interrupted run, so
    # it is deleted once training completes; 'retain' keeps it.
    if checkpoints == 'on':
        if jax.host_id() == 0:
            if tf.io.gfile.exists(checkpoint_path):
                tf.io.gfile.remove(checkpoint_path)
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

    The loop works as follows:
      - ``restore_checkpoint(state, workdir)`` is applied first, and training
        resumes from the restored ``state.step``.
      - The run trains until ``num_steps``. That is ``config.num_train_steps``,
        or ``steps_per_epoch * config.num_epochs`` when it is -1.
      - Evaluation runs after every full epoch.
      - A checkpoint is saved every 10 epochs and after the final step.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
        Checkpoints are also restored from and saved to this directory.

    Returns:
      Final TrainState. It is still replicated across local devices, because
      it is never unreplicated before being returned.

    Raises:
      ValueError: If ``config.batch_size`` is not divisible by the total number
        of devices.
    """
    # Use the process API consistently (``jax.host_id`` is the deprecated
    # alias of ``jax.process_index``). Only process 0 writes summaries and
    # profiles; the others just log.
    is_chief = jax.process_index() == 0
    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=not is_chief)

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    # Each process feeds its share of the global batch.
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    # Half precision: bfloat16 on TPU, float16 on every other platform.
    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder(config.dataset)
    train_iter = create_input_iter(
        dataset_builder, local_batch_size, image_size, input_dtype,
        train=True, cache=config.cache)
    eval_iter = create_input_iter(
        dataset_builder, local_batch_size, image_size, input_dtype,
        train=False, cache=config.cache)

    steps_per_epoch = (
        dataset_builder.info.splits['train'].num_examples // config.batch_size)

    # -1 means "derive from the epoch count".
    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    # -1 means "one pass over the validation split" (remainder dropped).
    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    # Linear LR scaling relative to a reference batch size of 256.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(
        model_cls=model_cls, half_precision=config.half_precision)

    learning_rate_fn = create_learning_rate_fn(
        config, base_learning_rate, steps_per_epoch)

    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=learning_rate_fn),
        axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    train_metrics = []
    hooks = []
    if is_chief:
        hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]
    train_metrics_last_t = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        if step == step_offset:
            # The first pmap call triggers compilation.
            logging.info('Initial compilation completed.')

        if config.get('log_every_steps'):
            train_metrics.append(metrics)
            if (step + 1) % config.log_every_steps == 0:
                train_metrics = common_utils.get_metrics(train_metrics)
                summary = {
                    f'train_{k}': v
                    for k, v in jax.tree_util.tree_map(
                        lambda x: x.mean(), train_metrics).items()
                }
                summary['steps_per_second'] = config.log_every_steps / (
                    time.time() - train_metrics_last_t)
                writer.write_scalars(step + 1, summary)
                train_metrics = []
                train_metrics_last_t = time.time()

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                         epoch, summary['loss'], summary['accuracy'] * 100)
            writer.write_scalars(
                step + 1, {f'eval_{key}': val for key, val in summary.items()})
            writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            # Sync batch stats so the saved replica is representative.
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask # summarize metrics metrics = { "loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum() } metrics = jax.lax.psum(metrics, axis_name="batch") return metrics p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, )) # Replicate the train state on each device state = jax_utils.replicate(state) train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(tokenized_datasets["train"])
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) if not gfile.isdir(FLAGS.save_dir): gfile.makedirs(FLAGS.save_dir) hparam_str_dict = json.loads(FLAGS.xm_parameters) hparam_str = ','.join([ '%s=%s' % (shorten(k), str(hparam_str_dict[k])) for k in hparam_str_dict.keys() ]) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) batch_size = FLAGS.per_device_batch_size * n_devices io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.max_characters) predict_io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.predict_max_characters) target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length) # Setup DSL # --------------------------------------------------------------------------- # Build token tables. if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']: spec_vocab = robust_fill_dsl.CHARACTER + input_pipeline.SEPARATOR_TOKEN spec_id_token_table = { i + 3: token for i, token in enumerate(spec_vocab) } bos_id = 1 eos_id = 2 spec_id_token_table[bos_id] = robust_fill_dsl.BOS spec_id_token_table[eos_id] = robust_fill_dsl.EOS spec_token_id_table = { token: id for id, token in spec_id_token_table.items() } spec_vocab_size = len(spec_token_id_table) + 1 # For padding. program_id_token_table, _ = dsl_tokens.build_token_tables() program_vocab_size = len(program_id_token_table) + 1 elif FLAGS.dataset_type == 'scan': # TODO(jxihong): Scan is not handled yet. raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) else: raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) # Parse io and program token sequences (for eval). 
def decode_io(inputs, outputs): """Convert from int tensors to strings.""" if FLAGS.dataset_type == 'robust_fill': def decode_str(s): """Decode string tokens.""" return ''.join( [spec_id_token_table[t_id] for t_id in s if t_id > 0]) inps, outs = [], [] for inp, out in zip(inputs, outputs): inps.append(decode_str(inp)) outs.append(decode_str(out)) return inps, outs elif FLAGS.dataset_type == 'scan': def decode_str(s): """Decode string tokens.""" return ' '.join( [spec_id_token_table[t_id] for t_id in s if t_id > 0]) inps = [decode_str(inp) for inp in inputs] dummy_outs = [''] * len(inps) return inps, dummy_outs else: raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) def decode_spec(target): """Convert from int tensor to a string.""" target = target[:np.argmax(target == eos_id)].astype(np.int32) if FLAGS.dataset_type == 'robust_fill': target = target[target != bos_id].tolist() return ''.join( [spec_id_token_table[t_id] for t_id in target if t_id > 0]) elif FLAGS.dataset_type == 'scan': # TODO(jxihong): Scan is not handled yet. raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) else: raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) def decode_program(program): """Decode program tokens into a program (program object or string).""" program = program[:np.argmax(program == eos_id) + 1].astype(np.int32) if FLAGS.dataset_type == 'robust_fill': # Returns either a Concat program object, or None. program = program[program != bos_id].tolist() try: return robust_fill_dsl.decode_program(program, program_id_token_table) except: # pylint: disable=bare-except return None # Program does not compile. elif FLAGS.dataset_type == 'scan': # Returns a string. 
program = program[jnp.logical_and(program != bos_id, program != eos_id)].tolist() return ' '.join(scan_vocab.decode(program, program_id_token_table)) else: raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) def decode_program_str(program): # pylint: disable=unused-variable """Decode program tokens into a string.""" decoded = decode_program(program) if FLAGS.dataset_type == 'robust_fill': try: return decoded.to_string() # pytype: disable=attribute-error except: # pylint: disable=bare-except return 'did not compile' else: assert isinstance(decoded, str), '{} should be string'.format(decoded) return decoded # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if not FLAGS.dataset_filepattern: raise ValueError('Must specify filepattern to dataset.') # Training dataset. logging.info('Loading dataset from %s', FLAGS.dataset_filepattern) padded_shapes = { 'inputs': io_shape[1:], 'outputs': io_shape[1:], 'target': target_shape[1:], } logging.info('padded_shapes: %s', padded_shapes) if FLAGS.dataset_type == 'robust_fill': if FLAGS.model_type == 'spec_decomposer_model': create_dataset_fn = input_pipeline.create_robust_fill_dataset_for_spec_decomposer_model elif FLAGS.model_type == 'synthesizer_model': create_dataset_fn = input_pipeline.create_robust_fill_dataset_for_synthesizer_model else: raise ValueError(f'Unhandled model_type: {FLAGS.model_type}') elif FLAGS.dataset_type == 'scan': raise NotImplementedError() # TODO(kshi): Implement. # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record else: raise ValueError('Unhandled dataset_type: {}'.format( FLAGS.dataset_type)) dataset = create_dataset_fn(FLAGS.dataset_filepattern, spec_token_id_table, FLAGS.num_strings_per_task) dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes, drop_remainder=True) # Split evaluation and training. 
eval_ds = dataset.take(FLAGS.num_eval_steps) # Decrease batch of predict dataset to handle beam search. predict_padded_shapes = padded_shapes.copy() predict_padded_shapes['inputs'] = predict_io_shape[1:] predict_padded_shapes['outputs'] = predict_io_shape[1:] logging.info('predict_padded_shapes: %s', predict_padded_shapes) predict_ds = eval_ds.unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes) train_ds = dataset.skip(FLAGS.num_eval_steps) if FLAGS.train_set_batches > 0: train_ds = train_ds.take(FLAGS.train_set_batches) train_ds = train_ds.repeat() test_dataset = create_dataset_fn(FLAGS.test_dataset_filepattern, spec_token_id_table, FLAGS.num_strings_per_task) test_dataset = test_dataset.padded_batch( batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False) quick_test_dataset = (test_dataset.take( FLAGS.num_quick_test_steps).unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)) final_test_dataset = (test_dataset.take( FLAGS.num_final_test_steps).unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)) # Build Model and Optimizer # --------------------------------------------------------------------------- if FLAGS.model_type == 'spec_decomposer_model': output_vocab_size = spec_vocab_size elif FLAGS.model_type == 'synthesizer_model': output_vocab_size = program_vocab_size else: raise ValueError(f'Unhandled model_type: {FLAGS.model_type}') base_config = base_models.TransformerConfig( vocab_size=spec_vocab_size, output_vocab_size=output_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_target_length), dropout_rate=FLAGS.dropout_rate, attention_dropout_rate=FLAGS.attention_dropout_rate, use_relative_attention=FLAGS.use_relative_attention, deterministic=False, decode=False, 
bos_token=bos_id, num_input_relative_position_buckets=FLAGS.num_position_buckets, max_input_distance=FLAGS.max_distance, num_output_relative_position_buckets=FLAGS.num_position_buckets, max_output_distance=FLAGS.max_distance, num_input_cross_output_relative_position_buckets=( FLAGS.num_position_buckets), max_input_cross_output_distance=FLAGS.max_distance, num_program_relative_position_buckets=FLAGS.num_position_buckets, max_program_distance=FLAGS.max_distance, num_program_cross_embed_relative_position_buckets=( FLAGS.num_position_buckets), max_program_cross_embed_distance=FLAGS. max_program_cross_embed_distance, num_flat_encoding_relative_position_buckets=( FLAGS.num_position_buckets), max_flat_encoding_distance=FLAGS.max_distance) train_config = models.DecomposeAttentionTransformerConfig( base_config=base_config, dataset_type=FLAGS.dataset_type, flat_encoded_self_attention=FLAGS.flat_encoded_self_attention) eval_config = train_config.replace(base_config=base_config.replace( deterministic=True)) predict_config = train_config.replace(base_config=base_config.replace( shift=False, deterministic=True, decode=not FLAGS.slow_decode, max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length))) rng = jax.random.PRNGKey(FLAGS.seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) dropout_rng = jax.random.split(rng, jax.local_device_count()) del rng m = models.DecomposeAttentionTransformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) optimizer_def = optim.Adam(FLAGS.lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) del initial_variables # Don't keep a copy of the initial model. start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. 
optimizer = checkpoints.restore_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer) # Grab last step. start_step = int(optimizer.state.step) logging.info('Found model checkpointed at step %d.', start_step) if FLAGS.finetune_start_step > 0: logging.info( 'Checking that start_step (%s) == finetune_start_step (%s)', start_step, FLAGS.finetune_start_step) assert start_step >= FLAGS.finetune_start_step steps_to_skip = start_step - FLAGS.finetune_start_step else: steps_to_skip = start_step # TODO(kshi): It is likely that this code can lead to the job stalling for # 10+ hours when restarting from a checkpoint that had been trained a long # time, possibly because dataset skipping is slow. logging.info('Skipping %s steps...', steps_to_skip) train_ds = train_ds.skip(steps_to_skip) dummy_p_train_step = jax.pmap( lambda dropout_rng: jax.random.split(dropout_rng)[1]) for _ in range(steps_to_skip): dropout_rng = dummy_p_train_step(dropout_rng) logging.info('Finished skipping steps') logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) # TODO(jxihong): Implement fast decoding. assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.' if FLAGS.finetune_start_step <= 0: learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr) else: # Constant LR for finetuning. 
learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr, factors='constant') p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn, config=train_config), axis_name='batch') p_eval_step = jax.pmap(functools.partial(eval_step, eos_token=eos_id, config=eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=FLAGS.max_target_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap(functools.partial( predict_step, eos_token=eos_id, max_decode_len=FLAGS.max_target_length, config=predict_config, slow_decode=FLAGS.slow_decode), axis_name='batch', static_broadcasted_argnums=(4, )) # Main Train Loop # --------------------------------------------------------------------------- logging.info('Starting training!') metrics_all = [] tick = time.time() train_iter = train_ds.as_numpy_iterator() for step in range(start_step, FLAGS.num_train_steps): inputs, outputs, targets = load_data(next(train_iter)) optimizer, metrics, dropout_rng = p_train_step(optimizer, inputs, outputs, targets, dropout_rng=dropout_rng) metrics_all.append(metrics) is_last_step = step == FLAGS.num_train_steps - 1 # Periodic metric handling. 
# Training Metrics if (step and step % FLAGS.log_freq == 0) or is_last_step: logging.info('Gathering training metrics.') metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums) summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) if jax.host_id() == 0: logging.info('Train in step: %d, loss: %.4f', step, summary['loss']) tock = time.time() steps_per_sec = FLAGS.log_freq / (tock - tick) tick = tock summary_writer.scalar('train/steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar('train/' + key, val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Evaluation Metrics if (step and step % FLAGS.eval_freq == 0) or is_last_step: logging.info('Gathering evaluation metrics.') t_evaluation_start = time.time() eval_metrics = [] for batches in eval_ds.as_numpy_iterator(): inputs, outputs, targets = load_data(batches) metrics = p_eval_step(optimizer.target, inputs, outputs, targets) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: logging.info('Evaluation time: %.4f s step %d, loss: %.4f.', time.time() - t_evaluation_start, step, eval_summary['loss']) for key, val in eval_summary.items(): summary_writer.scalar('eval/' + key, val, step) summary_writer.flush() # Beam search metrics. 
if (step and step % FLAGS.predict_freq == 0) or is_last_step: logging.info('Gathering beam search metrics.') test_ds = final_test_dataset if is_last_step else quick_test_dataset for dataset, predict_or_test in [(predict_ds, 'predict'), (test_ds, 'test')]: for beam_size in [1, 10]: t_inference_start = time.time() total_successes = 0 total_denominator = 0 ios, targets_list, predictions, top_of_beams, scores = ([], [], [], [], []) for batches in dataset.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch['inputs'].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) # pylint: disable=cell-var-from-loop pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) inputs, outputs, targets = load_data(pred_batch) cache = (p_init_cache(inputs, outputs, targets) if not FLAGS.slow_decode else None) predicted = p_pred_step(optimizer.target, inputs, outputs, cache, beam_size) predicted = tohost(predicted) inputs, outputs, targets = map( tohost, (inputs, outputs, targets)) for i, beams in enumerate(predicted): inps, outs = decode_io(inputs[i], outputs[i]) if FLAGS.model_type == 'spec_decomposer_model': ground_truth = decode_spec(targets[i]) best_prediction, score = eval_predicted_spec_decomposer_model( beams, ground_truth, decode_spec) decode_to_str_fn = decode_spec elif FLAGS.model_type == 'synthesizer_model': ground_truth = decode_program_str(targets[i]) best_prediction, score = eval_predicted_synthesizer_model( beams, inps, outs, decode_program) decode_to_str_fn = decode_program_str else: raise ValueError( f'Unknown model type {FLAGS.model_type}') if score > 0: total_successes += 1 total_denominator += 1 beams_target = [ decode_to_str_fn(beam) for beam in beams ] ios.append(' ; '.join(map(str, zip(inps, outs)))) targets_list.append(ground_truth) predictions.append(best_prediction) scores.append(score) 
logging.info('') logging.info('ios: %s', ios[-1]) logging.info('targets[%s]: %s', i, targets[i]) logging.info('ground_truth: %s', ground_truth) logging.info('predicted beam: %s', '\n'.join(beams_target)) logging.info('best_prediction: %s', best_prediction) logging.info('score: %s', score) logging.info('beams: %s', beams) if not ground_truth: logging.warn('ground_truth is empty!') top_of_beam = [] for index, beam in enumerate(beams[:-5:-1]): top_of_beam.append( 'index: {}, decoded: {}, tokens: {}'. format(index, decode_to_str_fn(beam), beam)) top_of_beams.append('\n\n'.join(top_of_beam)) all_total_successes, all_total_denominator = per_host_sum_pmap( jax.tree_map(np.array, (total_successes, total_denominator))) # Record beam search results as text summaries. message = [] for n in np.random.choice(np.arange(len(predictions)), 8): text = ( f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n' f'predicted: {predictions[n]}\n\n' f'score: {scores[n]}\n\n' f'top of beam:\n\n{top_of_beams[n]}\n\n') message.append(text) # Write to tensorboard. if jax.host_id() == 0: accuracy = 100 * all_total_successes / all_total_denominator logging.info( '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)', predict_or_test, step, beam_size, all_total_successes, all_total_denominator, accuracy, time.time() - t_inference_start) summary_writer.scalar( '{}/beam-size-{}'.format(predict_or_test, beam_size), accuracy, step) summary_writer.text( '{}-samples-beam-{}'.format( predict_or_test, beam_size), '\n------\n'.join(message), step) summary_writer.flush() # Save a Checkpoint. Do this at the end of the training loop, so that if a # worker is descheduled during a round of prediction (which takes a while), # we will redo prediction upon restarting (to avoid losing data). if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step: if jax.host_id() == 0: # Save unreplicated optimizer + model state. 
checkpoints.save_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), jax_utils.unreplicate(optimizer), step)
def main():
    """Fine-tune a Flax token-classification model and evaluate it with seqeval.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``, either from command-line flags or from a single ``.json``
    file passed as the only argument. The dataset comes from the hub
    (``dataset_name``) or from local ``train_file`` / ``validation_file``. Training
    and evaluation metrics are written to TensorBoard in
    ``training_args.output_dir``. The model is saved there every ``save_steps``
    steps and at the last training step.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for token classification task available on the hub
    # at https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # The loader type ("csv", "json", ...) is inferred from the file extension.
        # The fallback attribute is `validation_file` (there is no `valid_file`).
        source_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = source_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Indexing a missing split raises KeyError, so test membership to make the
    # validation fallback reachable.
    if "train" in raw_datasets:
        column_names = raw_datasets["train"].column_names
        features = raw_datasets["train"].features
    else:
        column_names = raw_datasets["validation"].column_names
        features = raw_datasets["validation"].features

    if data_args.text_column_name is not None:
        text_column_name = data_args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    if data_args.label_column_name is not None:
        label_column_name = data_args.label_column_name
    elif f"{data_args.task_name}_tags" in column_names:
        label_column_name = f"{data_args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        """Return the sorted list of distinct labels over all label sequences."""
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
        label_to_id = {i: i for i in range(len(label_list))}
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
        label_to_id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    auth_token = True if model_args.use_auth_token else None
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        # Build the mappings from the label *names*. Deriving them from `label_to_id`
        # would give {i: i} for ClassLabel datasets and lose the names in the config.
        label2id={label: i for i, label in enumerate(label_list)},
        id2label=dict(enumerate(label_list)),
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )
    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    tokenizer_kwargs = dict(
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )
    if config.model_type in {"gpt2", "roberta"}:
        # These tokenizers need a prefix space to handle pre-split words (is_split_into_words=True).
        tokenizer_kwargs["add_prefix_space"] = True
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    model = FlaxAutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )

    # Preprocessing the datasets
    # Tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
        """Tokenize pre-split words and give each sub-token a label id (or -100 to ignore it)."""
        tokenized_inputs = tokenizer(
            examples[text_column_name],
            max_length=data_args.max_seq_length,
            padding="max_length",
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []

        for i, label in enumerate(examples[label_column_name]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label_to_id[label[word_idx]])
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
                previous_word_idx = word_idx

            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    processed_raw_datasets = raw_datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=raw_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
    summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        """Write the accumulated per-step training metrics, one scalar per step ending at `step`."""
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        """Write each evaluation metric as an `eval_<name>` scalar at `step`."""
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args)

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`.

        Returns a triple `(new_state, metrics, new_dropout_rng)`. Gradients and the
        metrics dict ("loss", "learning_rate") are averaged over the "batch" pmap axis.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")

        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        """Run the model in inference mode on `batch` and return `state.logits_fn(logits)`."""
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    metric = load_metric("seqeval")

    def get_labels(y_pred, y_true):
        """Map ids to label names, dropping positions whose gold label is -100."""
        # Transform predictions and references tensos to numpy arrays

        # Remove ignored index (special tokens)
        true_predictions = [
            [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        true_labels = [
            [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        return true_predictions, true_labels

    def compute_metrics():
        """Finalize the seqeval metric; flatten per-entity results if requested."""
        results = metric.compute()
        if data_args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")

    # make sure weights are replicated on each device
    state = replicate(state)

    train_time = 0
    step_per_epoch = len(train_dataset) // train_batch_size
    total_steps = step_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        for step, batch in enumerate(
            tqdm(
                train_data_collator(input_rng, train_dataset, train_batch_size),
                total=step_per_epoch,
                desc="Training...",
                position=1,
            )
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            # 1-based count of completed optimizer steps, so the last step of training
            # has cur_step == total_steps.
            cur_step = epoch * step_per_epoch + step + 1

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                # Accumulate only the time since the previous measurement. Adding the
                # full time since epoch start each time would count it repeatedly.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:

                # evaluate
                for batch in tqdm(
                    eval_data_collator(eval_dataset, eval_batch_size),
                    total=len(eval_dataset) // eval_batch_size,
                    desc="Evaluating ...",
                    position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = p_eval_step(state, batch)
                    # Flatten the leading (device, per-device batch) axes into one batch axis.
                    predictions = np.array([pred for pred in chain(*predictions)])
                    labels = np.array([label for label in chain(*labels)])
                    # Materialize the iterator. np.array(chain(...)) would produce a 0-d
                    # object array and make this masking a silent no-op.
                    attention_mask = np.array([mask for mask in chain(*batch["attention_mask"])])
                    labels[attention_mask == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                # evaluate also on leftover examples (not divisible by batch_size)
                num_leftover_samples = len(eval_dataset) % eval_batch_size

                # make sure leftover batch is evaluated on one device
                if num_leftover_samples > 0 and jax.process_index() == 0:
                    # take leftover samples
                    batch = eval_dataset[-num_leftover_samples:]
                    batch = {k: np.array(v) for k, v in batch.items()}

                    labels = batch.pop("labels")
                    predictions = eval_step(unreplicate(state), batch)
                    labels = np.array(labels)
                    labels[np.array(batch["attention_mask"]) == 0] = -100
                    preds, refs = get_labels(predictions, labels)
                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )

                eval_metrics = compute_metrics()

                if data_args.return_entity_level_metrics:
                    logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics})")
                else:
                    logger.info(
                        f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
                    )

                if jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint every `save_steps` steps and at the last step, optionally pushing it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(
                        training_args.output_dir,
                        params=params,
                        push_to_hub=training_args.push_to_hub,
                        commit_message=f"Saving weights and logs of step {cur_step}",
                    )
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
def train(model_def, model_dir, batch_size, num_epochs, learning_rate,
          sgd_momentum, make_lr_fun=None, l2_reg=0.0005, run_seed=0):
    """Train model on CIFAR-10, evaluating at the end of every epoch.

    Args:
      model_def: model definition forwarded to `create_model`.
      model_dir: directory for TensorBoard summaries.
      batch_size: global batch size; must be divisible by `jax.device_count()`.
      num_epochs: number of training epochs.
      learning_rate: base learning rate for the optimizer and the LR schedule.
      sgd_momentum: momentum forwarded to `create_optimizer`.
      make_lr_fun: optional callable `(base_lr, steps_per_epoch) -> lr_fn`. If
        None, a stepped schedule is built with `lr_schedule` using boundaries
        [[60, 0.2], [120, 0.04], [160, 0.008]] (presumably (epoch, factor)
        pairs; see `lr_schedule`).
      l2_reg: L2 regularization weight forwarded to `train_step`.
      run_seed: seed for the JAX PRNG.

    Raises:
      ValueError: if run on more than one host, or if `batch_size` is not
        divisible by the number of devices.
    """
    if jax.host_count() > 1:
        raise ValueError('CIFAR-10 example should not be run on '
                         'more than 1 host (for now)')

    if make_lr_fun is None:
        # No learning rate function provided
        # Default to stepped LR schedule for CIFAR-10 and Wide ResNet
        def make_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
            return lr_schedule.create_stepped_learning_rate_schedule(
                base_lr, steps_per_epoch,
                [[60, 0.2], [120, 0.04], [160, 0.008]])

    summary_writer = tensorboard.SummaryWriter(model_dir)

    rng = random.PRNGKey(run_seed)

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    device_batch_size = batch_size // jax.device_count()

    # Load dataset
    data_source = input_pipeline.CIFAR10DataSource(
        train_batch_size=batch_size, eval_batch_size=batch_size)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds

    # Compute steps per epoch and nb of eval steps
    steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
    steps_per_eval = data_source.EVAL_IMAGES // batch_size
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = learning_rate

    # Create the model
    image_size = 32
    model, state = create_model(rng, device_batch_size, image_size, model_def)
    state = jax_utils.replicate(state)
    optimizer = create_optimizer(model, base_learning_rate, sgd_momentum)
    del model  # don't keep a copy of the initial model

    # Learning rate schedule
    learning_rate_fn = make_lr_fun(base_learning_rate, steps_per_epoch)

    # pmap the train and eval functions
    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=learning_rate_fn,
                          l2_reg=l2_reg),
        axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # Create the training batch iterator. Evaluation creates a fresh iterator
    # per epoch (see below).
    train_iter = iter(train_ds)

    # Gather metrics
    train_metrics = []
    epoch = 1
    for step, batch in zip(range(num_steps), train_iter):
        # Generate a PRNG key that will be rolled into the batch
        rng, step_key = jax.random.split(rng)
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)
        # Shard the step PRNG key
        sharded_keys = common_utils.shard_prng_key(step_key)

        # Train step
        optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                                 sharded_keys)
        train_metrics.append(metrics)

        if (step + 1) % steps_per_epoch == 0:
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard
            for key, vals in train_metrics.items():
                tag = 'train_%s' % key
                for i, val in enumerate(vals):
                    summary_writer.scalar(tag, val, step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation. Start a fresh iterator each time. Sharing one iterator
            # across epochs would exhaust a non-repeating eval dataset after the
            # first epoch, or evaluate a different window each epoch.
            eval_iter = iter(eval_ds)
            eval_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(optimizer.target, state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info(
                'Epoch %d: TRAIN loss=%.6f, err=%.2f, EVAL loss=%.6f, err=%.2f',
                epoch, train_summary['loss'],
                train_summary['error_rate'] * 100.0, eval_summary['loss'],
                eval_summary['error_rate'] * 100.0)

            summary_writer.scalar('eval_loss', eval_summary['loss'], epoch)
            summary_writer.scalar('eval_error_rate', eval_summary['error_rate'],
                                  epoch)
            summary_writer.flush()

            epoch += 1
def main(args):
    """Fine-tune a pretrained VisionTransformer on `args.dataset`.

    Pretrained weights are read from `<args.vit_pretrained_dir>/<args.model>.npz`.
    Training runs for `args.total_steps` steps, or the dataset preset's
    'total_steps' when that is falsy. Train loss and ETA are logged every
    `args.progress_every` steps. Test accuracy is logged every `args.eval_every`
    steps and at the final step. If `args.output` is set, the unreplicated
    parameters are saved there. When `args.copy_to` is set, log files (and the
    output checkpoint) are mirrored to `<args.copy_to>/<args.name>/`.
    """
    logdir = os.path.join(args.logdir, args.name)
    logger = logging.setup_logger(logdir)
    logger.info(args)
    logger.info(f'Available devices: {jax.devices()}')

    # ---- Input pipeline --------------------------------------------------
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')
    ds_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        mixup_alpha=args.mixup_alpha,
        batch_size=args.batch,
        shuffle_buffer=args.shuffle_buffer,
        tfds_data_dir=args.tfds_data_dir,
        tfds_manual_dir=args.tfds_manual_dir)
    sample_batch = next(iter(ds_train))
    logger.info(ds_train)
    ds_test = input_pipeline.get_data(
        dataset=args.dataset,
        mode='test',
        repeats=1,
        batch_size=args.batch_eval,
        tfds_data_dir=args.tfds_data_dir,
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(ds_test)

    # ---- Model -----------------------------------------------------------
    vit_def = models.KNOWN_MODELS[args.model]
    vit = vit_def.partial(num_classes=dataset_info['num_classes'])
    # Initialize from one sample batch's image spec. Dropping shape[0] discards
    # the "num_local_devices" dimension.
    sample_images = sample_batch['image']
    _, params = vit.init_by_shape(
        jax.random.PRNGKey(0),
        [(sample_images.shape[1:], sample_images.dtype.name)])
    del sample_images, sample_batch

    params = checkpoint.load_pretrained(
        pretrained_path=os.path.join(args.vit_pretrained_dir,
                                     f'{args.model}.npz'),
        init_params=params,
        model_config=models.CONFIGS[args.model],
        logger=logger)

    # Device-parallel forward pass and update function.
    vit_fn_repl = jax.pmap(vit.call)
    update_fn_repl = make_update_fn(vit.call, args.accum_steps)

    # Build the optimizer and replicate it across local devices. The host-side
    # parameter tree is dropped afterwards since only the replica is used.
    opt_repl = flax_utils.replicate(
        momentum_clip.Optimizer(
            dtype=args.optim_dtype,
            grad_norm_clip=args.grad_norm_clip).create(params))
    del params

    def copyfiles(paths):
        """Copy each of `paths` into `<args.copy_to>/<args.name>/` (no-op when copy_to is unset)."""
        if args.copy_to:
            for src in paths:
                dst = os.path.join(args.copy_to, args.name, os.path.basename(src))
                tf.io.gfile.makedirs(os.path.dirname(dst))
                tf.io.gfile.copy(src, dst, overwrite=True)
                logger.info(f'Copied {src} to {dst}.')

    total_steps = (args.total_steps or
                   input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

    # Learning-rate values are pre-fetched to device so the loop never waits on them.
    lr_fn = hyper.create_learning_rate_schedule(
        total_steps, args.base_lr, args.decay_type, args.warmup_steps)
    lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
    update_rngs = jax.random.split(jax.random.PRNGKey(0),
                                   jax.local_device_count())

    # ---- Training loop ---------------------------------------------------
    writer = metric_writers.create_default_writer(logdir, asynchronous=False)
    writer.write_hparams(
        {name: value for name, value in vars(args).items() if value is not None})
    logger.info('Starting training loop; initial compile can take a while...')
    t0 = time.time()

    train_batches = input_pipeline.prefetch(ds_train, args.prefetch)
    for step, batch, lr_repl in zip(range(1, total_steps + 1), train_batches,
                                    lr_iter):

        opt_repl, loss_repl, update_rngs = update_fn_repl(
            opt_repl, lr_repl, batch, update_rngs)

        if step == 1:
            # Exclude the first (compilation-heavy) step from the ETA clock.
            logger.info(f'First step took {time.time() - t0:.1f} seconds.')
            t0 = time.time()

        if args.progress_every and step % args.progress_every == 0:
            writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
            done = step / total_steps
            logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                        f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
            copyfiles(glob.glob(f'{logdir}/*'))

        is_eval_step = args.eval_every and step % args.eval_every == 0
        if is_eval_step or step == total_steps:
            # Per-example correctness over the whole test set; the argmax over
            # axis 2 follows the (device, batch, classes) layout this code indexes.
            hits = []
            for test_batch in input_pipeline.prefetch(ds_test, args.prefetch):
                logits = vit_fn_repl(opt_repl.target, test_batch['image'])
                correct = (np.argmax(logits, axis=2) ==
                           np.argmax(test_batch['label'], axis=2))
                hits.extend(correct.ravel())
            accuracy_test = np.mean(hits)

            current_lr = float(lr_repl[0])
            logger.info(f'Step: {step} '
                        f'Learning rate: {current_lr:.7f}, '
                        f'Test accuracy: {accuracy_test:0.5f}')
            writer.write_scalars(step,
                                 dict(accuracy_test=accuracy_test, lr=current_lr))
            copyfiles(glob.glob(f'{logdir}/*'))

    if args.output:
        checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
        logger.info(f'Stored fine tuned checkpoint to {args.output}')
        copyfiles([args.output])
def _write_sample_files(sample, out_dir, name):
  """Writes one batch of generated samples as `<name>.np` and `<name>.png`.

  Args:
    sample: Array of generated images. Its first two axes are merged to build
      the image grid; presumably (devices, per_device_batch, H, W, C) with
      values in [0, 1] -- TODO confirm against `sampling.get_samples`.
    out_dir: Directory the two files are written into (created by the caller).
    name: File stem; the files written are `<name>.np` and `<name>.png`.
  """
  # Merge the two leading axes so all images form a single flat batch.
  image_grid = sample.reshape((-1, *sample.shape[2:]))
  # floor(sqrt(num_images)); presumably the number of images per grid row.
  nrow = int(np.sqrt(image_grid.shape[0]))
  # Only the raw array dump is quantized to uint8; the PNG grid is rendered
  # from the un-quantized `image_grid` computed above.
  quantized = np.clip(sample * 255, 0, 255).astype(np.uint8)
  with tf.io.gfile.GFile(
      os.path.join(out_dir, "{}.np".format(name)), "wb") as fout:
    np.save(fout, quantized)
  with tf.io.gfile.GFile(
      os.path.join(out_dir, "{}.png".format(name)), "wb") as fout:
    utils.save_image(image_grid, fout, nrow=nrow, padding=2)


def train(config, workdir):
  """Runs a training loop.

  Creates `samples/`, `tensorboard/`, `checkpoints/` and `checkpoints-meta/`
  under `workdir`, then trains the score model data-parallel with
  `jax.pmap`. During training it:
    * logs the training loss every 50 steps (host 0 only);
    * saves a pre-emption checkpoint to `checkpoints-meta/` every
      `config.training.snapshot_freq_for_preemption` steps (host 0 only);
    * evaluates on one eval batch every 100 steps;
    * saves a checkpoint to `checkpoints/` and, if
      `config.training.snapshot_sampling`, writes generated samples to
      `samples/iter_<step>_host_<host_id>/` whenever
      `(step + 1) % config.training.snapshot_freq == 0` and at the last step.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint
      (the one restored from `checkpoints-meta/`).
  """
  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  # `writer` only exists on host 0; every use below is gated on host_id == 0.
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      # Model inputs: images plus one int32 vector with one entry per image.
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # Class-conditional models take a second int32 vector of that shape.
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  state = mutils.State(
      step=0,
      optimizer=optimizer,
      lr=config.optim.lr,
      model_state=init_model_state,
      ema_rate=config.model.ema_rate,
      params_ema=initial_params,
      rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(checkpoint_dir, max_to_keep=None)
  # NOTE(review): the return value is deliberately discarded; resumption
  # comes from the meta checkpoint below, so presumably this call only seeds
  # an empty `checkpoints/` directory -- confirm against utils.Checkpoint.
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(checkpoint_meta_dir, max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  # NOTE(review): on resume this is the host-folded rng saved below, which is
  # folded with the host id a second time before the loop.
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(
        losses.ddpm_loss,
        ddpm_params=ddpm_params,
        train=True,
        optimize_fn=optimize_fn)
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  # Give each host its own RNG stream.
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    # One fresh key per local device for the pmapped step.
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if (step % config.training.snapshot_freq_for_preemption == 0 and
        jax.host_id() == 0):
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    # NOTE(review): uses `step + 1` here but `step` for the pre-emption save
    # above; presumably because the saved `state.step` is already `step + 1`
    # after the update -- confirm.
    if ((step + 1) % config.training.snapshot_freq == 0 or
        step == num_train_steps):
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples (on every host, into a per-host directory).
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(
            sample_rng,
            config,
            flax_utils.unreplicate(state),
            init_shape,
            scaler,
            inverse_scaler,
            class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:
          # Do not save intermediate samples
          _write_sample_files(samples[-1], this_sample_dir, "sample")
        else:
          # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            _write_sample_files(sample, this_sample_dir, "sample_{}".format(i))