def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  config = FLAGS.config

  model = create_model(config)
  optimizer = create_optimizer(config, model)
  del model  # don't keep a copy of the initial model

  output_dir = get_output_dir(config)
  gfile.makedirs(output_dir)

  # Restore from a local checkpoint, if one exists.
  optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
  start_step = int(optimizer.state[0].step)

  optimizer = optimizer.replicate()

  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
  tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
  tokenizer.model_max_length = config.max_seq_length

  # The commented lines below correspond to a data pipeline that uses publicly
  # available data, in the form of English Wikipedia as processed and hosted by
  # the HuggingFace datasets library. The pipeline works, and downstream task
  # performance shows a benefit to pre-training, but I (Nikita) have yet to
  # confirm that final model quality is on par with the original BERT.
  #
  # dataset = datasets.load_dataset('wikipedia', '20200501.en')['train']
  # data_pipeline = data.PretrainingDataPipelineV1(
  #     dataset, tokenizer,
  #     max_predictions_per_seq=config.max_predictions_per_seq)

  # The data pipeline below relies on having text files of Wikipedia + Books in
  # the same format as the original BERT data. That original data is not
  # publicly available, so you will need to provide your own. I (Nikita) have
  # had success using data from Gong et al. "Efficient Training of BERT by
  # Progressively Stacking", but this data was also obtained through private
  # correspondence and may not be generally available.
  #
  # The data_files argument may be a list, if data is split across multiple
  # input files.
  dataset = datasets.load_dataset(
      'bert_data.py',
      data_files=os.path.expanduser('~/data/bert/corpus.train.tok'))['train']
  data_pipeline = data.PretrainingDataPipeline(
      dataset, tokenizer,
      max_predictions_per_seq=config.max_predictions_per_seq)

  datasets.logging.set_verbosity_error()

  learning_rate_fn = training.create_learning_rate_scheduler(
      factors='constant * linear_warmup * linear_decay',
      base_learning_rate=config.learning_rate,
      warmup_steps=config.num_warmup_steps,
      steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
  )

  train_history = training.TrainStateHistory(learning_rate_fn)
  train_state = train_history.initial_state()

  if config.do_train:
    train_iter = data_pipeline.get_inputs(
        batch_size=config.train_batch_size, training=True)
    train_step_fn = training.create_train_step(
        compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

    for step, batch in zip(
        range(start_step, config.num_train_steps), train_iter):
      optimizer, train_state = train_step_fn(optimizer, batch, train_state)
      if jax.host_id() == 0 and (step % config.save_checkpoints_steps == 0
                                 or step == config.num_train_steps - 1):
        checkpoints.save_checkpoint(output_dir, optimizer.unreplicate(), step)
        config_path = os.path.join(output_dir, 'config.json')
        if not os.path.exists(config_path):
          with open(config_path, 'w') as f:
            json.dump({'model_type': 'bert', **config.model}, f)

  if config.do_eval:
    eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
    eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
    eval_fn = training.create_eval_fn(
        compute_pretraining_stats, sample_feature_name='input_ids')
    eval_stats = eval_fn(optimizer, eval_iter)

    eval_metrics = {
        'loss': jnp.mean(eval_stats['loss']),
        'masked_lm_loss': jnp.mean(eval_stats['masked_lm_loss']),
        'next_sentence_loss': jnp.mean(eval_stats['next_sentence_loss']),
        'masked_lm_accuracy': jnp.sum(
            eval_stats['masked_lm_correct']) / jnp.sum(
                eval_stats['masked_lm_total']),
        'next_sentence_accuracy': jnp.sum(
            eval_stats['next_sentence_correct']) / jnp.sum(
                eval_stats['next_sentence_total']),
    }

    eval_results = []
    for name, val in sorted(eval_metrics.items()):
      line = f'{name} = {val:.06f}'
      print(line, flush=True)
      eval_results.append(line)

    eval_results_path = os.path.join(output_dir, 'eval_results.txt')
    with gfile.GFile(eval_results_path, 'w') as f:
      for line in eval_results:
        f.write(line + '\n')
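
# For reference, a minimal sketch (ours, not the training library's code) of
# the factor-based schedule that training.create_learning_rate_scheduler
# builds for the factors string 'constant * linear_warmup * linear_decay'
# used above: each named factor is evaluated at the current step and the
# results are multiplied together. The name and exact decay convention below
# are assumptions; the real helper supports additional factors.
def linear_warmup_linear_decay_schedule(base_learning_rate, warmup_steps,
                                        steps_per_cycle):
  def step_fn(step):
    # Ramp from 0 to 1 over warmup_steps, then hold at 1.
    warmup = jnp.minimum(1.0, step / warmup_steps)
    # Decay linearly from 1 to 0 over the remaining steps_per_cycle steps.
    decay = jnp.maximum(0.0, 1.0 - (step - warmup_steps) / steps_per_cycle)
    return base_learning_rate * warmup * decay
  return step_fn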
def train_and_evaluate(random_seed, batch_size, learning_rate, num_train_steps,
                       num_eval_steps, eval_freq, max_target_length,
                       max_eval_target_length, weight_decay, data_dir,
                       model_dir, restore_checkpoints, save_checkpoints,
                       checkpoint_freq, max_predict_token_length,
                       sampling_temperature, sampling_top_k, prompt_str):
  """Executes model training and evaluation loop.

  Args:
    random_seed: Seed for initializing the PRNG.
    batch_size: Batch size for training.
    learning_rate: Learning rate for the Adam optimizer.
    num_train_steps: Number of training steps.
    num_eval_steps: Number of evaluation steps.
    eval_freq: Frequency of evaluation during training.
    max_target_length: Maximum length of training examples.
    max_eval_target_length: Maximum length of eval examples.
    weight_decay: Decay factor for AdamW-style weight decay.
    data_dir: Directory containing TFDS lm1b/subwords32k datasets.
    model_dir: Directory where to store model data.
    restore_checkpoints: Whether to restore from existing model checkpoints.
    save_checkpoints: Whether to save model checkpoints.
    checkpoint_freq: Save a checkpoint every this many steps.
    max_predict_token_length: Maximum token length for example text inference.
    sampling_temperature: Sampling temperature for language model inference.
    sampling_top_k: Top k cutoff for logit sampling.
    prompt_str: Prompt for language model sampling.
  """
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(model_dir, 'eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
      n_devices=jax.local_device_count(),
      data_dir=data_dir,
      batch_size=batch_size,
      dynamic_batching=True,
      max_target_length=max_target_length,
      max_eval_target_length=max_eval_target_length)
  vocab_size = info_ds['text'].encoder.vocab_size
  encoder = info_ds['text'].encoder

  train_iter = iter(train_ds)
  input_shape = (batch_size, max_target_length)

  transformer_lm_kwargs = {
      'vocab_size': vocab_size,
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max(max_target_length, max_eval_target_length)
  }

  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = random.split(rng)
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs)
  optimizer = create_optimizer(model, learning_rate, weight_decay)
  del model  # Don't keep a copy of the initial model.

  start_step = 0
  if restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  p_pred_step = jax.pmap(predict_step, axis_name='batch')

  metrics_all = []
  tick = time.time()
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.host_id() == 0 and save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            model_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_metrics = []
      eval_iter = iter(eval_ds)
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        # pylint: disable=protected-access
        eval_batch = common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), eval_batch))
        # pylint: enable=protected-access
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()

      # Fast inference of prompt extension using trained LM.
      rng, subrng = jax.random.split(rng)
      pred_rngs = random.split(subrng, jax.local_device_count())
      prompt = jnp.array(encoder.encode(prompt_str))
      prompt = jax_utils.replicate(prompt)
      prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
      cache = jax_utils.replicate(
          cache_def.initialize_cache((1, max_predict_token_length)))
      predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs,
                              max_predict_token_length, sampling_temperature,
                              sampling_top_k)
      predicted = tohost(predicted)
      exemplars = ''
      for n in range(predicted.shape[0]):
        exemplars += encoder.decode(predicted[n]) + '\n\n'
      if jax.host_id() == 0:
        eval_summary_writer.text('samples', exemplars, step)
        eval_summary_writer.flush()
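
# `tohost` is used above but defined elsewhere in this file. A sketch of the
# conventional Flax-example implementation (an assumption about the real
# definition), which collapses the leading (device, per-device batch) axes of
# pmap output back into a single host-side batch axis:
def tohost_sketch(x):
  """Collect batches from all devices to host and flatten batch dims."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))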
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""

    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch size of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining.
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size, 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warning('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)

  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')

  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')

  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')

  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)
  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast, beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
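
# The training loop above shards each host batch across local devices with
# common_utils.shard (from flax.training.common_utils). For reference, a
# sketch of its behaviour, assuming the leading batch axis divides evenly:
# (batch, ...) -> (local_device_count, batch // local_device_count, ...).
def shard_sketch(xs):
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)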
def main(argv):
  global CFG
  CFG = FLAGS.config

  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
  _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if CFG.hardware_rng:
    models.set_hardware_bernoulli()

  if 'module_import' in CFG and CFG.module_import:
    for module in CFG.module_import:
      importlib.import_module(module)

  if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
    t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

  num_partitions = CFG.num_partitions
  topology = train_lib.compute_multihost_topology(num_partitions)
  batch_size = CFG.batch_size
  eval_batch_size = CFG.eval_batch_size
  per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
  if batch_size % topology.num_replicas:
    raise ValueError('Batch size must be divisible by the number of replicas.')

  steps_per_epoch = CFG.steps_per_epoch
  logging.info('steps per epoch: %d', steps_per_epoch)

  broadcast = functools.partial(
      train_lib.broadcast,
      num_replicas=topology.per_replica_set_num_replicas,
      num_partitions=topology.per_host_num_partitions,
      devices=topology.this_host_device_assignment)

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    tf.io.gfile.copy(FLAGS['config'].config_filename,
                     os.path.join(FLAGS.model_dir, 'config.py'),
                     overwrite=True)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  # Write summaries in background thread to avoid blocking on device sync.
  if CFG.infeed:
    # Infeed is currently synchronous, so do it in a background thread too.
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
      CFG, topology.num_replica_sets, topology.replica_set_id,
      topology.per_replica_set_host_id)

  vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
  encoder = vocab.tf_tokenizer
  eos_id = vocab.tokenizer.eos_id()

  def decode_tokens(toks, eos_id=eos_id, max_id=32000):
    """Decode tokens back to unicode."""
    del eos_id
    # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  Double check this
    # is the best decoding function or just switch to using tf_decode.
    # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    valid_toks = toks.astype(np.int32)
    valid_toks[valid_toks >= max_id] = 3
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  train_config, eval_config, predict_config = get_configs(CFG)

  rng = random.PRNGKey(CFG.random_seed)
  rng, init_rng = random.split(rng)
  # This is used for infeed conversion from feature dict <--> tuple.
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  device_train_input_shape = tuple([
      (batch_size // topology.num_replicas,
       CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
      for k in train_keys
  ])

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      factors=CFG.schedule,
      base_learning_rate=CFG.learning_rate,
      warmup_steps=CFG.warmup_steps)

  # First, we only abstractly initialize the optimizer and model parameters,
  # since the parameters may not even fit in device memory!
  # TODO(jekbradbury): make optimizer_defs compare by value so it can be
  # created in get_initial_params without causing pytree incompatibility.
  optimizer_def = optim.Adafactor(
      CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset)
  initialize_params_fn = functools.partial(
      get_initial_params,
      config=CFG,
      transformer_config=eval_config,
      optimizer_def=optimizer_def)
  optimizer = jax.eval_shape(initialize_params_fn, init_rng)
  # Tuple-like pytree leaves for global_arg_shapes.
  optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                  optimizer)

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  if num_partitions > 1:
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(num_partitions, optimizer.state_dict()))
    per_host_optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(topology.per_host_num_partitions,
                                  optimizer.state_dict()))

  # Restore unreplicated optimizer + model state from last checkpoint.
  # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
  existing_checkpoint_found = False
  if CFG.restore_checkpoints:
    existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

  # Import a pretrained-T5 checkpoint only if we didn't import a local
  # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.)
  # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
  if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
    optimizer = checkpoint_importer.restore_from_t5_checkpoint(
        optimizer, CFG.restore_t5_checkpoint)

  if CFG.restore_t5_checkpoint or existing_checkpoint_found:
    if num_partitions > 1:
      # Until checkpoint/restore is sharded, the restored checkpoint is global
      # and we need to slice each sharded parameter into the chunk containing
      # only the partitions that are present on this host.
      def per_host_chunk(x, spec):
        if spec is None or spec is x:  # unsharded or not a parameter
          return x
        if spec[0] == 1:
          dim_size = x.shape[1]
        elif spec[1] == 1:
          dim_size = x.shape[0]
        else:
          raise NotImplementedError()
        chunk_size = (
            dim_size * topology.per_host_num_partitions // num_partitions)
        lower = topology.per_replica_set_host_id * chunk_size
        upper = (topology.per_replica_set_host_id + 1) * chunk_size
        if spec[0] == 1:
          return x[:, lower:upper]
        else:
          return x[lower:upper]

      optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                    optimizer_partitions)
  else:
    # If pretraining and no checkpoint imported, we jit the (sharded-) init
    # function to minimize fragmentation. We use the same pmap(sharded_jit)
    # setup as the training step/loop to initialize everything "in-place" and
    # avoid communication or OOM.
    if num_partitions > 1:
      initialize_params_fn = sharded_jit(
          initialize_params_fn,
          in_parts=None,
          local_in_parts=None,
          out_parts=optimizer_partitions,
          local_out_parts=per_host_optimizer_partitions,
          # devices=one_replica_device_assignment,
      )
      initialize_params_fn = jax.pmap(
          initialize_params_fn,
          'batch',
          in_axes=0,
          axis_size=topology.num_replicas,
          devices=topology.device_assignment)
      init_rng = broadcast(init_rng)
      optimizer = initialize_params_fn(init_rng)
      # We maintain the optimizer in unbroadcasted form (i.e. with no leading
      # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
      # out_axes=None.
      optimizer = train_lib.unbroadcast(optimizer)
    else:
      optimizer = jax.jit(initialize_params_fn)(init_rng)

  # ---------------------------------------------------------------------------
  # Compile multidevice versions of train/eval/predict step and cache init fn.
  # ---------------------------------------------------------------------------

  # We can use either a single train-step for a host training loop:

  # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
  #  --> new_optimizer, metrics, new_dropout_rng
  def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
    return train_lib.train_step(
        optimizer,
        batch,
        prev_metrics,
        dropout_rng,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss,
        use_bfloat16=CFG.use_bfloat16)

  if num_partitions > 1:
    p_train_step = sharded_jit(
        p_train_step,
        in_parts=(optimizer_partitions, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None),
        out_parts=(optimizer_partitions, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None))
  # TODO(levskaya): the in_axes spec below might be wrong, double-check.
  p_train_step = jax.pmap(
      p_train_step,
      axis_name='batch',
      in_axes=(None, 0, 0, 0),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # OR, we use an on-device loop that feeds the training step via infeed queue.
  def device_train_loop_cond(args):
    """Stopping criterion for on-device loop."""
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    """On-device loop body."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    input_data, token = lax.infeed(
        token,
        shape=tuple(
            [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape]))
    # Rebuild input dict from infeed data tuple.
    batch = {k: v for k, v in zip(train_keys, input_data)}
    # Run the train_step function and return the loop state.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(
        device_train_loop,
        in_parts=(optimizer_partitions, None, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None, None),
        out_parts=(optimizer_partitions, None, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None, None))
  p_train_epoch = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, 0, 0, None, None),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Reduction psum for metric data.
  def p_allreduce_metrics(x):
    return lax.psum(x, axis_name='batch')

  if num_partitions > 1:
    p_allreduce_metrics = sharded_jit(
        p_allreduce_metrics,
        in_parts=None,
        local_in_parts=None,
        out_parts=None,
        local_out_parts=None,
        num_partitions=num_partitions,
        local_num_partitions=topology.per_host_num_partitions)
  p_allreduce_metrics = jax.pmap(
      p_allreduce_metrics,
      axis_name='batch',
      global_arg_shapes=None,
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)

  # Training evaluation computation.

  # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
  def p_eval_step(params, batch):
    return train_lib.eval_step(
        params, batch, config=eval_config, label_smoothing=CFG.label_smoothing)

  if num_partitions > 1:
    p_eval_step = sharded_jit(
        p_eval_step,
        in_parts=(optimizer_partitions.target, None),
        local_in_parts=(per_host_optimizer_partitions.target, None),
        out_parts=None,
        local_out_parts=None)
  p_eval_step = jax.pmap(
      p_eval_step,
      axis_name='batch',
      in_axes=(None, 0),
      global_arg_shapes=(optimizer_shapes.target, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Fast autoregressive decoding loop.
  # For inference and model evaluation.

  # predict_step(inputs, params,
  #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
  def p_pred_step(inputs, params):
    return train_lib.predict_step(inputs, params, eos_id,
                                  CFG.max_eval_target_length, predict_config,
                                  CFG.beam_size)

  if num_partitions > 1:
    p_pred_step = sharded_jit(
        p_pred_step,
        in_parts=(None, optimizer_partitions.target),
        local_in_parts=(None, per_host_optimizer_partitions.target),
        out_parts=None,
        local_out_parts=None)
  p_pred_step = jax.pmap(
      p_pred_step,
      axis_name='batch',
      in_axes=(0, None),
      global_arg_shapes=(None, optimizer_shapes.target),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # ---------------------------------------------------------------------------
  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # There should be a unique dropout key for each replica represented on this
  # host, but the key should be the same for the same replica on other hosts.
  # Again, this is what the replica set abstraction is for.
  dropout_rngs = random.split(
      random.fold_in(rng, topology.replica_set_id),
      topology.per_replica_set_num_replicas)

  # Restore step from last checkpoint.
  host_step = int(optimizer.state.step)
  empty_metrics = broadcast({
      'loss': 0.0,
      'accuracy': 0.0,
      'learning_rate': 0.0,
      'denominator': 0.0
  })
  if CFG.infeed:
    # TODO(jekbradbury): support something like this for the Python-loop case
    logging.info('Precompiling training loop and moving optimizer to device.')
    optimizer, _, metrics, _ = p_train_epoch(
        optimizer, dropout_rngs, empty_metrics,
        jnp.array(0, dtype=jnp.int32), 1)
    optimizer = train_lib.unbroadcast(optimizer)
    metrics['loss'].block_until_ready()

  logging.info('Starting training loop.')

  local_devices = jax.local_devices()
  device_step = broadcast(host_step)
  first_epoch = host_step // steps_per_epoch

  # Main Loop over "epochs".
  train_iter = train_ds.as_numpy_iterator()
  for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
    metrics = empty_metrics

    # NOTE: 'optimizer' is unbroadcast by construction at initialization or
    # when loading a checkpoint. It is maintained in 'unbroadcast' state to
    # enable the XLA cross-replica sharding optimization. The broadcasting is
    # handled automatically by the pmap'd functions that use it.

    # Gather all task evaluation metrics.
    logging.info('Evaluating tasks.')
    if epoch == first_epoch + 1:
      train_lib.sync_devices()
    for task in eval_cache.tasks:
      logging.info('Evaluating task %s', task.name)
      all_predicted, all_bs = [], []
      for pred_batch in eval_cache.preprocessed_examples[task.name]:
        # Handle final odd-sized batch by padding instead of dropping it.
        input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
            pred_batch['inputs'], per_replica_set_eval_batch_size)
        all_bs.append(unpadded_batch_size)
        # Split batch dimensions for pmap.
        input_batch = jax.tree_map(
            lambda x: x.reshape(
                (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
            input_batch)
        # Run fast inference on batch.
        all_predicted.append(p_pred_step(input_batch, optimizer.target))

      # Pad out the number of batches so each host has the same number.
      max_host_batch_number = np.max(
          eval_cache.preprocessed_batch_sizes[task.name])
      batch_shortfall = max_host_batch_number - len(all_predicted)
      if batch_shortfall > 0:
        # TODO(levskaya): Fix for case of entirely empty all_predicted.
        # To make sure the cross-host barriers work, we run the program the
        # same number of times on all hosts. The results of this call are
        # ignored, and the predictions are populated with zeros instead.
        p_pred_step(input_batch, optimizer.target)  # Dummy call.
        all_predicted.extend(
            [jnp.zeros_like(all_predicted[0])] * batch_shortfall)
        all_bs.extend([0] * batch_shortfall)
      all_predicted = jnp.concatenate(all_predicted)
      all_bs = jnp.array(all_bs)

      # Collect all batches from across hosts and reverse sharding.
      all_predicted = train_lib.host_allgather(
          all_predicted, topology.num_replica_sets, topology.replica_set_id,
          topology.per_replica_set_host_id == 0)
      seqlength = all_predicted.shape[-1]
      total_examples = np.sum(
          train_lib.host_allgather(
              all_bs, topology.num_replica_sets, topology.replica_set_id,
              topology.per_replica_set_host_id == 0))
      del all_bs
      assert total_examples == len(eval_cache.examples[task.name]), (
          'Total number of batches incorrect for task %s.' % task.name)
      # De-shard the collected predicted tokens and remove padding.
      all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
          -1, seqlength)[:total_examples]

      # We now run the post-processing and metric-fns on a single host.
      if jax.host_id() == 0:
        assert eval_summary_writer
        raw_predictions = []
        for tokens in all_predicted:
          raw_predictions.append(decode_tokens(tokens))

        # Post-process predictions for metric fns.
        predictions = [
            task.postprocess_fn(p, example=ex)
            for p, ex in zip(raw_predictions, eval_cache.examples[task.name])
        ]

        for metric_fn in task.metric_fns:
          scores = metric_fn(eval_cache.targets[task.name], predictions)
          for metric_name, metric_value in scores.items():
            tag = f'eval/{task.name}/{metric_name}'
            eval_summary_writer.scalar(tag, metric_value, host_step)
            logging.info('EVAL %s at step %d: %.3f', tag, host_step,
                         metric_value)
          eval_summary_writer.flush()

        # Save text samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
          tgt_txt = tf.compat.as_text(
              eval_cache.examples[task.name][n]['targets_plaintext'])
          pred_txt = raw_predictions[n]
          exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                        f'target: {tgt_txt}\n\n'
                        f'prediction: {pred_txt}\n\n')
        eval_summary_writer.text(f'{task.name} samples', exemplars, host_step)
        eval_summary_writer.flush()

    # Take an Xprof trace after the first loop has compiled everything.
    if epoch == first_epoch + 1:
      train_lib.sync_devices()

    # For on-device loop, we launch the computation before feeding data.
    logging.info('BEGIN Train loop.')
    if CFG.infeed:
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          optimizer, dropout_rngs, metrics,
          train_lib.unbroadcast(device_step), epoch)
      optimizer = train_lib.unbroadcast(optimizer)

    # Epoch loop.
    while int(host_step // steps_per_epoch) == epoch:
      batch = next(train_iter)
      batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          batch)
      # Feed the on-device training loop.
      if CFG.infeed:
        for i, device in enumerate(local_devices):
          # When using infeed to provide data to the computation, we're on our
          # own for feeding the right values to the right devices. Each device
          # should get the minibatch corresponding to its replica, a slice of
          # the larger batch corresponding to the host's replica set.
          if device.platform == 'tpu':
            device_coords = (*device.coords, device.id % 2)
          else:
            device_coords = (device.host_id, i)
          per_replica_set_device_coords = tuple(
              dc % prsm
              for dc, prsm in zip(device_coords,
                                  topology.per_replica_set_mesh))
          per_replica_set_replica_coords = tuple(
              prsdc // prm
              for prsdc, prm in zip(per_replica_set_device_coords,
                                    topology.per_replica_mesh))
          per_replica_set_replica_id = 0
          for prsm, prm, prsrc in zip(topology.per_replica_set_mesh,
                                      topology.per_replica_mesh,
                                      per_replica_set_replica_coords):
            per_replica_set_replica_id = (
                per_replica_set_replica_id * prsm // prm + prsrc)
          input_tuple = tuple(
              [batch[k][per_replica_set_replica_id] for k in train_keys])
          # Safety check: infeed does not check shape or types but requires
          # them to agree with on-device spec, otherwise the queue and program
          # stalls.
          tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
          tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
          assert tuple_shapes == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (tuple_shapes, device_train_input_shape))
          assert tuple(set(tuple_dtypes)) == (jnp.int32,), (
              'infeed dtype error %s not all of type %s' %
              (tuple_dtypes, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      # Host training loop.
      else:
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, metrics, dropout_rngs)
        optimizer = train_lib.unbroadcast(optimizer)
      host_step += 1

    logging.info('END Train loop.')

    # Maybe save a checkpoint on one host.
    if (CFG.save_checkpoints and
        epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and
        jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

    # Gather training metrics.
    metrics = p_allreduce_metrics(metrics)
    metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
    denominator = metrics.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
    logging.info('train in step: %s, %s', host_step, summary)
    if jax.host_id() == 0:
      assert train_summary_writer
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, host_step)
      train_summary_writer.flush()

    # Gather training evaluation metrics.
    logging.info('Gathering training evaluation metrics.')
    eval_metrics = []
    eval_iter = eval_ds.as_numpy_iterator()
    for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    # Average metrics across devices.
    eval_metrics = p_allreduce_metrics(eval_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    # Average metrics across steps.
    eval_metrics = jax.tree_map(np.sum, eval_metrics)
    eval_denominator = eval_metrics.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics)
    logging.info('eval in step: %s, %s', host_step, eval_summary)
    if jax.host_id() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, host_step)
      eval_summary_writer.flush()

  # Wait until computations are done before exiting.
  logging.info('Finished.')
  train_lib.sync_devices()
  # Shut down the infeed threadpool.
  if CFG.infeed:
    infeed_pool.shutdown()
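
# train_lib.unbroadcast appears throughout the loop above. We believe (an
# assumption about train_lib, not confirmed by this file) it strips the
# leading replica axis that pmap'd outputs carry, keeping the optimizer in
# single-copy "unbroadcast" form between steps:
def unbroadcast_sketch(tree):
  return jax.tree_map(lambda x: x[0], tree)  # Keep one replica's copy.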
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  transformer_kwargs = {
      'vocab_size': vocab_size,
      'output_vocab_size': vocab_size,
      'emb_dim': FLAGS.emb_dim,
      'num_heads': FLAGS.num_heads,
      'num_layers': FLAGS.num_layers,
      'qkv_dim': FLAGS.qkv_dim,
      'mlp_dim': FLAGS.mlp_dim,
      'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      'share_embeddings': FLAGS.share_embeddings,
      'logits_via_embedding': FLAGS.logits_via_embedding,
  }

  start_step = 0
  rng = random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  model, cache_def = create_model(init_rng, input_shape, target_shape,
                                  transformer_kwargs)
  optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay)
  # We access model only from optimizer below via optimizer.target.
  del model

  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate,
      warmup_steps=FLAGS.warmup_steps)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          use_bfloat16=FLAGS.use_bfloat16,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info('Starting training loop.')
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir,
                                  jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % FLAGS.eval_frequency != 0 and step > 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['learning_rate'] = lr
    steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
    logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info('Translating evaluation dataset.')
    t_inference_start = time.time()
    predict_iter = iter(predict_ds)
    sources, references, predictions = [], [], []
    for _, pred_batch in enumerate(predict_iter):
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      per_device_batchsize = pred_batch['inputs'].shape[1]
      cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
      cache = jax_utils.replicate(
          cache_def.initialize_cache(
              (per_device_batchsize, FLAGS.max_predict_length),
              dtype=cache_dtype))
      predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                              eos_id, FLAGS.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch['inputs'])
      targets = tohost(pred_batch['targets'])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))
    logging.info('Translation time: %.4f s step %d.',
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)
    # Save translation samples for tensorboard.
    exemplars = ''
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
    if jax.host_id() == 0:
      eval_summary_writer.scalar('bleu', bleu_score, step)
      eval_summary_writer.text('samples', exemplars, step)
      eval_summary_writer.flush()
    logging.info('Translation BLEU Score %.4f', bleu_score)
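
# pad_examples is used above to round the final predict batch up to a
# multiple of the device count. A sketch of the usual Flax-example helper
# (an assumption about the definition elsewhere in this file), which repeats
# the last example as padding; the padded rows are sliced off again after
# decoding via predicted[:cur_pred_batch_size]:
def pad_examples_sketch(x, desired_batch_size):
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)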
def main(unused_argv):
  config = utils.load_config()

  dataset = datasets.get_dataset('test', FLAGS.data_dir, config)
  model, init_variables = models.construct_mipnerf(
      random.PRNGKey(20200823), dataset.peek())
  optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates 'speckle' artifacts.
  def render_eval_fn(variables, _, rays):
    return jax.lax.all_gather(
        model.apply(
            variables,
            random.PRNGKey(0),  # Unused.
            rays,
            randomized=False,
            white_bkgd=config.white_bkgd),
        axis_name='batch')

  # pmap over only the data input.
  render_eval_pfn = jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=2,
      axis_name='batch',
  )

  ssim_fn = jax.jit(functools.partial(math.compute_ssim, max_val=1.))

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      'path_renders' if config.render_path else 'test_preds')
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, 'eval'))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    avg_values = []
    if not FLAGS.eval_once:
      showcase_index = random.randint(
          random.PRNGKey(step), (), 0, dataset.size)
    for idx in range(dataset.size):
      print(f'Evaluating {idx+1}/{dataset.size}')
      batch = next(dataset)
      pred_color, pred_distance, pred_acc = models.render_image(
          functools.partial(render_eval_pfn, state.optimizer.target),
          batch['rays'],
          None,
          chunk=FLAGS.chunk)
      vis_suite = vis.visualize_suite(pred_distance, pred_acc)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_acc = pred_acc
        showcase_vis_suite = vis_suite
        if not config.render_path:
          showcase_gt = batch['pixels']
      if not config.render_path:
        psnr = float(
            math.mse_to_psnr(((pred_color - batch['pixels'])**2).mean()))
        ssim = float(ssim_fn(pred_color, batch['pixels']))
        print(f'PSNR={psnr:.4f} SSIM={ssim:.4f}')
        psnr_values.append(psnr)
        ssim_values.append(ssim)
      if FLAGS.save_output and (config.test_render_interval > 0):
        if (idx % config.test_render_interval) == 0:
          utils.save_img_uint8(
              pred_color, path.join(out_dir, 'color_{:03d}.png'.format(idx)))
          utils.save_img_float32(
              pred_distance,
              path.join(out_dir, 'distance_{:03d}.tiff'.format(idx)))
          utils.save_img_float32(
              pred_acc, path.join(out_dir, 'acc_{:03d}.tiff'.format(idx)))
          for k, v in vis_suite.items():
            utils.save_img_uint8(
                v, path.join(out_dir, k + '_{:03d}.png'.format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image('pred_color', showcase_color, step)
      summary_writer.image('pred_acc', showcase_acc, step)
      for k, v in showcase_vis_suite.items():
        summary_writer.image('pred_' + k, v, step)
      if not config.render_path:
        summary_writer.scalar('psnr', np.mean(np.array(psnr_values)), step)
        summary_writer.scalar('ssim', np.mean(np.array(ssim_values)), step)
        summary_writer.image('target', showcase_gt, step)
    if FLAGS.save_output and (not config.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f'psnrs_{step}.txt'), 'w') as f:
        f.write(' '.join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f'ssims_{step}.txt'), 'w') as f:
        f.write(' '.join([str(v) for v in ssim_values]))
    if FLAGS.eval_once:
      break
    if int(step) >= config.max_steps:
      break
    last_step = step
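
# The PSNR numbers above come from math.mse_to_psnr. For reference, PSNR for
# signals normalized to [0, 1] is -10 * log10(MSE); a sketch (ours, not the
# project's source):
def mse_to_psnr_sketch(mse):
  return -10.0 * np.log(mse) / np.log(10.0)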
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') if FLAGS.jax_backend_target: jax.config.FLAGS.jax_xla_backend = 'tpu_driver' jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: tf.io.gfile.makedirs(FLAGS.model_dir) if FLAGS.batch_size % n_devices: raise ValueError('Batch size must be divisible by the number of devices') vocab_path = FLAGS.vocab_path if vocab_path is None: vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model') tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if FLAGS.dynamic: train_ds_mgr, eval_ds, predict_ds, encoder = input_pipeline.get_dynamic_datasets( dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), data_dir=FLAGS.data_dir, vocab_path=FLAGS.vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length, paracrawl_size=FLAGS.paracrawl_size, is_scores_path=FLAGS.is_scores_path, num_buckets=FLAGS.num_data_buckets) if FLAGS.static: weights = np.array([float(w) for w in FLAGS.static.split(',')]) assert len(weights) == FLAGS.num_data_buckets train_ds = train_ds_mgr.sampled_dataset(weights) FLAGS.dynamic = False else: init_dist = np.zeros(FLAGS.num_data_buckets) if FLAGS.data_selection_size < FLAGS.num_data_buckets: init_dist[range(FLAGS.data_selection_size)] = 1.0 train_ds = train_ds_mgr.sampled_dataset(init_dist) else: train_ds = build_split(train_ds_mgr, 1.0) else: train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length, paracrawl_size=FLAGS.paracrawl_size, is_scores_path=FLAGS.is_scores_path, num_to_keep=FLAGS.data_selection_size, pseudo_path=FLAGS.pseudo_path, repeat_count=FLAGS.repeat_count, newscommentary_size=FLAGS.newscommentary_size) if FLAGS.aux_eval_dataset: aux_datasets = [] aux_names = FLAGS.aux_eval_dataset.split(',') for name in aux_names: _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets( dataset_name=name, eval_dataset_name=None, shard_idx=jax.host_id(), shard_count=jax.host_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length, paracrawl_size=FLAGS.paracrawl_size, is_scores_path=FLAGS.is_scores_path, num_to_keep=FLAGS.data_selection_size, pseudo_path=FLAGS.pseudo_path, repeat_count=FLAGS.repeat_count, newscommentary_size=FLAGS.newscommentary_size) aux_datasets.append(aux_eval_ds) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. 
def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=FLAGS.share_embeddings, logits_via_embedding=FLAGS.logits_via_embedding, dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32, emb_dim=FLAGS.emb_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.qkv_dim, mlp_dim=FLAGS.mlp_dim, max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length), dropout_rate=FLAGS.dropout_rate, attention_dropout_rate=FLAGS.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = jax.random.PRNGKey(FLAGS.random_seed) rng, init_rng = jax.random.split(rng) # It's possible that this is supposed to be the per-device batch size. input_shape = (FLAGS.batch_size, FLAGS.max_target_length) target_shape = (FLAGS.batch_size, FLAGS.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # Apply an optimizer to this tree. optimizer_def = optim.Adam( FLAGS.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) # We access model params only from optimizer below via optimizer.target. del initial_variables if FLAGS.restore_checkpoints: logging.info('Restoring checkpoint.') # If we have a pretrained model, use that. Else, just continue where we left off. model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir optimizer = checkpoints.restore_checkpoint(model_path, optimizer) # Grab last step. start_step = int(optimizer.state.step) if FLAGS.adapter != NONE: adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path) optimizer = optimizer_def.create(optimizer.target, focus=adapter) writer = metric_writers.create_default_writer( FLAGS.model_dir, just_logging=jax.process_index() > 0) flag_key = [k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k] if flag_key: flag_key = flag_key[0] local_flags = { f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key] } writer.write_hparams(local_flags) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) if FLAGS.adapter != NONE: learning_rate_fn = common.create_learning_rate_scheduler( factors='constant', base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps) else: learning_rate_fn = common.create_learning_rate_scheduler( base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps, steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step, finetune_lr=FLAGS.finetune_lr) # Compile multidevice versions of train/eval/predict step and cache init fn.
p_train_step = jax.pmap( functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=FLAGS.label_smoothing), axis_name='batch', donate_argnums=(0,)) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial(eval_step, config=eval_config), axis_name='batch') p_init_cache = jax.pmap( functools.partial( initialize_cache, max_decode_len=FLAGS.max_predict_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap( functools.partial( predict_step, config=predict_config, beam_size=FLAGS.beam_size), axis_name='batch', static_broadcasted_argnums=(3, 4)) # eos token and max_length are constant. p_get_diag_grads = jax.pmap( functools.partial( get_diag_grads, config=eval_config), axis_name='batch') p_get_bucket_score = jax.pmap( functools.partial( get_diag_score, strategy=FLAGS.strategy), axis_name='batch') # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = jax.random.split(rng, jax.local_device_count()) del rng logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=FLAGS.num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5) ] train_metrics = [] total_steps = start_step + FLAGS.num_train_steps best_eval_loss = 1000 curr_eval_loss = 1000 with metric_writers.ensure_flushes(writer): for step in range(start_step, total_steps): is_last_step = step == total_steps - 1 if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0): # Dynamic macro: use gradient alignment to score different ratios of # top-k vs. bottom N-k bins. if FLAGS.macro: train_iter = get_macro_distribution(p_get_diag_grads, p_get_bucket_score, aux_eval_ds, train_ds_mgr, optimizer, eval_ds) else: # Use gradient alignment to score bins; take the top k bins and # sample uniformly from them. raw_distribution = get_new_distribution(p_get_diag_grads, p_get_bucket_score, aux_eval_ds, train_ds_mgr, optimizer, eval_ds) logging.info(raw_distribution) selected = np.argsort( raw_distribution)[::-1][:FLAGS.data_selection_size] new_distribution = np.zeros(100) new_distribution[selected] = 1.0 logging.info(new_distribution) train_ds = train_ds_mgr.sampled_dataset(new_distribution) train_iter = iter(train_ds) # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): try: batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter))) optimizer, metrics = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) except StopIteration: is_last_step = True # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) # Periodic metric handling.
if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step: with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed('eval'): eval_results = evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=eval_ds, num_eval_steps=FLAGS.num_eval_steps) curr_eval_loss = eval_results['loss'] writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()}) if FLAGS.aux_eval_dataset: for aux_i, aux_eval_ds in enumerate(aux_datasets): with report_progress.timed('aux_eval'): eval_results = evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=aux_eval_ds, num_eval_steps=FLAGS.num_eval_steps) writer.write_scalars( step, { 'aux' + str(aux_i) + '_eval_' + k: v for k, v in eval_results.items() }) if FLAGS.compute_bleu: with report_progress.timed('translate_and_bleu'): exemplars, bleu_score = translate_and_calculate_bleu( p_pred_step=p_pred_step, p_init_cache=p_init_cache, target=optimizer.target, predict_ds=predict_ds, decode_tokens=decode_tokens, max_predict_length=FLAGS.max_predict_length) writer.write_scalars(step, {'bleu': bleu_score}) writer.write_texts(step, {'samples': exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or is_last_step) if FLAGS.save_checkpoints and save_checkpoint and jax.host_id() == 0: if curr_eval_loss < best_eval_loss: # only save better checkpoints best_eval_loss = curr_eval_loss with report_progress.timed('checkpoint'): checkpoints.save_checkpoint( FLAGS.model_dir, jax_utils.unreplicate(optimizer), step, keep=FLAGS.chkpts_to_keep, overwrite=True) if is_last_step: break
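# Hedged sketch of the metric aggregation pattern used above: per-step metric
# dicts are summed across steps and normalized by a shared 'denominator'
# entry. The dicts below are illustrative stand-ins, not real training output.
import jax
import jax.numpy as jnp

step_metrics = [{'loss': jnp.array(4.0), 'denominator': jnp.array(2.0)},
                {'loss': jnp.array(6.0), 'denominator': jnp.array(3.0)}]
sums = jax.tree_map(lambda *xs: sum(xs), *step_metrics)
denominator = sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, sums)  # {'loss': 2.0}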
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs training interleaved with evaluation.""" # Setup input pipeline dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train') ds_train, ds_test = input_pipeline.get_datasets(config) batch = next(iter(ds_train)) logging.info(ds_train) logging.info(ds_test) # Build VisionTransformer architecture model_cls = {'ViT': models.VisionTransformer, 'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')] model = model_cls(num_classes=dataset_info['num_classes'], **config.model) def init_model(): return model.init( jax.random.PRNGKey(0), # Discard the "num_local_devices" dimension for initialization. jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name), train=False) # Use JIT to make sure params reside in CPU memory. variables = jax.jit(init_model, backend='cpu')() model_or_filename = config.get('model_or_filename') if model_or_filename: # Loading model from repo published with "How to train your ViT? Data, # Augmentation, and Regularization in Vision Transformers" paper. # https://arxiv.org/abs/2106.10270 if '-' in model_or_filename: filename = model_or_filename else: # Select best checkpoint from i21k pretraining by final upstream # validation accuracy. df = checkpoint.get_augreg_df(directory=config.pretrained_dir) sel = df.filename.apply( lambda filename: filename.split('-')[0] == model_or_filename) best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1] filename = best.filename logging.info('Selected filename="%s" for "%s" with final_val=%.3f', filename, model_or_filename, best.final_val) pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz') else: # ViT / Mixer papers filename = config.model.name pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz') if not tf.io.gfile.exists(pretrained_path): raise ValueError( f'Could not find "{pretrained_path}" - you can download models from ' '"gs://vit_models/imagenet21k" or directly set ' '--config.pretrained_dir="gs://vit_models/imagenet21k".') params = checkpoint.load_pretrained( pretrained_path=pretrained_path, init_params=variables['params'], model_config=config.model) total_steps = config.total_steps lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr, config.decay_type, config.warmup_steps) update_fn_repl = make_update_fn( apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn) infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False)) # Create optimizer and replicate it over all TPUs/GPUs opt = momentum_clip.Optimizer( dtype=config.optim_dtype, grad_norm_clip=config.grad_norm_clip).create(params) initial_step = 1 opt, initial_step = flax_checkpoints.restore_checkpoint( workdir, (opt, initial_step)) logging.info('Will start/continue training at initial_step=%d', initial_step) opt_repl = flax.jax_utils.replicate(opt) # Delete references to the objects that are not needed anymore del opt del params # Prepare the learning-rate and pre-fetch it to device to avoid delays. update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0)) # Setup metric writer & hooks.
writer = metric_writers.create_default_writer(workdir, asynchronous=False) writer.write_hparams(config.to_dict()) hooks = [ periodic_actions.Profile(logdir=workdir), periodic_actions.ReportProgress( num_train_steps=total_steps, writer=writer), ] # Run training loop logging.info('Starting training loop; initial compile can take a while...') t0 = lt0 = time.time() lstep = initial_step for step, batch in zip( range(initial_step, total_steps + 1), input_pipeline.prefetch(ds_train, config.prefetch)): with jax.profiler.StepTraceContext('train', step_num=step): opt_repl, loss_repl, update_rng_repl = update_fn_repl( opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl) for hook in hooks: hook(step) if step == initial_step: logging.info('First step took %.1f seconds.', time.time() - t0) t0 = time.time() lt0, lstep = time.time(), step # Report training metrics if config.progress_every and step % config.progress_every == 0: img_sec_core_train = (config.batch * (step - lstep) / (time.time() - lt0)) / jax.device_count() lt0, lstep = time.time(), step writer.write_scalars( step, dict( train_loss=float(flax.jax_utils.unreplicate(loss_repl)), img_sec_core_train=img_sec_core_train)) done = step / total_steps logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, ' # pylint: disable=logging-format-interpolation f'img/sec/core: {img_sec_core_train:.1f}, ' f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h') # Run evaluation if ((config.eval_every and step % config.eval_every == 0) or (step == total_steps)): accuracies = [] lt0 = time.time() for test_batch in input_pipeline.prefetch(ds_test, config.prefetch): logits = infer_fn_repl( dict(params=opt_repl.target), test_batch['image']) accuracies.append( (np.argmax(logits, axis=-1) == np.argmax(test_batch['label'], axis=-1)).mean()) accuracy_test = np.mean(accuracies) img_sec_core_test = ( config.batch_eval * ds_test.cardinality().numpy() / (time.time() - lt0) / jax.device_count()) lt0 = time.time() lr = float(lr_fn(step)) logging.info(f'Step: {step} ' # pylint: disable=logging-format-interpolation f'Learning rate: {lr:.7f}, ' f'Test accuracy: {accuracy_test:0.5f}, ' f'img/sec/core: {img_sec_core_test:.1f}') writer.write_scalars( step, dict( accuracy_test=accuracy_test, lr=lr, img_sec_core_test=img_sec_core_test)) # Store checkpoint. if ((config.checkpoint_every and step % config.checkpoint_every == 0) or step == total_steps): checkpoint_path = flax_checkpoints.save_checkpoint( workdir, (flax.jax_utils.unreplicate(opt_repl), step), step) logging.info('Stored checkpoint at step %d to "%s"', step, checkpoint_path) return flax.jax_utils.unreplicate(opt_repl)
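# Minimal sketch (an assumption, not utils.create_learning_rate_schedule
# itself) of a warmup-then-cosine-decay schedule of the kind lr_fn provides
# above: linear ramp over warmup_steps, then a half-cosine down to zero.
import numpy as np

def warmup_cosine_lr(total_steps, base_lr, warmup_steps):
  def lr_fn(step):
    warmup = np.clip(step / max(warmup_steps, 1), 0.0, 1.0)
    progress = np.clip((step - warmup_steps) /
                       max(total_steps - warmup_steps, 1), 0.0, 1.0)
    return base_lr * warmup * 0.5 * (1.0 + np.cos(np.pi * progress))
  return lr_fn

lr_fn_sketch = warmup_cosine_lr(total_steps=10_000, base_lr=3e-3,
                                warmup_steps=500)
assert lr_fn_sketch(0) == 0.0 and abs(lr_fn_sketch(500) - 3e-3) < 1e-12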
def main(_): if FLAGS.use_transformer: assert ( FLAGS.encoder_fn_name == 'transformer' ), 'encoder_fn_name must be transformer if use_transformer is True!' assert (FLAGS.epochs % FLAGS.measurements == 0 ), 'Number of measurements must divide number of epochs!' measurement_epochs = FLAGS.epochs // FLAGS.measurements assert FLAGS.results_save_dir != '', 'Specify results_save_dir!' assert FLAGS.label != '', 'Specify label!' if FLAGS.load_model: assert FLAGS.load_model_dir != '', 'Specify load_model_dir!' assert FLAGS.load_model_step > 0, 'Loaded model must have been trained for more than 0 steps.' if FLAGS.save_model: assert FLAGS.save_model_dir != '', 'Specify save_model_dir!' datum = { 'label': FLAGS.label, 'encoder_fn_name': FLAGS.encoder_fn_name, 'encoder_fn_kwargs_path': FLAGS.encoder_fn_kwargs_path, 'reduce_fn_name': FLAGS.reduce_fn_name, 'reduce_fn_kwargs_path': FLAGS.reduce_fn_kwargs_path, 'epochs': FLAGS.epochs, 'measurements': FLAGS.measurements, 'lens_batch_size': FLAGS.lens_batch_size, 'knn_batch_size': FLAGS.knn_batch_size, 'encoder_lr': FLAGS.encoder_lr, 'lens_lr': FLAGS.lens_lr, 'predictor_lr': FLAGS.predictor_lr, 'encoder_wd': FLAGS.encoder_wd, 'lens_wd': FLAGS.lens_wd, 'predictor_wd': FLAGS.predictor_wd, 'train_families': FLAGS.train_families, 'lens_train_samples': FLAGS.lens_train_samples, 'first_test_family': FLAGS.first_test_family, 'last_test_family': FLAGS.last_test_family, 'lens_shuffle_seed': FLAGS.lens_shuffle_seed, 'lens_sample_random_state': FLAGS.lens_sample_random_state, 'knn_shuffle_seed': FLAGS.knn_shuffle_seed, 'knn_sample_random_state': FLAGS.knn_sample_random_state, 'random_key': FLAGS.random_key, 'use_transformer': FLAGS.use_transformer, 'use_bert': FLAGS.use_bert, 'restore_transformer_dir': FLAGS.restore_transformer_dir, 'gcs_bucket': FLAGS.gcs_bucket, 'data_partitions_dirpath': FLAGS.data_partitions_dirpath, 'results_save_dir': FLAGS.results_save_dir, 'load_model': FLAGS.load_model, 'load_model_dir': FLAGS.load_model_dir, 'load_model_step': FLAGS.load_model_step, 'save_model': FLAGS.save_model, 'save_model_dir': FLAGS.save_model_dir } gcsfs = GCSFS(FLAGS.gcs_bucket) print(datum) df = pd.DataFrame([datum]) with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'), 'w') as gcs_file: df.to_csv(gcs_file, index=False) knn_train_samples_ = [1, 5, 10, 50] num_families = len(family_ids) loss_fn_kwargs = {'num_classes': num_families} lens_knn_train_family_accessions = [] for _ in range(1, FLAGS.train_families + 1): family_name = 'PF%05d' % _ lens_knn_train_family_accessions.append(family_name) knn_test_family_accessions = [] for _ in range(FLAGS.first_test_family, FLAGS.last_test_family + 1): family_name = 'PF%05d' % _ knn_test_family_accessions.append(family_name) encoder_fn = encoder_fn_name_to_fn(FLAGS.encoder_fn_name) encoder_fn_kwargs = json.load( open( resource_filename( 'contextual_lenses.resources', os.path.join('encoder_fn_kwargs_resources', FLAGS.encoder_fn_kwargs_path + '.json')))) reduce_fn = reduce_fn_name_to_fn(FLAGS.reduce_fn_name) reduce_fn_kwargs = json.load( open( resource_filename( 'contextual_lenses.resources', os.path.join('reduce_fn_kwargs_resources', FLAGS.reduce_fn_kwargs_path + '.json')))) layers, trainable_encoder = architecture_to_layers(FLAGS.encoder_fn_name, FLAGS.reduce_fn_name) embedding_model = create_model( use_transformer=FLAGS.use_transformer, use_bert=FLAGS.use_bert, restore_transformer_dir=FLAGS.restore_transformer_dir, encoder_fn=encoder_fn, encoder_fn_kwargs=encoder_fn_kwargs, reduce_fn=reduce_fn, 
reduce_fn_kwargs=reduce_fn_kwargs, layers=layers, output='embedding') datum.update( measure_nearest_neighbor_performance( accuracy_label= 'train_knn_accuracy_untrained_lens_1_knn_train_samples', encoder=embedding_model, family_accessions=lens_knn_train_family_accessions, batch_size=FLAGS.knn_batch_size, train_samples=1, shuffle_seed=FLAGS.knn_shuffle_seed, sample_random_state=FLAGS.knn_sample_random_state)) for knn_train_samples in knn_train_samples_: datum.update( measure_nearest_neighbor_performance( accuracy_label='test_knn_accuracy_untrained_lens_' + str(knn_train_samples) + '_knn_train_samples', encoder=embedding_model, family_accessions=knn_test_family_accessions, batch_size=FLAGS.knn_batch_size, train_samples=knn_train_samples, shuffle_seed=FLAGS.knn_shuffle_seed, sample_random_state=FLAGS.knn_sample_random_state)) encoder_fn_params = None reduce_fn_params = None predict_fn_params = None model = create_model(use_transformer=FLAGS.use_transformer, use_bert=FLAGS.use_bert, restore_transformer_dir=FLAGS.restore_transformer_dir, encoder_fn=encoder_fn, encoder_fn_kwargs=encoder_fn_kwargs, reduce_fn=reduce_fn, reduce_fn_kwargs=reduce_fn_kwargs, layers=layers, output='prediction', encoder_fn_params=encoder_fn_params, reduce_fn_params=reduce_fn_params, predict_fn_params=predict_fn_params) optimizer = create_optimizer( model=model, learning_rate=[FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr], weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd], layers=layers) if FLAGS.load_model: optimizer = checkpoints.restore_checkpoint(ckpt_dir=os.path.join( 'gs://' + FLAGS.gcs_bucket, FLAGS.load_model_dir), target=optimizer, step=FLAGS.load_model_step) trained_params = optimizer.target.params embedding_model = set_model_parameters(model=embedding_model, params=trained_params) if FLAGS.save_model: checkpoints.save_checkpoint(ckpt_dir=os.path.join( 'gs://' + FLAGS.gcs_bucket, FLAGS.save_model_dir), target=optimizer, step=FLAGS.load_model_step) for i in range(FLAGS.measurements): train_batches, train_indexes = create_pfam_batches( family_accessions=lens_knn_train_family_accessions, batch_size=FLAGS.lens_batch_size, samples=FLAGS.lens_train_samples, epochs=measurement_epochs, drop_remainder=True, shuffle_seed=FLAGS.lens_shuffle_seed + i, sample_random_state=FLAGS.lens_sample_random_state) optimizer = train( model=optimizer.target, train_data=train_batches, loss_fn=cross_entropy_loss, loss_fn_kwargs=loss_fn_kwargs, learning_rate=[ FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr ], weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd], layers=layers) results, preds = pfam_evaluate( predict_fn=optimizer.target, test_family_accessions=lens_knn_train_family_accessions, title=None, loss_fn_kwargs=loss_fn_kwargs, batch_size=FLAGS.lens_batch_size, data_partitions_dirpath=FLAGS.data_partitions_dirpath, gcs_bucket=FLAGS.gcs_bucket) lens_accuracy = results['accuracy'] datum['lens_accuracy' + '_measurement_' + str(i)] = lens_accuracy lens_cross_entropy = float(results['cross_entropy']) datum['lens_cross_entropy' + '_measurement_' + str(i)] = lens_cross_entropy trained_params = optimizer.target.params embedding_model = set_model_parameters(model=embedding_model, params=trained_params) datum.update( measure_nearest_neighbor_performance( accuracy_label= 'train_knn_accuracy_trained_lens_1_knn_train_samples' + '_measurement_' + str(i), encoder=embedding_model, family_accessions=lens_knn_train_family_accessions, batch_size=FLAGS.knn_batch_size, train_samples=1, 
shuffle_seed=FLAGS.knn_shuffle_seed, sample_random_state=FLAGS.knn_sample_random_state)) for knn_train_samples in knn_train_samples_: datum.update( measure_nearest_neighbor_performance( accuracy_label='test_knn_accuracy_trained_lens_' + str(knn_train_samples) + '_knn_train_samples' + '_measurement_' + str(i), encoder=embedding_model, family_accessions=knn_test_family_accessions, batch_size=FLAGS.knn_batch_size, train_samples=knn_train_samples, shuffle_seed=FLAGS.knn_shuffle_seed, sample_random_state=FLAGS.knn_sample_random_state)) print(datum) df = pd.DataFrame([datum]) with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'), 'w') as gcs_file: df.to_csv(gcs_file, index=False) if FLAGS.save_model: checkpoints.save_checkpoint(ckpt_dir=os.path.join( 'gs://' + FLAGS.gcs_bucket, FLAGS.save_model_dir), target=optimizer, step=FLAGS.load_model_step + FLAGS.epochs)
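# Hedged sketch of the idea behind measure_nearest_neighbor_performance:
# classify each test embedding by the label of its nearest train embedding
# (1-NN). All names here are illustrative, not from the codebase.
import numpy as np

def one_nn_accuracy(train_emb, train_labels, test_emb, test_labels):
  # Pairwise squared Euclidean distances, shape (num_test, num_train).
  d2 = ((test_emb[:, None, :] - train_emb[None, :, :])**2).sum(-1)
  preds = train_labels[np.argmin(d2, axis=1)]
  return float((preds == test_labels).mean())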
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) if not gfile.isdir(FLAGS.save_dir): gfile.mkdir(FLAGS.save_dir) hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr) # Get hyperparameters. if FLAGS.xm_parameters: for key, value in json.loads(FLAGS.xm_parameters).items(): if key not in hparam_str_dict: hparam_str_dict[key] = value hparam_str = ','.join([ '%s=%s' % (k, str(hparam_str_dict[k])) for k in sorted(hparam_str_dict.keys()) ]) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) batch_size = FLAGS.per_device_batch_size * n_devices io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task, FLAGS.max_characters) program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length) # Setup DSL # --------------------------------------------------------------------------- # Build token tables. id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)} char_id_table = {char: id for id, char in id_char_table.items()} id_token_table, token_id_table = dsl_tokens.build_token_tables() io_vocab_size = len(char_id_table) + 1 # For padding. program_vocab_size = len(token_id_table) + 1 bos_token = token_id_table[dsl.BOS] eos_token = token_id_table[dsl.EOS] def decode_io(inputs, outputs): """Decode IO example tokens.""" def decode_str(s): """Decode string tokens.""" return ''.join([id_char_table[c_id] for c_id in s if c_id > 0]) io_string = '' inps, outs = [], [] for inp, out in zip(inputs, outputs): inps.append(decode_str(inp)) outs.append(decode_str(out)) io_string += inps[-1] + ' < ' + outs[-1] + ' > ' return inps, outs, io_string[:-3] # Remove last separator. def decode_program(program): """Decode program tokens.""" program = program[:np.argmax(program == eos_token) + 1].astype(np.int32) try: p = dsl.decode_program(program, id_token_table) return p, p.to_string() except: # pylint: disable=bare-except return None, '' # Program does not compile. # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') if not FLAGS.dataset_filepattern: raise ValueError('Must specify filepattern to dataset.') # Training dataset. dataset = input_pipeline.create_dataset_from_tf_record( FLAGS.dataset_filepattern, token_id_table, char_id_table) dataset = dataset.padded_batch(batch_size, padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]), drop_remainder=True) # Split evaluation and training. eval_ds = dataset.take(FLAGS.num_eval_steps) # Decrease batch of predict dataset to handle beam search.
predict_ds = eval_ds.unbatch().padded_batch( int(np.ceil(batch_size / 10)), padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:])) train_ds = dataset.skip(FLAGS.num_eval_steps).repeat() train_iter = train_ds.as_numpy_iterator() # Build Model and Optimizer # --------------------------------------------------------------------------- base_train_config = models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=program_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_program_length), deterministic=False, decode=False, bos_token=bos_token) base_eval_config = base_train_config.replace(deterministic=True, train_vq=False) base_predict_config = base_train_config.replace(shift=False, deterministic=True, train_vq=False, decode=True) train_config = models.LatentTransformerConfig( base_cfg=base_train_config, latent_vocab_size=FLAGS.latent_vocab_size, c=FLAGS.c, train_vq=True, commitment_cost_vq=FLAGS.commitment_cost_vq) eval_config = models.LatentTransformerConfig( base_cfg=base_eval_config, latent_vocab_size=FLAGS.latent_vocab_size, c=FLAGS.c, train_vq=True, commitment_cost_vq=FLAGS.commitment_cost_vq) predict_config = models.LatentTransformerConfig( base_cfg=base_predict_config, latent_vocab_size=FLAGS.latent_vocab_size, c=FLAGS.c, train_vq=True, commitment_cost_vq=FLAGS.commitment_cost_vq) # Latent Predictor. lp_train_config = models.TransformerConfig( vocab_size=io_vocab_size, output_vocab_size=FLAGS.latent_vocab_size, shift=True, emb_dim=FLAGS.embedding_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.embedding_dim, mlp_dim=FLAGS.hidden_dim, max_len=max(FLAGS.max_characters, FLAGS.max_program_length), deterministic=False, decode=False, bos_token=bos_token) lp_eval_config = lp_train_config.replace(deterministic=True) lp_predict_config = lp_train_config.replace(shift=False, deterministic=True, decode=True) rng = jax.random.PRNGKey(0) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) m = models.LatentProgramTransformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(program_shape, jnp.float32)) lp_m = models.ProgramTransformer(lp_eval_config) lp_initial_variables = jax.jit(lp_m.init)(init_rng, jnp.ones(io_shape, jnp.float32), jnp.ones(io_shape, jnp.float32), jnp.ones(program_shape, jnp.float32)) optimizer_def = optim.Adam(FLAGS.lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) lp_optimizer = optimizer_def.create(lp_initial_variables['params']) state = TrainState(step=0, optimizer=optimizer, model_state=initial_variables['vqvae'], lp_optimizer=lp_optimizer) # Don't keep a copy of the initial model. del initial_variables, lp_initial_variables train_rngs = jax.random.split(rng, jax.local_device_count()) start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. state = checkpoints.restore_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), state) # Grab last step. 
start_step = int(state.step) logging.info('Found model checkpointed at step %d.', start_step) state = jax_utils.replicate(state) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.lr) p_train_step = jax.pmap(functools.partial( train_step, bos_token=bos_token, eos_token=eos_token, learning_rate_fn=learning_rate_fn, config=train_config, lp_config=lp_train_config), axis_name='batch', static_broadcasted_argnums=(4, )) p_eval_step = jax.pmap(functools.partial(eval_step, bos_token=bos_token, eos_token=eos_token, config=eval_config, lp_config=lp_eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=FLAGS.max_program_length, config=predict_config, lp_config=lp_predict_config), axis_name='batch') p_pred_step = jax.pmap(functools.partial( predict_step, bos_token=bos_token, eos_token=eos_token, max_decode_len=FLAGS.max_program_length, config=predict_config, lp_config=lp_predict_config), axis_name='batch', static_broadcasted_argnums=(5, )) metrics_all = [] latent_metrics_all = [] tick = time.time() for step in range(start_step, FLAGS.num_train_steps): inputs, outputs, programs = common_utils.shard(next(train_iter)) state, metrics, latent_metrics, train_rngs = p_train_step( state, inputs, outputs, programs, step <= FLAGS.num_pretrain_steps, train_rng=train_rngs) metrics, latent_metrics = jax.tree_map(np.array, (metrics, latent_metrics)) metrics_all.append(metrics) latent_metrics_all.append(latent_metrics) # Save a Checkpoint if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or step == FLAGS.num_train_steps - 1): if jax.host_id() == 0: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint( os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), jax_utils.unreplicate(state), step) # Periodic metric handling. if not step or step % FLAGS.log_freq != 0: continue logging.info('Gathering training metrics.') # Training Metrics metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums) summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) latent_metrics_all = common_utils.get_metrics(latent_metrics_all) metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all) denominator = metrics_sums.pop('denominator') summary.update( jax.tree_map( lambda x: x / denominator, # pylint: disable=cell-var-from-loop metrics_sums)) if jax.host_id() == 0: logging.info('Train in step: %d, loss: %.4f, acc: %.4f', step, summary['loss'], summary['accuracy']) tock = time.time() steps_per_sec = FLAGS.log_freq / (tock - tick) tick = tock summary_writer.scalar('train/steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar('train/' + key, val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. 
metrics_all = [] latent_metrics_all = [] # Evaluation Metrics logging.info('Gathering evaluation metrics.') t_evaluation_start = time.time() eval_metrics = [] latent_eval_metrics = [] for batches in eval_ds.as_numpy_iterator(): inputs, outputs, programs = common_utils.shard(batches) all_metrics = p_eval_step(state, inputs, outputs, programs) metrics, latent_metrics = jax.tree_map(np.array, all_metrics) eval_metrics.append(metrics) latent_eval_metrics.append(latent_metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary.update( jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums)) if jax.host_id() == 0: logging.info('Evaluation time: %.4f s step %d, loss: %.4f', time.time() - t_evaluation_start, step, eval_summary['loss']) for key, val in eval_summary.items(): summary_writer.scalar('eval/' + key, val, step) summary_writer.flush() # Beam search metrics. logging.info('Gathering beam search metrics.') for beam_size in [10, 50, 100]: t_inference_start = time.time() pred_acc = 0 pred_denominator = 0 ios, targets, predictions, latent_predictions = [], [], [], [] for batches in predict_ds.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch[0].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) # pylint: disable=cell-var-from-loop inputs, outputs, programs = common_utils.shard(pred_batch) cache, lp_cache = p_init_cache(inputs, outputs, programs) predicted, latent_predicted = p_pred_step( state, inputs, outputs, cache, lp_cache, beam_size) predicted, latent_predicted = map( tohost, (predicted, latent_predicted)) inputs, outputs, programs = map(tohost, (inputs, outputs, programs)) pred_denominator += programs.shape[0] for i, beams in enumerate(predicted): inps, outs, io_string = decode_io(inputs[i], outputs[i]) p, p_idx, p_score = eval_predicted( beams, inps, outs, parse_beam_fn=lambda x: decode_program(x)[0]) if p_score >= len(inps): pred_acc += 1 ios.append(io_string) targets.append(decode_program(programs[i])[1]) predictions.append(p.to_string() if p else '') latent_predictions.append(' '.join( list(np.array(latent_predicted[i, p_idx]).astype(str)))) all_pred_acc, all_pred_denominator = per_host_sum_pmap( jax.tree_map(np.array, (pred_acc, pred_denominator))) # Record beam search results as text summaries. message = [] for n in np.random.choice(np.arange(len(predictions)), 8): text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n' f'predicted: {predictions[n]}\n\n' f'latent_predicted: {latent_predictions[n]}\n\n') message.append(text) # Write to tensorboard. 
if jax.host_id() == 0: logging.info( 'Prediction time (beam %d): %.4f s step %d, score %.4f.', beam_size, time.time() - t_inference_start, step, all_pred_acc / all_pred_denominator) summary_writer.scalar('predict/score-{}'.format(beam_size), all_pred_acc / all_pred_denominator, step) summary_writer.text('samples-{}'.format(beam_size), '\n------\n'.join(message), step) summary_writer.flush()
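# Hedged sketch of the pad_examples helper assumed by the beam-search loop
# above: repeat the final example so the batch divides evenly across devices.
import numpy as np

def pad_examples_sketch(x, padded_size):
  pad = padded_size - x.shape[0]
  reps = (pad,) + (1,) * (x.ndim - 1)
  return np.concatenate([x, np.tile(x[-1:], reps)], axis=0)

batch = np.arange(6).reshape(3, 2)
assert pad_examples_sketch(batch, 4).shape == (4, 2)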
def main(argv): tf.config.experimental.set_visible_devices([], 'GPU') del argv logging.info('*** Starting experiment') gin_configs = FLAGS.gin_configs logging.info('*** Loading Gin configs from: %s', str(gin_configs)) gin.parse_config_files_and_bindings(config_files=gin_configs, bindings=FLAGS.gin_bindings, skip_unknown=True) # Load configurations. exp_config = configs.ExperimentConfig() model_config = configs.ModelConfig() train_config = configs.TrainConfig() # Get directory information. exp_dir = gpath.GPath(FLAGS.base_folder) if exp_config.subname: exp_dir = exp_dir / exp_config.subname summary_dir = exp_dir / 'summaries' / 'train' checkpoint_dir = exp_dir / 'checkpoints' # Log and create directories if this is the main host. if jax.process_index() == 0: logging.info('exp_dir = %s', exp_dir) if not exp_dir.exists(): exp_dir.mkdir(parents=True, exist_ok=True) logging.info('summary_dir = %s', summary_dir) if not summary_dir.exists(): summary_dir.mkdir(parents=True, exist_ok=True) logging.info('checkpoint_dir = %s', checkpoint_dir) if not checkpoint_dir.exists(): checkpoint_dir.mkdir(parents=True, exist_ok=True) config_str = gin.operative_config_str() logging.info('Configuration: \n%s', config_str) with (exp_dir / 'config.gin').open('w') as f: f.write(config_str) logging.info('Starting host %d. There are %d hosts : %s', jax.process_index(), jax.process_count(), str(jax.process_indices())) logging.info('Found %d accelerator devices: %s.', jax.local_device_count(), str(jax.local_devices())) logging.info('Found %d total devices: %s.', jax.device_count(), str(jax.devices())) rng = random.PRNGKey(exp_config.random_seed) # Shift the numpy random seed by process_index() to shuffle data loaded by # different hosts. np.random.seed(exp_config.random_seed + jax.process_index()) if train_config.batch_size % jax.device_count() != 0: raise ValueError( 'Batch size must be divisible by the number of devices.') devices = jax.local_devices() datasource_spec = exp_config.datasource_spec if datasource_spec is None: datasource_spec = { 'type': exp_config.datasource_type, 'data_dir': FLAGS.data_dir, } logging.info('Creating datasource: %s', datasource_spec) datasource = datasets.from_config( datasource_spec, image_scale=exp_config.image_scale, use_appearance_id=model_config.use_appearance_metadata, use_camera_id=model_config.use_camera_metadata, use_warp_id=model_config.use_warp, use_time=model_config.warp_metadata_encoder_type == 'time', random_seed=exp_config.random_seed, **exp_config.datasource_kwargs) train_iter = datasource.create_iterator( datasource.train_ids, flatten=True, shuffle=True, batch_size=train_config.batch_size, prefetch_size=3, shuffle_buffer_size=train_config.shuffle_buffer_size, devices=devices, ) points_iter = None if train_config.use_background_loss: points = datasource.load_points(shuffle=True) points_batch_size = min( len(points), len(devices) * train_config.background_points_batch_size) points_batch_size -= points_batch_size % len(devices) points_dataset = tf.data.Dataset.from_tensor_slices(points) points_iter = datasets.iterator_from_dataset( points_dataset, batch_size=points_batch_size, prefetch_size=3, devices=devices) learning_rate_sched = schedules.from_config(train_config.lr_schedule) warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule) time_alpha_sched = schedules.from_config(train_config.time_alpha_schedule) elastic_loss_weight_sched = schedules.from_config( train_config.elastic_loss_weight_schedule) rng, key = random.split(rng) params = {} model, params['model']
= models.construct_nerf( key, model_config, batch_size=train_config.batch_size, appearance_ids=datasource.appearance_ids, camera_ids=datasource.camera_ids, warp_ids=datasource.warp_ids, near=datasource.near, far=datasource.far, use_warp_jacobian=train_config.use_elastic_loss, use_weights=train_config.use_elastic_loss) optimizer_def = optim.Adam(learning_rate_sched(0)) optimizer = optimizer_def.create(params) state = model_utils.TrainState(optimizer=optimizer, warp_alpha=warp_alpha_sched(0), time_alpha=time_alpha_sched(0)) scalar_params = training.ScalarParams( learning_rate=learning_rate_sched(0), elastic_loss_weight=elastic_loss_weight_sched(0), warp_reg_loss_weight=train_config.warp_reg_loss_weight, warp_reg_loss_alpha=train_config.warp_reg_loss_alpha, warp_reg_loss_scale=train_config.warp_reg_loss_scale, background_loss_weight=train_config.background_loss_weight) state = checkpoints.restore_checkpoint(checkpoint_dir, state) init_step = state.optimizer.state.step + 1 state = jax_utils.replicate(state, devices=devices) del params logging.info('Initializing models') summary_writer = None if jax.process_index() == 0: summary_writer = tensorboard.SummaryWriter(str(summary_dir)) summary_writer.text('gin/train', textdata=gin.config.markdown(config_str), step=0) train_step = functools.partial( training.train_step, model, elastic_reduce_method=train_config.elastic_reduce_method, elastic_loss_type=train_config.elastic_loss_type, use_elastic_loss=train_config.use_elastic_loss, use_background_loss=train_config.use_background_loss, use_warp_reg_loss=train_config.use_warp_reg_loss, ) ptrain_step = jax.pmap( train_step, axis_name='batch', devices=devices, # rng_key, state, batch, scalar_params. in_axes=(0, 0, 0, None), # Treat use_elastic_loss as compile-time static. donate_argnums=(2, ), # Donate the 'batch' argument. ) if devices: n_local_devices = len(devices) else: n_local_devices = jax.local_device_count() logging.info('Starting training') rng = rng + jax.process_index() # Make random seed separate across hosts. 
keys = random.split(rng, n_local_devices) time_tracker = utils.TimeTracker() time_tracker.tic('data', 'total') for step, batch in zip(range(init_step, train_config.max_steps + 1), train_iter): if points_iter is not None: batch['background_points'] = next(points_iter) time_tracker.toc('data') # pytype: disable=attribute-error scalar_params = scalar_params.replace( learning_rate=learning_rate_sched(step), elastic_loss_weight=elastic_loss_weight_sched(step)) warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices) time_alpha = jax_utils.replicate(time_alpha_sched(step), devices) state = state.replace(warp_alpha=warp_alpha, time_alpha=time_alpha) with time_tracker.record_time('train_step'): state, stats, keys = ptrain_step(keys, state, batch, scalar_params) time_tracker.toc('total') if step % train_config.print_every == 0 and jax.process_index() == 0: logging.info('step=%d, warp_alpha=%.04f, time_alpha=%.04f, %s', step, warp_alpha_sched(step), time_alpha_sched(step), time_tracker.summary_str('last')) coarse_metrics_str = ', '.join( [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()]) fine_metrics_str = ', '.join( [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()]) logging.info('\tcoarse metrics: %s', coarse_metrics_str) if 'fine' in stats: logging.info('\tfine metrics: %s', fine_metrics_str) if step % train_config.save_every == 0 and jax.process_index() == 0: training.save_checkpoint(checkpoint_dir, state) if step % train_config.log_every == 0 and jax.process_index() == 0: # Only log via host 0. _log_to_tensorboard(summary_writer, jax_utils.unreplicate(state), scalar_params, jax_utils.unreplicate(stats), time_dict=time_tracker.summary('mean')) time_tracker.reset() if step % train_config.histogram_every == 0 and jax.process_index( ) == 0: _log_histograms(summary_writer, model, jax_utils.unreplicate(state)) time_tracker.tic('data', 'total') if train_config.max_steps % train_config.save_every != 0: training.save_checkpoint(checkpoint_dir, state)
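# Hedged sketch of the schedule objects consumed above: schedules.from_config
# is assumed to return a callable mapping step -> value. This linear ramp is
# only an illustration of that interface.
def linear_schedule(initial_value, final_value, num_steps):
  def sched(step):
    t = min(max(step / num_steps, 0.0), 1.0)
    return initial_value + t * (final_value - initial_value)
  return sched

warp_alpha_sched_sketch = linear_schedule(0.0, 8.0, 80_000)
assert warp_alpha_sched_sketch(40_000) == 4.0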
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ is_first_process = jax.process_index() == 0 tf.io.gfile.makedirs(workdir) # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets( config) config.seq_length = 250 vocab_size = int(encoder.vocab_size()) config.num_classes = vocab_size config.data_shape = (config.seq_length, 1) logging.info('Training with vocab size %d', vocab_size) def decode_tokens(toks): return encoder.detokenize(toks) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) config.per_device_batch_size = config.batch_size // jax.process_count() logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer # --------------------------------------------------------------------------- model, initial_variables = model_setup(init_rng, config) # Instead of passing the optimizer fns directly, we use a fn that returns # the optimizer given a learning rate. def tx_fn(lr): return optax.adamw( lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0, weight_decay=config.weight_decay) state = language_train_state.TrainState.create( params=initial_variables['params'], tx_fn=tx_fn) # We access model params only from state below via state.params. del initial_variables if config.restore_checkpoints: # Restore unreplicated model state from last checkpoint. state = checkpoints.restore_checkpoint(workdir, state) # Grab last step. start_step = int(state.step) writer = metric_writers.create_default_writer( workdir, just_logging=not is_first_process) if start_step == 0: config_dict = dict(config) writer.write_hparams(config_dict) if is_first_process and start_step == 0: # Dump config file to work dir for easy model loading. config_path = os.path.join(workdir, 'config') with tf.io.gfile.GFile(config_path, 'wb') as fp: pickle.dump(config, fp) print('Using state', type(state)) # Replicate state. state = jax_utils.replicate(state) learning_rate_fn = create_learning_rate_scheduler( factors=config.lr_factors, base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # Compile multidevice versions of train/eval/predict step fn. p_train_step = jax.pmap( functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, clip_grad=config.clip_grad, ema_momentum=config.get('ema_momentum', 0.999)), axis_name='batch', in_axes=(0, 0), donate_argnums=(0,)) p_eval_step = jax.pmap( functools.partial( eval_step, model=model), axis_name='batch') # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of train PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. 
rng = jax.random.fold_in(rng, jax.process_index()) rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5) train_rngs = jax.random.split(rng1, jax.local_device_count()) eval_rngs = jax.random.split(rng2, jax.local_device_count()) test_rngs = jax.random.split(rng3, jax.local_device_count()) del rng, rng1, rng2, rng3 logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if is_first_process: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5) ] train_metrics = [] # Iterator that does epoch-wise indefinite iteration. def iterate_train(train_ds): epoch = 1 while True: msg = f'Starting epoch {epoch}' logging.info(msg) for batch in train_ds: yield batch epoch += 1 train_iter = iterate_train(train_ds) kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps) kl_history = [] with metric_writers.ensure_flushes(writer): step = start_step for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter))) state, metrics = p_train_step( state, batch, rng=train_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) # Periodic metric handling. if step > 0 and (step % config.eval_every_steps == 0 or is_last_step): with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.get_metrics(train_metrics) # First handle loss terms per step. t_batch = train_metrics.pop('t_batch') nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch') kl_tracker_train.update( t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1)) kl_values = kl_tracker_train.get_kl_per_t() kl_history.append(np.array(kl_values)) kl_history = kl_history[-100:] # Keep last 100 items only. # Handle remaining `standard` metrics summary = jax.tree_map(jnp.mean, train_metrics) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed('eval'): eval_results, eval_rngs = evaluate( p_eval_step=p_eval_step, params=state.ema_params, eval_ds=eval_ds, rng=eval_rngs) writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()}) test_results, test_rngs = evaluate( p_eval_step=p_eval_step, params=state.ema_params, eval_ds=test_ds, rng=test_rngs) writer.write_scalars( step, {'test_' + k: v for k, v in test_results.items()}) if step == 1000 or (step > 0 and step % config.detailed_eval_every_steps == 0): if is_first_process: loss_components_path = os.path.join(workdir, 'loss_components') with tf.io.gfile.GFile(loss_components_path, 'wb') as fp: pickle.dump(kl_history[-1], fp) extensive_eval_rngs = extensive_eval( config, extensive_eval_rngs, writer, workdir, model, state, kl_history, test_ds, step, decode_tokens) with report_progress.timed('generate_text'): generate_prediction(sample_rng, config, model, state, writer, decode_tokens, step) # Save a checkpoint on one host after every checkpoint_freq steps. 
save_checkpoint = ( step > 0 and (step % config.checkpoint_every_steps == 0 or is_last_step)) if config.save_checkpoints and save_checkpoint and is_first_process: with report_progress.timed('checkpoint'): checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state), step, overwrite=True)
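# Hedged sketch of the exponential-moving-average update implied by
# ema_momentum above: the train state is assumed to keep shadow ema_params
# that trail the live params; the update itself is the standard one.
import jax

def ema_update(ema_params, params, momentum=0.999):
  return jax.tree_map(lambda e, p: momentum * e + (1.0 - momentum) * p,
                      ema_params, params)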
def main(_): assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0 save_dir = FLAGS.model_dir if FLAGS.save_dir is None else FLAGS.save_dir logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count()) logging.info("JAX local devices: %r", jax.local_devices()) rng = jax.random.PRNGKey(FLAGS.seed) rng, rng_coarse, rng_fine = jax.random.split(rng, 3) ### Load dataset and data values datasets, counts, optics, render_datasets = get_dataset( FLAGS.data_dir, FLAGS.config, num_poses=FLAGS.config.num_poses) train_ds, val_ds, test_ds = datasets train_items, val_items, test_items = counts hwf, r_hwf, near, far = optics render_ds, render_vdirs_ds, num_poses = render_datasets logging.info("Num poses: %d", num_poses) logging.info("Splits: train - %d, val - %d, test - %d", *counts) logging.info("Images: height %d, width %d, focal %.5f", *hwf) logging.info("Render: height %d, width %d, focal %.5f", *r_hwf) ### Init model parameters and optimizer initialized_ = functools.partial(initialized, model_config=FLAGS.config.model) pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3) views_shape = (FLAGS.config.num_rand, 3) model_coarse, params_coarse = initialized_(rng_coarse, pts_shape, views_shape) schedule_fn = optax.exponential_decay( init_value=FLAGS.config.learning_rate, transition_steps=FLAGS.config.lr_decay * 1000, decay_rate=FLAGS.config.decay_factor, ) tx = optax.adam(learning_rate=schedule_fn) state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None), params={"coarse": params_coarse}, tx=tx) if FLAGS.config.num_importance > 0: pts_shape = ( FLAGS.config.num_rand, FLAGS.config.num_importance + FLAGS.config.num_samples, 3, ) model_fine, params_fine = initialized_(rng_fine, pts_shape, views_shape) state = train_state.TrainState.create( apply_fn=(model_coarse.apply, model_fine.apply), params={ "coarse": params_coarse, "fine": params_fine }, tx=tx, ) state = checkpoints.restore_checkpoint(FLAGS.model_dir, state) step = int(state.step) state = jax.device_put_replicated(state, jax.local_devices()) # TODO: TPU Colab breaks without message if this is a list # a list is preferred bc tqdm can show an ETA render_dict = { "train": zip(range(train_items), train_ds), "val": zip(range(val_items), val_ds), "test": zip(range(test_items), test_ds), "poses": zip(range(num_poses), render_ds), } render_poses = render_dict[FLAGS.render_video_set] def render_fn(state, rays): step_fn = functools.partial(eval_step, FLAGS.config, near, far, state) return lax.map(step_fn, rays) p_eval_step = jax.pmap( render_fn, axis_name="batch", # in_axes=(0, 0, None), # donate_argnums=(0, 1)) ) if FLAGS.render_video: rgb_list = [] disp_list = [] losses = [] for _, inputs in tqdm(render_poses, desc="Rays render"): rays, padding = prepare_render_data(inputs["rays"].numpy()) preds, *_ = p_eval_step(state, rays) preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds) rgb_list.append(preds["rgb"]) disp_list.append(preds["disp"]) if FLAGS.config.render_factor == 1 and FLAGS.render_video_set != "render": loss = np.mean((preds["rgb"] - inputs["image"])**2.0) losses.append(loss) if FLAGS.config.render_factor == 1 and FLAGS.render_video_set != "render": loss = np.mean(losses) logging.info("Loss %.5f", loss) logging.info("PSNR %.5f", psnr_fn(loss)) gen_video(save_dir, np.stack(rgb_list), "rgb", r_hwf, step) disp = np.stack(disp_list) gen_video(save_dir, disp_post(disp, FLAGS.config), "disp", r_hwf, step, ch=1) if FLAGS.render_testset: test_losses = [] for idx, inputs in 
tqdm(zip(range(test_items), test_ds), desc="Test render"): rays, padding = prepare_render_data(inputs["rays"].numpy()) preds, *_ = p_eval_step(state, rays) preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds) save_test_imgs(save_dir, preds["rgb"], r_hwf, step, idx) if FLAGS.config.render_factor == 1: loss = np.mean((preds["rgb"] - inputs["image"])**2.0) test_losses.append(loss) if FLAGS.config.render_factor == 1: loss = np.mean(test_losses) logging.info("Loss %.5f", loss) logging.info("PSNR %.5f", psnr_fn(loss))
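# Hedged sketch of what prepare_render_data is assumed to do above: flatten
# an image's rays, zero-pad so they split evenly across local devices, and
# return the padding so it can be stripped from the predictions afterwards.
import numpy as np
import jax

def shard_rays_sketch(rays, chunk):
  n_dev = jax.local_device_count()
  padding = (-rays.shape[0]) % (n_dev * chunk)
  pad_block = np.zeros((padding,) + rays.shape[1:], rays.dtype)
  rays = np.concatenate([rays, pad_block], axis=0)
  return rays.reshape(n_dev, -1, chunk, rays.shape[-1]), padding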
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() # make sure tf does not allocate gpu memory tf.config.experimental.set_visible_devices([], 'GPU') if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir) image_size = 224 batch_size = FLAGS.batch_size if batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') local_batch_size = batch_size // jax.host_count() device_batch_size = batch_size // jax.device_count() platform = jax.local_devices()[0].platform dynamic_scale = None if FLAGS.half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 input_dtype = tf.bfloat16 else: model_dtype = jnp.float16 input_dtype = tf.float16 dynamic_scale = optim.DynamicScale() else: model_dtype = jnp.float32 input_dtype = tf.float32 train_iter = imagenet_train_utils.create_input_iter( local_batch_size, FLAGS.data_dir, image_size, input_dtype, train=True, cache=FLAGS.cache) eval_iter = imagenet_train_utils.create_input_iter( local_batch_size, FLAGS.data_dir, image_size, input_dtype, train=False, cache=FLAGS.cache) # Create the hyperparameter object if FLAGS.hparams_config_dict: # In this case, there are multiple training configs defined in the config # dict, so we pull out the one this training run should use. if 'configs' in FLAGS.hparams_config_dict: hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx] else: hparams_config_dict = FLAGS.hparams_config_dict hparams = os_hparams_utils.load_hparams_from_config_dict( hparams_config.TrainingHParams, models.ResNet.HParams, hparams_config_dict) else: raise ValueError('Please provide a base config dict.') os_hparams_utils.write_hparams_to_file_with_host_id_check( hparams, FLAGS.model_dir) # get num_epochs from hparam instead of FLAGS num_epochs = hparams.lr_scheduler.num_epochs steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size steps_per_checkpoint = steps_per_epoch * 10 num_steps = steps_per_epoch * num_epochs # Estimate compute / memory costs if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost: estimate_compute_and_memory_cost( image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams) logging.info('Writing training HLO and estimating compute/memory costs.') rng = random.PRNGKey(hparams.seed) model, variables = imagenet_train_utils.create_model( rng, device_batch_size, image_size, model_dtype, hparams=hparams.model_hparams, train=True, is_teacher=hparams.is_teacher) # pylint: disable=g-long-lambda if hparams.teacher_model == 'resnet50-8bit': teacher_config = w8a8auto_paper_config() teacher_hparams = os_hparams_utils.load_hparams_from_config_dict( hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config) teacher_model, _ = imagenet_train_utils.create_model( rng, device_batch_size, image_size, model_dtype, hparams=teacher_hparams.model_hparams, train=False, is_teacher=True) # teacher model does not need to be trainable # Directory where checkpoints are saved ckpt_model_dir = FLAGS.resnet508b_ckpt_path # will restore to best checkpoint state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None) teacher_variables = {'params': state_load['optimizer']['target']} teacher_variables.update(state_load['model_state']) # create a dictionary for better argument passing teacher = { 'model': lambda var, img, labels: jax.nn.softmax( teacher_model.apply(var, img)), 'variables': teacher_variables, } elif 
hparams.teacher_model == 'labels': teacher = { 'model': lambda var, img, labels: common_utils.onehot( labels, num_classes=1000), 'variables': {}, # no variables are needed in this case } else: raise ValueError('The specified teacher model is not supported.') model_state, params = variables.pop('params') if hparams.optimizer == 'sgd': optimizer = optim.Momentum( beta=hparams.momentum, nesterov=True).create(params) elif hparams.optimizer == 'adam': optimizer = optim.Adam( beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params) else: raise ValueError('Optimizer type is not supported.') state = imagenet_train_utils.TrainState( step=0, optimizer=optimizer, model_state=model_state, dynamic_scale=dynamic_scale) del params, model_state # do not keep a copy of the initial model state = restore_checkpoint(state) step_offset = int(state.step) # step_offset > 0 if restarting from checkpoint state = jax_utils.replicate(state) base_learning_rate = hparams.base_learning_rate * batch_size / 256. learning_rate_fn = create_learning_rate_fn(base_learning_rate, steps_per_epoch, hparams.lr_scheduler, batch_size) p_train_step = jax.pmap( functools.partial( imagenet_train_utils.train_step, model, learning_rate_fn=learning_rate_fn, teacher=teacher), axis_name='batch', static_broadcasted_argnums=(2, 3, 4)) p_eval_step = jax.pmap( functools.partial(imagenet_train_utils.eval_step, model), axis_name='batch', static_broadcasted_argnums=(2,)) epoch_metrics = [] state_dict_summary_all = [] state_dict_keys = _get_state_dict_keys_from_flags() t_loop_start = time.time() last_log_step = 0 for step, batch in zip(range(step_offset, num_steps), train_iter): if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps * steps_per_epoch: break update_bounds = train_utils.should_update_bounds( hparams.activation_bound_update_freq, hparams.activation_bound_start_step, step) # Pass the resulting bool to p_train_step; should_quantize_weights takes # hparams.weight_quant_start_step and the current epoch as inputs. quantize_weights = train_utils.should_quantize_weights( hparams.weight_quant_start_step, step // steps_per_epoch) state, metrics = p_train_step(state, batch, hparams, update_bounds, quantize_weights) state_dict_summary = summary_utils.get_state_dict_summary( state.model_state, state_dict_keys) state_dict_summary_all.append(state_dict_summary) epoch_metrics.append(metrics) def should_log(step): epoch_no = step // steps_per_epoch step_in_epoch = step - epoch_no * steps_per_epoch do_log = False do_log = do_log or (step + 1 == num_steps) # log at the end end_of_train = step / num_steps > 0.9 do_log = do_log or ((step_in_epoch % (steps_per_epoch // 4) == 0) and not end_of_train) do_log = do_log or ((step_in_epoch % (steps_per_epoch // 16) == 0) and end_of_train) return do_log if should_log(step): epoch = step // steps_per_epoch epoch_metrics = common_utils.get_metrics(epoch_metrics) summary = jax.tree_map(lambda x: x.mean(), epoch_metrics) logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start) last_log_step = step t_loop_start = time.time() # Write to TensorBoard state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all) if jax.host_id() == 0: for key, vals in epoch_metrics.items(): tag = 'train_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) summary_writer.scalar('steps per second', steps_per_sec, step) if FLAGS.write_summary:
summary_utils.write_state_dict_summaries_to_tb( state_dict_summary_all, summary_writer, FLAGS.state_dict_summary_freq, step) state_dict_summary_all = [] epoch_metrics = [] eval_metrics = [] # sync batch statistics across replicas state = imagenet_train_utils.sync_batch_stats(state) for _ in range(steps_per_eval): eval_batch = next(eval_iter) metrics = p_eval_step(state, eval_batch, quantize_weights) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_map(lambda x: x.mean(), eval_metrics) logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) if jax.host_id() == 0: for key, val in eval_metrics.items(): tag = 'eval_%s' % key summary_writer.scalar(tag, val.mean(), step) summary_writer.flush() if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: state = imagenet_train_utils.sync_batch_stats(state) save_checkpoint(state) # Wait until computations are done before exiting jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
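# The teacher dict above exposes a callable returning soft labels: either a
# softmax over the 8-bit teacher's logits or a one-hot encoding of the ground
# truth. A minimal sketch of how a train step can consume such soft targets as
# a distillation cross-entropy; student_logits and soft_targets are
# illustrative names, not part of imagenet_train_utils.
import jax
import jax.numpy as jnp

def soft_cross_entropy(student_logits, soft_targets):
  """Cross-entropy of student predictions against teacher soft targets."""
  log_probs = jax.nn.log_softmax(student_logits)
  return -jnp.mean(jnp.sum(soft_targets * log_probs, axis=-1))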
def main(unused_argv): rng = random.PRNGKey(20200823) if FLAGS.config is not None: utils.update_flags(FLAGS) if FLAGS.train_dir is None: raise ValueError("train_dir must be set. None set now.") if FLAGS.data_dir is None: raise ValueError("data_dir must be set. None set now.") dataset = datasets.get_dataset("test", FLAGS) rng, key = random.split(rng) model, init_variables = models.get_model(key, dataset.peek(), FLAGS) optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables) state = utils.TrainState(optimizer=optimizer) del optimizer, init_variables # Rendering is forced to be deterministic even if training was randomized, as # this eliminates "speckle" artifacts. def render_fn(variables, key_0, key_1, rays): return jax.lax.all_gather( model.apply(variables, key_0, key_1, rays, False), axis_name="batch") # pmap over only the data input. render_pfn = jax.pmap( render_fn, in_axes=(None, None, None, 0), donate_argnums=3, axis_name="batch", ) # Compiling to the CPU because it's faster and more accurate. ssim_fn = jax.jit( functools.partial(utils.compute_ssim, max_val=1.), backend="cpu") last_step = 0 out_dir = path.join(FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds") if not FLAGS.eval_once: summary_writer = tensorboard.SummaryWriter( path.join(FLAGS.train_dir, "eval")) while True: state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) step = int(state.optimizer.state.step) if step <= last_step: continue if FLAGS.save_output and (not utils.isdir(out_dir)): utils.makedirs(out_dir) psnrs = [] ssims = [] if not FLAGS.eval_once: showcase_index = np.random.randint(0, dataset.size) for idx in range(dataset.size): print(f"Evaluating {idx+1}/{dataset.size}") batch = next(dataset) pred_color, pred_disp, pred_acc = utils.render_image( functools.partial(render_pfn, state.optimizer.target), batch["rays"], rng, FLAGS.dataset == "llff", chunk=FLAGS.chunk) if jax.host_id() != 0: # Only record via host 0. 
continue if not FLAGS.eval_once and idx == showcase_index: showcase_color = pred_color showcase_disp = pred_disp showcase_acc = pred_acc if not FLAGS.render_path: showcase_gt = batch["pixels"] if not FLAGS.render_path: psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean()) ssim = ssim_fn(pred_color, batch["pixels"]) print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}") psnrs.append(float(psnr)) ssims.append(float(ssim)) if FLAGS.save_output: utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx))) utils.save_img(pred_disp[Ellipsis, 0], path.join(out_dir, "disp_{:03d}.png".format(idx))) if (not FLAGS.eval_once) and (jax.host_id() == 0): summary_writer.image("pred_color", showcase_color, step) summary_writer.image("pred_disp", showcase_disp, step) summary_writer.image("pred_acc", showcase_acc, step) if not FLAGS.render_path: summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step) summary_writer.scalar("ssim", np.mean(np.array(ssims)), step) summary_writer.image("target", showcase_gt, step) if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0): with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in psnrs])) with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in ssims])) with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f: f.write("{}".format(np.mean(np.array(psnrs)))) with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f: f.write("{}".format(np.mean(np.array(ssims)))) if FLAGS.eval_once: break if int(step) >= FLAGS.max_steps: break last_step = step
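# utils.render_image above renders a full image by splitting the rays into
# chunks of FLAGS.chunk rays so device memory stays bounded regardless of
# image resolution. A simplified sketch of that pattern, assuming a flat
# [num_rays, ...] layout; the real helper also pads chunks to the device
# count and restores the image shape afterwards.
import jax.numpy as jnp

def render_in_chunks(render_fn, rays, chunk=8192):
  outputs = []
  for i in range(0, rays.shape[0], chunk):
    outputs.append(render_fn(rays[i:i + chunk]))  # render one chunk of rays
  return jnp.concatenate(outputs, axis=0)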
def train_and_evaluate(config, workdir, vocab_filepath, random_seed=0): """Runs a training and evaluation loop. Args: config: Model and training configuration. workdir: Working directory for checkpoints and Tensorboard summaries. If this contains a checkpoint, training will be resumed from the latest checkpoint. vocab_filepath: Absolute path to SentencePiece vocab model. random_seed: Random number generator seed. Raises: ValueError: If the training or eval batch size is not divisible by the total number of devices, or if the config is underspecified. """ n_processes = jax.process_count() # Number of processes n_devices = jax.local_device_count() # Number of local devices per process if config.train_batch_size % (n_processes * n_devices) > 0: raise ValueError( "Training batch size must be divisible by the total number of devices, " "but training batch size = %d, while total number of devices = %d " "(%d processes, each with %d devices)" % (config.train_batch_size, n_processes * n_devices, n_processes, n_devices)) if config.eval_batch_size % (n_processes * n_devices) > 0: raise ValueError( "Eval batch size must be divisible by the total number of devices, " "but eval batch size = %d, while total number of devices = %d " "(%d processes, each with %d devices)" % (config.eval_batch_size, n_processes * n_devices, n_processes, n_devices)) per_process_train_batch_size = config.train_batch_size // n_processes per_process_eval_batch_size = config.eval_batch_size // n_processes if jax.process_index() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "train")) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "eval")) else: train_summary_writer = None eval_summary_writer = None rng = random.PRNGKey(random_seed) rng, init_rng = random.split(rng) tokenizer = spm.SentencePieceProcessor() tokenizer.Load(vocab_filepath) ds_info = tfds.builder(config.dataset_name).info num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples num_train_steps = int(num_train_examples * config.num_train_epochs // config.train_batch_size) num_warmup_steps = int(config.warmup_proportion * num_train_steps) # Round up evaluation frequency to a multiple of 10. eval_frequency = int( math.ceil(config.eval_proportion * num_train_steps / 10)) * 10 is_regression_task = config.dataset_name == "glue/stsb" num_classes = (1 if is_regression_task else ds_info.features["label"].num_classes) with config.unlocked(): config.vocab_size = tokenizer.GetPieceSize() frozen_config = ml_collections.FrozenConfigDict(config) model = models.SequenceClassificationModel( config=frozen_config, n_classes=num_classes) params = _init_params(model, init_rng, config) optimizer = _create_adam_optimizer(config.learning_rate, params) # In case the current job restarts, ensure that we continue from where we # left off. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) start_step = int(optimizer.state.step) # Otherwise, try to restore optimizer and model state from config checkpoint. if (start_step == 0 and "init_checkpoint_dir" in config and config.init_checkpoint_dir): optimizer = _restore_pretrained_model(optimizer, params, config) # We access model state only from optimizer via optimizer.target.
del params optimizer = jax_utils.replicate(optimizer) if is_regression_task: compute_stats = functools.partial(_compute_regression_stats, model=model) else: compute_stats = functools.partial( _compute_classification_stats, model=model) learning_rate_fn = train_utils.create_learning_rate_scheduler( factors="constant * linear_warmup * linear_decay", base_learning_rate=config.learning_rate, warmup_steps=num_warmup_steps, decay_steps=num_train_steps - num_warmup_steps, ) glue_inputs = functools.partial( input_pipeline.glue_inputs, dataset_name=config.dataset_name, max_seq_length=config.max_seq_length, tokenizer=tokenizer) train_iter = glue_inputs( split=tfds.Split.TRAIN, batch_size=per_process_train_batch_size, training=True) if config.dataset_name == "glue/mnli": # MNLI contains two validation and test datasets. split_suffixes = ["_matched", "_mismatched"] else: split_suffixes = [""] # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. rngs = random.split(rng, n_devices) loss_and_metrics_fn = functools.partial( _compute_loss_and_metrics, model=model) p_train_step = jax.pmap( functools.partial( train_utils.train_step, loss_and_metrics_fn=loss_and_metrics_fn, learning_rate_fn=learning_rate_fn), axis_name="batch") p_eval_step = jax.pmap( functools.partial(train_utils.eval_step, metric_fn=compute_stats), axis_name="batch") eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name, is_regression_task) train_metrics = [] seconds = 0.0 logging.info("Starting training loop.") logging.info("====================") for step, train_batch in zip(range(start_step, num_train_steps), train_iter): train_batch = common_utils.shard(train_batch) curr_time = time.time() optimizer, train_step_metrics, rngs = p_train_step( optimizer, train_batch, rng=rngs) seconds += time.time() - curr_time train_metrics.append(train_step_metrics) if ((step > 0 and config.save_checkpoints_steps and step % config.save_checkpoints_steps == 0) or step == num_train_steps - 1) and jax.process_index() == 0: # Save un-replicated optimizer and model state. checkpoints.save_checkpoint( workdir, jax_utils.unreplicate(optimizer), step, keep=2) # Periodic metric handling. if step % eval_frequency != 0 and step < num_train_steps - 1: continue logging.info("Gathering training metrics at step: %d", step) train_metrics = common_utils.get_metrics(train_metrics) train_summary = { "loss": jnp.sum(train_metrics["loss"]) / jnp.sum(train_metrics["num_labels"]), "learning_rate": learning_rate_fn(step) } if not is_regression_task: train_summary["accuracy"] = jnp.sum( train_metrics["correct_predictions"]) / jnp.sum( train_metrics["num_labels"]) if jax.process_index() == 0: assert train_summary_writer steps_per_sec = (step - start_step + 1) / seconds train_summary_writer.scalar("steps per second", steps_per_sec, step) for key, val in train_summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() # Reset metric accumulation for next evaluation cycle. 
train_metrics = [] logging.info("Gathering validation metrics at step: %d", step) for split_suffix in split_suffixes: eval_iter = glue_inputs( split=tfds.Split.VALIDATION + split_suffix, batch_size=per_process_eval_batch_size, training=False) eval_metrics = [] for _, eval_batch in zip(range(config.max_num_eval_steps), eval_iter): eval_metrics.append( _compute_eval_metrics(p_eval_step, optimizer.target, eval_batch, n_devices)) if eval_metrics: eval_metrics = common_utils.get_metrics(eval_metrics) eval_summary = eval_metrics_fn(eval_metrics) if jax.process_index() == 0: assert eval_summary_writer for key, val in eval_summary.items(): eval_summary_writer.scalar(f"{key}{split_suffix}", val, step) eval_summary_writer.flush() else: logging.warning("Error gathering eval_metrics at step: %d", step)
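# common_utils.shard above reshapes every array in the batch from
# [batch, ...] to [local_device_count, batch // local_device_count, ...] so
# that pmap can split the leading axis across devices. A minimal sketch of
# the same idea:
import jax

def shard_batch(batch):
  n = jax.local_device_count()
  return jax.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), batch)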
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains a checkpoint, training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "train")) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "eval")) if config.batch_size % n_devices: raise ValueError("Batch size must be divisible by the number of devices") vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=n_devices, dataset_name=config.dataset_name, eval_dataset_name=config.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), vocab_path=vocab_path, target_vocab_size=config.vocab_size, batch_size=config.batch_size, max_corpus_chars=config.max_corpus_chars, max_length=config.max_target_length, max_eval_length=config.max_eval_target_length) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default SentencePiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: predict_ds = predict_ds.take(config.num_predict_steps) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = random.PRNGKey(config.seed) rng, init_rng = random.split(rng) input_shape = (config.batch_size, config.max_target_length) target_shape = (config.batch_size, config.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam( config.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=config.weight_decay) optimizer = optimizer_def.create(initial_variables["params"]) # We access model params only from optimizer below via optimizer.target.
del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # compile multidevice versions of train/eval/predict step and cache init fn. p_train_step = jax.pmap( functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=config.label_smoothing), axis_name="batch", donate_argnums=(0,)) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial( eval_step, config=eval_config, label_smoothing=config.label_smoothing), axis_name="batch") p_init_cache = jax.pmap( functools.partial( initialize_cache, max_decode_len=config.max_predict_length, config=predict_config), axis_name="batch") p_pred_step = jax.pmap( functools.partial( predict_step, config=predict_config, beam_size=config.beam_size), axis_name="batch", static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, n_devices) logging.info("Starting training loop.") metrics_all = [] t_loop_start = time.time() for step, batch in zip(range(start_step, config.num_train_steps), train_iter): # Shard data to devices and do a training step. batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) # Save a checkpoint on one host after every checkpoint_freq steps. if (config.save_checkpoints and step % config.checkpoint_freq == 0 and step > 0 and jax.host_id() == 0): checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling.
if step % config.eval_frequency != 0 and step > 0: continue # Training Metrics logging.info("Gathering training metrics.") metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop("learning_rate").mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop("denominator") summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr steps_per_eval = config.eval_frequency if step != 0 else 1 steps_per_sec = steps_per_eval / (time.time() - t_loop_start) t_loop_start = time.time() if jax.host_id() == 0: train_summary_writer.scalar("steps per second", steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() metrics_all = [] logging.info("train in step: %d, loss: %.4f", step, summary["loss"]) # Eval Metrics logging.info("Gathering evaluation metrics.") t_eval_start = time.time() eval_metrics = [] eval_iter = iter(eval_ds) for _, eval_batch in zip(range(config.num_eval_steps), eval_iter): eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush() logging.info("eval in step: %d, loss: %.4f", step, eval_summary["loss"]) logging.info("eval time: %.4f s step %d", time.time() - t_eval_start, step) # Translation and BLEU Score. logging.info("Translating evaluation dataset.") t_inference_start = time.time() sources, references, predictions = [], [], [] for pred_batch in predict_ds: pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch["inputs"].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) # pylint: disable=cell-var-from-loop pred_batch = common_utils.shard(pred_batch) cache = p_init_cache(pred_batch["inputs"]) predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache, eos_id, config.max_predict_length) predicted = tohost(predicted) inputs = tohost(pred_batch["inputs"]) targets = tohost(pred_batch["targets"]) # Iterate through non-padding examples of batch. for i, s in enumerate(predicted[:cur_pred_batch_size]): sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) logging.info("Translation: %d predictions %d references %d sources.", len(predictions), len(references), len(sources)) logging.info("Translation time: %.4f s step %d.", time.time() - t_inference_start, step) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) all_bleu_matches = per_host_sum_pmap(bleu_matches) bleu_score = bleu.complete_bleu(*all_bleu_matches) # Save translation samples for tensorboard. 
exemplars = "" for n in np.random.choice(np.arange(len(predictions)), 8): exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n" if jax.host_id() == 0: eval_summary_writer.scalar("bleu", bleu_score, step) eval_summary_writer.text("samples", exemplars, step) eval_summary_writer.flush() logging.info("Translation BLEU Score %.4f", bleu_score)
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() config = FLAGS.config logging.info('===========Config Dict============') logging.info(config) batch_size = config.batch_size learning_rate = config.learning_rate num_train_steps = config.num_train_steps num_eval_steps = config.num_eval_steps eval_freq = config.eval_frequency random_seed = config.random_seed model_type = config.model_type max_length = config.max_length if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'summary')) if batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets( n_devices=jax.local_device_count(), task_name=FLAGS.task_name, data_dir=FLAGS.data_dir, batch_size=batch_size, fixed_vocab=None, max_length=max_length) vocab_size = encoder.vocab_size logging.info('Vocab Size: %d', vocab_size) train_ds = train_ds.repeat() train_iter = iter(train_ds) input_shape = (batch_size, max_length) model_kwargs = { 'vocab_size': vocab_size, 'emb_dim': config.emb_dim, 'num_heads': config.num_heads, 'num_layers': config.num_layers, 'qkv_dim': config.qkv_dim, 'mlp_dim': config.mlp_dim, 'max_len': max_length, 'classifier': True, 'num_classes': CLASS_MAP[FLAGS.task_name], 'classifier_pool': config.classifier_pool } rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = random.split(rng) # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, jax.local_device_count()) if model_type == 'transformer': model = create_model(init_rng, transformer.TransformerEncoder, input_shape, model_kwargs) else: raise ValueError('Model type not supported') optimizer = create_optimizer(model, learning_rate, weight_decay=FLAGS.config.weight_decay) del model # Don't keep a copy of the initial model. start_step = 0 if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = train_utils.create_learning_rate_scheduler( factors=config.factors, base_learning_rate=learning_rate, warmup_steps=config.warmup) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn), axis_name='batch') p_eval_step = jax.pmap(eval_step, axis_name='batch') # p_pred_step = jax.pmap(predict_step, axis_name='batch') metrics_all = [] tick = time.time() logging.info('Starting training') logging.info('====================') for step, batch in zip(range(start_step, num_train_steps), train_iter): batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) logging.info('train in step: %d', step) # Save a Checkpoint if ((step % config.checkpoint_freq == 0 and step > 0) or step == num_train_steps - 1): if jax.host_id() == 0 and config.save_checkpoints: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint(FLAGS.model_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if step % eval_freq == 0 and step > 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('train in step: %d, loss: %.4f, acc: %.4f', step, summary['loss'], summary['accuracy']) if jax.host_id() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): summary_writer.scalar(f'train_{key}', val, step) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Eval Metrics eval_metrics = [] eval_iter = iter(eval_ds) if num_eval_steps == -1: num_iter = itertools.repeat(1) else: num_iter = range(num_eval_steps) for _, eval_batch in zip(num_iter, eval_iter): # pylint: disable=protected-access eval_batch = common_utils.shard( jax.tree_map(lambda x: x._numpy(), eval_batch)) # pylint: enable=protected-access metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) # Calculate (clipped) perplexity after averaging log-perplexities: eval_summary['perplexity'] = jnp.clip(jnp.exp( eval_summary['loss']), a_max=1.0e4) logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step, eval_summary['loss'], eval_summary['accuracy']) if jax.host_id() == 0: for key, val in eval_summary.items(): summary_writer.scalar(f'eval_{key}', val, step) summary_writer.flush()
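# Several of these loops build their learning-rate schedule from a factors
# string such as "constant * linear_warmup * linear_decay". A simplified
# sketch of evaluating such composed factors per step; the real
# create_learning_rate_scheduler supports more factor types and options.
import numpy as np

def make_schedule(base_lr, warmup_steps, decay_steps):
  def lr_fn(step):
    lr = base_lr                                # constant
    lr *= np.minimum(1.0, step / warmup_steps)  # linear_warmup
    decay = 1.0 - np.maximum(0.0, step - warmup_steps) / decay_steps
    lr *= np.maximum(0.0, decay)                # linear_decay
    return lr
  return lr_fn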
def main(unused_argv): rng = random.PRNGKey(20200823) # Shift the numpy random seed by host_id() to shuffle data loaded by different # hosts. np.random.seed(20201473 + jax.host_id()) if FLAGS.config is not None: utils.update_flags(FLAGS) if FLAGS.batch_size % jax.device_count() != 0: raise ValueError( "Batch size must be divisible by the number of devices.") if FLAGS.train_dir is None: raise ValueError("train_dir must be set. None set now.") if FLAGS.data_dir is None: raise ValueError("data_dir must be set. None set now.") dataset = datasets.get_dataset("train", FLAGS) test_dataset = datasets.get_dataset("test", FLAGS) rng, key = random.split(rng) model, variables = models.get_model(key, dataset.peek(), FLAGS) optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables) state = utils.TrainState(optimizer=optimizer) del optimizer, variables learning_rate_fn = functools.partial(utils.learning_rate_decay, lr_init=FLAGS.lr_init, lr_final=FLAGS.lr_final, max_steps=FLAGS.max_steps, lr_delay_steps=FLAGS.lr_delay_steps, lr_delay_mult=FLAGS.lr_delay_mult) train_pstep = jax.pmap(functools.partial(train_step, model), axis_name="batch", in_axes=(0, 0, 0, None), donate_argnums=(2, )) def render_fn(variables, key_0, key_1, rays): return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays, FLAGS.randomized), axis_name="batch") render_pfn = jax.pmap( render_fn, in_axes=(None, None, None, 0), # Only distribute the data input. donate_argnums=(3, ), axis_name="batch", ) # Compiling to the CPU because it's faster and more accurate. ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.), backend="cpu") if not utils.isdir(FLAGS.train_dir): utils.makedirs(FLAGS.train_dir) state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) # Resume training at the step of the last checkpoint. init_step = state.optimizer.state.step + 1 state = flax.jax_utils.replicate(state) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir) # Prefetch buffer size = 3 x batch_size. pdataset = flax.jax_utils.prefetch_to_device(dataset, 3) n_local_devices = jax.local_device_count() rng = rng + jax.host_id() # Make random seed separate across hosts. keys = random.split(rng, n_local_devices) # For pmapping RNG keys. gc.disable() # Disable automatic garbage collection for efficiency. stats_trace = [] reset_timer = True for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset): if reset_timer: t_loop_start = time.time() reset_timer = False lr = learning_rate_fn(step) state, stats, keys = train_pstep(keys, state, batch, lr) if jax.host_id() == 0: stats_trace.append(stats) if step % FLAGS.gc_every == 0: gc.collect() # Log training summaries. This is put behind a host_id check because in # multi-host evaluation, all hosts need to run inference even though we # only use host 0 to record results.
if jax.host_id() == 0: if step % FLAGS.print_every == 0: summary_writer.scalar("train_loss", stats.loss[0], step) summary_writer.scalar("train_psnr", stats.psnr[0], step) summary_writer.scalar("train_sparsity", stats.sparsity[0], step) summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step) summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step) summary_writer.scalar("train_sparsity_coarse", stats.sparsity_c[0], step) summary_writer.scalar("weight_l2", stats.weight_l2[0], step) avg_loss = np.mean( np.concatenate([s.loss for s in stats_trace])) avg_psnr = np.mean( np.concatenate([s.psnr for s in stats_trace])) stats_trace = [] summary_writer.scalar("train_avg_loss", avg_loss, step) summary_writer.scalar("train_avg_psnr", avg_psnr, step) summary_writer.scalar("learning_rate", lr, step) steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start) reset_timer = True rays_per_sec = FLAGS.batch_size * steps_per_sec summary_writer.scalar("train_steps_per_sec", steps_per_sec, step) summary_writer.scalar("train_rays_per_sec", rays_per_sec, step) precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1 print(("{:" + "{:d}".format(precision) + "d}").format(step) + f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " + f"avg_loss={avg_loss:0.4f}, " + f"weight_l2={stats.weight_l2[0]:0.2e}, " + f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec") if step % FLAGS.save_every == 0: state_to_save = jax.device_get( jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state_to_save, int(step), keep=100) # Test-set evaluation. if FLAGS.render_every > 0 and step % FLAGS.render_every == 0: # We reuse the same random number generator from the optimization step # here on purpose so that the visualization matches what happened in # training. t_eval_start = time.time() eval_variables = jax.device_get(jax.tree_map( lambda x: x[0], state)).optimizer.target test_case = next(test_dataset) (pred_color, pred_disp, pred_acc, pred_features, pred_specular) = utils.render_image(functools.partial( render_pfn, eval_variables), test_case["rays"], keys[0], FLAGS.dataset == "llff", chunk=FLAGS.chunk) # Log eval summaries on host 0. if jax.host_id() == 0: psnr = utils.compute_psnr( ((pred_color - test_case["pixels"])**2).mean()) ssim = ssim_fn(pred_color, test_case["pixels"]) eval_time = time.time() - t_eval_start num_rays = jnp.prod( jnp.array(test_case["rays"].directions.shape[:-1])) rays_per_sec = num_rays / eval_time summary_writer.scalar("test_rays_per_sec", rays_per_sec, step) print( f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec" ) summary_writer.scalar("test_psnr", psnr, step) summary_writer.scalar("test_ssim", ssim, step) summary_writer.image("test_pred_color", pred_color, step) summary_writer.image("test_pred_disp", pred_disp, step) summary_writer.image("test_pred_acc", pred_acc, step) summary_writer.image("test_pred_features", pred_features, step) summary_writer.image("test_pred_specular", pred_specular, step) summary_writer.image("test_target", test_case["pixels"], step) if FLAGS.max_steps % FLAGS.save_every != 0: state = jax.device_get(jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
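# utils.learning_rate_decay above anneals the learning rate from lr_init to
# lr_final over max_steps. A common choice in NeRF training code, sketched
# here as an assumption about that helper, is log-linear interpolation (the
# real helper also supports a delayed-start multiplier via lr_delay_steps):
import numpy as np

def log_linear_decay(step, lr_init, lr_final, max_steps):
  t = np.clip(step / max_steps, 0.0, 1.0)
  return float(np.exp((1.0 - t) * np.log(lr_init) + t * np.log(lr_final)))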
def evaluate(config, workdir): """Evaluation function.""" # Hide the GPUs and TPUs from TF so it does not reserve memory on them for # LPIPS computation or dataset loading. tf.config.experimental.set_visible_devices([], "GPU") tf.config.experimental.set_visible_devices([], "TPU") rng = jax.random.PRNGKey(config.seed) #---------------------------------------------------------------------------- # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.process_index()) config.dataset.data_dir = os.path.join(config.dataset.base_dir, config.dataset.scene) train_ds, test_ds = datasets.create_dataset(config) example_batch = train_ds.peek() rng, key = jax.random.split(rng) #---------------------------------------------------------------------------- # Initialize model. learning_rate_fn = train_utils.create_learning_rate_fn(config) model, state = models.create_train_state(config, key, learning_rate_fn=learning_rate_fn, example_batch=example_batch) #---------------------------------------------------------------------------- # Get the rendering function. Rendering is forced to be deterministic even if # training is randomized. render_pfn = render_utils.get_render_function(model, config, randomized=False) last_step = 0 out_dir = os.path.join( workdir, "path_renders" if config.dataset.render_path else "test_preds") if not config.eval.eval_once: # Prepare Metric Writers summary_writer = tensorboard.SummaryWriter( os.path.join(workdir, "eval")) while True: state = checkpoints.restore_checkpoint(workdir, state) step = int(state.step) if step <= last_step: continue if config.eval.save_output and (not file_utils.isdir(out_dir)): file_utils.makedirs(out_dir) psnr_values = [] ssim_values = [] if not config.eval.eval_once: showcase_index = np.random.randint(0, test_ds.size) for idx in range(test_ds.size): logging.info("Evaluating [%d / %d].", idx, test_ds.size) batch = next(test_ds) test_pixels = batch.target_view.rgb if test_pixels is not None: test_pixels = model_utils.uint2float(test_pixels) #----------------------------------------------------------- # Render Image variables = {"params": state.params} pred_color, pred_disp, pred_acc = render_utils.render_image( functools.partial(render_pfn, variables), batch, rng, render_utils.normalize_disp(config.dataset.name), chunk=config.eval.chunk) if jax.process_index() != 0: continue #----------------------------------------------------------- # Get showcase example for logging if not config.eval.eval_once and idx == showcase_index: showcase_color = pred_color showcase_disp = pred_disp showcase_acc = pred_acc if not config.dataset.render_path: showcase_gt = test_pixels #----------------------------------------------------------- # If ground-truth pixels are available, evaluate if not config.dataset.render_path: psnr = model_utils.compute_psnr( ((pred_color - test_pixels)**2).mean()) ssim = skmetrics.structural_similarity( pred_color.astype(np.float32), test_pixels.astype(np.float32), win_size=11, multichannel=True, gaussian_weights=True) logging.info(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}") # pylint: disable=logging-format-interpolation psnr_values.append(float(psnr)) ssim_values.append(float(ssim)) #----------------------------------------------------------- # Save generated image if config.eval.save_output: model_utils.save_img( pred_color, os.path.join(out_dir, "{:03d}.png".format(idx))) if pred_disp is not None: model_utils.save_img( pred_disp[Ellipsis, 0], os.path.join(out_dir, "disp_{:03d}.png".format(idx)))
#----------------------------------------------------------- if (not config.eval.eval_once) and (jax.process_index() == 0): summary_writer.image("pred_color", showcase_color, step) if showcase_disp is not None: summary_writer.image("pred_disp", showcase_disp, step) if showcase_acc is not None: summary_writer.image("pred_acc", showcase_acc, step) if not config.dataset.render_path: summary_writer.scalar("eval_metric/psnr", np.mean(np.array(psnr_values)), step) summary_writer.scalar("eval_metric/ssim", np.mean(np.array(ssim_values)), step) summary_writer.image("target", showcase_gt, step) #----------------------------------------------------------- # Save the metric to file if config.eval.save_output and (not config.dataset.render_path) and ( jax.process_index() == 0): with file_utils.open_file( os.path.join(out_dir, f"psnrs_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in psnr_values])) with file_utils.open_file( os.path.join(out_dir, f"ssims_{step}.txt"), "w") as f: f.write(" ".join([str(v) for v in ssim_values])) with file_utils.open_file(os.path.join(out_dir, "psnr.txt"), "w") as f: f.write("{}".format(np.mean(np.array(psnr_values)))) with file_utils.open_file(os.path.join(out_dir, "ssim.txt"), "w") as f: f.write("{}".format(np.mean(np.array(ssim_values)))) if config.eval.eval_once: break if int(step) >= config.train.max_steps: break last_step = step logging.info("Finishing evaluation at step %d", last_step)
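# model_utils.uint2float above converts uint8 ground-truth pixels to the
# [0, 1] float range the renderer outputs, so PSNR/SSIM compare like with
# like. A one-line sketch of the assumed conversion:
import numpy as np

def uint2float(img):
  return img.astype(np.float32) / 255.0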
def main(argv): del argv logging.info("*** Starting experiment") gin_configs = FLAGS.gin_configs logging.info("*** Loading Gin configs from: %s", str(gin_configs)) gin.parse_config_files_and_bindings(config_files=gin_configs, bindings=FLAGS.gin_bindings, skip_unknown=True) # Load configurations. exp_config = configs.ExperimentConfig() model_config = configs.ModelConfig(use_stratified_sampling=False) train_config = configs.TrainConfig() eval_config = configs.EvalConfig() # Get directory information. exp_dir = gpath.GPath(FLAGS.exp_dir) if exp_config.subname: exp_dir = exp_dir / exp_config.subname logging.info("\texp_dir = %s", exp_dir) if not exp_dir.exists(): exp_dir.mkdir(parents=True, exist_ok=True) summary_dir = exp_dir / "summaries" / "eval" logging.info("\tsummary_dir = %s", summary_dir) if not summary_dir.exists(): summary_dir.mkdir(parents=True, exist_ok=True) renders_dir = exp_dir / "renders" logging.info("\trenders_dir = %s", renders_dir) if not renders_dir.exists(): renders_dir.mkdir(parents=True, exist_ok=True) checkpoint_dir = exp_dir / "checkpoints" logging.info("\tcheckpoint_dir = %s", checkpoint_dir) rng = random.PRNGKey(20200823) devices_to_use = jax.devices() n_devices = len( devices_to_use) if devices_to_use else jax.local_device_count() datasource_spec = exp_config.datasource_spec if datasource_spec is None: datasource_spec = { "type": exp_config.datasource_type, "data_dir": FLAGS.data_dir, } logging.info("Creating datasource: %s", datasource_spec) datasource = datasets.from_config( datasource_spec, image_scale=exp_config.image_scale, use_appearance_id=model_config.use_appearance_metadata, use_camera_id=model_config.use_camera_metadata, use_warp_id=model_config.use_warp, random_seed=exp_config.random_seed, ) # Get training IDs to evaluate. train_eval_ids = utils.strided_subset(datasource.train_ids, eval_config.num_train_eval) train_eval_iter = datasource.create_iterator(train_eval_ids, batch_size=0) val_eval_ids = utils.strided_subset(datasource.val_ids, eval_config.num_val_eval) val_eval_iter = datasource.create_iterator(val_eval_ids, batch_size=0) test_cameras = datasource.load_test_cameras( count=eval_config.num_test_eval) if test_cameras: test_dataset = datasource.create_cameras_dataset(test_cameras) test_eval_ids = [f"{x:03d}" for x in range(len(test_cameras))] test_eval_iter = datasets.iterator_from_dataset(test_dataset, batch_size=0) else: test_eval_ids = None test_eval_iter = None rng, key = random.split(rng) params = {} model, params["model"] = models.nerf( key, model_config, batch_size=eval_config.chunk, num_appearance_embeddings=len(datasource.appearance_ids), num_camera_embeddings=len(datasource.camera_ids), num_warp_embeddings=len(datasource.warp_ids), near=datasource.near, far=datasource.far, use_warp_jacobian=False, use_weights=False, ) optimizer_def = optim.Adam(0.0) optimizer = optimizer_def.create(params) init_state = model_utils.TrainState(optimizer=optimizer, warp_alpha=0.0) del params def _model_fn(key_0, key_1, params, rays_dict, alpha): out = model.apply( {"params": params}, rays_dict, warp_alpha=alpha, rngs={ "coarse": key_0, "fine": key_1 }, mutable=False, ) return jax.lax.all_gather(out, axis_name="batch") pmodel_fn = jax.pmap( # Note rng_keys are useless in eval mode since there's no randomness. _model_fn, in_axes=(0, 0, 0, 0, None), # Only distribute the data input. devices=devices_to_use, donate_argnums=(3, ), # Donate the 'rays' argument. 
axis_name="batch", ) render_fn = functools.partial( evaluation.render_image, model_fn=pmodel_fn, device_count=n_devices, chunk=eval_config.chunk, ) last_step = 0 summary_writer = tensorboard.SummaryWriter(str(summary_dir)) while True: if not checkpoint_dir.exists(): logging.info("No checkpoints yet.") time.sleep(10) continue state = checkpoints.restore_checkpoint(checkpoint_dir, init_state) state = jax_utils.replicate(state, devices=devices_to_use) step = int(state.optimizer.state.step[0]) if step <= last_step: logging.info("No new checkpoints (%d <= %d).", step, last_step) time.sleep(10) continue save_dir = renders_dir if eval_config.save_output else None process_iterator( tag="train", item_ids=train_eval_ids, iterator=train_eval_iter, state=state, rng=rng, step=step, render_fn=render_fn, summary_writer=summary_writer, save_dir=save_dir, datasource=datasource, ) process_iterator( tag="val", item_ids=val_eval_ids, iterator=val_eval_iter, state=state, rng=rng, step=step, render_fn=render_fn, summary_writer=summary_writer, save_dir=save_dir, datasource=datasource, ) if test_eval_iter: process_iterator( tag="test", item_ids=test_eval_ids, iterator=test_eval_iter, state=state, rng=rng, step=step, render_fn=render_fn, summary_writer=summary_writer, save_dir=save_dir, datasource=datasource, ) if eval_config.eval_once: break if step >= train_config.max_steps: break last_step = step
def restore_checkpoint(state, workdir): return checkpoints.restore_checkpoint(workdir, state)
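# restore_checkpoint above is a thin wrapper over flax's checkpointing. A
# matching save-side sketch that un-replicates the pmapped train state and
# writes from a single process; the keep count here is an assumption.
import jax
from flax.training import checkpoints

def save_checkpoint(state, workdir):
  if jax.process_index() == 0:
    # Take the first replica's copy of each array before serializing.
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(workdir, state, int(state.step), keep=3)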
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains a checkpoint, training will be resumed from the latest checkpoint. """ model = create_model(config) optimizer = create_optimizer(config, model) del model # don't keep a copy of the initial model output_dir = os.path.join(workdir, "pretraining") gfile.makedirs(output_dir) # Restore from a local checkpoint, if one exists. optimizer = checkpoints.restore_checkpoint(output_dir, optimizer) start_step = int(optimizer.state.step) optimizer = optimizer.replicate() tokenizer = spm.SentencePieceProcessor() if "vocab_file" in config: tokenizer.Load(config.vocab_file) tokenizer.SetEncodeExtraOptions("") # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer. learning_rate_fn = train_utils.create_learning_rate_scheduler( factors="constant * linear_warmup * cosine_decay", base_learning_rate=config.learning_rate, warmup_steps=config.num_warmup_steps, steps_per_cycle=config.num_train_steps - config.num_warmup_steps, ) train_history = train_utils.TrainStateHistory(learning_rate_fn) train_state = train_history.initial_state() if config.do_train: train_iter = input_pipeline.c4_masked_lm_inputs( config.train_batch_size, tokenizer, config.max_seq_length, config.max_predictions_per_seq) train_step_fn = train_utils.create_train_step( compute_pretraining_loss_and_metrics, clip_grad_norm=1.0) for step in range(start_step, config.num_train_steps): with jax.profiler.StepTraceContext("train", step_num=step): batch = next(train_iter) optimizer, train_state = train_step_fn(optimizer, batch, train_state) if step % config.save_checkpoints_steps == 0 and jax.host_id() == 0: checkpoints.save_checkpoint(output_dir, optimizer.unreplicate(), step) if config.do_eval: eval_iter = input_pipeline.c4_masked_lm_inputs( config.eval_batch_size, tokenizer, config.max_seq_length, config.max_predictions_per_seq) eval_iter = itertools.islice(eval_iter, config.max_eval_steps) eval_fn = train_utils.create_eval_fn(compute_pretraining_stats, sample_feature_name="input_ids") eval_stats = eval_fn(optimizer, eval_iter) eval_metrics = { "loss": jnp.mean(eval_stats["loss"]), "masked_lm_loss": jnp.mean(eval_stats["masked_lm_loss"]), "next_sentence_loss": jnp.mean(eval_stats["next_sentence_loss"]), "masked_lm_accuracy": jnp.sum(eval_stats["masked_lm_correct"]) / jnp.sum(eval_stats["masked_lm_total"]), "next_sentence_accuracy": jnp.sum(eval_stats["next_sentence_correct"]) / jnp.sum(eval_stats["next_sentence_total"]), } eval_results = [] for name, val in sorted(eval_metrics.items()): line = f"{name} = {val:.06f}" print(line, flush=True) logging.info(line) eval_results.append(line) eval_results_path = os.path.join(output_dir, "eval_results.txt") with gfile.GFile(eval_results_path, "w") as f: for line in eval_results: f.write(line + "\n")
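# input_pipeline.c4_masked_lm_inputs above yields masked-LM batches. As a
# rough sketch of the standard BERT-style masking such a pipeline applies
# (the exact implementation may differ): up to max_predictions_per_seq
# positions are chosen per sequence, then 80% become [MASK], 10% a random
# token, and 10% keep the original token.
import numpy as np

def mask_tokens(ids, mask_id, vocab_size, max_predictions, rng):
  ids = np.array(ids)
  positions = rng.choice(
      len(ids), size=min(max_predictions, len(ids)), replace=False)
  for pos in positions:
    r = rng.random()
    if r < 0.8:
      ids[pos] = mask_id                    # replace with [MASK]
    elif r < 0.9:
      ids[pos] = rng.integers(vocab_size)   # replace with a random token
    # else: leave the original token in place as a prediction target
  return ids, positions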