예제 #1
0
파일: demo_lib.py 프로젝트: skye/flax_bert
def get_learning_rate_fn(config, num_train_steps):
    """Build a warmup-then-linear-decay learning-rate schedule.

    The first ``int(config.warmup_proportion * num_train_steps)`` steps are
    used for linear warmup; the remaining steps are passed to the scheduler
    as the decay cycle length.

    Args:
      config: Configuration object exposing ``warmup_proportion`` and
        ``learning_rate``.
      num_train_steps: Total number of training steps.

    Returns:
      Whatever ``training.create_learning_rate_scheduler`` returns for the
      ``'constant * linear_warmup * linear_decay'`` factor string.
    """
    num_warmup = int(config.warmup_proportion * num_train_steps)
    schedule_kwargs = dict(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup,
        steps_per_cycle=num_train_steps - num_warmup,
    )
    return training.create_learning_rate_scheduler(**schedule_kwargs)
예제 #2
0
def main(argv):
    """Fine-tune, evaluate and/or predict with a BERT sequence classifier.

    The dataset is loaded with ``datasets.load_dataset(config.dataset_path,
    config.dataset_name)`` and tokenized with ``BertTokenizerFast``. The
    phases that run are selected by ``config.do_train``, ``config.do_eval``
    and ``config.do_predict``. Evaluation metrics are printed and written to
    ``eval_results.txt``; test-set predictions are written as ``.tsv`` files.
    Both go into the directory returned by ``get_output_dir(config)``.

    Args:
      argv: Command-line arguments. Only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = get_config()

    datasets.logging.set_verbosity_error()
    dataset = datasets.load_dataset(config.dataset_path, config.dataset_name)
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length
    data_pipeline = data.ClassificationDataPipeline(dataset, tokenizer)

    num_train_examples = len(dataset['train'])
    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    warmup_steps = int(config.warmup_proportion * num_train_steps)
    cooldown_steps = num_train_steps - warmup_steps

    # A float32 label column is treated as a regression task: one output
    # unit and regression statistics instead of class counts/accuracy.
    is_regression_task = (
        dataset['train'].features['label'].dtype == 'float32')
    if is_regression_task:
        num_classes = 1
        compute_stats = compute_regression_stats
    else:
        num_classes = dataset['train'].features['label'].num_classes
        compute_stats = compute_classification_stats

    model = create_model(config, num_classes=num_classes)
    optimizer = create_optimizer(config, model)
    optimizer = optimizer.replicate()
    del model  # don't keep a copy of the initial model

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=warmup_steps,
        steps_per_cycle=cooldown_steps,
    )

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_step_fn = training.create_train_step(compute_loss_and_metrics,
                                                   clip_grad_norm=1.0)
        train_iter = data_pipeline.get_inputs(
            split='train', batch_size=config.train_batch_size, training=True)

        # zip() caps training at num_train_steps batches, however many
        # batches train_iter can yield.
        for _, batch in zip(range(num_train_steps), train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)

    if config.do_eval:
        eval_step = training.create_eval_fn(compute_stats)
        eval_results = []

        # MNLI has two validation sets; metrics for the mismatched one are
        # reported with an 'eval_mismatched' prefix below.
        if config.dataset_path == 'glue' and config.dataset_name == 'mnli':
            validation_splits = ['validation_matched', 'validation_mismatched']
        else:
            validation_splits = ['validation']

        for split in validation_splits:
            eval_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            eval_stats = eval_step(optimizer, eval_iter)
            eval_metric = datasets.load_metric(config.dataset_path,
                                               config.dataset_name)
            eval_metric.add_batch(predictions=eval_stats['prediction'],
                                  references=eval_stats['label'])
            eval_metrics = eval_metric.compute()
            prefix = 'eval_mismatched' if split == 'validation_mismatched' else 'eval'
            for name, val in sorted(eval_metrics.items()):
                line = f'{prefix}_{name} = {val:.06f}'
                print(line, flush=True)
                eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')

    if config.do_predict:
        predict_step = training.create_eval_fn(compute_stats)

        # (dataset_path, dataset_name, split) -> output .tsv file name.
        path_map = {
            ('glue', 'cola', 'test'): 'CoLA.tsv',
            ('glue', 'mrpc', 'test'): 'MRPC.tsv',
            ('glue', 'qqp', 'test'): 'QQP.tsv',
            ('glue', 'sst2', 'test'): 'SST-2.tsv',
            ('glue', 'stsb', 'test'): 'STS-B.tsv',
            ('glue', 'mnli', 'test_matched'): 'MNLI-m.tsv',
            ('glue', 'mnli', 'test_mismatched'): 'MNLI-mm.tsv',
            ('glue', 'qnli', 'test'): 'QNLI.tsv',
            ('glue', 'rte', 'test'): 'RTE.tsv',
            # No eval on WNLI for now. BERT accuracy on WNLI is below baseline,
            # unless a special training recipe is used.
            # ('glue', 'wnli', 'test'): 'WNLI.tsv',
        }
        # (dataset_path, dataset_name) -> label string for each predicted
        # class index. Only consulted for classification tasks; regression
        # tasks write the raw prediction value instead.
        label_sets = {
            ('glue', 'cola'): ['0', '1'],
            ('glue', 'mrpc'): ['0', '1'],
            ('glue', 'qqp'): ['0', '1'],
            ('glue', 'sst2'): ['0', '1'],
            ('glue', 'mnli'): ['entailment', 'neutral', 'contradiction'],
            ('glue', 'qnli'): ['entailment', 'not_entailment'],
            ('glue', 'rte'): ['entailment', 'not_entailment'],
        }

        for (candidate_dataset_path, candidate_dataset_name,
             split), tsv_name in path_map.items():
            if (candidate_dataset_path != config.dataset_path
                    or candidate_dataset_name != config.dataset_name):
                continue

            predict_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            predict_stats = predict_step(optimizer, predict_iter)
            idxs = predict_stats['idx']
            predictions = predict_stats['prediction']

            tsv_path = os.path.join(output_dir, tsv_name)
            with gfile.GFile(tsv_path, 'w') as f:
                f.write('index\tprediction\n')
                if is_regression_task:
                    for idx, val in zip(idxs, predictions):
                        f.write(f'{idx}\t{val:.06f}\n')
                else:
                    label_set = label_sets[config.dataset_path,
                                           config.dataset_name]
                    for idx, val in zip(idxs, predictions):
                        f.write(f'{idx}\t{label_set[val]}\n')
            print('Wrote', tsv_path)
예제 #3
0
def main(argv):
    """Fine-tune, evaluate and/or predict with a BERT sequence classifier.

    Builds ``modeling.BertForSequenceClassification`` with initial parameters
    from ``get_initial_params(model, init_checkpoint=config.init_checkpoint)``
    and runs the phases selected by ``config.do_train``, ``config.do_eval`` and
    ``config.do_predict``. Evaluation metrics are printed and written to
    ``eval_results.txt``; test-set predictions are written as ``.tsv`` files.
    Both go into the directory returned by ``get_output_dir(config)``.

    Args:
      argv: Command-line arguments. Only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = get_config()

    datasets.logging.set_verbosity_error()
    # Workaround for https://github.com/huggingface/datasets/issues/812
    logging.getLogger("filelock").setLevel(logging.ERROR)
    dataset = datasets.load_dataset(config.dataset_path, config.dataset_name)
    data_pipeline = data.ClassificationDataPipeline(
        dataset, config.init_checkpoint, max_seq_length=config.max_seq_length)

    num_train_examples = len(dataset["train"])
    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    warmup_steps = int(config.warmup_proportion * num_train_steps)
    cooldown_steps = num_train_steps - warmup_steps

    # A float32 label column is treated as a regression task: one output
    # unit and regression statistics instead of class counts/accuracy.
    is_regression_task = dataset["train"].features["label"].dtype == "float32"
    if is_regression_task:
        num_classes = 1
        compute_stats = compute_regression_stats
    else:
        num_classes = dataset["train"].features["label"].num_classes
        compute_stats = compute_classification_stats

    model = modeling.BertForSequenceClassification(config=config.model,
                                                   n_classes=num_classes)
    initial_params = get_initial_params(model,
                                        init_checkpoint=config.init_checkpoint)
    optimizer = create_optimizer(config, initial_params)
    del initial_params  # the optimizer takes ownership of all params
    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=warmup_steps,
        steps_per_cycle=cooldown_steps,
    )

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_step_fn = training.create_train_step(
            model,
            compute_loss_and_metrics,
            max_grad_norm=config.max_grad_norm)
        train_iter = data_pipeline.get_inputs(
            split="train", batch_size=config.train_batch_size, training=True)

        # zip() caps training at num_train_steps batches, however many
        # batches train_iter can yield.
        for _, batch in zip(range(num_train_steps), train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)

    if config.do_eval:
        eval_step = training.create_eval_fn(model, compute_stats)
        eval_results = []

        # MNLI has two validation sets; metrics for the mismatched one are
        # reported with an "eval_mismatched" prefix below.
        if config.dataset_path == "glue" and config.dataset_name == "mnli":
            validation_splits = ["validation_matched", "validation_mismatched"]
        else:
            validation_splits = ["validation"]

        for split in validation_splits:
            eval_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            eval_stats = eval_step(optimizer, eval_iter)
            eval_metric = datasets.load_metric(config.dataset_path,
                                               config.dataset_name)
            eval_metric.add_batch(predictions=eval_stats["prediction"],
                                  references=eval_stats["label"])
            eval_metrics = eval_metric.compute()
            prefix = "eval_mismatched" if split == "validation_mismatched" else "eval"
            for name, val in sorted(eval_metrics.items()):
                line = f"{prefix}_{name} = {val:.06f}"
                print(line, flush=True)
                eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")

    if config.do_predict:
        predict_step = training.create_eval_fn(model, compute_stats)

        # (dataset_path, dataset_name, split) -> output .tsv file name.
        path_map = {
            ("glue", "cola", "test"): "CoLA.tsv",
            ("glue", "mrpc", "test"): "MRPC.tsv",
            ("glue", "qqp", "test"): "QQP.tsv",
            ("glue", "sst2", "test"): "SST-2.tsv",
            ("glue", "stsb", "test"): "STS-B.tsv",
            ("glue", "mnli", "test_matched"): "MNLI-m.tsv",
            ("glue", "mnli", "test_mismatched"): "MNLI-mm.tsv",
            ("glue", "qnli", "test"): "QNLI.tsv",
            ("glue", "rte", "test"): "RTE.tsv",
            # No eval on WNLI for now. BERT accuracy on WNLI is below baseline,
            # unless a special training recipe is used.
            # ("glue", "wnli", "test"): "WNLI.tsv",
        }
        # (dataset_path, dataset_name) -> label string for each predicted
        # class index. Only consulted for classification tasks; regression
        # tasks write the raw prediction value instead.
        label_sets = {
            ("glue", "cola"): ["0", "1"],
            ("glue", "mrpc"): ["0", "1"],
            ("glue", "qqp"): ["0", "1"],
            ("glue", "sst2"): ["0", "1"],
            ("glue", "mnli"): ["entailment", "neutral", "contradiction"],
            ("glue", "qnli"): ["entailment", "not_entailment"],
            ("glue", "rte"): ["entailment", "not_entailment"],
        }

        for (candidate_dataset_path, candidate_dataset_name,
             split), tsv_name in path_map.items():
            if (candidate_dataset_path != config.dataset_path
                    or candidate_dataset_name != config.dataset_name):
                continue

            predict_iter = data_pipeline.get_inputs(
                split=split, batch_size=config.eval_batch_size, training=False)
            predict_stats = predict_step(optimizer, predict_iter)
            idxs = predict_stats["idx"]
            predictions = predict_stats["prediction"]

            tsv_path = os.path.join(output_dir, tsv_name)
            with gfile.GFile(tsv_path, "w") as f:
                f.write("index\tprediction\n")
                if is_regression_task:
                    for idx, val in zip(idxs, predictions):
                        f.write(f"{idx}\t{val:.06f}\n")
                else:
                    label_set = label_sets[config.dataset_path,
                                           config.dataset_name]
                    for idx, val in zip(idxs, predictions):
                        f.write(f"{idx}\t{label_set[val]}\n")
            print("Wrote", tsv_path)
예제 #4
0
def main(argv):
    """Pre-train BERT (masked LM + next-sentence prediction) and/or evaluate.

    Restores from a checkpoint in the output directory if one exists. When
    ``config.do_train`` is set, it trains from the restored step up to
    ``config.num_train_steps``. Host 0 saves a checkpoint every
    ``config.save_checkpoints_steps`` steps and at the final step, and writes
    ``config.json`` once. When ``config.do_eval`` is set, pre-training metrics
    over at most ``config.max_eval_steps`` batches are printed and written to
    ``eval_results.txt``.

    Args:
      argv: Command-line arguments. Only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = FLAGS.config

    model = create_model(config)
    optimizer = create_optimizer(config, model)
    del model  # don't keep a copy of the initial model

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists; training resumes at the
    # optimizer's stored step counter.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    start_step = int(optimizer.state[0].step)

    optimizer = optimizer.replicate()

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length

    # Set before load_dataset so the load's own log output is also suppressed.
    datasets.logging.set_verbosity_error()

    # The commented lines below correspond to a data pipeline that uses publicly
    # available data, in the form of English Wikipedia as processed and hosted by
    # the HuggingFace datasets library. The pipeline works, and downstream task
    # performance shows a benefit to pre-training, but I (Nikita) have yet to
    # confirm that final model quality is on par with the original BERT.
    #
    # dataset = datasets.load_dataset('wikipedia', '20200501.en')['train']
    # data_pipeline = data.PretrainingDataPipelineV1(
    #   dataset, tokenizer,
    #   max_predictions_per_seq=config.max_predictions_per_seq)

    # The data pipeline below relies on having text files of Wikipedia + Books in
    # the same format as the original BERT data. That original data is not
    # publicly available, so you will need to provide your own. I (Nikita) have
    # had success using data from Gong et al. "Efficient Training of BERT by
    # Progressively Stacking", but this data was also obtained through private
    # correspondence and may not be generally available.
    # The data_files argument may be a list, if data is split across multiple
    # input files.
    dataset = datasets.load_dataset(
        'bert_data.py',
        data_files=os.path.expanduser('~/data/bert/corpus.train.tok'))['train']
    data_pipeline = data.PretrainingDataPipeline(
        dataset,
        tokenizer,
        max_predictions_per_seq=config.max_predictions_per_seq)

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_iter = data_pipeline.get_inputs(
            batch_size=config.train_batch_size, training=True)
        train_step_fn = training.create_train_step(
            compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)
            # Only host 0 writes checkpoints; the checkpoint is labelled with
            # the index of the step just taken.
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(output_dir,
                                            optimizer.unreplicate(), step)
                # Use gfile (like every other write into output_dir) so this
                # also works when output_dir is not a local path.
                config_path = os.path.join(output_dir, 'config.json')
                if not gfile.exists(config_path):
                    with gfile.GFile(config_path, 'w') as f:
                        json.dump({'model_type': 'bert', **config.model}, f)

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(compute_pretraining_stats,
                                          sample_feature_name='input_ids')
        eval_stats = eval_fn(optimizer, eval_iter)

        # Losses are averaged; accuracies are pooled (total correct / total
        # count) across all evaluated batches.
        eval_metrics = {
            'loss':
            jnp.mean(eval_stats['loss']),
            'masked_lm_loss':
            jnp.mean(eval_stats['masked_lm_loss']),
            'next_sentence_loss':
            jnp.mean(eval_stats['next_sentence_loss']),
            'masked_lm_accuracy':
            jnp.sum(eval_stats['masked_lm_correct']) /
            jnp.sum(eval_stats['masked_lm_total']),
            'next_sentence_accuracy':
            jnp.sum(eval_stats['next_sentence_correct']) /
            jnp.sum(eval_stats['next_sentence_total']),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f'{name} = {val:.06f}'
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')
예제 #5
0
def main(argv):
    """Pre-train BERT from cached TFRecord shards and/or evaluate it.

    Input shards are the files matching ``cache/pretrain.*_of_*.tfrecord``
    (relative to the working directory). Restores from a checkpoint in the
    output directory if one exists. When ``config.do_train`` is set, it trains
    up to ``config.num_train_steps``, and host 0 saves checkpoints and
    ``config.json``. When ``config.do_eval`` is set, pre-training metrics are
    written to ``eval_results.txt``.

    Args:
      argv: Command-line arguments. Only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
      ValueError: If no input shards match the glob pattern.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config = FLAGS.config

    # glob.glob returns matches in arbitrary, filesystem-dependent order;
    # sort so the shard order fed to the pipeline is deterministic.
    input_pattern = 'cache/pretrain.*_of_*.tfrecord'
    input_files = sorted(glob.glob(input_pattern))
    if not input_files:
        raise ValueError(f'No pre-training input files match {input_pattern}')

    model = create_model(config)
    optimizer = create_optimizer(config, model)
    del model  # don't keep a copy of the initial model

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists; training resumes at the
    # optimizer's stored step counter.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    start_step = int(optimizer.state[0].step)

    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
    tokenizer.model_max_length = config.max_seq_length

    data_pipeline = data.PretrainingDataPipeline(
        input_files,
        tokenizer,
        max_predictions_per_seq=config.max_predictions_per_seq)

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors='constant * linear_warmup * linear_decay',
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_iter = data_pipeline.get_inputs(
            batch_size=config.train_batch_size, training=True)
        train_step_fn = training.create_train_step(
            compute_pretraining_loss_and_metrics, clip_grad_norm=1.0)

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)
            # Only host 0 writes checkpoints; the checkpoint is labelled with
            # the index of the step just taken.
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(output_dir,
                                            optimizer.unreplicate(), step)
                # Use gfile (like every other write into output_dir) so this
                # also works when output_dir is not a local path.
                config_path = os.path.join(output_dir, 'config.json')
                if not gfile.exists(config_path):
                    with gfile.GFile(config_path, 'w') as f:
                        json.dump({'model_type': 'bert', **config.model}, f)

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(compute_pretraining_stats,
                                          sample_feature_name='input_ids')
        eval_stats = eval_fn(optimizer, eval_iter)

        # Losses are averaged; accuracies are pooled (total correct / total
        # count) across all evaluated batches.
        eval_metrics = {
            'loss':
            jnp.mean(eval_stats['loss']),
            'masked_lm_loss':
            jnp.mean(eval_stats['masked_lm_loss']),
            'next_sentence_loss':
            jnp.mean(eval_stats['next_sentence_loss']),
            'masked_lm_accuracy':
            jnp.sum(eval_stats['masked_lm_correct']) /
            jnp.sum(eval_stats['masked_lm_total']),
            'next_sentence_accuracy':
            jnp.sum(eval_stats['next_sentence_correct']) /
            jnp.sum(eval_stats['next_sentence_total']),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f'{name} = {val:.06f}'
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, 'eval_results.txt')
        with gfile.GFile(eval_results_path, 'w') as f:
            for line in eval_results:
                f.write(line + '\n')
예제 #6
0
def main(argv):
    """Pre-train ``modeling.BertForPreTraining`` and/or evaluate it.

    Input files are the matches of every glob pattern in
    ``config.input_files``. Initial parameters come from
    ``get_initial_params(model, init_checkpoint=config.init_checkpoint)``,
    unless a checkpoint already exists in the output directory. When
    ``config.do_train`` is set, it trains up to ``config.num_train_steps``,
    and host 0 saves checkpoints, ``config.json`` and a copy of the tokenizer
    model. When ``config.do_eval`` is set, pre-training metrics are written to
    ``eval_results.txt``.

    Args:
      argv: Command-line arguments. Only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
      ValueError: If no input files match, or if ``train_batch_size`` is not
        divisible by the number of hosts.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = FLAGS.config

    # Glob once and reuse the result for the pipeline, so the files that were
    # validated and reported are exactly the ones trained on. Matches are
    # sorted within each pattern (glob order is filesystem-dependent) while
    # the order of the patterns themselves is preserved.
    input_files = [
        path for pattern in config.input_files
        for path in sorted(glob.glob(pattern))
    ]
    if not input_files:
        raise ValueError("No input files!")
    print(f"Training with {len(input_files)} input files, including:")
    print(f" - {input_files[0]}")

    model = modeling.BertForPreTraining(config=config.model)
    initial_params = get_initial_params(model,
                                        init_checkpoint=config.init_checkpoint)
    optimizer = create_optimizer(config, initial_params)
    del initial_params  # the optimizer takes ownership of all params

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists. The optimizer state may
    # be a single object or a per-parameter-group list/tuple; in the latter
    # case the first entry's step counter is used.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    if isinstance(optimizer.state, (list, tuple)):
        start_step = int(optimizer.state[0].step)
    else:
        start_step = int(optimizer.state.step)

    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    data_pipeline = data.PretrainingDataPipeline(
        input_files,
        config.tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
    )

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        # config.train_batch_size is the global batch; each host requests an
        # equal share of it from its own pipeline.
        train_batch_size = config.train_batch_size
        num_hosts = jax.host_count()
        if num_hosts > 1:
            if train_batch_size % num_hosts != 0:
                raise ValueError(
                    "train_batch_size must be divisible by number of hosts")
            train_batch_size = train_batch_size // num_hosts
        train_iter = data_pipeline.get_inputs(batch_size=train_batch_size,
                                              training=True)
        train_step_fn = training.create_train_step(
            model,
            compute_pretraining_loss_and_metrics,
            max_grad_norm=config.max_grad_norm,
        )

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)
            # Only host 0 writes checkpoints; the checkpoint is labelled with
            # the index of the step just taken.
            if jax.host_id() == 0 and (
                    step % config.save_checkpoints_steps == 0
                    or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(output_dir,
                                            optimizer.unreplicate(), step)
                # Use gfile (like every other write into output_dir) so these
                # also work when output_dir is not a local path.
                config_path = os.path.join(output_dir, "config.json")
                if not gfile.exists(config_path):
                    with gfile.GFile(config_path, "w") as f:
                        json.dump({"model_type": "bert", **config.model}, f)
                tokenizer_path = os.path.join(output_dir,
                                              "sentencepiece.model")
                if not gfile.exists(tokenizer_path):
                    gfile.copy(config.tokenizer, tokenizer_path)

        # With the current Rust data pipeline code, running more than one pipeline
        # at a time will lead to a hang. A simple workaround is to fully delete the
        # training pipeline before potentially starting another for evaluation.
        del train_iter

    if config.do_eval:
        # NOTE(review): unlike training, eval_batch_size is not divided by
        # jax.host_count(); confirm whether get_inputs expects a per-host or a
        # global batch size here.
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(model,
                                          compute_pretraining_stats,
                                          sample_feature_name="input_ids")
        eval_stats = eval_fn(optimizer, eval_iter)

        # Losses are averaged; accuracies are pooled (total correct / total
        # count) across all evaluated batches.
        eval_metrics = {
            "loss":
            jnp.mean(eval_stats["loss"]),
            "masked_lm_loss":
            jnp.mean(eval_stats["masked_lm_loss"]),
            "next_sentence_loss":
            jnp.mean(eval_stats["next_sentence_loss"]),
            "masked_lm_accuracy":
            jnp.sum(eval_stats["masked_lm_correct"]) /
            jnp.sum(eval_stats["masked_lm_total"]),
            "next_sentence_accuracy":
            jnp.sum(eval_stats["next_sentence_correct"]) /
            jnp.sum(eval_stats["next_sentence_total"]),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f"{name} = {val:.06f}"
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")