def save_config(self, config):
    """
    Log an experiment configuration.

    Call this once at the top of your experiment, passing in all important
    config vars as a dict. The config is converted with ``convert_json`` so
    that values which cannot be serialized directly are written as
    informative strings instead of failing, then dumped as sorted,
    indented JSON.

    If ``self.exp_name`` is set it is stored under the ``"exp_name"`` key.
    Only the process whose ``proc_id()`` is 0 prints the config and writes
    ``config.json`` into ``self.output_dir`` (presumably the MPI root rank,
    so parallel workers don't all write the same file).

    Example use:

    .. code-block:: python

        logger = EpochLogger(**logger_kwargs)
        logger.save_config(locals())
    """
    serializable = convert_json(config)
    if self.exp_name is not None:
        serializable["exp_name"] = self.exp_name

    # Every process above converts the config; only process 0 emits it.
    if proc_id() != 0:
        return

    pretty = json.dumps(
        serializable,
        separators=(",", ":\t"),
        indent=4,
        sort_keys=True,
    )
    print(colorize("Saving config:\n", color="cyan", bold=True))
    print(pretty)

    config_path = osp.join(self.output_dir, "config.json")
    with open(config_path, "w") as handle:
        handle.write(pretty)
def print(self):
    """Print a helpful report about the experiment grid.

    The report contains a header with the grid's name, then, for every
    parameter, its full key, its shorthand (if one was given) and each
    candidate value (passed through ``convert_json`` for display). It ends
    with the number of variants the grid expands to, both counting and not
    counting distinct seeds.
    """
    print("=" * DIV_LINE_WIDTH)

    # Prepare announcement at top of printing. If the ExperimentGrid has a
    # short name, write this as one line. If the name is long, break the
    # announcement over two lines.
    base_msg = "ExperimentGrid %s runs over parameters:\n"
    name_insert = "[" + self._name + "]"
    if len(base_msg % name_insert) <= 80:
        msg = base_msg % name_insert
    else:
        msg = base_msg % (name_insert + "\n")
    print(colorize(msg, color="green", bold=True))

    # List off parameters, shorthands, and possible values.
    for k, v, sh in zip(self.keys, self.vals, self.shs):
        color_k = colorize(k.ljust(40), color="cyan", bold=True)
        print("", color_k, "[" + sh + "]" if sh is not None else "", "\n")
        for val in v:
            print("\t" + str(convert_json(val)))
        print()

    # Count up the number of variants. The number counting seeds
    # is the total number of experiments that will run; the number not
    # counting seeds is the total number of otherwise-unique configs
    # being investigated.
    nvars_total = int(np.prod([len(v) for v in self.vals]))
    if "seed" in self.keys:
        num_seeds = len(self.vals[self.keys.index("seed")])
        # nvars_total is a product that includes num_seeds, so integer
        # division is exact (float division could round on huge grids).
        # An empty seed list makes nvars_total 0; avoid dividing by zero.
        nvars_seedless = nvars_total // num_seeds if num_seeds else 0
    else:
        nvars_seedless = nvars_total
    print(" Variants, counting seeds: ".ljust(40), nvars_total)
    print(" Variants, not counting seeds: ".ljust(40), nvars_seedless)

    print()
    print("=" * DIV_LINE_WIDTH)
def call_experiment(
    exp_name, thunk, seed=0, num_cpu=1, data_dir=None, datestamp=False, **kwargs
):
    """
    Run a function (thunk) with hyperparameters (kwargs), plus configuration.

    This wraps a few pieces of functionality which are useful when you want
    to run many experiments in sequence, including logger configuration and
    splitting into multiple processes for MPI.

    There's also a Fired Up specific convenience added into executing the
    thunk: if ``env_name`` is one of the kwargs passed to call_experiment, it's
    assumed that the thunk accepts an argument called ``env_fn``, and that
    the ``env_fn`` should make a gym environment with the given ``env_name``.

    The way the experiment is actually executed is slightly complicated: the
    function is serialized to a string, and then ``run_entrypoint.py`` is
    executed in a subprocess call with the serialized string as an argument.
    ``run_entrypoint.py`` unserializes the function call and executes it.
    We choose to do it this way---instead of just calling the function
    directly here---to avoid leaking state between successive experiments.

    Args:
        exp_name (string): Name for experiment.

        thunk (callable): A python function.

        seed (int): Seed for random number generators. Also forwarded to the
            thunk as the ``seed`` kwarg.

        num_cpu (int): Number of MPI processes to split into. Also accepts
            'auto', which will set up as many procs as there are physical
            CPU cores on the machine (falling back to the logical core count,
            then to 1, if the physical count cannot be determined).

        data_dir (string): Used in configuring the logger, to decide where
            to store experiment results. Note: if left as None, data_dir will
            default to ``DEFAULT_DATA_DIR`` from ``fireup/user_config.py``.

        datestamp (bool): Forwarded to ``setup_logger_kwargs`` when this
            function builds the logger configuration; presumably toggles a
            date/time stamp in the output directory name (confirm against
            ``setup_logger_kwargs``). Unused if ``logger_kwargs`` is passed
            explicitly.

        **kwargs: All kwargs to pass to thunk.

    Raises:
        CalledProcessError: If the experiment subprocess exits with a
            non-zero status. An explanatory banner is printed before the
            exception is re-raised.
    """
    # Determine number of CPU cores to run on. psutil.cpu_count(logical=False)
    # returns None when the physical core count is undetermined, so fall back
    # to the logical count and finally to a single process.
    if num_cpu == "auto":
        num_cpu = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1

    # Send random seed to thunk
    kwargs["seed"] = seed

    # Be friendly and print out your kwargs, so we all know what's up
    print(colorize("Running experiment:\n", color="cyan", bold=True))
    print(exp_name + "\n")
    print(colorize("with kwargs:\n", color="cyan", bold=True))
    kwargs_json = convert_json(kwargs)
    print(json.dumps(kwargs_json, separators=(",", ":\t"), indent=4, sort_keys=True))
    print("\n")

    # Set up logger output directory
    if "logger_kwargs" not in kwargs:
        kwargs["logger_kwargs"] = setup_logger_kwargs(
            exp_name, seed, data_dir, datestamp
        )
    else:
        print("Note: Call experiment is not handling logger_kwargs.\n")

    def thunk_plus():
        # Runs inside the subprocess: build env_fn, fork MPI workers, run thunk.

        # Make 'env_fn' from 'env_name'
        if "env_name" in kwargs:
            import gym

            env_name = kwargs["env_name"]
            kwargs["env_fn"] = lambda: gym.make(env_name)
            del kwargs["env_name"]

        # Fork into multiple processes
        mpi_fork(num_cpu)

        # Run thunk
        thunk(**kwargs)

    # Prepare to launch a script to run the experiment: the closure is
    # pickled, compressed and base64-encoded so it can travel as a CLI arg.
    pickled_thunk = cloudpickle.dumps(thunk_plus)
    encoded_thunk = base64.b64encode(zlib.compress(pickled_thunk)).decode("utf-8")

    entrypoint = osp.join(osp.abspath(osp.dirname(__file__)), "run_entrypoint.py")
    cmd = [sys.executable if sys.executable else "python", entrypoint, encoded_thunk]
    try:
        subprocess.check_call(cmd, env=os.environ)
    except CalledProcessError:
        err_msg = (
            "\n" * 3
            + "=" * DIV_LINE_WIDTH
            + "\n"
            + dedent(
                """

            There appears to have been an error in your experiment.

            Check the traceback above to see what actually went wrong. The
            traceback below, included for completeness (but probably not useful
            for diagnosing the error), shows the stack leading up to the
            experiment launch.

            """
            )
            + "=" * DIV_LINE_WIDTH
            + "\n" * 3
        )
        print(err_msg)
        raise

    # Tell the user about where results are, and how to check them
    logger_kwargs = kwargs["logger_kwargs"]

    plot_cmd = "python3 -m fireup.run plot " + logger_kwargs["output_dir"]
    plot_cmd = colorize(plot_cmd, "green")

    test_cmd = "python3 -m fireup.run test_policy " + logger_kwargs["output_dir"]
    test_cmd = colorize(test_cmd, "green")

    output_msg = (
        "\n" * 5
        + "=" * DIV_LINE_WIDTH
        + "\n"
        + dedent(
            """\
            End of experiment.


            Plot results from this run with:

            %s


            Watch the trained agent with:

            %s


            """
            % (plot_cmd, test_cmd)
        )
        + "=" * DIV_LINE_WIDTH
        + "\n" * 5
    )

    print(output_msg)