Ejemplo n.º 1
0
def main():
    """Run this workload's Tune experiment and dump per-trial dataframes.

    Initializes via ``com.init_ray()`` and wraps the workload's data, model,
    loss and optimizer creators into a trainable. The trainable uses
    ``workload.WLMOperator`` as its training operator. The function runs the
    trainable with ``tune.run`` and writes every trial's dataframe to a CSV
    file inside the directory it is keyed by in
    ``analysis.trial_dataframes``.
    """
    eta, seed = com.init_ray()
    # An eta of 1 is replaced by 3.
    # NOTE(review): eta is not read anywhere below; setup_tune_scheduler() is
    # called without it. Confirm whether it should be passed along or dropped.
    if eta == 1:
        eta = 3

    trainable = TorchTrainer.as_trainable(
        data_creator=workload.data_creator,
        model_creator=workload.model_creator,
        loss_creator=workload.loss_creator,
        optimizer_creator=workload.optimizer_creator,
        training_operator_cls=workload.WLMOperator,
        config={"seed": seed, "extra_fluid_trial_resources": {}},
    )

    # Merge run options: later sources win on key collisions.
    run_kwargs = dict(com.run_options(__file__))
    run_kwargs["stop"] = workload.create_stopper()
    run_kwargs.update(setup_tune_scheduler())

    analysis = tune.run(trainable, **run_kwargs)

    # NOTE(review): "trail" looks like a typo for "trial"; the file name is
    # kept as-is in case downstream tooling looks for this exact name.
    for log_dir, trial_df in analysis.trial_dataframes.items():
        trial_df.to_csv(Path(log_dir) / "trail_dataframe.csv")
Ejemplo n.º 2
0
def main():
    """Run this workload's Tune experiment and dump per-trial dataframes.

    Initializes via ``com.init_ray()``, which yields a worker count and a
    seed. It wraps the workload's data, model, loss and optimizer creators
    into a trainable whose config carries the seed and a batch size of 64.
    The worker count is forwarded to ``setup_tune_scheduler``. After
    ``tune.run`` finishes, every trial's dataframe is written to a CSV file
    inside the directory it is keyed by in ``analysis.trial_dataframes``.
    """
    num_worker, seed = com.init_ray()

    trainable = TorchTrainer.as_trainable(
        data_creator=workload.data_creator,
        model_creator=workload.model_creator,
        loss_creator=workload.loss_creator,
        optimizer_creator=workload.optimizer_creator,
        config={
            "seed": seed,
            BATCH_SIZE: 64,
            "extra_fluid_trial_resources": {},
        },
    )

    # Merge run options: later sources win on key collisions.
    run_kwargs = dict(com.run_options(__file__))
    run_kwargs["stop"] = workload.create_stopper()
    run_kwargs.update(setup_tune_scheduler(num_worker))

    analysis = tune.run(trainable, **run_kwargs)

    # NOTE(review): "trail" looks like a typo for "trial"; the file name is
    # kept as-is in case downstream tooling looks for this exact name.
    for log_dir, trial_df in analysis.trial_dataframes.items():
        trial_df.to_csv(Path(log_dir) / "trail_dataframe.csv")
Ejemplo n.º 3
0
def main():
    """Run the DCGAN workload's Tune experiment and dump per-trial dataframes.

    Initializes via ``com.init_ray()``, keeping only the seed, then calls
    ``workload.init_dcgan()``. The workload's creators are wrapped into a
    trainable that uses ``workload.GANOperator`` as its training operator,
    with the seed merged into ``workload.static_config()``. After
    ``tune.run`` finishes, every trial's dataframe is written to a CSV file
    inside the directory it is keyed by in ``analysis.trial_dataframes``.
    """
    _, seed = com.init_ray()
    # Must run before the trainable is built (and before static_config()).
    workload.init_dcgan()

    trainable = TorchTrainer.as_trainable(
        data_creator=workload.data_creator,
        model_creator=workload.model_creator,
        loss_creator=workload.loss_creator,
        optimizer_creator=workload.optimizer_creator,
        training_operator_cls=workload.GANOperator,
        # static_config() entries override "seed" if they share the key.
        config={"seed": seed, **workload.static_config()},
    )

    # Merge run options: later sources win on key collisions.
    run_kwargs = dict(com.run_options(__file__))
    run_kwargs["stop"] = workload.create_stopper()
    run_kwargs.update(setup_tune_scheduler())

    analysis = tune.run(trainable, **run_kwargs)

    # NOTE(review): "trail" looks like a typo for "trial"; the file name is
    # kept as-is in case downstream tooling looks for this exact name.
    for log_dir, trial_df in analysis.trial_dataframes.items():
        trial_df.to_csv(Path(log_dir) / "trail_dataframe.csv")