Example #1
0
    def dump_tabular(self, console=True):
        """
        Write all of the diagnostics from the current iteration.

        On process 0 this prints a two-column table to stdout (when
        ``console`` is true) and writes one tab-separated row to the output
        file (when one is open). The header row is written to the file the
        first time through, i.e. while ``self.first_row`` is still true.

        If no headers have been registered there is nothing to report, so
        the table and the file row are skipped instead of failing on
        ``max()`` of an empty sequence.

        On every process the current row is cleared and ``self.first_row``
        is set to ``False`` afterwards.

        Args:
            console (bool): Whether to print the table to stdout. The output
                file is written regardless of this flag.
        """
        if proc_id() == 0 and self.log_headers:
            # Key column is at least 15 wide, or as wide as the longest key.
            max_key_len = max(15, max(len(key) for key in self.log_headers))
            fmt = "| %" + "%d" % max_key_len + "s | %15s |"
            # A row is "| " + key + " | " + 15-wide value + " |", i.e. the
            # key width plus 22 fixed characters.
            rule = "-" * (22 + max_key_len)

            vals = []
            if console:
                print(rule)
            for key in self.log_headers:
                # Headers with no value logged this iteration show as blank.
                val = self.log_current_row.get(key, "")
                vals.append(val)
                if console:
                    # Only format when actually printing; float-like values
                    # get a compact numeric form, anything else is printed
                    # as-is.
                    valstr = ("%8.3g" % val
                              if hasattr(val, "__float__") else val)
                    print(fmt % (key, valstr))
            if console:
                print(rule)

            if self.output_file is not None:
                if self.first_row:
                    self.output_file.write("\t".join(self.log_headers) + "\n")
                self.output_file.write("\t".join(map(str, vals)) + "\n")
                # Flush so the file is usable while the run is in progress.
                self.output_file.flush()
        self.log_current_row.clear()
        self.first_row = False
Example #2
0
    def save_state(self, state_dict, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent parameters for the model you
        previously set up saving for with ``setup_tf_saver``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for 'itr'.

        Only process 0 writes anything. A failure to pickle ``state_dict``
        is reported as a warning via ``self.log`` and does not abort the
        call; interpreter-exit signals such as ``KeyboardInterrupt`` and
        ``SystemExit`` are not suppressed.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.

            itr: An int, or None. Current iteration of training. When given,
                the state is written to ``vars<itr>.pkl`` instead of
                ``vars.pkl``.
        """
        if proc_id() == 0:
            fname = 'vars.pkl' if itr is None else 'vars%d.pkl' % itr
            try:
                joblib.dump(state_dict, osp.join(self.output_dir, fname))
            except Exception:
                # Best-effort: an unpicklable state should not kill training.
                # Catch Exception rather than BaseException so Ctrl-C and
                # sys.exit() still propagate.
                self.log('Warning: could not pickle state_dict.', color='red')
            # Model parameters are saved only if a saver has been set up.
            if hasattr(self, 'tf_saver_elements'):
                self._tf_simple_save(itr)
Example #3
0
    def __init__(self,
                 output_dir=None,
                 output_fname='progress.txt',
                 exp_name=None,
                 phase='train',
                 console_out=False):
        """
        Initialize a Logger.

        Only process 0 creates the output directory and opens the output
        file; on every other process ``output_dir`` and ``output_file`` are
        set to ``None``.

        Args:
            output_dir (string): A directory for saving results to. If
                ``None`` (or otherwise falsy), defaults to a directory of
                the form ``/tmp/experiments/<integer timestamp>``.

            output_fname (string): Name for the tab-separated-value file
                containing metrics logged throughout a training run.
                Defaults to ``progress.txt``.

            exp_name (string): Experiment name. If you run multiple training
                runs and give them all the same ``exp_name``, the plotter
                will know to group them. (Use case: if you run the same
                hyperparameter configuration with multiple random seeds, you
                should give them all the same ``exp_name``.)

            phase (string): Accepted for interface compatibility; not used
                by this constructor.

            console_out (bool): If true, print a warning to stdout when the
                output directory already exists.

        Raises:
            ValueError: If ``output_fname`` is ``progress.txt`` and that file
                already exists in the output directory.
        """
        if proc_id() != 0:
            # Non-root processes never write to disk.
            self.output_dir = None
            self.output_file = None
        else:
            if output_dir:
                self.output_dir = output_dir
            else:
                self.output_dir = "/tmp/experiments/%i" % int(time.time())

            if not osp.exists(self.output_dir):
                os.makedirs(self.output_dir)
            elif console_out:
                print(
                    "Warning: Main exp file dir %s already exists! Storing info there anyway."
                    % self.output_dir)

            progress_path = osp.join(self.output_dir, output_fname)
            # TODO: Deal with this later.
            # Refuse to clobber existing training progress.
            if output_fname == 'progress.txt' and osp.exists(progress_path):
                message = colorize(
                    'Experiment files %s already exists!' % progress_path,
                    'yellow',
                    bold=True)
                raise ValueError(message)

            self.output_file = open(progress_path, 'w')
            # Make sure the file is closed when the interpreter exits.
            atexit.register(self.output_file.close)

        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name
Example #4
0
 def _tf_simple_save(self, itr=None):
     """
     Export the model with ``simple_save`` and store ``self.tf_saver_info``
     next to it, so tensors can be associated with variables after restore.

     Only process 0 does any work. The export directory is
     ``simple_save`` (or ``simple_save<itr>`` when ``itr`` is given) inside
     ``self.output_dir``; an existing directory of that name is removed
     first.
     """
     if proc_id() != 0:
         return

     assert hasattr(self, 'tf_saver_elements'), \
         "First have to setup saving with self.setup_tf_saver"

     suffix = '' if itr is None else '%d' % itr
     export_path = osp.join(self.output_dir, 'simple_save' + suffix)

     # simple_save will not write into an existing directory, so clear
     # out any previous export at this path.
     if osp.exists(export_path):
         shutil.rmtree(export_path)

     tf.saved_model.simple_save(export_dir=export_path,
                                **self.tf_saver_elements)
     info_path = osp.join(export_path, 'model_info.pkl')
     joblib.dump(self.tf_saver_info, info_path)
Example #5
0
    def save_config(self, config, append=False):
        """
        Log an experiment configuration.

        Call this once at the top of your experiment, passing in all important
        config vars as a dict. The config is passed through ``convert_json``
        and serialized to JSON, then written to ``config.json`` in the output
        directory by process 0. If the logger has an ``exp_name`` it is added
        to the config, and any ``context_sampler`` entry is dropped before
        writing.

        Args:
            config: The configuration to record (typically a dict).

            append (bool): If true, append to an existing ``config.json``
                instead of overwriting it.

        Example use:

        .. code-block:: python

            logger = EpochLogger(**logger_kwargs)
            logger.save_config(locals())
        """
        config_json = convert_json(config)
        if self.exp_name is not None:
            config_json['exp_name'] = self.exp_name

        if proc_id() != 0:
            return

        # Drop context sampler config
        if 'context_sampler' in config_json:
            del config_json['context_sampler']

        serialized = json.dumps(config_json,
                                separators=(',', ':\t'),
                                indent=4,
                                sort_keys=True)

        # The two write paths differ only in the file mode.
        mode = 'a+' if append else 'w'
        config_path = osp.join(self.output_dir, "config.json")
        with open(config_path, mode) as out:
            out.write(serialized)
Example #6
0
 def log(self, msg, color='green'):
     """Print ``msg`` to stdout in bold ``color``; only process 0 prints."""
     if proc_id() != 0:
         return
     print(colorize(msg, color, bold=True))