Пример #1
0
 def __init__(self,
              log_dir=None,
              n_boot=5000,
              therm_frac=0.25,
              nw_include=None,
              calc_stats=True,
              filter_str=None,
              runs_np=False):
     """Record the configuration and load the saved parameters.

     Args:
         log_dir: Directory expected to contain `parameters.pkl`; also
             forwarded to `io.get_run_dirs`.
             NOTE(review): the `os.path.join` call below rejects the
             default `None`, so a directory effectively must be given.
         n_boot: Stored as `self._n_boot` (presumably the number of
             bootstrap resamples -- TODO confirm).
         therm_frac: Stored as `self._therm_frac` (presumably the
             fraction of samples dropped for thermalization -- TODO
             confirm).
         nw_include: Stored as `self._nw_include`.
         calc_stats: Stored as `self._calc_stats`.
         filter_str: Forwarded to `io.get_run_dirs`.
         runs_np: Forwarded to `io.get_run_dirs`.
     """
     # Plain settings are kept verbatim on the instance.
     self._log_dir = log_dir
     self._n_boot = n_boot
     self._therm_frac = therm_frac
     self._nw_include = nw_include
     self._calc_stats = calc_stats

     self.run_dirs = io.get_run_dirs(log_dir, filter_str, runs_np)

     params_file = os.path.join(self._log_dir, 'parameters.pkl')
     self._params = io.loadz(params_file)

     # The six training weights, read in a fixed order.
     weight_keys = (
         'x_scale_weight',
         'x_translation_weight',
         'x_transformation_weight',
         'v_scale_weight',
         'v_translation_weight',
         'v_transformation_weight',
     )
     self._train_weights = tuple(self._params[key] for key in weight_keys)

     # Pre-rendered forms of the weights: one for titles, one for filenames.
     joined_weights = ', '.join(str(weight) for weight in self._train_weights)
     self._tws_title = '(' + joined_weights + ')'
     self._tws_fname = ''.join(
         io.strf(weight) for weight in self._train_weights
     )
Пример #2
0
    def _plot_setup(self, run_params, idx=None, nw_run=True):
        """Setup for creating plots.

        Returns:
            fname (str): String containing the filename containing info about
                data.
            title_str (str): Title string to set as title of figure.
        """
        eps = run_params['eps']
        beta = run_params['beta']
        net_weights = run_params['net_weights']

        nw_str = ''.join((io.strf(i).replace('.', '') for i in net_weights))
        nws = '(' + ', '.join((str(i) for i in net_weights)) + ')'

        lf = self._params['num_steps']
        fname = f'lf{lf}'
        run_steps = run_params['run_steps']
        fname += f'_steps{run_steps}'
        title_str = (r"$N_{\mathrm{LF}} = $" + f'{lf}, '
                     r"$\beta = $" + f'{beta:.1g}, '
                     r"$\varepsilon = $" + f'{eps:.3g}')
        eps_str = f'{eps:4g}'.replace('.', '')
        fname += f'_e{eps_str}'

        if self._params.get('eps_fixed', False):
            title_str += ' (fixed)'
            fname += '_fixed'

        if any([tw == 0 for tw in self._train_weights]):
            title_str += (', ' + r"$\mathrm{nw}_{\mathrm{train}} = $" +
                          f' {self._tws_title}')
            fname += f'_train{self._tws_fname}'

        clip_value = self._params.get('clip_value', 0)
        if clip_value > 0:
            title_str += f', clip: {clip_value}'
            fname += f'_clip{clip_value}'.replace('.', '')

        if nw_run:
            title_str += ', ' + r"$\mathrm{nw}_{\mathrm{run}}=$" + f' {nws}'
            fname += f'_{nw_str}'
            #  fname += f'_{net_weights_str}'

        if idx is not None:
            fname += f'_{idx}'

        return fname, title_str