示例#1
0
def test_smart_path():
    '''Check util.smart_path on relative, (presumably) nonexistent, and absolute
    paths, and its as_dir=True mode which yields the containing directory.'''
    expected = os.path.abspath(__file__)
    relative = 'test/lib/test_util.py'
    missing_relative = relative + '_fake'
    # a relative path to this very test file resolves to its absolute path
    assert util.smart_path(relative) == expected
    # the '_fake' suffix is carried through to the resolved path
    assert util.smart_path(missing_relative) == expected + '_fake'
    # an already-absolute path comes back unchanged
    assert util.smart_path(expected) == expected
    # as_dir=True returns the parent directory instead of the file path
    assert util.smart_path(expected, as_dir=True) == os.path.dirname(expected)
示例#2
0
文件: viz.py 项目: ronald-xie/SLM-Lab
def save_image(figure, filepath=None):
    '''Render a plotly figure to an image file with the orca CLI.

    Does nothing when the PY_ENV environment variable equals 'test'.

    Args:
        figure: plotly figure (dict-like); serialized with json.dumps and passed
            to orca on the command line.
        filepath: output path. Defaults to PLOT_FILEDIR/<layout.title>.png and is
            resolved through util.smart_path.
    '''
    import shlex  # local import: safely quote shell arguments below

    # .get so that an unset PY_ENV does not raise KeyError
    if os.environ.get('PY_ENV') == 'test':
        return
    if filepath is None:
        filepath = f'{PLOT_FILEDIR}/{ps.get(figure, "layout.title")}.png'
    filepath = util.smart_path(filepath)
    dirname, filename = os.path.split(filepath)
    try:
        # Quote both arguments: a single quote inside the figure JSON (e.g. in a
        # title) or a space in the filename would otherwise break the shell command.
        # For inputs without such characters the command is unchanged.
        cmd = f'orca graph -o {shlex.quote(filename)} {shlex.quote(json.dumps(figure))}'
        if 'linux' in sys.platform:
            # presumably orca needs an X display; run it under a virtual one on linux
            cmd = 'xvfb-run -a -s "-screen 0 1400x900x24" -- ' + cmd
        # Popen does not wait for orca: the render runs in the background with its
        # output discarded, so failures after launch are not caught below.
        Popen(cmd, cwd=dirname, shell=True, stderr=DEVNULL, stdout=DEVNULL)
        logger.info(f'Graph saved to {dirname}/{filename}')
    except Exception:
        logger.exception(
            'Please install orca for plotly and run retro-analysis to generate graphs.')
示例#3
0
def test_write_read_as_plain_list(test_str, filename, dtype):
    '''Round-trip test_str through util.write / util.read and check the read type.'''
    data_path = f'test/fixture/lib/util/{filename}'
    # Resolve once and use the same path for write, existence check, and read.
    # Checking the bare relative path interprets it against the current working
    # directory, which may not be where smart_path directed the write.
    abs_data_path = util.smart_path(data_path)
    util.write(test_str, abs_data_path)
    assert os.path.exists(abs_data_path)
    data = util.read(abs_data_path)
    assert isinstance(data, dtype)
示例#4
0
def test_write_read_as_df(test_df, filename, dtype):
    '''Round-trip test_df through util.write / util.read and check the read type.'''
    data_path = f'test/fixture/lib/util/{filename}'
    # Resolve once and use the same path for write, existence check, and read.
    # Checking the bare relative path interprets it against the current working
    # directory, which may not be where smart_path directed the write.
    abs_data_path = util.smart_path(data_path)
    util.write(test_df, abs_data_path)
    assert os.path.exists(abs_data_path)
    data_df = util.read(abs_data_path)
    assert isinstance(data_df, dtype)
示例#5
0
def load(net, model_path):
    '''Load model weights from a path into a net module.

    Args:
        net: module whose parameters are replaced via load_state_dict.
        model_path: path to saved weights (a state_dict); resolved through
            util.smart_path.
    '''
    # Without CUDA, map stored tensors onto the CPU so GPU-saved weights still
    # load; with CUDA, keep torch.load's default placement (map_location=None).
    device = None if torch.cuda.is_available() else 'cpu'
    state_dict = torch.load(util.smart_path(model_path), map_location=device)
    net.load_state_dict(state_dict)
    logger.info(f'Loaded model from {model_path}')
示例#6
0
def save(net, model_path):
    '''Write the net's state_dict to model_path (resolved via util.smart_path)
    with torch.save, then log the destination.'''
    weights = net.state_dict()
    destination = util.smart_path(model_path)
    torch.save(weights, destination)
    logger.info(f'Saved model to {model_path}')
示例#7
0
def get_env_path(env_name):
    '''Return the smart_path-resolved location of an npm-distributed Unity env binary.

    Only the binary's parent directory is checked (via assert, so the check is
    skipped under `python -O`); the binary file itself is not.
    '''
    relative = f'node_modules/slm-env-{env_name}/build/{env_name}'
    resolved = util.smart_path(relative)
    assert os.path.exists(os.path.dirname(resolved)), f'Missing {resolved}. See README to install from yarn.'
    return resolved
示例#8
0
def save(net, model_path):
    '''Write the net's state_dict to model_path (resolved via util.smart_path).'''
    weights = net.state_dict()
    destination = util.smart_path(model_path)
    torch.save(weights, destination)
示例#9
0
def test_write_read_as_plain_list(test_str, filename, dtype):
    '''Round-trip test_str through util.write / util.read and check the read type.'''
    data_path = f'test/fixture/lib/util/{filename}'
    # Resolve once and use the same path for write, existence check, and read.
    # Checking the bare relative path interprets it against the current working
    # directory, which may not be where smart_path directed the write.
    abs_data_path = util.smart_path(data_path)
    util.write(test_str, abs_data_path)
    assert os.path.exists(abs_data_path)
    data = util.read(abs_data_path)
    assert isinstance(data, dtype)
示例#10
0
def test_write_read_as_df(test_df, filename, dtype):
    '''Round-trip test_df through util.write / util.read and check the read type.'''
    data_path = f'test/fixture/lib/util/{filename}'
    # Resolve once and use the same path for write, existence check, and read.
    # Checking the bare relative path interprets it against the current working
    # directory, which may not be where smart_path directed the write.
    abs_data_path = util.smart_path(data_path)
    util.write(test_df, abs_data_path)
    assert os.path.exists(abs_data_path)
    data_df = util.read(abs_data_path)
    assert isinstance(data_df, dtype)
示例#11
0
def load(net, model_path):
    '''Load model weights from a path into a net module.

    Args:
        net: module whose parameters are replaced via load_state_dict.
        model_path: path to saved weights (a state_dict); resolved through
            util.smart_path.
    '''
    # Without CUDA, map stored tensors onto the CPU so weights saved on a GPU
    # machine still load; with CUDA, keep torch.load's default placement.
    device = None if torch.cuda.is_available() else 'cpu'
    state_dict = torch.load(util.smart_path(model_path), map_location=device)
    net.load_state_dict(state_dict)
    logger.info(f'Loaded model from {model_path}')
示例#12
0
The data visualization module
TODO pie, swarm, box plots
'''

from plotly import (
    graph_objs as go,
    offline as py,
    tools,
)
from slm_lab import config
from slm_lab.lib import logger, util
import os
import plotly
import pydash as ps

# Directory for saved plots, resolved through util.smart_path; created at import time.
PLOT_FILEDIR = util.smart_path('data')
os.makedirs(PLOT_FILEDIR, exist_ok=True)
# Inside Jupyter, enable plotly's notebook mode so py.iplot can render inline
# (connected=True presumably loads plotly.js from a CDN rather than embedding it — confirm).
if util.is_jupyter():
    py.init_notebook_mode(connected=True)
# NOTE: rebinds the imported `logger` module name to this module's logger
# instance; code below this line must use the instance, not the module.
logger = logger.get_logger(__name__)


def plot(*args, **kwargs):
    '''Plot via plotly offline: py.iplot inside Jupyter, py.plot elsewhere.

    Outside Jupyter, auto_open defaults to False unless the caller passes it.
    All arguments are forwarded unchanged otherwise.
    '''
    if not util.is_jupyter():
        kwargs.setdefault('auto_open', False)
        return py.plot(*args, **kwargs)
    return py.iplot(*args, **kwargs)


def save_image(figure, filepath=None):
示例#13
0
from slm_lab.lib import util
import colorlog
import logging
import os
import sys

# Set early so TensorFlow's C++ logging honors it if TF is imported later
# (presumably elsewhere in the project — confirm).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # mute tf warnings on optimized setup
# Log file path is timestamped when this module is imported, via util.get_ts().
LOG_FILEPATH = util.smart_path(f'log/{util.get_ts()}_slm_lab.log')
# Plain (uncolored) format string for log records.
LOG_FORMAT = '[%(asctime)s %(levelname)s] %(message)s'
# Any non-empty DEBUG value (even '0') selects DEBUG level; otherwise INFO.
LOG_LEVEL = logging.DEBUG if bool(os.environ.get('DEBUG')) else logging.INFO


class DedentFormatter(logging.Formatter):
    '''Formatter that passes each record's msg through util.dedent before the
    standard formatting, so indented multi-line log strings presumably render
    without their common leading whitespace. Note: mutates record.msg in place.'''
    def format(self, record):
        dedented = util.dedent(record.msg)
        record.msg = dedented
        return super().format(record)


os.makedirs(os.path.dirname(LOG_FILEPATH), exist_ok=True)
# Colorized console format; colorlog fills in %(log_color)s / %(reset)s.
color_formatter = colorlog.ColoredFormatter(
    '%(log_color)s[%(asctime)s %(levelname)s]%(reset)s %(message)s')
fh = logging.FileHandler(LOG_FILEPATH)
# Without a formatter the file handler writes bare messages (no timestamp or
# level); use the plain, uncolored LOG_FORMAT for the log file.
fh.setFormatter(logging.Formatter(LOG_FORMAT))
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(color_formatter)

# Project logger 'slm' writes to both the file and stdout; propagate=False
# keeps records from also reaching ancestor (root) handlers.
lab_logger = logging.getLogger('slm')
lab_logger.setLevel(LOG_LEVEL)
lab_logger.addHandler(fh)
lab_logger.addHandler(sh)
lab_logger.propagate = False
示例#14
0
                     showlegend=False)
            if idx == len(envs) - 1:
                # plot extra to crop legend out
                plot_env(algos,
                         env,
                         data_folder,
                         legend_list=legend_list,
                         frame_scales=frame_scales,
                         showlegend=True)
        except Exception as e:
            logger.warning(f'Cant plot for env: {env}. Error: {e}')


# Discrete
# LunarLander + Small Atari + Unity
data_folder = util.smart_path('../Desktop/benchmark/discrete')

# Algorithm identifiers for the discrete-action benchmark plots.
# NOTE(review): the leading '*' on '*sac' looks like a glob/marker — confirm its meaning.
algos = ['dqn', 'ddqn_per', 'a2c_gae', 'a2c_nstep', 'ppo', '*sac']
legend_list = [
    'DQN',
    'DDQN+PER',
    'A2C (GAE)',
    'A2C (n-step)',
    'PPO',
示例#15
0
from slm_lab.lib import util
import colorlog
import logging
import os
import sys

# Set early so TensorFlow's C++ logging honors it if TF is imported later
# (presumably elsewhere in the project — confirm).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # mute tf warnings on optimized setup
# Log file name embeds PY_ENV and a timestamp taken at import time; if PY_ENV is
# unset, os.environ.get returns None and the name contains a literal 'None'.
LOG_FILEPATH = util.smart_path(
    f'log/{os.environ.get("PY_ENV")}_{util.get_timestamp()}_slm_lab.log')
# Plain (uncolored) format string for log records.
LOG_FORMAT = '[%(asctime)s %(levelname)s] %(message)s'
# Any non-empty DEBUG value (even '0') selects DEBUG level; otherwise INFO.
LOG_LEVEL = logging.DEBUG if bool(os.environ.get('DEBUG')) else logging.INFO


class DedentFormatter(logging.Formatter):
    '''Formatter that passes each record's msg through util.dedent before the
    standard formatting, so indented multi-line log strings presumably render
    without their common leading whitespace. Note: mutates record.msg in place.'''

    def format(self, record):
        dedented = util.dedent(record.msg)
        record.msg = dedented
        return super().format(record)


os.makedirs(os.path.dirname(LOG_FILEPATH), exist_ok=True)
# Colorized console format; colorlog fills in %(log_color)s / %(reset)s.
color_formatter = colorlog.ColoredFormatter(
    '%(log_color)s[%(asctime)s %(levelname)s]%(reset)s %(message)s')
fh = logging.FileHandler(LOG_FILEPATH)
# Without a formatter the file handler writes bare messages (no timestamp or
# level); use the plain, uncolored LOG_FORMAT for the log file.
fh.setFormatter(logging.Formatter(LOG_FORMAT))
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(color_formatter)

lab_logger = logging.getLogger('slm')
lab_logger.setLevel(LOG_LEVEL)
lab_logger.addHandler(fh)
示例#16
0
文件: viz.py 项目: ronald-xie/SLM-Lab
'''
from plotly import (
    graph_objs as go,
    offline as py,
    tools,
)
from slm_lab import config
from slm_lab.lib import logger, util
from subprocess import Popen, DEVNULL
import os
import plotly
import pydash as ps
import sys
import ujson as json

# Directory for saved plots, resolved through util.smart_path; created at import time.
PLOT_FILEDIR = util.smart_path('data')
os.makedirs(PLOT_FILEDIR, exist_ok=True)
# Inside Jupyter, enable plotly's notebook mode so py.iplot can render inline
# (connected=True presumably loads plotly.js from a CDN rather than embedding it — confirm).
if util.is_jupyter():
    py.init_notebook_mode(connected=True)
# NOTE: rebinds the imported `logger` module name to this module's logger
# instance; code below this line must use the instance, not the module.
logger = logger.get_logger(__name__)


def plot(*args, **kwargs):
    '''Plot via plotly offline: py.iplot inside Jupyter, py.plot elsewhere.

    Outside Jupyter, auto_open defaults to False unless the caller passes it.
    All arguments are forwarded unchanged otherwise.
    '''
    if not util.is_jupyter():
        kwargs.setdefault('auto_open', False)
        return py.plot(*args, **kwargs)
    return py.iplot(*args, **kwargs)


def save_image(figure, filepath=None):