Example No. 1
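This snippet is an excerpt, so its imports are not shown. A minimal sketch of what it assumes, given a transformers release that still exposes the GLUE dataset utilities (the location of the two helper functions is a guess):

import os

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining
from transformers import (AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, GlueDataset, GlueDataTrainingArguments,
                          Trainer, TrainingArguments, glue_tasks_num_labels)

# download_data and build_compute_metrics_fn are helper utilities that ship
# with the example itself (e.g. a local utils module), not with transformers
# or ray; import them from wherever they live in your project.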
def tune_transformer(num_samples=8, gpus_per_trial=0, smoke_test=False):
    data_dir_name = "./data" if not smoke_test else "./test_data"
    data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))
    if not os.path.exists(data_dir):
        os.mkdir(data_dir, 0o755)

    # Change these as needed.
    model_name = "bert-base-uncased" if not smoke_test \
        else "sshleifer/tiny-distilroberta-base"
    task_name = "rte"

    task_data_dir = os.path.join(data_dir, task_name.upper())

    num_labels = glue_tasks_num_labels[task_name]

    config = AutoConfig.from_pretrained(model_name,
                                        num_labels=num_labels,
                                        finetuning_task=task_name)

    # Download and cache tokenizer, model, and features
    print("Downloading and caching Tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Triggers model download to cache
    print("Downloading and caching pre-trained model")
    AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
    )

    def get_model():
        return AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=config,
        )

    # Download data.
    download_data(task_name, data_dir)

    data_args = GlueDataTrainingArguments(task_name=task_name,
                                          data_dir=task_data_dir)

    train_dataset = GlueDataset(data_args,
                                tokenizer=tokenizer,
                                mode="train",
                                cache_dir=task_data_dir)
    eval_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="dev",
                               cache_dir=task_data_dir)

    training_args = TrainingArguments(
        output_dir=".",
        learning_rate=1e-5,  # config
        do_train=True,
        do_eval=True,
        no_cuda=gpus_per_trial <= 0,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        num_train_epochs=2,  # config
        max_steps=-1,
        per_device_train_batch_size=16,  # config
        per_device_eval_batch_size=16,  # config
        warmup_steps=0,
        weight_decay=0.1,  # config
        logging_dir="./logs",
        skip_memory_metrics=True,
        report_to="none")

    trainer = Trainer(model_init=get_model,
                      args=training_args,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset,
                      compute_metrics=build_compute_metrics_fn(task_name))

    tune_config = {
        "per_device_train_batch_size": 32,
        "per_device_eval_batch_size": 32,
        "num_train_epochs": tune.choice([2, 3, 4, 5]),
        "max_steps": 1 if smoke_test else -1,  # Used for smoke test.
    }

    scheduler = PopulationBasedTraining(time_attr="training_iteration",
                                        metric="eval_acc",
                                        mode="max",
                                        perturbation_interval=1,
                                        hyperparam_mutations={
                                            "weight_decay":
                                            tune.uniform(0.0, 0.3),
                                            "learning_rate":
                                            tune.uniform(1e-5, 5e-5),
                                            "per_device_train_batch_size":
                                            [16, 32, 64],
                                        })

    reporter = CLIReporter(parameter_columns={
        "weight_decay": "w_decay",
        "learning_rate": "lr",
        "per_device_train_batch_size": "train_bs/gpu",
        "num_train_epochs": "num_epochs"
    },
                           metric_columns=[
                               "eval_acc", "eval_loss", "epoch",
                               "training_iteration"
                           ])

    trainer.hyperparameter_search(
        hp_space=lambda _: tune_config,
        backend="ray",
        n_trials=num_samples,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        scheduler=scheduler,
        keep_checkpoints_num=1,
        checkpoint_score_attr="training_iteration",
        stop={"training_iteration": 1} if smoke_test else None,
        progress_reporter=reporter,
        local_dir="~/ray_results/",
        name="tune_transformer_pbt",
        log_to_file=True)
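A quick way to exercise the function (the argument values below are illustrative): the smoke-test path swaps in the tiny model and stops after a single training iteration, while a full run needs one or more GPUs per trial.

if __name__ == "__main__":
    # Fast local check: tiny model, single step, CPU only.
    tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)

    # Full PBT run, e.g. 8 trials with one GPU each:
    # tune_transformer(num_samples=8, gpus_per_trial=1)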
Example No. 2
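Compared with the first example, this variant initializes Ray itself and hands the search to tune.run with a function trainable, so on top of the imports sketched above it also assumes:

import ray

# The "wandb" entry in the config below is presumably consumed by the
# trainable for Weights & Biases logging; drop it if you do not log to wandb.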
def tune_transformer(num_samples=8,
                     gpus_per_trial=0,
                     smoke_test=False,
                     ray_address=None):
    ray.init(ray_address, log_to_driver=False)
    data_dir_name = "./data" if not smoke_test else "./test_data"
    data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))
    if not os.path.exists(data_dir):
        os.mkdir(data_dir, 0o755)

    # Change these as needed.
    model_name = "bert-base-uncased" if not smoke_test \
        else "sshleifer/tiny-distilroberta-base"
    task_name = "rte"

    task_data_dir = os.path.join(data_dir, task_name.upper())

    # Download and cache tokenizer, model, and features
    print("Downloading and caching Tokenizer")

    # Triggers tokenizer download to cache
    AutoTokenizer.from_pretrained(model_name)
    print("Downloading and caching pre-trained model")

    # Triggers model download to cache
    AutoModelForSequenceClassification.from_pretrained(model_name)

    # Download data.
    download_data(task_name, data_dir)

    config = {
        "model_name": model_name,
        "task_name": task_name,
        "data_dir": task_data_dir,
        "per_gpu_val_batch_size": 32,
        "per_gpu_train_batch_size": tune.choice([16, 32, 64]),
        "learning_rate": tune.uniform(1e-5, 5e-5),
        "weight_decay": tune.uniform(0.0, 0.3),
        "num_epochs": tune.choice([2, 3, 4, 5]),
        "max_steps": 1 if smoke_test else -1,  # Used for smoke test.
        "wandb": {
            "project": "pbt_transformers",
            "reinit": True,
            "allow_val_change": True
        }
    }

    scheduler = PopulationBasedTraining(time_attr="training_iteration",
                                        metric="eval_acc",
                                        mode="max",
                                        perturbation_interval=1,
                                        hyperparam_mutations={
                                            "weight_decay":
                                            tune.uniform(0.0, 0.3),
                                            "learning_rate":
                                            tune.uniform(1e-5, 5e-5),
                                            "per_gpu_train_batch_size":
                                            [16, 32, 64],
                                        })

    reporter = CLIReporter(parameter_columns={
        "weight_decay": "w_decay",
        "learning_rate": "lr",
        "per_gpu_train_batch_size": "train_bs/gpu",
        "num_epochs": "num_epochs"
    },
                           metric_columns=[
                               "eval_acc", "eval_loss", "epoch",
                               "training_iteration"
                           ])

    analysis = tune.run(train_transformer,
                        resources_per_trial={
                            "cpu": 1,
                            "gpu": gpus_per_trial
                        },
                        config=config,
                        num_samples=num_samples,
                        scheduler=scheduler,
                        keep_checkpoints_num=3,
                        checkpoint_score_attr="training_iteration",
                        stop={"training_iteration": 1} if smoke_test else None,
                        progress_reporter=reporter,
                        local_dir="~/ray_results/",
                        name="tune_transformer_pbt")

    if not smoke_test:
        test_best_model(analysis, config["model_name"], config["task_name"],
                        config["data_dir"])