Example #1
def test_retry(ray_start_2_cpus):
    def train_func():
        ckpt = train.load_checkpoint()
        restored = bool(ckpt)  # Does a previous checkpoint exist?
        itr = 0
        if ckpt:
            itr = ckpt["iter"] + 1

        for i in range(itr, 4):
            if i == 2 and not restored:
                raise Exception("try to fail me")
            train.save_checkpoint(iter=i)
            train.report(test=i, training_iteration=i)

    trainer = Trainer(TestConfig(), num_workers=1)
    TestTrainable = trainer.to_tune_trainable(train_func)

    analysis = tune.run(TestTrainable, max_failures=3)
    last_ckpt = analysis.trials[0].checkpoint.value
    checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["iter"] == 3
    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 4
Example #2
def test_reuse_checkpoint(ray_start_2_cpus):
    def train_func(config):
        itr = 0
        ckpt = train.load_checkpoint()
        if ckpt is not None:
            itr = ckpt["iter"] + 1

        for i in range(itr, config["max_iter"]):
            train.save_checkpoint(iter=i)
            train.report(test=i, training_iteration=i)

    trainer = Trainer(TestConfig(), num_workers=1)
    TestTrainable = trainer.to_tune_trainable(train_func)

    [trial] = tune.run(TestTrainable, config={"max_iter": 5}).trials
    last_ckpt = trial.checkpoint.value
    checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["iter"] == 4
    analysis = tune.run(
        TestTrainable, config={"max_iter": 10}, restore=last_ckpt)
    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 5
Example #3
def test_tune_error(ray_start_2_cpus):
    def train_func(config):
        raise RuntimeError("Error in training function!")

    trainer = Trainer(TestConfig(), num_workers=1)
    TestTrainable = trainer.to_tune_trainable(train_func)

    with pytest.raises(TuneError):
        tune.run(TestTrainable)
Example #4
def tune_linear(num_workers, num_samples):
    trainer = Trainer("torch", num_workers=num_workers)
    Trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(Trainable,
                        num_samples=num_samples,
                        config={
                            "lr": tune.loguniform(1e-4, 1e-1),
                            "batch_size": tune.choice([4, 16, 32]),
                            "epochs": 3
                        })
    results = analysis.get_best_config(metric="loss", mode="min")
    print(results)
    return results
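
The tune_linear example above passes a train_func that is defined elsewhere in the original script. A minimal sketch of such a function, assuming the same train.report API and the config keys (lr, batch_size, epochs) used by the other examples, might look like the following; the model and data handling are placeholders, not the original code:

def train_func(config):
    # Hypothetical sketch: a real training loop would build a model and data
    # loaders from config["lr"] and config["batch_size"].
    for epoch in range(config["epochs"]):
        loss = 1.0 / (epoch + 1)  # placeholder metric for illustration
        train.report(loss=loss)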
Example #5
def test_tune_checkpoint(ray_start_2_cpus):
    def train_func():
        for i in range(10):
            train.report(test=i)
        train.save_checkpoint(hello="world")

    trainer = Trainer(TestConfig(), num_workers=1)
    TestTrainable = trainer.to_tune_trainable(train_func)

    [trial] = tune.run(TestTrainable).trials
    checkpoint_file = os.path.join(trial.checkpoint.value, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["hello"] == "world"
Example #6
def tune_tensorflow_mnist(num_workers, use_gpu, num_samples):
    epochs = 2
    trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=use_gpu)
    MnistTrainable = trainer.to_tune_trainable(tensorflow_mnist_train_func)

    analysis = tune.run(MnistTrainable,
                        num_samples=num_samples,
                        config={
                            "lr": tune.loguniform(1e-4, 1e-1),
                            "batch_size": tune.choice([32, 64, 128]),
                            "epochs": epochs
                        })

    # Check that loss decreases in each trial.
    for path, df in analysis.trial_dataframes.items():
        assert df.loc[1, "loss"] < df.loc[0, "loss"]
Example #7
def tune_tensorflow_mnist(num_workers, num_samples):
    trainer = Trainer(backend="tensorflow", num_workers=num_workers)
    Trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(
        Trainable,
        num_samples=num_samples,
        config={
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
            "epochs": 3
        })
    best_loss = analysis.get_best_config(metric="loss", mode="min")
    best_accuracy = analysis.get_best_config(metric="accuracy", mode="max")
    print(f"Best loss config: {best_loss}")
    print(f"Best accuracy config: {best_accuracy}")
    return analysis
Example #8
def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
    trainer = Trainer(backend="torch",
                      num_workers=num_workers,
                      use_gpu=use_gpu)
    TorchTrainable = trainer.to_tune_trainable(training_loop)

    return tune.run(
        TorchTrainable,
        num_samples=1,
        config={
            "lr": tune.grid_search([1e-4, 1e-3]),
            "test_mode": test_mode,
            "batch_size": 128,
        },
        stop={"training_iteration": 2},
        verbose=1,
        metric="val_loss",
        mode="min",
        checkpoint_at_end=True,
    )

if __name__ == "__main__":
    # NOTE: the listing is truncated here. The --smoke-test, --address, and
    # --num-workers arguments referenced below were defined on the parser in
    # the full script but are not shown; a minimal parser setup is assumed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-gpu",
                        action="store_true",
                        default=False,
                        help="Enables GPU training")

    args, _ = parser.parse_known_args()
    if args.smoke_test:
        ray.init(num_cpus=4)
    else:
        ray.init(address=args.address)

    trainer = Trainer("torch",
                      num_workers=args.num_workers,
                      use_gpu=args.use_gpu)
    Trainable = trainer.to_tune_trainable(train_func)
    pbt_scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="loss",
        mode="min",
        perturbation_interval=1,
        hyperparam_mutations={
            # distribution for resampling
            "lr": lambda: np.random.uniform(0.001, 1),
            # allow perturbations within this set of categorical values
            "momentum": [0.8, 0.9, 0.99],
        },
    )

    reporter = CLIReporter()
    reporter.add_metric_column("loss", "loss")
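
    # The listing ends before the scheduler and reporter are used. A
    # hypothetical continuation, assuming the Trainable, pbt_scheduler, and
    # reporter defined above; num_samples and the search space are
    # illustrative, not from the original script.
    analysis = tune.run(
        Trainable,
        num_samples=4,  # illustrative value
        config={
            # initial values; PBT mutates lr and momentum during training
            "lr": tune.uniform(0.001, 1),
            "momentum": tune.uniform(0.8, 0.99),
        },
        scheduler=pbt_scheduler,
        progress_reporter=reporter,
    )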