def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = json.loads(FLAGS.xm_parameters)
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in hparam_str_dict.keys()])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)
  indices_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)  # pylint: disable=unused-variable

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
    spec_vocab = robust_fill_dsl.CHARACTER + '|'
    spec_id_token_table = {i + 3: token for i, token in enumerate(spec_vocab)}
    bos_id = 1
    eos_id = 2
    spec_id_token_table[bos_id] = robust_fill_dsl.BOS
    spec_id_token_table[eos_id] = robust_fill_dsl.EOS
    spec_token_id_table = {
        token: id for id, token in spec_id_token_table.items()
    }
    spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
    program_id_token_table, program_token_id_table = (
        dsl_tokens.build_token_tables())
    program_vocab_size = len(program_id_token_table) + 1  # pylint: disable=unused-variable
  elif FLAGS.dataset_type == 'scan':
    # TODO(jxihong): Scan is not handled yet.
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    if FLAGS.dataset_type == 'robust_fill':

      def decode_str(s):
        """Decode string tokens."""
        return ''.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps, outs = [], []
      for inp, out in zip(inputs, outputs):
        inps.append(decode_str(inp))
        outs.append(decode_str(out))
      return inps, outs

    elif FLAGS.dataset_type == 'scan':

      def decode_str(s):
        """Decode string tokens."""
        return ' '.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps = [decode_str(inp) for inp in inputs]
      dummy_outs = [''] * len(inps)
      return inps, dummy_outs

    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_target(target):
    """Decode program tokens."""
    target = target[:np.argmax(target == eos_id)].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      target = target[target != bos_id].tolist()
      return ''.join([spec_id_token_table[t_id]
                      for t_id in target if t_id > 0])
    elif FLAGS.dataset_type == 'scan':
      # TODO(jxihong): Scan is not handled yet.
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      # Returns either a Concat program object, or None.
      program = program[program != bos_id].tolist()
      try:
        return robust_fill_dsl.decode_program(program, program_id_token_table)
      except:  # pylint: disable=bare-except
        return None  # Program does not compile.
    elif FLAGS.dataset_type == 'scan':
      # Returns a string.
      program = program[jnp.logical_and(program != bos_id,
                                        program != eos_id)].tolist()
      return ' '.join(scan_vocab.decode(program, program_id_token_table))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program_str(program):  # pylint: disable=unused-variable
    """Decode program tokens into a string."""
    decoded = decode_program(program)
    if FLAGS.dataset_type == 'robust_fill':
      try:
        return decoded.to_string()
      except:  # pylint: disable=bare-except
        return 'did not compile'
    else:
      assert isinstance(decoded, str), '{} should be string'.format(decoded)
      return decoded

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = {
      'inputs': io_shape[1:],
      'outputs': io_shape[1:],
      'spec_parts': target_shape[1:],
      'start_index': indices_shape[1:],
      'end_index': indices_shape[1:],
      # TODO(kshi): include programs.
      # 'program': program_shape[1:],
  }
  logging.info('padded_shapes: %s', padded_shapes)

  if FLAGS.dataset_type == 'robust_fill':
    create_dataset_fn = input_pipeline.create_robust_fill_dataset_from_tf_record
  elif FLAGS.dataset_type == 'scan':
    raise NotImplementedError()  # TODO(kshi): Implement.
    # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  dataset = create_dataset_fn(FLAGS.dataset_filepattern,
                              program_token_id_table,
                              spec_token_id_table,
                              max_target_length=FLAGS.max_target_length)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=padded_shapes,
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = padded_shapes.copy()
  predict_padded_shapes['inputs'] = predict_io_shape[1:]
  predict_padded_shapes['outputs'] = predict_io_shape[1:]
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = create_dataset_fn(FLAGS.test_dataset_filepattern,
                                   program_token_id_table,
                                   spec_token_id_table,
                                   max_target_length=FLAGS.max_target_length)
  test_dataset = test_dataset.padded_batch(
      batch_size,
      padded_shapes=predict_padded_shapes,
      drop_remainder=False)
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_config = base_models.TransformerConfig(
      vocab_size=spec_vocab_size,
      output_vocab_size=spec_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_id,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=FLAGS.max_distance,
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=FLAGS.max_distance,
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=FLAGS.max_distance,
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=FLAGS.max_distance,
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=FLAGS.max_distance,
      bidirectional_program_attention=FLAGS.bidirectional_program_attention)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      dataset_type=FLAGS.dataset_type)
  eval_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(deterministic=True),
      dataset_type=FLAGS.dataset_type)
  predict_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(
          shift=False,
          deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)),
      dataset_type=FLAGS.dataset_type)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info('Checking that start_step (%s) == finetune_start_step (%s)',
                   start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr,
        factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_id,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_id,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, targets = load_data(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, targets, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, targets = load_data(batches)
        metrics = p_eval_step(optimizer.target, inputs, outputs, targets)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time() - t_evaluation_start, step,
                     eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for dataset, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:
        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0

          ios, targets_list, predictions, top_of_beams = [], [], [], []
          for batches in dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, targets = load_data(pred_batch)

            cache = (p_init_cache(inputs, outputs, targets)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, targets = map(tohost, (inputs, outputs, targets))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])
              ground_truth = decode_target(targets[i])
              beams_target = [decode_target(beam) for beam in beams]
              predicted_target = beams_target[0]
              for beam_target in beams_target:
                if beam_target == ground_truth:
                  predicted_target = beam_target
                  total_successes += 1
                  break
              total_denominator += 1

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets_list.append(ground_truth)
              predictions.append(predicted_target)
              logging.info('')
              logging.info('ios: %s', ios[-1])
              logging.info('targets[%s]: %s', i, targets[i])
              logging.info('ground_truth: %s', ground_truth)
              logging.info('predicted beam: %s', '\n'.join(beams_target))
              logging.info('predicted_target: %s', predicted_target)
              logging.info('beams: %s', beams)

              if not ground_truth:
                logging.warn('ground_truth is empty!')

              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decode_target(beam), beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))

          # Record beam search results as text summaries.
          message = []
          for n in np.random.choice(np.arange(len(predictions)), 8):
            text = (f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                    f'predicted: {predictions[n]}\n\n'
                    f'top of beam:\n\n{top_of_beams[n]}\n\n')
            message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size, all_total_successes,
                all_total_denominator, accuracy,
                time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)
            summary_writer.text(
                '{}-samples-beam-{}'.format(predict_or_test, beam_size),
                '\n------\n'.join(message), step)
            summary_writer.flush()

    # Save a Checkpoint. Do this at the end of the training loop, so that if a
    # worker is descheduled during a round of prediction (which takes a while),
    # we will redo prediction upon restarting (to avoid losing data).
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)
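

# The loop above relies on `load_data`, `tohost`, `pad_examples`, and
# `per_host_sum_pmap`, which are defined elsewhere in this module. As a
# reading aid, here is a minimal sketch of `tohost` and `pad_examples` in the
# style of the Flax examples these scripts follow; the `_sketch` names are
# hypothetical and the real definitions may differ.
def tohost_sketch(x):
  """Collapses a [n_device, n_batch, ...] array into one host-side batch."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))


def pad_examples_sketch(x, desired_batch_size):
  """Pads a short final batch by repeating its last example."""
  batch_pad = desired_batch_size - x.shape[0]
  tile_dims = (batch_pad,) + (1,) * (x.ndim - 1)
  return np.concatenate([x, np.tile(x[-1], tile_dims)], axis=0)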
def evaluate(config, workdir, eval_folder="eval"):
  """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results.
      Defaults to "eval".
  """
  # Create eval_dir
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)

  rng = jax.random.PRNGKey(config.seed + 1)

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0,
                       optimizer=optimizer,
                       lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args
  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  rng = eval_meta.rng
  # Use inceptionV3 for images with higher resolution.
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        logging.warn("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and not ready to read from yet.
    try:
      state = utils.load_state_dict(ckpt_filename, state)
    except:  # pylint: disable=bare-except
      time.sleep(60)
      try:
        state = utils.load_state_dict(ckpt_filename, state)
      except:  # pylint: disable=bare-except
        time.sleep(120)
        state = utils.load_state_dict(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)

      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng,
                                     config,
                                     state,
                                     init_shape,
                                     scaler,
                                     inverse_scaler,
                                     class_conditional=class_conditional)
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      gc.collect()
      latents = evaluation.run_inception_distributed(
          samples, inception_model, inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and FIDs.
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")
        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warn("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on them.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) * partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz"), "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer,
              all_losses=all_losses,
              mean_loss=all_losses.mean(),
              ISs=inception_scores,
              fids=fids,
              kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        else:
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID to work under eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, "
            "KID: %.6e", ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(
            os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer,
              all_losses=all_losses,
              mean_loss=all_losses.mean(),
              IS=inception_score,
              fid=fid,
              kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)
      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for file in meta_files:
    tf.io.gfile.remove(file)
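

# A minimal sketch (illustration only, not part of the original module) of the
# pre-emption-safe resume pattern `evaluate` uses above: a small flax
# dataclass is checkpointed after each sampling round so a restarted job picks
# up at the next (ckpt, round) pair. The `_Sketch` names are hypothetical.
@flax.struct.dataclass
class EvalMetaSketch:
  ckpt_id: int
  round_id: int


def resume_point_sketch(eval_dir, num_rounds):
  """Returns (begin_ckpt, begin_round) recovered from the last saved meta."""
  meta = EvalMetaSketch(ckpt_id=0, round_id=-1)
  meta = checkpoints.restore_checkpoint(eval_dir, meta, prefix="meta_0_")
  if meta.round_id < num_rounds - 1:
    return meta.ckpt_id, meta.round_id + 1  # Resume mid-checkpoint.
  return meta.ckpt_id + 1, 0  # That checkpoint is done; start the next one.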
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  def encode_strings(strs, max_len):
    tokenized_batch = np.zeros((len(strs), max_len), np.int32)
    for i, s in enumerate(strs):
      toks = encoder.tokenize(s).numpy()
      # Remove EOS token in prompt.
      tokenized_batch[i, :toks.shape[0] - 1] = toks[:-1]
    return tokenized_batch

  tokenized_prompts = encode_strings(
      [config.prompts], config.max_predict_length)

  logging.info("Initializing model, optimizer, and step functions.")
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  rng, inference_rng = random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.TransformerLM(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          config=predict_config,
          temperature=config.sampling_temperature,
          top_k=config.sampling_top_k),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceContext("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary["perplexity"] = jnp.clip(
              jnp.exp(summary["loss"]), a_max=1.0e4)
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          # (clipped) perplexity after averaging log-perplexities
          eval_results["perplexity"] = jnp.clip(
              jnp.exp(eval_results["loss"]), a_max=1.0e4)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("generate_text"):
          exemplars = generate_prediction(
              p_pred_step=p_pred_step,
              target=optimizer.target,
              tokenized_prompts=tokenized_prompts,
              eos_id=eos_id,
              inference_rng=inference_rng,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(
              workdir, jax_utils.unreplicate(optimizer), step)
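

# `create_learning_rate_scheduler` is referenced above but defined elsewhere.
# A minimal sketch of the Flax-style factor scheduler it mirrors (assumption:
# only the usual 'constant * linear_warmup * rsqrt_decay' factors matter):
def create_learning_rate_scheduler_sketch(
    factors="constant * linear_warmup * rsqrt_decay",
    base_learning_rate=0.5,
    warmup_steps=1000):
  """Builds a step -> learning-rate fn from a '*'-separated factor string."""
  factor_names = [f.strip() for f in factors.split("*")]

  def step_fn(step):
    ret = 1.0
    for name in factor_names:
      if name == "constant":
        ret *= base_learning_rate
      elif name == "linear_warmup":
        ret *= jnp.minimum(1.0, step / warmup_steps)
      elif name == "rsqrt_decay":
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      else:
        raise ValueError("Unknown factor %s." % name)
    return jnp.asarray(ret, dtype=jnp.float32)

  return step_fn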
def save_checkpoint(state):
  if jax.host_id() == 0:
    # get train state from the first replica
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    step = int(state.step)
    checkpoints.save_checkpoint(FLAGS.model_dir, state, step, keep=3)
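

# Example usage (a sketch; assumes `state` is a pmap-replicated train state
# whose leaves carry a leading device axis and an integer `step` field):
#
#   for batch in train_iter:
#     state, metrics = p_train_step(state, batch)
#     if int(jax_utils.unreplicate(state).step) % FLAGS.checkpoint_freq == 0:
#       save_checkpoint(state)  # De-replicates and writes on host 0 only.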
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value
  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens."""

    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      use_relative_attention=FLAGS.use_relative_attention,
      num_relative_position_buckets=FLAGS.num_relative_position_buckets,
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_lib.initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(train_lib.predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(4, 5, 6))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)
      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 100]:
      t_inference_start = time.time()

      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: train_lib.pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = p_init_cache(inputs, outputs, programs)
        predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                eos_token, programs.shape[-1], beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_score = train_lib.eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')

      all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
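

# `train_lib.per_host_sum_pmap` aggregates the counts across hosts above. A
# minimal sketch (assumption, following the standard Flax-example recipe;
# assumes `collections` is imported) of such a helper, which psums a pytree
# over one device per host:
def per_host_sum_pmap_sketch(in_tree):
  """Sums in_tree's leaves across hosts via a one-device-per-host psum."""
  host2devices = collections.defaultdict(list)
  for d in jax.devices():
    host2devices[d.host_id].append(d)
  devices = [host2devices[k][0] for k in host2devices]
  host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i', devices=devices)

  def pre_pmap(xs):
    return jax.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)

  def post_pmap(xs):
    return jax.tree_map(lambda x: x[0], xs)

  return post_pmap(host_psum(pre_pmap(in_tree)))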
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables

  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model),
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch_buffer_size = 3 x batch_size
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " + f"lr={lr:0.2e}, " +
              f"{rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(
          jax.tree_map(lambda x: x[0], state)).optimizer.target
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, eval_variables),
          test_case["rays"],
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
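

# `utils.learning_rate_decay` is referenced above. A minimal sketch
# (assumption, matching the delayed log-linear schedule common in NeRF
# codebases): the rate interpolates log-linearly from lr_init to lr_final
# over max_steps, optionally damped near step 0 by a smooth delay ramp.
def learning_rate_decay_sketch(step, lr_init, lr_final, max_steps,
                               lr_delay_steps=0, lr_delay_mult=1.0):
  if lr_delay_steps > 0:
    # A smooth ramp from lr_delay_mult up to 1 over lr_delay_steps.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.0
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp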
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)
  test_render_fn = jax.pmap(
      # Note rng_keys are useless in eval mode since there's no randomness.
      # pylint: disable=g-long-lambda
      lambda key_0, key_1, model, rays: jax.lax.all_gather(
          model(key_0, key_1, *rays), axis_name="batch"),
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=3,
      axis_name="batch",
  )

  rng, key = random.split(rng)
  init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
  optimizer_def = optim.Adam(FLAGS.lr)
  optimizer = optimizer_def.create(init_model)
  state = model_utils.TrainState(
      step=0, optimizer=optimizer, model_state=init_state)
  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  offset = state.step + 1
  state = jax_utils.replicate(state)
  del init_model, init_state

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  t_loop_start = time.time()
  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      init_lr=FLAGS.lr,
      decay_steps=FLAGS.lr_decay * 1000,
      decay_rate=0.1)

  ptrain_step = jax.pmap(
      train_step,
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=2)

  # Prefetch_buffer_size = 3 x batch_size
  pdataset = jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
    lr = learning_rate_fn(step)
    state, stats, keys = ptrain_step(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats[0])
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # --- Train logs start ---
    # Put the training time visualization before the host_id check as in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          state_to_eval,
          test_case["rays"],
          test_render_fn,
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)

    if jax.host_id() != 0:  # Only log via host 0.
      continue
    if step % FLAGS.print_every == 0:
      summary_writer.scalar("train_loss", stats[0].loss[0], step)
      summary_writer.scalar("train_psnr", stats[0].psnr[0], step)
      if len(stats) > 1:
        summary_writer.scalar("train_loss_coarse", stats[1].loss[0], step)
        summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0], step)
      avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
      avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
      stats_trace = []
      summary_writer.scalar("train_avg_loss", avg_loss, step)
      summary_writer.scalar("train_avg_psnr", avg_psnr, step)
      summary_writer.scalar("learning_rate", lr, step)
      steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
      t_loop_start = time.time()
      rays_per_sec = FLAGS.batch_size * steps_per_sec
      summary_writer.scalar("steps_per_sec", steps_per_sec, step)
      summary_writer.scalar("rays_per_sec", rays_per_sec, step)
      precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
      print(("{:" + "{:d}".format(precision) + "d}").format(step) +
            f"/{FLAGS.max_steps:d}: " + f"i_loss={stats[0].loss[0]:0.5f} | " +
            f"avg_loss={avg_loss:0.5f}, " + f"{rays_per_sec:0.3f} rays/sec")
    if step % FLAGS.save_every == 0:
      state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
      checkpoints.save_checkpoint(
          FLAGS.train_dir, state_to_save, state_to_save.step, keep=100)
    # --- Train logs end ---

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(state.step), keep=100)
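

# Both NeRF loops above call `utils.compute_psnr` on a mean squared error. A
# minimal sketch (assumption) of the standard definition for images scaled to
# [0, 1], where PSNR = -10 * log10(MSE):
def compute_psnr_sketch(mse):
  """PSNR in dB from mean squared error, assuming a max signal value of 1."""
  return -10.0 * jnp.log(mse) / jnp.log(10.0)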
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') if FLAGS.jax_backend_target: jax.config.FLAGS.jax_xla_backend = 'tpu_driver' jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target # This seems to be necessary even when importing TF2? tf.enable_v2_behavior() # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: tf.io.gfile.makedirs(FLAGS.model_dir) train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) if FLAGS.batch_size % n_devices: raise ValueError( 'Batch size must be divisible by the number of devices') vocab_path = FLAGS.vocab_path if vocab_path is None: vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model') tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset logging.info('Initializing dataset.') train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=n_devices, dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_token = 2 # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer transformer_kwargs = { 'vocab_size': vocab_size, 'output_vocab_size': vocab_size, 'emb_dim': FLAGS.emb_dim, 'num_heads': FLAGS.num_heads, 'num_layers': FLAGS.num_layers, 'qkv_dim': FLAGS.qkv_dim, 'mlp_dim': FLAGS.mlp_dim, 'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length), 'share_embeddings': FLAGS.share_embeddings, 'logits_via_embedding': FLAGS.logits_via_embedding, } start_step = 0 rng = random.PRNGKey(FLAGS.random_seed) rng, init_rng = random.split(rng) input_shape = (FLAGS.batch_size, FLAGS.max_target_length) target_shape = (FLAGS.batch_size, FLAGS.max_target_length) model, cache_def = create_model(init_rng, input_shape, target_shape, transformer_kwargs) optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay) # We access model only from optimizer below via optimizer.target. del model if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer. 
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          use_bfloat16=FLAGS.use_bfloat16,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info('Starting training loop.')
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir,
                                  jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % FLAGS.eval_frequency != 0 and step > 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['learning_rate'] = lr
    steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
    logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info('Translating evaluation dataset.')
    t_inference_start = time.time()
    predict_iter = iter(predict_ds)
    sources, references, predictions = [], [], []
    for _, pred_batch in enumerate(predict_iter):
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      per_device_batchsize = pred_batch['inputs'].shape[1]
      cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
      cache = jax_utils.replicate(
          cache_def.initialize_cache(
              (per_device_batchsize, FLAGS.max_predict_length),
              dtype=cache_dtype))
      predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                              eos_token, FLAGS.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch['inputs'])
      targets = tohost(pred_batch['targets'])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))
    logging.info('Translation time: %.4f s step %d.',
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)
    # Save translation samples for tensorboard.
    exemplars = ''
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
    if jax.host_id() == 0:
      eval_summary_writer.scalar('bleu', bleu_score, step)
      eval_summary_writer.text('samples', exemplars, step)
      eval_summary_writer.flush()
    logging.info('Translation BLEU Score %.4f', bleu_score)
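# `pad_examples` above is defined elsewhere in this file. A plausible minimal
# version (an assumption about the helper, not its exact source) pads the
# batch axis by repeating the final example; the padded rows are discarded
# later by the `predicted[:cur_pred_batch_size]` slice:
import numpy as np

def pad_examples_sketch(x, desired_batch_size):
  """Pads a 2D batch of token ids along axis 0 up to desired_batch_size."""
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)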
def main(argv):
  global CFG
  CFG = FLAGS.config

  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
  _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if CFG.hardware_rng:
    models.set_hardware_bernoulli()

  if 'module_import' in CFG and CFG.module_import:
    for module in CFG.module_import:
      importlib.import_module(module)

  if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
    t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

  num_partitions = CFG.num_partitions
  topology = train_lib.compute_multihost_topology(num_partitions)
  batch_size = CFG.batch_size
  eval_batch_size = CFG.eval_batch_size
  per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
  if batch_size % topology.num_replicas:
    raise ValueError('Batch size must be divisible by the number of replicas.')

  steps_per_epoch = CFG.steps_per_epoch
  logging.info('steps per epoch: %d', steps_per_epoch)

  broadcast = functools.partial(
      train_lib.broadcast,
      num_replicas=topology.per_replica_set_num_replicas,
      num_partitions=topology.per_host_num_partitions,
      devices=topology.this_host_device_assignment)

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    tf.io.gfile.copy(FLAGS['config'].config_filename,
                     os.path.join(FLAGS.model_dir, 'config.py'),
                     overwrite=True)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  # Write summaries in background thread to avoid blocking on device sync.
  if CFG.infeed:
    # Infeed is currently synchronous, so do it in a background thread too.
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
      CFG, topology.num_replica_sets, topology.replica_set_id,
      topology.per_replica_set_host_id)

  vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
  encoder = vocab.tf_tokenizer
  eos_id = vocab.tokenizer.eos_id()

  def decode_tokens(toks, eos_id=eos_id, max_id=32000):
    """Decode tokens back to unicode."""
    del eos_id
    # TODO(levskaya): T5 doesn't seem to emit EOS tokens? Double-check whether
    # this is the best decoding function or just switch to using tf_decode.
    # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    valid_toks = toks.astype(np.int32)
    valid_toks[valid_toks >= max_id] = 3
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  train_config, eval_config, predict_config = get_configs(CFG)

  rng = random.PRNGKey(CFG.random_seed)
  rng, init_rng = random.split(rng)

  # This is used for infeed conversion from feature dict <--> tuple.
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  device_train_input_shape = tuple([
      (batch_size // topology.num_replicas,
       CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
      for k in train_keys
  ])

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      factors=CFG.schedule,
      base_learning_rate=CFG.learning_rate,
      warmup_steps=CFG.warmup_steps)

  # First, we only abstractly initialize the optimizer and model parameters,
  # since the parameters may not even fit in device memory!
  # TODO(jekbradbury): make optimizer_defs compare by value so it can be
  # created in get_initial_params without causing pytree incompatibility.
  optimizer_def = optim.Adafactor(
      CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset)
  initialize_params_fn = functools.partial(
      get_initial_params,
      config=CFG,
      transformer_config=eval_config,
      optimizer_def=optimizer_def)
  optimizer = jax.eval_shape(initialize_params_fn, init_rng)
  # tuple-like pytree leaves for global_arg_shapes
  optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                  optimizer)

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  if num_partitions > 1:
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(num_partitions, optimizer.state_dict()))
    per_host_optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(topology.per_host_num_partitions,
                                  optimizer.state_dict()))

  # Restore unreplicated optimizer + model state from last checkpoint.
  # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
  existing_checkpoint_found = False
  if CFG.restore_checkpoints:
    existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

  # Import a pretrained-T5 checkpoint only if we didn't import a local
  # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run).
  # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
  if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
    optimizer = checkpoint_importer.restore_from_t5_checkpoint(
        optimizer, CFG.restore_t5_checkpoint)

  if CFG.restore_t5_checkpoint or existing_checkpoint_found:
    if num_partitions > 1:
      # Until checkpoint/restore is sharded, the restored checkpoint is global
      # and we need to slice each sharded parameter into the chunk containing
      # only the partitions that are present on this host.
      def per_host_chunk(x, spec):
        if spec is None or spec is x:  # unsharded or not a parameter
          return x
        if spec[0] == 1:
          dim_size = x.shape[1]
        elif spec[1] == 1:
          dim_size = x.shape[0]
        else:
          raise NotImplementedError()
        chunk_size = (
            dim_size * topology.per_host_num_partitions // num_partitions)
        lower = topology.per_replica_set_host_id * chunk_size
        upper = (topology.per_replica_set_host_id + 1) * chunk_size
        if spec[0] == 1:
          return x[:, lower:upper]
        else:
          return x[lower:upper]

      optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                    optimizer_partitions)
  else:
    # If pretraining and no checkpoint imported, we jit the (sharded-) init
    # function to minimize fragmentation. We use the same pmap(sharded_jit)
    # setup as the training step/loop to initialize everything "in-place" and
    # avoid communication or OOM.
    if num_partitions > 1:
      initialize_params_fn = sharded_jit(
          initialize_params_fn,
          in_parts=None,
          local_in_parts=None,
          out_parts=optimizer_partitions,
          local_out_parts=per_host_optimizer_partitions,
          # devices=one_replica_device_assignment,
      )
      initialize_params_fn = jax.pmap(
          initialize_params_fn,
          'batch',
          in_axes=0,
          axis_size=topology.num_replicas,
          devices=topology.device_assignment)
      init_rng = broadcast(init_rng)
      optimizer = initialize_params_fn(init_rng)
      # We maintain the optimizer in unbroadcasted form (i.e. with no leading
      # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
      # out_axes=None.
      optimizer = train_lib.unbroadcast(optimizer)
    else:
      optimizer = jax.jit(initialize_params_fn)(init_rng)

  # ---------------------------------------------------------------------------
  # Compile multidevice versions of train/eval/predict step and cache init fn.
  # ---------------------------------------------------------------------------

  # We can use either a single train-step for a host training loop:
  #   train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
  #     --> new_optimizer, metrics, new_dropout_rng
  def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
    return train_lib.train_step(
        optimizer,
        batch,
        prev_metrics,
        dropout_rng,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss,
        use_bfloat16=CFG.use_bfloat16)

  if num_partitions > 1:
    p_train_step = sharded_jit(
        p_train_step,
        in_parts=(optimizer_partitions, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None),
        out_parts=(optimizer_partitions, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None))
  # TODO(levskaya): the in_axes spec below might be wrong, double-check.
  p_train_step = jax.pmap(
      p_train_step,
      axis_name='batch',
      in_axes=(None, 0, 0, 0),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # OR, we use an on-device loop that feeds the training step via infeed queue.
  def device_train_loop_cond(args):
    """Stopping criterion for on-device loop."""
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    """On-device loop body."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    input_data, token = lax.infeed(
        token,
        shape=tuple(
            [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape]))
    # Rebuild input dict from infeed data tuple.
    batch = {k: v for k, v in zip(train_keys, input_data)}
    # Run the train_step function and return the loop state.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(
        device_train_loop,
        in_parts=(optimizer_partitions, None, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None, None),
        out_parts=(optimizer_partitions, None, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None, None))
  p_train_epoch = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, 0, 0, None, None),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Reduction psum for metric data.
  def p_allreduce_metrics(x):
    return lax.psum(x, axis_name='batch')

  if num_partitions > 1:
    p_allreduce_metrics = sharded_jit(
        p_allreduce_metrics,
        in_parts=None,
        local_in_parts=None,
        out_parts=None,
        local_out_parts=None,
        num_partitions=num_partitions,
        local_num_partitions=topology.per_host_num_partitions)
  p_allreduce_metrics = jax.pmap(
      p_allreduce_metrics,
      axis_name='batch',
      global_arg_shapes=None,
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)

  # Training evaluation computation.
  #   eval_step(params, batch, config, label_smoothing=0.0) --> metrics
  def p_eval_step(params, batch):
    return train_lib.eval_step(
        params, batch, config=eval_config, label_smoothing=CFG.label_smoothing)

  if num_partitions > 1:
    p_eval_step = sharded_jit(
        p_eval_step,
        in_parts=(optimizer_partitions.target, None),
        local_in_parts=(per_host_optimizer_partitions.target, None),
        out_parts=None,
        local_out_parts=None)
  p_eval_step = jax.pmap(
      p_eval_step,
      axis_name='batch',
      in_axes=(None, 0),
      global_arg_shapes=(optimizer_shapes.target, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Fast autoregressive decoding loop for inference and model evaluation.
  #   predict_step(inputs, params, eos_id, max_decode_len, config,
  #                beam_size=4) --> beam_seqs
  def p_pred_step(inputs, params):
    return train_lib.predict_step(inputs, params, eos_id,
                                  CFG.max_eval_target_length, predict_config,
                                  CFG.beam_size)

  if num_partitions > 1:
    p_pred_step = sharded_jit(
        p_pred_step,
        in_parts=(None, optimizer_partitions.target),
        local_in_parts=(None, per_host_optimizer_partitions.target),
        out_parts=None,
        local_out_parts=None)
  p_pred_step = jax.pmap(
      p_pred_step,
      axis_name='batch',
      in_axes=(0, None),
      global_arg_shapes=(None, optimizer_shapes.target),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # ---------------------------------------------------------------------------
  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # There should be a unique dropout key for each replica represented on this
  # host, but the key should be the same for the same replica on other hosts.
  # Again, this is what the replica set abstraction is for.
  dropout_rngs = random.split(
      random.fold_in(rng, topology.replica_set_id),
      topology.per_replica_set_num_replicas)

  # Restore step from last checkpoint.
  host_step = int(optimizer.state.step)
  empty_metrics = broadcast({
      'loss': 0.0,
      'accuracy': 0.0,
      'learning_rate': 0.0,
      'denominator': 0.0
  })
  if CFG.infeed:
    # TODO(jekbradbury): support something like this for the Python-loop case
    logging.info('Precompiling training loop and moving optimizer to device.')
    optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                             empty_metrics,
                                             jnp.array(0, dtype=jnp.int32), 1)
    optimizer = train_lib.unbroadcast(optimizer)
    metrics['loss'].block_until_ready()

  logging.info('Starting training loop.')

  local_devices = jax.local_devices()
  device_step = broadcast(host_step)
  first_epoch = host_step // steps_per_epoch

  # Main Loop over "epochs".
  train_iter = train_ds.as_numpy_iterator()
  for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
    metrics = empty_metrics

    # NOTE: 'optimizer' is unbroadcast by construction at initialization or
    # when loading a checkpoint.
    # It is maintained in 'unbroadcast' state to enable the XLA cross-replica
    # sharding optimization. The broadcasting is handled automatically by the
    # pmap'd functions that use it.

    # Gather all task evaluation metrics.
    logging.info('Evaluating tasks.')
    if epoch == first_epoch + 1:
      train_lib.sync_devices()
    for task in eval_cache.tasks:
      logging.info('Evaluating task %s', task.name)
      all_predicted, all_bs = [], []
      for pred_batch in eval_cache.preprocessed_examples[task.name]:
        # Handle final odd-sized batch by padding instead of dropping it.
        input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
            pred_batch['inputs'], per_replica_set_eval_batch_size)
        all_bs.append(unpadded_batch_size)
        # Split batch dimensions for pmap.
        input_batch = jax.tree_map(
            lambda x: x.reshape(
                (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
            input_batch)
        # Run fast inference on batch.
        all_predicted.append(p_pred_step(input_batch, optimizer.target))

      # Pad out the number of batches so each host has the same number.
      max_host_batch_number = np.max(
          eval_cache.preprocessed_batch_sizes[task.name])
      batch_shortfall = max_host_batch_number - len(all_predicted)
      if batch_shortfall > 0:
        # TODO(levskaya): Fix for case of entirely empty all_predicted.
        # To make sure the cross-host barriers work, we run the program the
        # same number of times on all hosts. The results of this call are
        # ignored, and the predictions are populated with zeros instead.
        p_pred_step(input_batch, optimizer.target)  # Dummy call.
        all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                             batch_shortfall)
        all_bs.extend([0] * batch_shortfall)
      all_predicted = jnp.concatenate(all_predicted)
      all_bs = jnp.array(all_bs)

      # Collect all batches from across hosts and reverse sharding.
      all_predicted = train_lib.host_allgather(
          all_predicted, topology.num_replica_sets, topology.replica_set_id,
          topology.per_replica_set_host_id == 0)
      seqlength = all_predicted.shape[-1]
      total_examples = np.sum(
          train_lib.host_allgather(all_bs, topology.num_replica_sets,
                                   topology.replica_set_id,
                                   topology.per_replica_set_host_id == 0))
      del all_bs
      assert total_examples == len(eval_cache.examples[task.name]), (
          'Total number of batches incorrect for task %s.' % task.name)
      # De-shard the collected predicted tokens and remove padding.
      all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
          -1, seqlength)[:total_examples]

      # We now run the post-processing and metric-fns on a single host.
      if jax.host_id() == 0:
        assert eval_summary_writer
        raw_predictions = []
        for tokens in all_predicted:
          raw_predictions.append(decode_tokens(tokens))

        # Post-process predictions for metric fns.
        predictions = [
            task.postprocess_fn(p, example=ex)
            for p, ex in zip(raw_predictions, eval_cache.examples[task.name])
        ]

        for metric_fn in task.metric_fns:
          scores = metric_fn(eval_cache.targets[task.name], predictions)
          for metric_name, metric_value in scores.items():
            tag = f'eval/{task.name}/{metric_name}'
            eval_summary_writer.scalar(tag, metric_value, host_step)
            logging.info('EVAL %s at step %d: %.3f', tag, host_step,
                         metric_value)
          eval_summary_writer.flush()

        # Save text samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
          tgt_txt = tf.compat.as_text(
              eval_cache.examples[task.name][n]['targets_plaintext'])
          pred_txt = raw_predictions[n]
          exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                        f'target: {tgt_txt}\n\n'
                        f'prediction: {pred_txt}\n\n')
        eval_summary_writer.text(f'{task.name} samples', exemplars, host_step)
        eval_summary_writer.flush()

    # Take an Xprof trace after the first loop has compiled everything.
    if epoch == first_epoch + 1:
      train_lib.sync_devices()

    # For on-device loop, we launch the computation before feeding data.
    logging.info('BEGIN Train loop.')
    if CFG.infeed:
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          optimizer, dropout_rngs, metrics,
          train_lib.unbroadcast(device_step), epoch)
      optimizer = train_lib.unbroadcast(optimizer)

    # Epoch loop.
    while int(host_step // steps_per_epoch) == epoch:
      batch = next(train_iter)
      batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          batch)
      # Feed the on-device training loop.
      if CFG.infeed:
        for i, device in enumerate(local_devices):
          # When using infeed to provide data to the computation, we're on our
          # own for feeding the right values to the right devices. Each device
          # should get the minibatch corresponding to its replica, a slice of
          # the larger batch corresponding to the host's replica set.
          if device.platform == 'tpu':
            device_coords = (*device.coords, device.id % 2)
          else:
            device_coords = (device.host_id, i)
          per_replica_set_device_coords = tuple(
              dc % prsm
              for dc, prsm in zip(device_coords, topology.per_replica_set_mesh))
          per_replica_set_replica_coords = tuple(
              prsdc // prm
              for prsdc, prm in zip(per_replica_set_device_coords,
                                    topology.per_replica_mesh))
          per_replica_set_replica_id = 0
          for prsm, prm, prsrc in zip(topology.per_replica_set_mesh,
                                      topology.per_replica_mesh,
                                      per_replica_set_replica_coords):
            per_replica_set_replica_id = (
                per_replica_set_replica_id * prsm // prm + prsrc)
          input_tuple = tuple(
              [batch[k][per_replica_set_replica_id] for k in train_keys])
          # Safety check: infeed does not check shape or types but requires
          # them to agree with the on-device spec, otherwise the queue and
          # program stall.
          tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
          tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
          assert tuple_shapes == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (tuple_shapes, device_train_input_shape))
          assert tuple(set(tuple_dtypes)) == (jnp.int32,), (
              'infeed dtype error %s not all of type %s' %
              (tuple_dtypes, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      # Host training loop.
      else:
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, metrics, dropout_rngs)
        optimizer = train_lib.unbroadcast(optimizer)
      host_step += 1

    logging.info('END Train loop.')

    # Maybe save a checkpoint on one host.
    if (CFG.save_checkpoints and
        epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and
        jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

    # Gather training metrics.
    metrics = p_allreduce_metrics(metrics)
    metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
    denominator = metrics.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
    logging.info('train in step: %s, %s', host_step, summary)
    if jax.host_id() == 0:
      assert train_summary_writer
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, host_step)
      train_summary_writer.flush()

    # Gather training evaluation metrics.
    logging.info('Gathering training evaluation metrics.')
    eval_metrics = []
    eval_iter = eval_ds.as_numpy_iterator()
    for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    # Average metrics across devices.
    eval_metrics = p_allreduce_metrics(eval_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    # Average metrics across steps.
    eval_metrics = jax.tree_map(np.sum, eval_metrics)
    eval_denominator = eval_metrics.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics)
    logging.info('eval in step: %s, %s', host_step, eval_summary)
    if jax.host_id() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, host_step)
      eval_summary_writer.flush()

  # Wait until computations are done before exiting.
  logging.info('Finished.')
  train_lib.sync_devices()
  # Shut down the infeed threadpool.
  if CFG.infeed:
    infeed_pool.shutdown()
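# The on-device epoch above is a lax.while_loop whose carry threads the
# optimizer, rngs, metrics, infeed token, step and epoch. A stripped-down,
# CPU-runnable sketch of the same control flow (no infeed; a toy update
# stands in for train_step):
import jax.numpy as jnp
from jax import lax

STEPS_PER_EPOCH = 10

def loop_cond(carry):
  _, step, epoch = carry
  return step // STEPS_PER_EPOCH == epoch  # run until the epoch boundary

def loop_body(carry):
  params, step, epoch = carry
  return params + 1.0, step + 1, epoch  # stand-in for train_step

params, step, _ = lax.while_loop(
    loop_cond, loop_body,
    (jnp.zeros(()), jnp.asarray(0, jnp.int32), jnp.asarray(0, jnp.int32)))
# After one "epoch": params == 10.0 and step == 10.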
def test_save_restore_checkpoints(self):
  tmp_dir = self.create_tempdir().full_path
  test_object0 = {
      'a': np.array([0, 0, 0], np.int32),
      'b': np.array([0, 0, 0], np.int32)
  }
  test_object1 = {
      'a': np.array([1, 2, 3], np.int32),
      'b': np.array([1, 1, 1], np.int32)
  }
  test_object2 = {
      'a': np.array([4, 5, 6], np.int32),
      'b': np.array([2, 2, 2], np.int32)
  }
  new_object = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(new_object, test_object0)
  # Create leftover temporary checkpoint, which should be ignored.
  gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
  checkpoints.save_checkpoint(tmp_dir, test_object1, 0, prefix='test_', keep=1)
  self.assertIn('test_0', os.listdir(tmp_dir))
  new_object = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(new_object, test_object1)
  checkpoints.save_checkpoint(tmp_dir, test_object1, 1, prefix='test_', keep=1)
  checkpoints.save_checkpoint(tmp_dir, test_object2, 2, prefix='test_', keep=1)
  new_object = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(new_object, test_object2)
  checkpoints.save_checkpoint(tmp_dir, test_object2, 3, prefix='test_', keep=2)
  checkpoints.save_checkpoint(tmp_dir, test_object1, 4, prefix='test_', keep=2)
  new_object = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(new_object, test_object1)
  new_object = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, step=3, prefix='test_')
  jtu.check_eq(new_object, test_object2)
  # Restore a specific path.
  new_object = checkpoints.restore_checkpoint(
      os.path.join(tmp_dir, 'test_3'), test_object0)
  jtu.check_eq(new_object, test_object2)
  # If a specific path is specified but does not exist, the same behavior as
  # for an empty directory applies: the target is returned unchanged.
  new_object = checkpoints.restore_checkpoint(
      os.path.join(tmp_dir, 'test_not_there'), test_object0)
  jtu.check_eq(new_object, test_object0)
  with self.assertRaises(ValueError):
    checkpoints.restore_checkpoint(tmp_dir, test_object0, step=5,
                                   prefix='test_')
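# The semantics exercised by the test above, condensed into one round trip
# (uses the same flax.training.checkpoints module; the directory is a temp
# dir rather than the test fixture):
import tempfile
import numpy as np
from flax.training import checkpoints

ckpt_dir = tempfile.mkdtemp()
target = {'a': np.zeros(3, np.int32)}
checkpoints.save_checkpoint(ckpt_dir, {'a': np.arange(3, dtype=np.int32)}, 0)
restored = checkpoints.restore_checkpoint(ckpt_dir, target)
# restored['a'] == [0, 1, 2]; with an empty directory, `target` itself
# would have been returned unchanged.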
def save_checkpoint(optimizer, ema):
  # Get the train state from the first replica.
  optimizer, ema = jax.device_get(
      jax.tree_map(lambda x: x[0], (optimizer, ema)))
  step = int(optimizer.state.step)
  checkpoints.save_checkpoint(FLAGS.model_dir, (optimizer, ema), step, keep=3)
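# The `ema` pytree saved alongside the optimizer is typically maintained with
# an exponential-moving-average update. The rule below is an assumption about
# how the (not shown) update code works, not a copy of it:
import jax

def ema_update_sketch(ema, params, decay=0.999):
  """One EMA step: ema <- decay * ema + (1 - decay) * params, per leaf."""
  return jax.tree_map(lambda e, p: decay * e + (1. - decay) * p, ema, params)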
def train(
    state: TrainState,
    *,
    train_iter: T.Iterator[PT.MiniBatch],
    train_step: TrainStep,
    n_steps: int,
    report_freq: int,
    reporter: Reporter,
    val_iter: T.Optional[T.Iterator[PT.MiniBatch]] = None,
    val_step: T.Optional[ValStep] = None,
    val_freq: T.Optional[int] = None,
    val_metrics: T.Optional[utils.Metrics] = None,
    distributed: bool = False,
    save_ckpts: bool = True,
    ckpt_metric: str = 'loss',
    ckpt_name: str = 'model',
    extra_report_fn: T.Optional[T.Callable[[TrainState, PT.MiniBatch, int],
                                           None]] = None,
    start_step: int = 0,
) -> TrainState:
  assert 'time/step' in reporter.train_names
  if val_step is not None:
    assert val_iter is not None
    assert val_metrics is not None
    assert val_freq is not None and val_freq % report_freq == 0
    assert 'time' in reporter.val_names

  iter_slice = itertools.islice(train_iter, 0, n_steps - start_step)
  train_iter = iter(
      tqdm(iter_slice, total=n_steps - start_step, desc='Training'))

  if distributed:
    state = flax.jax_utils.replicate(state)
    if hasattr(state, 'rngs'):
      pfold_in = partial(jax.pmap(jax.random.fold_in),
                         data=jnp.arange(jax.device_count()))
      state = state.replace(rngs=jax.tree_map(pfold_in, state.rngs))

  with reporter as rep, redirect_print_to_tqdm():
    cur_best = -1
    start_time = time.perf_counter()
    for i, batch in enumerate(train_iter, start=start_step):
      state = train_step(state, batch)

      if i % report_freq == 0 or i == n_steps - 1:
        time_per_step = (time.perf_counter() - start_time) / report_freq
        train_dict = {'time/step': time_per_step}
        train_dict.update(state.metrics.items())
        state = state.replace(metrics=state.metrics.reset())

        val_dict = {}
        if val_freq is not None and (i % val_freq == 0 or i == n_steps - 1):
          assert val_step is not None
          assert val_metrics is not None
          assert val_iter is not None
          start_time = time.perf_counter()
          val_state = (state if not distributed
                       else flax.jax_utils.unreplicate(state))
          variables = val_state.variables
          val_metrics = T.cast(utils.Metrics, val_metrics.reset())
          for val_batch in val_iter:
            val_metrics = val_step(val_batch, variables, val_metrics)
          val_dict = dict(val_metrics.items())
          val_dict['time'] = time.perf_counter() - start_time

          ckpt_metric_val = val_dict[ckpt_metric]
          if save_ckpts and cur_best < ckpt_metric_val:
            # TODO: Add comparison option (i.e. less or more is better).
            cur_best = ckpt_metric_val
            checkpoints.save_checkpoint(f'ckpts_{ckpt_name}',
                                        jax.device_get(val_state), i, keep=5)
            if 'ckpt' in reporter.val_names:
              val_dict['ckpt'] = 'Saved.'

        rep.report(i, train_dict, val_dict)
        if extra_report_fn is not None:
          if distributed:
            rstate, rbatch = flax.jax_utils.unreplicate((state, batch))
          else:
            rstate, rbatch = state, batch
          extra_report_fn(rstate, rbatch, i)
        start_time = time.perf_counter()

  return state if not distributed else flax.jax_utils.unreplicate(state)
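# The `pfold_in` trick above decorrelates per-device RNG streams: replicate
# one key, then fold in each device's index under pmap. Minimal sketch:
import jax
import jax.numpy as jnp

n = jax.local_device_count()
key = jax.random.PRNGKey(0)
per_device_keys = jax.pmap(jax.random.fold_in)(
    jnp.stack([key] * n), jnp.arange(n))
# per_device_keys[i] differs for every device index i.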
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    train_ds_mgr, eval_ds, predict_ds, encoder = (
        input_pipeline.get_dynamic_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=FLAGS.vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_buckets=FLAGS.num_data_buckets))
    if FLAGS.static:
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      assert len(weights) == FLAGS.num_data_buckets
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)
  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size,
        split_tokenizer=FLAGS.split_tokenizer,
        sample_size=FLAGS.sample_size,
        newscomment_sample_ratio=FLAGS.newscomment_sample_ratio)

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.
  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that this is supposed to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where we
    # left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  if FLAGS.adapter != train_util.NONE:
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.adapter != train_util.NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # Compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant
  p_get_diag_grads = jax.pmap(
      functools.partial(train_util.get_diag_grads, config=eval_config),
      axis_name='batch')
  p_get_bucket_score = jax.pmap(
      functools.partial(get_diag_score, strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  best_eval_loss = 1000
  curr_eval_loss = 1000
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # Dynamic macro: use gradient alignment to score different ratios of
        # top k vs bottom N-k bins.
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score, aux_eval_ds,
                                              train_ds_mgr, optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins, take the top k bins, and
          # sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  aux_eval_ds, train_ds_mgr,
                                                  optimizer, eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          new_distribution = np.zeros(100)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        logging.info('Doing Training.')
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            batch = common_utils.shard(
                jax.tree_map(np.asarray, next(train_iter)))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)
          except StopIteration:
            is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per-example loss.
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # Only save better checkpoints.
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir,
                jax_utils.unreplicate(optimizer),
                step,
                keep=FLAGS.chkpts_to_keep,
                overwrite=True)

      if is_last_step:
        break
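# The "weighted sums + shared denominator" reduction used for both the train
# and eval metrics above, in isolation (toy pytrees, host-side only):
import jax
import jax.numpy as jnp

metrics = [{'loss': jnp.asarray(4.0), 'denominator': jnp.asarray(2.0)},
           {'loss': jnp.asarray(6.0), 'denominator': jnp.asarray(3.0)}]
sums = jax.tree_map(lambda *leaves: sum(leaves), *metrics)
denominator = sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, sums)  # {'loss': 2.0}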
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""
  # Setup input pipeline.
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')
  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture.
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from repo published with "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  else:
    # ViT / Mixer papers.
    filename = config.model.name
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs.
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore.
  del opt
  del params

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(num_train_steps=total_steps,
                                      writer=writer),
  ]

  # Run training loop.
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics.
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation.
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):
      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits, axis=-1) == np.argmax(
                test_batch['label'], axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
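# utils.create_learning_rate_schedule above is repo-specific; a common shape
# for such a schedule (an assumption: linear warmup into cosine decay, one of
# the decay_type options typical for ViT-style training) can be sketched as:
import numpy as np

def cosine_lr_schedule_sketch(total_steps, base_lr, warmup_steps):
  def lr_fn(step):
    warmup = min(1.0, step / max(1, warmup_steps))
    progress = np.clip(
        (step - warmup_steps) / max(1, total_steps - warmup_steps), 0.0, 1.0)
    return base_lr * warmup * 0.5 * (1.0 + np.cos(np.pi * progress))
  return lr_fn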
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  config = FLAGS.config
  logging.info('===========Config Dict============')
  logging.info(config)
  batch_size = config.batch_size
  learning_rate = config.learning_rate
  num_train_steps = config.num_train_steps
  num_eval_steps = config.num_eval_steps
  eval_freq = config.eval_frequency
  random_seed = config.random_seed
  model_type = config.model_type
  max_length = config.max_length

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'summary'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets(
      n_devices=jax.local_device_count(),
      task_name=FLAGS.task_name,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      fixed_vocab=None,
      max_length=max_length)

  vocab_size = encoder.vocab_size
  logging.info('Vocab Size: %d', vocab_size)

  train_ds = train_ds.repeat()
  train_iter = iter(train_ds)
  input_shape = (batch_size, max_length)

  model_kwargs = {
      'vocab_size': vocab_size,
      'emb_dim': config.emb_dim,
      'num_heads': config.num_heads,
      'num_layers': config.num_layers,
      'qkv_dim': config.qkv_dim,
      'mlp_dim': config.mlp_dim,
      'max_len': max_length,
      'classifier': True,
      'num_classes': CLASS_MAP[FLAGS.task_name],
      'classifier_pool': config.classifier_pool
  }

  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = random.split(rng)
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  if model_type == 'transformer':
    model = create_model(init_rng, transformer.TransformerEncoder,
                         input_shape, model_kwargs)
  else:
    raise ValueError('Model type not supported')

  optimizer = create_optimizer(
      model, learning_rate, weight_decay=FLAGS.config.weight_decay)
  del model  # Don't keep a copy of the initial model.

  start_step = 0
  if config.restore_checkpoints or FLAGS.test_only:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors=config.factors,
      base_learning_rate=learning_rate,
      warmup_steps=config.warmup)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  # p_pred_step = jax.pmap(predict_step, axis_name='batch')

  def run_eval(eval_ds, num_eval_steps=-1):
    eval_metrics = []
    eval_iter = iter(eval_ds)
    if num_eval_steps == -1:
      num_iter = itertools.count()
    else:
      num_iter = range(num_eval_steps)
    for _, eval_batch in zip(num_iter, eval_iter):
      # pylint: disable=protected-access
      eval_batch = common_utils.shard(
          jax.tree_map(lambda x: x._numpy(), eval_batch))
      # pylint: enable=protected-access
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    # Calculate (clipped) perplexity after averaging log-perplexities:
    eval_summary['perplexity'] = jnp.clip(
        jnp.exp(eval_summary['loss']), a_max=1.0e4)
    return eval_summary

  if FLAGS.test_only:
    with tf.io.gfile.GFile(
        os.path.join(FLAGS.model_dir, 'results.json'), 'w') as f:
      test_summary = run_eval(test_ds)
      json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
    return

  metrics_all = []
  tick = time.time()
  logging.info('Starting training')
  logging.info('====================')
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)
    logging.info('train in step: %d', step)

    # Save a Checkpoint.
    if ((step % config.checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.host_id() == 0 and config.save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(FLAGS.model_dir,
                                    jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                   summary['loss'], summary['accuracy'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar(f'train_{key}', val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_summary = run_eval(eval_ds, num_eval_steps)
      logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          summary_writer.scalar(f'eval_{key}', val, step)
        summary_writer.flush()
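# common_utils.shard, used throughout the loop above, folds the leading batch
# axis into (local_devices, per_device_batch) so pmap can map over devices.
# An equivalent standalone sketch:
import jax

def shard_sketch(batch):
  n = jax.local_device_count()
  return jax.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), batch)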