コード例 #1
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def dump_tabular(self):
     """
     Write all of the diagnostics from the current iteration.

     Process 0 prints the current row to stdout as a two-column table and,
     when an output file is open, appends one tab-separated row to it (the
     header line is written first on the very first call). Every process
     then clears the current row and marks the header as written.
     """
     if proc_id() == 0:
         vals = []
         key_lens = [len(key) for key in self.log_headers]
         # Key column is at least 15 wide; ``default`` avoids the ValueError
         # ``max([])`` would raise when nothing has been logged yet.
         max_key_len = max(15, max(key_lens, default=0))
         keystr = "%" + "%d" % max_key_len
         fmt = "| " + keystr + "s | %15s |"
         # Row width: "| " + key column + " | " + 15-wide value + " |".
         n_slashes = 22 + max_key_len
         print("-" * n_slashes)
         for key in self.log_headers:
             val = self.log_current_row.get(key, "")
             # Anything float-convertible gets compact %g formatting.
             valstr = "%8.3g" % val if hasattr(val, "__float__") else val
             print(fmt % (key, valstr))
             vals.append(val)
         print("-" * n_slashes, flush=True)
         if self.output_file is not None:
             if self.first_row:
                 self.output_file.write("\t".join(self.log_headers) + "\n")
             self.output_file.write("\t".join(map(str, vals)) + "\n")
             self.output_file.flush()
     self.log_current_row.clear()
     self.first_row = False
コード例 #2
0
ファイル: tblog.py プロジェクト: jfpettit/flare
    def add_hist(
        self,
        key: str,
        val: Union[torch.Tensor, list, np.ndarray, tuple],
        step: int,
        bins: Optional[str] = "tensorflow",
    ):
        """
        Helper function to add a vector value to a tensorboard histogram. Called for you in :func:`~add_vals`

        Only process 0 writes. Empty inputs are skipped.

        Args:
            key (str): Key to dictionary value to write with.
            val (torch.Tensor or np.ndarray or list or tuple): Value to log to tensorboard.
            step (int): Training step.
            bins (str): binning style for histogram.
        """
        if proc_id() == 0:
            if isinstance(val, torch.Tensor):
                # ``.item()`` only works on one-element tensors (and would
                # collapse a vector to a scalar), so convert the whole tensor.
                # detach()/cpu() make ``.numpy()`` legal for grad/GPU tensors.
                val = val.detach().cpu().numpy()
            val = np.asarray(val)
            # ``size`` (unlike ``len``) is also defined for 0-d arrays.
            if val.size > 0:
                self.writer.add_histogram(key,
                                          val,
                                          global_step=step,
                                          bins=bins)
コード例 #3
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def __init__(self, output_dir=None, output_fname="progress.txt", exp_name=None):
     """
     Initialize a Logger.

     Only process 0 creates the output directory and opens the output file.
     On other processes ``output_dir`` and ``output_file`` are ``None``.

     Args:
         output_dir (string): A directory for saving results to. If
             ``None``, defaults to a temp directory of the form
             ``/tmp/experiments/<unix timestamp>``.
         output_fname (string): Name for the tab-separated-value file
             containing metrics logged throughout a training run.
             Defaults to ``progress.txt``.
         exp_name (string): Experiment name. If you run multiple training
             runs and give them all the same ``exp_name``, the plotter
             will know to group them. (Use case: if you run the same
             hyperparameter configuration with multiple random seeds, you
             should give them all the same ``exp_name``.)
     """
     if proc_id() == 0:
         self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
         # exist_ok avoids the race between an exists() check and creation.
         os.makedirs(self.output_dir, exist_ok=True)
         self.output_file = open(osp.join(self.output_dir, output_fname), "w")
         # Ensure the progress file is closed at interpreter exit.
         atexit.register(self.output_file.close)
         print(
             colorize(
                 "Logging data to %s" % self.output_file.name, "green", bold=True
             )
         )
     else:
         self.output_dir = None
         self.output_file = None
     self.first_row = True
     self.log_headers = []
     self.log_current_row = {}
     self.exp_name = exp_name
コード例 #4
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def save_state(self, state_dict, itr=None):
     """
     Saves the state of an experiment.

     To be clear: this is about saving *state*, not logging diagnostics.
     All diagnostic logging is separate from this function. This function
     pickles ``state_dict`` with joblib and then saves any model registered
     for saving (``tf_saver_elements`` and/or ``pytorch_saver_elements``).
     Only process 0 saves anything.

     Call with any frequency you prefer. If you only want to maintain a
     single state and overwrite it at each call with the most recent
     version, leave ``itr=None``. If you want to keep all of the states you
     save, provide unique (increasing) values for ``itr``.

     Args:
         state_dict (dict): Dictionary containing essential elements to
             describe the current state of training.
         itr: An int, or None. Current iteration of training.
     """
     if proc_id() == 0:
         fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
         try:
             joblib.dump(state_dict, osp.join(self.output_dir, fname))
         except Exception:
             # Best effort: state_dict may hold objects joblib cannot pickle.
             # Catch Exception (not a bare except) so KeyboardInterrupt and
             # SystemExit still propagate.
             self.log("Warning: could not pickle state_dict.", color="red")
         if hasattr(self, "tf_saver_elements"):
             self._tf_simple_save(itr)
         if hasattr(self, "pytorch_saver_elements"):
             self._pytorch_simple_save(itr)
コード例 #5
0
 def save(self):
     """Write the save dictionary to a timestamped .pkl file.

     Only process 0 writes, and only when ``self.saver_dict`` is non-empty.
     The file is placed in ``self.out_path`` and named with the current
     ``time.time()`` value.
     """
     if proc_id() == 0:
         ct = time.time()
         if self.saver_dict:
             fname = self.out_path / f"env_states_and_screens_saved_on{ct}.pkl"
             # Context manager guarantees the handle is closed and the pickle
             # fully flushed, even if dumping raises.
             with open(fname, "wb") as f:
                 pkl.dump(self.saver_dict, f)
コード例 #6
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def _tf_simple_save(self, itr=None):
     """
     Export the TensorFlow model with ``tf.saved_model.simple_save``.

     ``self.tf_saver_info`` is pickled next to the export as
     ``model_info.pkl``. Only process 0 does anything.

     Args:
         itr: Optional int appended to the ``tf1_save`` directory name;
             ``None`` reuses (and overwrites) a single directory.
     """
     if proc_id() != 0:
         return
     assert hasattr(
         self, "tf_saver_elements"
     ), "First have to setup saving with self.setup_tf_saver"
     suffix = "" if itr is None else "%d" % itr
     export_dir = osp.join(self.output_dir, "tf1_save" + suffix)
     # simple_save will not write into an existing directory, so remove any
     # previous export at this path first.
     if osp.exists(export_dir):
         shutil.rmtree(export_dir)
     tf.saved_model.simple_save(export_dir=export_dir, **self.tf_saver_elements)
     info_path = osp.join(export_dir, "model_info.pkl")
     joblib.dump(self.tf_saver_info, info_path)
コード例 #7
0
ファイル: tblog.py プロジェクト: jfpettit/flare
    def add_plot(
        self,
        key: str,
        val: Union[torch.Tensor, np.ndarray, list, float, int],
        step: int,
    ):
        """
        Helper function to add a scalar value to a tensorboard plot. Called for you in :func:`~add_vals`

        Only process 0 writes. A list is logged only when it holds exactly
        one element; empty or longer lists are skipped.

        Args:
            key (str): Key to dictionary value to write with.
            val (torch.Tensor or np.ndarray or list or float or int): Value to log to tensorboard.
            step (int): Training step.
        """
        if proc_id() == 0:
            if isinstance(val, torch.Tensor):
                self.writer.add_scalar(key, val.item(), global_step=step)
            elif isinstance(val, list):
                if len(val) == 1:
                    # ``np.float`` was removed in NumPy 1.24, and ``float()``
                    # of a list raises TypeError: unwrap the element first.
                    self.writer.add_scalar(key, float(val[0]), global_step=step)
            else:
                self.writer.add_scalar(key, val, global_step=step)
コード例 #8
0
ファイル: tblog.py プロジェクト: jfpettit/flare
    def __init__(self, fpath: Optional[str] = "flare_runs"):
        """
        Set up a TensorBoard ``SummaryWriter`` for process 0.

        Args:
            fpath (str or None): Log directory. If ``None``, the directory
                ``flare_runs/run_at_time_<time.time()>`` is used.

        Only process 0 checks for, creates, and writes to the directory and
        prints status messages. ``self.writer`` is only set on process 0.
        """
        if fpath is None:
            self.fpath = "flare_runs/run_at_time_" + str(time.time())
        else:
            self.fpath = fpath

        self.full_logdir = self.fpath

        if proc_id() == 0:
            # Checked on process 0 only, and before makedirs: other processes
            # could otherwise see the directory process 0 is about to create
            # and print spurious warnings.
            if os.path.exists(self.fpath):
                print(
                    utils.colorize(
                        f"Warning path at {self.fpath} already exists, storing info there anyway.",
                        "yellow",
                    ))
            os.makedirs(self.fpath, exist_ok=True)
            self.writer = SummaryWriter(log_dir=self.fpath, flush_secs=30)
            print(
                utils.colorize(f"TensorBoard Logdir: {self.full_logdir}",
                               "green"))
コード例 #9
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def save_config(self, config):
     """
     Record the experiment configuration as ``config.json``.

     Call this once near the start of an experiment, passing the important
     config variables as a dict. The dict is first passed through
     ``convert_json`` to make it JSON-serializable. If the logger has an
     ``exp_name``, it is stored under the ``"exp_name"`` key. Only process 0
     prints the config and writes the file into ``self.output_dir``.

     Example use:

     .. code-block:: python

         logger = EpochLogger(**logger_kwargs)
         logger.save_config(locals())
     """
     serializable = convert_json(config)
     if self.exp_name is not None:
         serializable["exp_name"] = self.exp_name
     if proc_id() != 0:
         return
     rendered = json.dumps(
         serializable, indent=4, sort_keys=True, separators=(",", ":\t")
     )
     print(colorize("Saving config:\n", color="cyan", bold=True))
     print(rendered)
     config_path = osp.join(self.output_dir, "config.json")
     with open(config_path, "w") as fh:
         fh.write(rendered)
コード例 #10
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def _pytorch_simple_save(self, itr=None):
     """
     Save ``self.pytorch_saver_elements`` with ``torch.save``.

     The target is ``<output_dir>/pyt_save/model.pt``, or
     ``model<itr>.pt`` when ``itr`` is given. Only process 0 saves.
     """
     if proc_id() != 0:
         return
     assert hasattr(
         self, "pytorch_saver_elements"
     ), "First have to setup saving with self.setup_pytorch_saver"
     save_dir = osp.join(self.output_dir, "pyt_save")
     tag = "" if itr is None else "%d" % itr
     target = osp.join(save_dir, "model" + tag + ".pt")
     os.makedirs(save_dir, exist_ok=True)
     with warnings.catch_warnings():
         # Whole objects are pickled here (not just weights), which ties the
         # checkpoint to the source layout at save time and triggers warnings
         # about not being able to save source code; silence them.
         warnings.simplefilter("ignore")
         torch.save(self.pytorch_saver_elements, target)
コード例 #11
0
ファイル: tblog.py プロジェクト: jfpettit/flare
 def end(self):
     """Flush the TensorBoard writer; a no-op on processes other than 0."""
     if proc_id() != 0:
         return
     self.writer.flush()
コード例 #12
0
ファイル: loggingfuncs.py プロジェクト: jfpettit/flare
 def log(self, msg, color="green"):
     """Print ``msg`` to stdout in bold ``color`` text (process 0 only)."""
     if proc_id() != 0:
         return
     print(colorize(msg, color, bold=True))