Exemplo n.º 1
0
def plot_env(algos,
             env,
             data_folder,
             legend_list=None,
             frame_scales=None,
             showlegend=False):
    '''Plot mean_returns vs frames for several algos on one env as a single multi-trial graph.

    Algos whose trial metrics path cannot be resolved are skipped with a warning,
    and their legend entries are dropped so the remaining legends stay aligned
    with the plotted trials.

    @param algos: algo names to plot, in legend order
    @param env: env name; passed through guard_env_name for the title and output path
    @param data_folder: folder searched for trial metrics and used as the graph output folder
    @param legend_list: optional legend labels, one per algo; the caller's list is copied, never mutated
    @param frame_scales: forwarded to viz.plot_multi_trial
    @param showlegend: when True, show the legend and write to '<env>-legend' instead of '<env>'
    '''
    legend_list = deepcopy(legend_list)
    trial_metrics_path_list = []
    failed_idxs = set()
    for idx, algo in enumerate(algos):
        try:
            trial_metrics_path_list.append(
                get_trial_metrics_path(algo, env, data_folder))
        except Exception:
            failed_idxs.add(idx)
            logger.warning(f'Nothing to plot for algo: {algo}, env: {env}')
    if legend_list is not None:
        # filter by the ORIGINAL index after the loop: deleting in place inside the
        # loop shifts later entries, so a second failure removed the wrong legend
        legend_list = [
            legend for idx, legend in enumerate(legend_list)
            if idx not in failed_idxs
        ]
    env = guard_env_name(env)
    title = env
    if showlegend:
        graph_prepath = f'{data_folder}/{env}-legend'
    else:
        graph_prepath = f'{data_folder}/{env}'
    # legend_list defaults to None; iterating it unconditionally raised TypeError.
    # NOTE(review): assumes viz.plot_multi_trial accepts palette=None — confirm
    if legend_list is None:
        palette = None
    else:
        palette = [master_palette_dict[k] for k in legend_list]
    viz.plot_multi_trial(trial_metrics_path_list,
                         legend_list,
                         title,
                         graph_prepath,
                         ma=True,
                         name_time_pairs=[('mean_returns', 'frames')],
                         frame_scales=frame_scales,
                         palette=palette,
                         showlegend=showlegend)
Exemplo n.º 2
0
def test_logger(test_str):
    '''Send test_str through each logger level method once, in a fixed order.'''
    level_methods = ('critical', 'debug', 'error', 'exception', 'info', 'warning')
    for method_name in level_methods:
        log_fn = getattr(logger, method_name)
        log_fn(test_str)
Exemplo n.º 3
0
 def reset(self):
     '''Send a reset command over every parent pipe and return the decoded replies.

     If a step is still pending (self.waiting_step), a warning is logged and
     step_wait() is called before the reset commands go out.
     '''
     step_in_flight = self.waiting_step
     if step_in_flight:
         logger.warning(
             'Called reset() while waiting for the step to complete')
         self.step_wait()
     # broadcast to all pipes first, then gather one reply from each
     for conn in self.parent_pipes:
         conn.send(('reset', None))
     replies = [conn.recv() for conn in self.parent_pipes]
     return self._decode_obses(replies)
Exemplo n.º 4
0
Arquivo: viz.py Projeto: kengz/SLM-Lab
def save_image(figure, filepath):
    '''Write a plotly figure to an image file at scale 2; a no-op when PY_ENV is 'test'.

    Export failures are logged as a warning instead of raised, so graphs can be
    regenerated later via retro-analysis.

    @param figure: plotly figure to export
    @param filepath: destination path, resolved through util.smart_path
    '''
    # .get() so an unset PY_ENV does not raise KeyError (same check as elsewhere in the lab)
    if os.environ.get('PY_ENV') == 'test':
        return
    filepath = util.smart_path(filepath)
    try:
        pio.write_image(figure, filepath, scale=2)
    except Exception as e:
        logger.warning(
            f'Failed to generate graph. Run retro-analysis to generate graphs later. {e}\nIf running on a headless server, prepend your Python command with `xvfb-run -a `, for example `xvfb-run -a python run_lab.py`'
        )
Exemplo n.º 5
0
    def check_fn(*args, **kwargs):
        '''Run the wrapped train_step and, when enabled, sanity-check its effect on the net.

        When to_check_train_step() is falsy this is a plain passthrough to fn.
        Otherwise it asserts the loss has no NaN, checks that parameters changed
        (or that grads are zero when loss is 0), checks gradient norms, stores the
        grad norms on the net, and returns the loss from the wrapped call.
        '''
        if not to_check_train_step():
            return fn(*args, **kwargs)

        net = args[0]  # first arg self
        # snapshot pre-update parameters to compare against after the step
        pre_params = [param.clone() for param in net.parameters()]

        # run train_step, get loss
        loss = fn(*args, **kwargs)
        assert not torch.isnan(loss).any(), loss

        # snapshot post-update parameters
        post_params = [param.clone() for param in net.parameters()]
        if loss == 0.0:
            # if loss is 0, there should be no updates
            # TODO if without momentum, parameters should not change too
            for _p_name, param in net.named_parameters():
                # a param with no grad received no gradient at all; calling
                # .norm() on None would raise AttributeError
                if param.grad is not None:
                    assert param.grad.norm() == 0
        else:
            # check parameter updates
            try:
                assert not all(
                    torch.equal(w1, w2)
                    for w1, w2 in zip(pre_params, post_params)
                ), f'Model parameter is not updated in train_step(), check if your tensor is detached from graph. Loss: {loss:g}'
                logger.info(
                    f'Model parameter is updated in train_step(). Loss: {loss:g}'
                )
            except Exception as e:
                logger.error(e)
                if os.environ.get('PY_ENV') == 'test':
                    # re-raise in unit tests; bare raise keeps the original traceback
                    raise

            # check grad norms
            min_norm, max_norm = 0.0, 1e5
            all_norms_ok = True
            for p_name, param in net.named_parameters():
                try:
                    grad_norm = param.grad.norm()
                    assert min_norm < grad_norm < max_norm, f'Gradient norm for {p_name} is {grad_norm:g}, fails the extreme value check {min_norm} < grad_norm < {max_norm}. Loss: {loss:g}. Check your network and loss computation.'
                except Exception as e:
                    all_norms_ok = False
                    logger.warning(e)
            # only report success when every parameter actually passed
            if all_norms_ok:
                logger.info('Gradient norms passed value check.')
        logger.debug('Passed network parameter update check.')
        # store grad norms for debugging
        net.store_grad_norms()
        return loss
Exemplo n.º 6
0
def get_random_baseline(env_name):
    '''Get a single random baseline for env; if does not exist in file, generate live and update the file

    If generation fails, None is stored for the env, so later calls return None
    from the file instead of retrying.

    @param env_name: name of the env to look up
    @returns the baseline read from (or just written to) FILEPATH, or None if generation failed
    '''
    random_baseline = util.read(FILEPATH)
    if env_name in random_baseline:
        return random_baseline[env_name]
    try:
        logger.info(f'Generating random baseline for {env_name}')
        baseline = gen_random_baseline(env_name, NUM_EVAL)
    except Exception as e:
        # stay best-effort, but surface why the env could not be started
        logger.warning(
            f'Cannot start env: {env_name}, skipping random baseline generation. Error: {e}'
        )
        baseline = None
    # update immediately
    logger.info(f'Updating new random baseline in {FILEPATH}')
    random_baseline[env_name] = baseline
    util.write(random_baseline, FILEPATH)
    return baseline
Exemplo n.º 7
0
def plot_envs(algos, envs, data_folder, legend_list, frame_scales=None):
    '''Plot every env in envs via plot_env, logging (not raising) per-env failures.

    For the last env, a second copy is rendered with the legend shown so the
    legend can be cropped out separately.
    '''
    shared_kwargs = dict(legend_list=legend_list, frame_scales=frame_scales)
    for position, env_name in enumerate(envs):
        try:
            plot_env(algos, env_name, data_folder, showlegend=False, **shared_kwargs)
            is_last_env = position == len(envs) - 1
            if is_last_env:
                # plot extra to crop legend out
                plot_env(algos, env_name, data_folder, showlegend=True, **shared_kwargs)
        except Exception as err:
            logger.warning(f'Cant plot for env: {env_name}. Error: {err}')
Exemplo n.º 8
0
Arquivo: viz.py Projeto: ssfve/SLM-Lab
# The data visualization module
# Defines plotting methods for analysis
from glob import glob
from plotly import graph_objs as go, io as pio, tools
from plotly.offline import init_notebook_mode, iplot
from slm_lab.lib import logger, util
import colorlover as cl
import os
import pydash as ps

logger = logger.get_logger(__name__)

# moving-average window size for plotting
PLOT_MA_WINDOW = 100
# warn orca failure only once; the exception is included so the cause is visible
orca_warn_once = ps.once(lambda e: logger.warning(
    f'Failed to generate graph. Run retro-analysis to generate graphs later. {e}'))
if util.is_jupyter():
    init_notebook_mode(connected=True)


def calc_sr_ma(sr, window=None):
    '''Calculate the moving-average of a series to be plotted

    @param sr: pandas Series to smooth
    @param window: rolling window size; None (default) uses PLOT_MA_WINDOW, read at call time
    @returns Series of rolling means; min_periods=1 makes the leading entries average
             over the values seen so far instead of being NaN
    '''
    if window is None:
        window = PLOT_MA_WINDOW
    return sr.rolling(window, min_periods=1).mean()


def create_label(y_col,
                 x_col,
                 title=None,
                 y_title=None,
                 x_title=None,
                 legend_name=None):
Exemplo n.º 9
0
# Defines plotting methods for analysis
from glob import glob
from plotly import graph_objs as go, io as pio, tools
from plotly.offline import init_notebook_mode, iplot
from slm_lab.lib import logger, util
import colorlover as cl
import os
import pydash as ps

logger = logger.get_logger(__name__)

# window length of the moving average applied to plotted series
PLOT_MA_WINDOW = 100


def _log_orca_failure(e):
    '''Warn that graph export failed, including the error and a headless-server hint.'''
    return logger.warning(
        f'Failed to generate graph. Run retro-analysis to generate graphs later. {e}\nIf running on a headless server, prepend your Python command with `xvfb-run -a `, for example `xvfb-run -a python run_lab.py`'
    )


# ps.once wraps the warning so it is emitted at most once
orca_warn_once = ps.once(_log_orca_failure)

if util.is_jupyter():
    init_notebook_mode(connected=True)


def calc_sr_ma(sr, window=None):
    '''Calculate the moving-average of a series to be plotted

    @param sr: pandas Series to smooth
    @param window: rolling window size; None (default) uses PLOT_MA_WINDOW, read at call time
    @returns Series of rolling means; min_periods=1 makes the leading entries average
             over the values seen so far instead of being NaN
    '''
    if window is None:
        window = PLOT_MA_WINDOW
    return sr.rolling(window, min_periods=1).mean()


def create_label(y_col,
                 x_col,
                 title=None,
                 y_title=None,
                 x_title=None,