Esempio n. 1
0
def test_meta_train_splitting_v2_multi_gpu_recurrent() -> None:
    """
    Meta-train splitting networks v2 with multiple processes, CUDA and a
    recurrent policy, comparing reward curves against a saved baseline for a
    MetaWorld ML10 benchmark.

    Starting from the default v2 meta-splitting config, this:
      * multiplies ``num_processes`` by MP_FACTOR and floor-divides
        ``num_updates`` by MP_FACTOR for both meta-train and meta-test,
      * sets ``cuda`` to True,
      * turns on a recurrent architecture with a hidden size of 64,
      * names the baseline metrics "meta_splitting_v2_multi_gpu_recurrent",
    and then hands the config to ``meta_train``.
    """

    with open(META_SPLITTING_V2_CONFIG_PATH) as handle:
        config = json.load(handle)

    train_settings = config["meta_train_config"]
    test_settings = config["meta_test_config"]

    # Scale processes up and updates down by the same factor so their product
    # stays roughly the same (floor division may drop a remainder).
    for settings in (train_settings, test_settings):
        settings["num_updates"] //= MP_FACTOR
        settings["num_processes"] *= MP_FACTOR

    config["cuda"] = True

    architecture = train_settings["architecture_config"]
    architecture["recurrent"] = True
    architecture["recurrent_hidden_size"] = 64

    train_settings["baseline_metrics_filename"] = (
        "meta_splitting_v2_multi_gpu_recurrent"
    )

    meta_train(config)
Esempio n. 2
0
def test_meta_train_splitting_v2() -> None:
    """
    Meta-train splitting networks v2 with the default configuration and
    compare reward curves against a saved baseline for a MetaWorld ML10
    benchmark.

    The only change to the default v2 meta-splitting config is the baseline
    metrics name, which is set to "meta_splitting_v2".
    """

    with open(META_SPLITTING_V2_CONFIG_PATH) as handle:
        config = json.load(handle)

    train_settings = config["meta_train_config"]
    train_settings["baseline_metrics_filename"] = "meta_splitting_v2"

    meta_train(config)
Esempio n. 3
0
def test_meta_train_splitting_v1_recurrent() -> None:
    """
    Meta-train splitting networks v1 with a recurrent policy and compare
    reward curves against a saved baseline for a MetaWorld ML10 benchmark.

    Starting from the default v1 meta-splitting config, this enables a
    recurrent architecture with a hidden size of 64, names the baseline
    metrics "meta_splitting_v1_recurrent", and then calls ``meta_train``.
    """

    with open(META_SPLITTING_V1_CONFIG_PATH) as handle:
        config = json.load(handle)

    train_settings = config["meta_train_config"]

    architecture = train_settings["architecture_config"]
    architecture["recurrent"] = True
    architecture["recurrent_hidden_size"] = 64

    train_settings["baseline_metrics_filename"] = "meta_splitting_v1_recurrent"

    meta_train(config)
Esempio n. 4
0
def test_meta_train_splitting_v1_multi() -> None:
    """
    Meta-train splitting networks v1 with multiple processes and compare
    reward curves against a saved baseline for a MetaWorld ML10 benchmark.

    Starting from the default v1 meta-splitting config, this multiplies
    ``num_processes`` by MP_FACTOR and floor-divides ``num_updates`` by
    MP_FACTOR for both meta-train and meta-test. It then names the baseline
    metrics "meta_splitting_v1_multi" and calls ``meta_train``.
    """

    with open(META_SPLITTING_V1_CONFIG_PATH) as handle:
        config = json.load(handle)

    train_settings = config["meta_train_config"]
    test_settings = config["meta_test_config"]

    # Scale processes up and updates down by the same factor so their product
    # stays roughly the same (floor division may drop a remainder).
    for settings in (train_settings, test_settings):
        settings["num_updates"] //= MP_FACTOR
        settings["num_processes"] *= MP_FACTOR

    train_settings["baseline_metrics_filename"] = "meta_splitting_v1_multi"

    meta_train(config)
Esempio n. 5
0
if __name__ == "__main__":

    # Command-line entry point: read a command name and a JSON config path,
    # load the config, and dispatch to the matching entry point.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "command",
        type=str,
        # The dispatch below accepts three commands; the help text lists all
        # of them (it previously omitted 'meta_train').
        help="Command to run. One of 'train', 'tune', or 'meta_train'.",
    )
    parser.add_argument(
        "config_filename",
        type=str,
        help="Name of config file to load from.",
    )
    args = parser.parse_args()

    # Load config file. JSON text is UTF-8 by specification, so decode it
    # explicitly instead of relying on the platform's default encoding.
    with open(args.config_filename, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    # Run specified command. An if/elif chain (rather than a dict of
    # callables) means only the selected entry point's name is looked up.
    if args.command == "train":
        train(config)
    elif args.command == "tune":
        tune(config)
    elif args.command == "meta_train":
        meta_train(config)
    else:
        raise ValueError("Unsupported command: '%s'" % args.command)