def write_per_example_losses(*, p_eval_step, target, eval_ds, num_eval_steps,
                             loss_filename):
  """Evaluate the target and write per-example losses and lengths to file."""
  logging.info('Gathering evaluation metrics.')
  losses = []
  lengths = []
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  for _, eval_batch in zip(range(num_eval_steps), eval_iter):
    eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
    eval_batch = common_utils.shard(eval_batch)
    loss, length = p_eval_step(target, eval_batch)
    losses.append(common.tohost(loss))
    lengths.append(common.tohost(length))
  # Only the lead host writes the losses and lengths to disk.
  if jax.process_index() == 0:
    with tf.io.gfile.GFile(loss_filename, 'w') as f:
      writer = csv.writer(f)
      for pos_losses in losses:
        for val in pos_losses:
          writer.writerow(list(val))
    with tf.io.gfile.GFile(loss_filename.replace('.csv', '_length.csv'),
                           'w') as f:
      writer = csv.writer(f)
      for val in lengths:
        writer.writerow([int(v) for v in list(val)])
  return
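

# For reference: `common.tohost`, used above, is assumed to collapse the
# leading (device, per-device batch) axes of a pmapped output back into a
# single host-side batch axis. A minimal sketch under that assumption,
# following the standard Flax WMT recipe (`_tohost_sketch` is a hypothetical
# local name, shown for reference only and not used by the pipeline):


def _tohost_sketch(x):
  """Collects a (n_device, n_batch, ...) device array onto the host."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))
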
def get_losses(ds_iter, optimizer, p_eval_step, model_id, test_cluster_id):
  """Given optimizer and dataset, compute losses and write to file."""
  logging.info('Start scoring loop.')
  n_devices = jax.local_device_count()
  t_loop_start = time.time()
  filename = '/losses_testcluster{test_cluster_id}_ftid{model_id}.csv'
  save_file = filename.format(
      test_cluster_id=test_cluster_id, model_id=model_id)
  save_file = FLAGS.save_path + save_file
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)
    for batch_idx, eval_batch in enumerate(ds_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        eval_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
      eval_batch = common_utils.shard(eval_batch)
      losses, lengths = p_eval_step(optimizer.target, eval_batch)
      if jax.process_index() == 0:
        losses = common.tohost(losses)
        lengths = common.tohost(lengths)
        if cur_pred_batch_size % n_devices:
          writer.writerow(losses[:cur_pred_batch_size])
        else:
          writer.writerow(losses)
      if batch_idx % 500 == 0:
        print('Batch', batch_idx)
        print(time.time() - t_loop_start)
      if batch_idx >= FLAGS.limit:
        break
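

# For reference: `common.pad_examples`, used above, is assumed to grow an
# odd-sized final batch up to `padded_size` so it shards evenly across
# devices; the padded rows are sliced off again before writing. A minimal
# sketch under that assumption, repeating the last example as in the Flax WMT
# example (`_pad_examples_sketch` is a hypothetical local name):


def _pad_examples_sketch(x, desired_batch_size):
  """Expands a batch to the desired size by repeating the last example."""
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)
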
def get_decodes(ds_iter, optimizer, p_init_cache, p_pred_step, model_id,
                test_cluster_id, decode_tokens):
  """Given optimizer and dataset, decode predictions and write to file."""
  logging.info('Start decoding loop.')
  n_devices = jax.local_device_count()
  predictions = []
  max_predict_length = 256
  filename = '/decodes_testcluster{test_cluster_id}_ftid{model_id}.csv'
  save_file = filename.format(
      test_cluster_id=test_cluster_id, model_id=model_id)
  save_file = FLAGS.save_path + save_file
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)
    for batch_idx, pred_batch in enumerate(ds_iter):
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
            pred_batch)
      pred_batch = common_utils.shard(pred_batch)
      cache = p_init_cache(pred_batch['inputs'])
      predicted, _ = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                                 decode.EOS_ID, max_predict_length)
      if jax.process_index() == 0:
        predicted = common.tohost(predicted)
        # Iterate through non-padding examples of the batch.
        for s in predicted[:cur_pred_batch_size]:
          predictions.append(decode_tokens(s))
      if batch_idx >= FLAGS.limit:
        break
    # All decodes are written out as a single CSV row.
    writer.writerow(predictions)
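

# For reference: `decode_tokens`, passed into get_decodes above, is assumed to
# map a row of predicted token ids back to a string, truncating at EOS. A
# plausible sketch given the SentencePiece encoder built in compute_is_scores
# below (`_make_decode_tokens_sketch` is a hypothetical helper and the
# `detokenize` call assumes a tf_text-style tokenizer; the real callable is
# supplied by the caller):


def _make_decode_tokens_sketch(encoder):
  def decode_tokens(toks):
    # Keep tokens up to and including the first EOS, then detokenize.
    valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')
  return decode_tokens
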
def compute_is_scores(filename):
  """Compute IS scores for training data."""
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, encoder = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size)
  print('Datasets created')

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # Note: this may be intended to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params below only via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that; otherwise continue from where
    # we left off in model_dir.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    # When loading a checkpoint trained with adapters (i.e. frozen weights),
    # restoring into the base optimizer fails. We catch this error and create
    # the optimizer with frozen weights instead.
    try:
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
      # Grab last step.
      start_step = int(optimizer.state.step)
    except ValueError:
      adapter = optim.ModelParamTraversal(
          lambda path, _: FLAGS.adapter in path)
      optimizer = optimizer_def.create(optimizer.target, focus=adapter)
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
      start_step = optimizer.state[0].step
  else:
    raise RuntimeError('Must restore a checkpoint to compute IS scores.')

  if FLAGS.adapter != NONE and not isinstance(optimizer, optim.MultiOptimizer):
    adapter = optim.ModelParamTraversal(
        lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  # Replicate optimizer across local devices.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(eval_for_is_step, config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  save_file = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
  length_fp = tf.io.gfile.GFile(save_file, 'w')
  lengths_writer = csv.writer(length_fp)

  save_file = FLAGS.is_save_path + '/' + filename + '.txt'
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)
    for batch_idx, eval_batch in enumerate(train_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      # Pad the final odd-sized batch so it shards evenly across devices.
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        eval_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
      eval_batch = common_utils.shard(eval_batch)
      losses, lengths = p_eval_step(optimizer.target, eval_batch)
      if jax.host_id() == 0:
        losses = common.tohost(losses)
        lengths = common.tohost(lengths)
        if cur_pred_batch_size % n_devices:
          # Drop the losses computed on the padding examples.
          writer.writerow(losses[:cur_pred_batch_size])
          lengths_writer.writerow(lengths[:cur_pred_batch_size])
        else:
          writer.writerow(losses)
          lengths_writer.writerow(lengths)
      if batch_idx % 500 == 0:
        print('Batch', batch_idx)
        print(time.time() - t_loop_start)
  length_fp.close()
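

# For reference: `eval_for_is_step`, pmapped in compute_is_scores above, is
# assumed to return one loss and one token count per example, which is what
# the CSV rows written above expect. A minimal sketch under that assumption
# (per-example cross-entropy summed over non-padding positions; the real step
# may differ, e.g. by adding label smoothing):


def _eval_for_is_step_sketch(params, batch, config):
  """Per-example summed cross-entropy loss and length for a sharded batch."""
  inputs, targets = batch['inputs'], batch['targets']
  weights = jnp.where(targets > 0, 1.0, 0.0)  # mask out padding positions
  logits = models.Transformer(config).apply({'params': params}, inputs,
                                            targets)
  # Cross entropy per position, masked and summed over the sequence axis.
  xent = -jnp.sum(
      common_utils.onehot(targets, logits.shape[-1]) *
      jax.nn.log_softmax(logits), axis=-1)
  losses = jnp.sum(xent * weights, axis=-1)  # (batch,) summed over positions
  lengths = jnp.sum(weights, axis=-1)  # (batch,) count of non-padding tokens
  return losses, lengths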