def get_learning_rate_fn(config, num_train_steps):
    warmup_steps = int(config.warmup_proportion * num_train_steps)
    cooldown_steps = num_train_steps - warmup_steps
    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=warmup_steps,
        steps_per_cycle=cooldown_steps,
    )
    return learning_rate_fn
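# A minimal sketch of what the 'constant * linear_warmup * linear_decay'
# factor string above is assumed to compute: each named factor contributes a
# multiplier and the learning rate at a step is their product. The helper
# below is hypothetical (it is not part of this repo's training module); it
# only illustrates the shape of the schedule.
def sketch_learning_rate(step, base_learning_rate=1e-4, warmup_steps=100,
                         steps_per_cycle=900):
    """Illustrative product of constant, linear-warmup, linear-decay factors."""
    warmup = min(1.0, step / max(1, warmup_steps))  # ramps 0 -> 1 over warmup
    decay_step = max(0, step - warmup_steps)
    decay = max(0.0, 1.0 - decay_step / max(1, steps_per_cycle))  # 1 -> 0 after
    return base_learning_rate * warmup * decay

# With the defaults above, the rate peaks at step 100 and decays linearly:
# sketch_learning_rate(0) == 0.0, sketch_learning_rate(100) == 1e-4,
# sketch_learning_rate(550) == 5e-5, sketch_learning_rate(1000) == 0.0.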
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = get_config()
    datasets.logging.set_verbosity_error()
    dataset = datasets.load_dataset(config.dataset_path, config.dataset_name)

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length
    data_pipeline = data.ClassificationDataPipeline(dataset, tokenizer)

    num_train_examples = len(dataset['train'])
    num_train_steps = int(
        num_train_examples * config.num_train_epochs // config.train_batch_size)
    warmup_steps = int(config.warmup_proportion * num_train_steps)
    cooldown_steps = num_train_steps - warmup_steps

    is_regression_task = (
        dataset['train'].features['label'].dtype == 'float32')
    if is_regression_task:
        num_classes = 1
        compute_stats = compute_regression_stats
    else:
        num_classes = dataset['train'].features['label'].num_classes
        compute_stats = compute_classification_stats

    model = create_model(config, num_classes=num_classes)
    optimizer = create_optimizer(config, model)
    optimizer = optimizer.replicate()
    del model  # don't keep a copy of the initial model

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=warmup_steps,
        steps_per_cycle=cooldown_steps,
    )

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_step_fn = training.create_train_step(
            compute_loss_and_metrics, clip_grad_norm=1.0)
        train_iter = data_pipeline.get_inputs(
            split='train', batch_size=config.train_batch_size, training=True)

        for step, batch in zip(range(0, num_train_steps), train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch, train_state)

    if config.do_eval:
        eval_step = training.create_eval_fn(compute_stats)
        eval_results = []

        if config.dataset_path == 'glue' and config.dataset_name == 'mnli':
            validation_splits = ['validation_matched', 'validation_mismatched']
        else:
            validation_splits = ['validation']

        for split in validation_splits:
            eval_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            eval_stats = eval_step(optimizer, eval_iter)

            eval_metric = datasets.load_metric(config.dataset_path,
                                               config.dataset_name)
            eval_metric.add_batch(
                predictions=eval_stats['prediction'],
                references=eval_stats['label'])
            eval_metrics = eval_metric.compute()

            prefix = ('eval_mismatched' if split == 'validation_mismatched'
                      else 'eval')
            for name, val in sorted(eval_metrics.items()):
                line = f'{prefix}_{name} = {val:.06f}'
                print(line, flush=True)
                eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')

    if config.do_predict:
        predict_step = training.create_eval_fn(compute_stats)
        predict_results = []

        path_map = {
            ('glue', 'cola', 'test'): 'CoLA.tsv',
            ('glue', 'mrpc', 'test'): 'MRPC.tsv',
            ('glue', 'qqp', 'test'): 'QQP.tsv',
            ('glue', 'sst2', 'test'): 'SST-2.tsv',
            ('glue', 'stsb', 'test'): 'STS-B.tsv',
            ('glue', 'mnli', 'test_matched'): 'MNLI-m.tsv',
            ('glue', 'mnli', 'test_mismatched'): 'MNLI-mm.tsv',
            ('glue', 'qnli', 'test'): 'QNLI.tsv',
            ('glue', 'rte', 'test'): 'RTE.tsv',
            # No eval on WNLI for now. BERT accuracy on WNLI is below baseline,
            # unless a special training recipe is used.
            # ('glue/wnli', 'test'): 'WNLI.tsv',
        }
        label_sets = {
            ('glue', 'cola'): ['0', '1'],
            ('glue', 'mrpc'): ['0', '1'],
            ('glue', 'qqp'): ['0', '1'],
            ('glue', 'sst2'): ['0', '1'],
            ('glue', 'mnli'): ['entailment', 'neutral', 'contradiction'],
            ('glue', 'qnli'): ['entailment', 'not_entailment'],
            ('glue', 'rte'): ['entailment', 'not_entailment'],
        }

        for path_map_key in path_map:
            candidate_dataset_path, candidate_dataset_name, split = path_map_key
            if (candidate_dataset_path != config.dataset_path or
                    candidate_dataset_name != config.dataset_name):
                continue

            predict_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            predict_stats = predict_step(optimizer, predict_iter)
            idxs = predict_stats['idx']
            predictions = predict_stats['prediction']

            tsv_path = os.path.join(
                output_dir,
                path_map[config.dataset_path, config.dataset_name, split])
            with gfile.GFile(tsv_path, 'w') as f:
                f.write('index\tprediction\n')
                if is_regression_task:
                    for idx, val in zip(idxs, predictions):
                        f.write(f'{idx}\t{val:.06f}\n')
                else:
                    label_set = label_sets[config.dataset_path,
                                           config.dataset_name]
                    for idx, val in zip(idxs, predictions):
                        f.write(f'{idx}\t{label_set[val]}\n')
            print('Wrote', tsv_path)
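# For reference, a hypothetical minimal config for the classification main()
# above. The field names are exactly those the function reads; the values are
# illustrative placeholders rather than tuned hyperparameters, and the repo's
# real get_config() may define additional fields. Using
# ml_collections.ConfigDict as the container is an assumption.
import ml_collections

def sketch_classifier_config():
    config = ml_collections.ConfigDict()
    config.dataset_path = 'glue'
    config.dataset_name = 'mrpc'
    config.tokenizer = 'bert-base-uncased'
    config.max_seq_length = 128
    config.learning_rate = 3e-5
    config.warmup_proportion = 0.1
    config.num_train_epochs = 3
    config.train_batch_size = 32
    config.eval_batch_size = 8
    config.do_train = True
    config.do_eval = True
    config.do_predict = False
    return config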
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = get_config()
    datasets.logging.set_verbosity_error()
    # Workaround for https://github.com/huggingface/datasets/issues/812
    logging.getLogger("filelock").setLevel(logging.ERROR)
    dataset = datasets.load_dataset(config.dataset_path, config.dataset_name)
    data_pipeline = data.ClassificationDataPipeline(
        dataset, config.init_checkpoint, max_seq_length=config.max_seq_length)

    num_train_examples = len(dataset["train"])
    num_train_steps = int(
        num_train_examples * config.num_train_epochs // config.train_batch_size)
    warmup_steps = int(config.warmup_proportion * num_train_steps)
    cooldown_steps = num_train_steps - warmup_steps

    is_regression_task = dataset["train"].features["label"].dtype == "float32"
    if is_regression_task:
        num_classes = 1
        compute_stats = compute_regression_stats
    else:
        num_classes = dataset["train"].features["label"].num_classes
        compute_stats = compute_classification_stats

    model = modeling.BertForSequenceClassification(
        config=config.model, n_classes=num_classes)
    initial_params = get_initial_params(
        model, init_checkpoint=config.init_checkpoint)
    optimizer = create_optimizer(config, initial_params)
    del initial_params  # the optimizer takes ownership of all params
    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=warmup_steps,
        steps_per_cycle=cooldown_steps,
    )

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_step_fn = training.create_train_step(
            model, compute_loss_and_metrics, max_grad_norm=config.max_grad_norm)
        train_iter = data_pipeline.get_inputs(
            split="train", batch_size=config.train_batch_size, training=True)

        for step, batch in zip(range(0, num_train_steps), train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch, train_state)

    if config.do_eval:
        eval_step = training.create_eval_fn(model, compute_stats)
        eval_results = []

        if config.dataset_path == "glue" and config.dataset_name == "mnli":
            validation_splits = ["validation_matched", "validation_mismatched"]
        else:
            validation_splits = ["validation"]

        for split in validation_splits:
            eval_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            eval_stats = eval_step(optimizer, eval_iter)

            eval_metric = datasets.load_metric(config.dataset_path,
                                               config.dataset_name)
            eval_metric.add_batch(
                predictions=eval_stats["prediction"],
                references=eval_stats["label"])
            eval_metrics = eval_metric.compute()

            prefix = ("eval_mismatched" if split == "validation_mismatched"
                      else "eval")
            for name, val in sorted(eval_metrics.items()):
                line = f"{prefix}_{name} = {val:.06f}"
                print(line, flush=True)
                eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")

    if config.do_predict:
        predict_step = training.create_eval_fn(model, compute_stats)
        predict_results = []

        path_map = {
            ("glue", "cola", "test"): "CoLA.tsv",
            ("glue", "mrpc", "test"): "MRPC.tsv",
            ("glue", "qqp", "test"): "QQP.tsv",
            ("glue", "sst2", "test"): "SST-2.tsv",
            ("glue", "stsb", "test"): "STS-B.tsv",
            ("glue", "mnli", "test_matched"): "MNLI-m.tsv",
            ("glue", "mnli", "test_mismatched"): "MNLI-mm.tsv",
            ("glue", "qnli", "test"): "QNLI.tsv",
            ("glue", "rte", "test"): "RTE.tsv",
            # No eval on WNLI for now. BERT accuracy on WNLI is below baseline,
            # unless a special training recipe is used.
            # ('glue/wnli', 'test'): 'WNLI.tsv',
        }
        label_sets = {
            ("glue", "cola"): ["0", "1"],
            ("glue", "mrpc"): ["0", "1"],
            ("glue", "qqp"): ["0", "1"],
            ("glue", "sst2"): ["0", "1"],
            ("glue", "mnli"): ["entailment", "neutral", "contradiction"],
            ("glue", "qnli"): ["entailment", "not_entailment"],
            ("glue", "rte"): ["entailment", "not_entailment"],
        }

        for path_map_key in path_map:
            candidate_dataset_path, candidate_dataset_name, split = path_map_key
            if (candidate_dataset_path != config.dataset_path or
                    candidate_dataset_name != config.dataset_name):
                continue

            predict_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            predict_stats = predict_step(optimizer, predict_iter)
            idxs = predict_stats["idx"]
            predictions = predict_stats["prediction"]

            tsv_path = os.path.join(
                output_dir,
                path_map[config.dataset_path, config.dataset_name, split])
            with gfile.GFile(tsv_path, "w") as f:
                f.write("index\tprediction\n")
                if is_regression_task:
                    for idx, val in zip(idxs, predictions):
                        f.write(f"{idx}\t{val:.06f}\n")
                else:
                    label_set = label_sets[config.dataset_path,
                                           config.dataset_name]
                    for idx, val in zip(idxs, predictions):
                        f.write(f"{idx}\t{label_set[val]}\n")
            print("Wrote", tsv_path)
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = FLAGS.config
    model = create_model(config)
    optimizer = create_optimizer(config, model)
    del model  # don't keep a copy of the initial model

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    start_step = int(optimizer.state[0].step)
    optimizer = optimizer.replicate()

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length

    # The commented lines below correspond to a data pipeline that uses publicly
    # available data, in the form of English Wikipedia as processed and hosted by
    # the HuggingFace datasets library. The pipeline works, and downstream task
    # performance shows a benefit to pre-training, but I (Nikita) have yet to
    # confirm that final model quality is on par with the original BERT.
    #
    # dataset = datasets.load_dataset('wikipedia', '20200501.en')['train']
    # data_pipeline = data.PretrainingDataPipelineV1(
    #     dataset, tokenizer,
    #     max_predictions_per_seq=config.max_predictions_per_seq)

    # The data pipeline below relies on having text files of Wikipedia + Books in
    # the same format as the original BERT data. That original data is not
    # publicly available, so you will need to provide your own. I (Nikita) have
    # had success using data from Gong et al. "Efficient Training of BERT by
    # Progressively Stacking", but this data was also obtained through private
    # correspondence and may not be generally available.
    # The data_files argument may be a list, if data is split across multiple
    # input files.
    dataset = datasets.load_dataset(
        'bert_data.py',
        data_files=os.path.expanduser('~/data/bert/corpus.train.tok'))['train']
    data_pipeline = data.PretrainingDataPipeline(
        dataset, tokenizer,
        max_predictions_per_seq=config.max_predictions_per_seq)

    datasets.logging.set_verbosity_error()

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_iter = data_pipeline.get_inputs(
            batch_size=config.train_batch_size, training=True)
        train_step_fn = training.create_train_step(
            compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch, train_state)
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(
                    output_dir, optimizer.unreplicate(), step)
                config_path = os.path.join(output_dir, 'config.json')
                if not os.path.exists(config_path):
                    with open(config_path, 'w') as f:
                        json.dump({'model_type': 'bert', **config.model}, f)

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(
            compute_pretraining_stats, sample_feature_name='input_ids')
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            'loss': jnp.mean(eval_stats['loss']),
            'masked_lm_loss': jnp.mean(eval_stats['masked_lm_loss']),
            'next_sentence_loss': jnp.mean(eval_stats['next_sentence_loss']),
            'masked_lm_accuracy': (
                jnp.sum(eval_stats['masked_lm_correct'])
                / jnp.sum(eval_stats['masked_lm_total'])),
            'next_sentence_accuracy': (
                jnp.sum(eval_stats['next_sentence_correct'])
                / jnp.sum(eval_stats['next_sentence_total'])),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f'{name} = {val:.06f}'
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')
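# The restore-then-resume pattern above comes from flax.training.checkpoints:
# restore_checkpoint(ckpt_dir, target) returns `target` unchanged when no
# checkpoint exists, so a fresh run starts at step 0, while an interrupted run
# picks up where its last save_checkpoint() left off. A toy sketch of the same
# pattern with a plain dict as state (the paths, update, and save cadence are
# illustrative only; the real code keeps the step inside the optimizer):
from flax.training import checkpoints

def sketch_resume(ckpt_dir, num_train_steps=1000, save_every=100):
    state = {'step': 0, 'value': 0.0}
    state = checkpoints.restore_checkpoint(ckpt_dir, state)  # no-op if fresh
    for step in range(int(state['step']), num_train_steps):
        state = {'step': step + 1, 'value': state['value'] + 1.0}  # fake update
        if step % save_every == 0 or step == num_train_steps - 1:
            checkpoints.save_checkpoint(ckpt_dir, state, step, overwrite=True)
    return state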
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = FLAGS.config
    model = create_model(config)
    optimizer = create_optimizer(config, model)
    del model  # don't keep a copy of the initial model

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    start_step = int(optimizer.state[0].step)
    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length
    data_pipeline = data.PretrainingDataPipeline(
        glob.glob('cache/pretrain.*_of_*.tfrecord'),
        tokenizer,
        max_predictions_per_seq=config.max_predictions_per_seq)

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_iter = data_pipeline.get_inputs(
            batch_size=config.train_batch_size, training=True)
        train_step_fn = training.create_train_step(
            compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch, train_state)
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(
                    output_dir, optimizer.unreplicate(), step)
                config_path = os.path.join(output_dir, 'config.json')
                if not os.path.exists(config_path):
                    with open(config_path, 'w') as f:
                        json.dump({'model_type': 'bert', **config.model}, f)

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(
            compute_pretraining_stats, sample_feature_name='input_ids')
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            'loss': jnp.mean(eval_stats['loss']),
            'masked_lm_loss': jnp.mean(eval_stats['masked_lm_loss']),
            'next_sentence_loss': jnp.mean(eval_stats['next_sentence_loss']),
            'masked_lm_accuracy': (
                jnp.sum(eval_stats['masked_lm_correct'])
                / jnp.sum(eval_stats['masked_lm_total'])),
            'next_sentence_accuracy': (
                jnp.sum(eval_stats['next_sentence_correct'])
                / jnp.sum(eval_stats['next_sentence_total'])),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f'{name} = {val:.06f}'
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = FLAGS.config

    input_files = sum([glob.glob(pattern) for pattern in config.input_files], [])
    assert input_files, "No input files!"
    print(f"Training with {len(input_files)} input files, including:")
    print(f" - {input_files[0]}")

    model = modeling.BertForPreTraining(config=config.model)
    initial_params = get_initial_params(
        model, init_checkpoint=config.init_checkpoint)
    optimizer = create_optimizer(config, initial_params)
    del initial_params  # the optimizer takes ownership of all params

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    if isinstance(optimizer.state, (list, tuple)):
        start_step = int(optimizer.state[0].step)
    else:
        start_step = int(optimizer.state.step)
    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    data_pipeline = data.PretrainingDataPipeline(
        input_files,
        config.tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
    )

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_batch_size = config.train_batch_size
        if jax.host_count() > 1:
            assert train_batch_size % jax.host_count() == 0, (
                "train_batch_size must be divisible by number of hosts")
            train_batch_size = train_batch_size // jax.host_count()
        train_iter = data_pipeline.get_inputs(
            batch_size=train_batch_size, training=True)
        train_step_fn = training.create_train_step(
            model,
            compute_pretraining_loss_and_metrics,
            max_grad_norm=config.max_grad_norm,
        )

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch, train_state)
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(
                    output_dir, optimizer.unreplicate(), step)
                config_path = os.path.join(output_dir, "config.json")
                if not os.path.exists(config_path):
                    with open(config_path, "w") as f:
                        json.dump({"model_type": "bert", **config.model}, f)
                tokenizer_path = os.path.join(output_dir, "sentencepiece.model")
                if not os.path.exists(tokenizer_path):
                    shutil.copy(config.tokenizer, tokenizer_path)

        # With the current Rust data pipeline code, running more than one
        # pipeline at a time will lead to a hang. A simple workaround is to
        # fully delete the training pipeline before potentially starting
        # another for evaluation.
        del train_iter

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(
            model, compute_pretraining_stats, sample_feature_name="input_ids")
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            "loss": jnp.mean(eval_stats["loss"]),
            "masked_lm_loss": jnp.mean(eval_stats["masked_lm_loss"]),
            "next_sentence_loss": jnp.mean(eval_stats["next_sentence_loss"]),
            "masked_lm_accuracy": (
                jnp.sum(eval_stats["masked_lm_correct"])
                / jnp.sum(eval_stats["masked_lm_total"])),
            "next_sentence_accuracy": (
                jnp.sum(eval_stats["next_sentence_correct"])
                / jnp.sum(eval_stats["next_sentence_total"])),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f"{name} = {val:.06f}"
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")
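# Note that the eval metrics above pool counts before dividing:
# sum(correct) / sum(total) weights every masked position equally, whereas
# averaging per-batch accuracies would over-weight batches with fewer masked
# tokens. A tiny numeric illustration (the counts are made up):
import jax.numpy as jnp

def sketch_pooled_vs_averaged():
    correct = jnp.array([8, 1])  # correct masked-LM predictions per batch
    total = jnp.array([10, 1])   # masked positions per batch
    pooled = jnp.sum(correct) / jnp.sum(total)  # 9/11 ~= 0.818
    averaged = jnp.mean(correct / total)        # (0.8 + 1.0)/2 = 0.9
    return pooled, averaged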