Esempio n. 1
0
from fireup.utils.run_utils import ExperimentGrid
from fireup import ppo
import torch

if __name__ == '__main__':
    import argparse

    # Command-line options: --cpu is forwarded to eg.run() as num_cpu and
    # --num_runs controls how many seed values are added to the grid.
    cli = argparse.ArgumentParser()
    cli.add_argument('--cpu', type=int, default=4)
    cli.add_argument('--num_runs', type=int, default=3)
    args = cli.parse_args()

    # One seed per requested run: 0, 10, 20, ...
    seeds = [10 * run for run in range(args.num_runs)]

    # Each tuple holds the positional arguments of one eg.add() call,
    # listed in the order the parameters are registered on the grid.
    grid_entries = (
        ('env_name', 'CartPole-v0', '', True),
        ('seed', seeds),
        ('epochs', 10),
        ('steps_per_epoch', 4000),
        ('ac_kwargs:hidden_sizes', [(32,), (64, 64)], 'hid'),
        ('ac_kwargs:activation', [torch.tanh, torch.relu], ''),
    )

    eg = ExperimentGrid(name='ppo-bench')
    for entry in grid_entries:
        eg.add(*entry)

    # Launch PPO over the whole grid.
    eg.run(ppo, num_cpu=args.cpu)
Esempio n. 2
0
def parse_and_execute_grid_search(cmd, args):
    """Interprets algorithm name and cmd line args into an ExperimentGrid.

    The algorithm is looked up as ``fireup.<cmd>``. ``args`` is then parsed
    as a flat sequence of ``--flag value value ...`` tokens, turned into an
    ``ExperimentGrid`` and executed.

    Args:
        cmd: Name of the algorithm, resolved as an attribute of ``fireup``.
        args: Sequence of command-line tokens (strings). A token starting
            with ``--`` opens a new flag; every following token up to the
            next flag is a value for it. Values are passed through
            ``eval`` and fall back to the raw string when that fails. A
            flag without values is given the single value ``True``. A flag
            written as ``--name[short]`` registers ``short`` as the
            shorthand for ``name``.

    Raises:
        AssertionError: If the first token is not a flag, if a key listed
            in ``RUN_KEYS`` or ``exp_name`` is given more than one value,
            if ``num_cpu`` is not 1 for an algorithm that is not in
            ``MPI_COMPATIBLE_ALGOS``, if ``env_name`` is missing, or if an
            ``env_name`` value is not a registered Gym environment id.
        SystemExit: If any token is ``--help``, ``-h`` or ``help`` (after
            the algorithm's docstring has been printed).
    """

    # Parse which algorithm to execute.
    # NOTE(review): eval() executes arbitrary text taken from the command
    # line. Only call this with trusted input; consider replacing with an
    # explicit lookup of supported algorithm names.
    algo = eval("fireup." + cmd)

    # Before all else, check to see if any of the flags is 'help'.
    valid_help = ("--help", "-h", "help")
    if any(arg in valid_help for arg in args):
        print("\n\nShowing docstring for fireup." + cmd + ":\n")
        print(algo.__doc__)
        sys.exit()

    def process(arg):
        # Process an arg by eval-ing it, so users can specify more
        # than just strings at the command line (eg allows for
        # users to give functions as args).
        # NOTE(review): same eval() caveat as above - trusted input only.
        try:
            return eval(arg)
        except Exception:
            # Anything that is not a valid expression (SyntaxError,
            # NameError, ...) is kept as a plain string. Exception is used
            # rather than a bare except so that KeyboardInterrupt and
            # SystemExit are not swallowed.
            return arg

    # Make first pass through args to build base arg_dict. Anything
    # with a '--' in front of it is an argument flag and everything after,
    # until the next flag, is a possible value.
    arg_dict = dict()
    for i, arg in enumerate(args):
        # Only a leading '--' marks a flag; a value that merely contains
        # '--' somewhere inside must stay a value.
        is_flag = arg.startswith("--")
        assert i > 0 or is_flag, friendly_err(
            "You didn't specify a first flag.")
        if is_flag:
            arg_key = arg.lstrip("-")
            arg_dict[arg_key] = []
        else:
            arg_dict[arg_key].append(process(arg))

    # Make second pass through, to catch flags that have no vals.
    # Assume such flags indicate that a boolean parameter should have
    # value True.
    for vals in arg_dict.values():
        if not vals:
            vals.append(True)

    # Third pass: check for user-supplied shorthands, where a key has
    # the form --keyname[kn]. The thing in brackets, 'kn', is the
    # shorthand. NOTE: modifying a dict while looping through its
    # contents is dangerous, and breaks in 3.6+. We loop over a fixed list
    # of keys to avoid this issue.
    given_shorthands = dict()
    fixed_keys = list(arg_dict)
    for k in fixed_keys:
        p1, p2 = k.find("["), k.find("]")
        if p1 >= 0 and p2 >= 0:
            # Both '[' and ']' found, so shorthand has been given
            k_new = k[:p1]
            shorthand = k[p1 + 1:p2]
            given_shorthands[k_new] = shorthand
            arg_dict[k_new] = arg_dict[k]
            del arg_dict[k]

    # Penultimate pass: sugar. Allow some special shortcuts in arg naming,
    # eg treat "env" the same as "env_name". This is super specific
    # to Fired Up implementations, and may be hard to maintain.
    # These special shortcuts are described by SUBSTITUTIONS.
    for special_name, true_name in SUBSTITUTIONS.items():
        if special_name in arg_dict:
            # swap it in arg dict
            arg_dict[true_name] = arg_dict[special_name]
            del arg_dict[special_name]

        if special_name in given_shorthands:
            # point the shortcut to the right name
            given_shorthands[true_name] = given_shorthands[special_name]
            del given_shorthands[special_name]

    # Final pass: check for the special args that go to the 'run' command
    # for an experiment grid, separate them from the arg dict, and make sure
    # that they have unique values. The special args are given by RUN_KEYS.
    run_kwargs = dict()
    for k in RUN_KEYS:
        if k in arg_dict:
            val = arg_dict[k]
            assert len(val) == 1, friendly_err(
                "You can only provide one value for %s." % k)
            run_kwargs[k] = val[0]
            del arg_dict[k]

    # Determine experiment name. If not given by user, will be determined
    # by the algorithm name.
    if "exp_name" in arg_dict:
        assert len(arg_dict["exp_name"]) == 1, friendly_err(
            "You can only provide one value for exp_name.")
        exp_name = arg_dict["exp_name"][0]
        del arg_dict["exp_name"]
    else:
        exp_name = "cmd_" + cmd

    # Make sure that if num_cpu > 1, the algorithm being used is compatible
    # with MPI. Any value other than 1 (including non-integers) triggers
    # the check.
    if run_kwargs.get("num_cpu", 1) != 1:
        assert cmd in MPI_COMPATIBLE_ALGOS, friendly_err(
            "This algorithm can't be run with num_cpu > 1.")

    # Special handling for environment: make sure that env_name is a real,
    # registered gym environment.
    valid_envs = {e.id for e in gym.envs.registry.all()}
    assert "env_name" in arg_dict, friendly_err(
        "You did not give a value for --env_name! Add one and try again.")
    for env_name in arg_dict["env_name"]:
        err_msg = dedent("""

            %s is not registered with Gym.

            Recommendations:

                * Check for a typo (did you include the version tag?)

                * View the complete list of valid Gym environments at

                    https://gym.openai.com/envs/

            """ % env_name)
        assert env_name in valid_envs, err_msg

    # Construct and execute the experiment grid. Parameters without a
    # user-supplied shorthand are added with shorthand=None.
    eg = ExperimentGrid(name=exp_name)
    for k, v in arg_dict.items():
        eg.add(k, v, shorthand=given_shorthands.get(k))
    eg.run(algo, **run_kwargs)
Esempio n. 3
0
from fireup.utils.run_utils import ExperimentGrid
from fireup import ppo
import torch

if __name__ == "__main__":
    import argparse

    # --cpu is handed to eg.run() as num_cpu; --num_runs decides how many
    # seeds the grid sweeps over.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--cpu", default=4, type=int)
    arg_parser.add_argument("--num_runs", default=3, type=int)
    args = arg_parser.parse_args()

    # Seeds are multiples of ten, one per run: 0, 10, 20, ...
    seed_values = list(range(0, 10 * args.num_runs, 10))

    # Values swept for the actor-critic keyword arguments.
    hidden_layouts = [(32, ), (64, 64)]
    activations = [torch.tanh, torch.relu]

    eg = ExperimentGrid(name="ppo-bench")
    eg.add("env_name", "CartPole-v0", "", True)
    eg.add("seed", seed_values)
    eg.add("epochs", 10)
    eg.add("steps_per_epoch", 4000)
    eg.add("ac_kwargs:hidden_sizes", hidden_layouts, "hid")
    eg.add("ac_kwargs:activation", activations, "")

    # Run PPO across every combination registered above.
    eg.run(ppo, num_cpu=args.cpu)