예제 #1
0
 def resume(self, base_dir: Optional[str] = None):
     """
     Restore the entries of ``self._trainer_modules`` from saved files, if any.

     Files are looked up under ``<base_dir or self._base_dir>/model``:
       - single-file mode (``self._save2single_file``): one ``checkpoint.pth``
         holding a dict keyed by entry name;
       - otherwise: one ``<name>.pth`` file per entry.
     Entries with no saved file are left as they are (i.e. keep their
     initial values).

     Args:
         base_dir: optional override of the training base directory.
     """
     model_dir = os.path.join(base_dir or self._base_dir, 'model')

     def _apply(name, entry, state):
         # Objects exposing load_state_dict take the saved state directly;
         # anything else is treated as a tensor attribute of self and is
         # overwritten in place with fill_.
         if hasattr(entry, 'load_state_dict'):
             self._trainer_modules[name].load_state_dict(state)
         else:
             getattr(self, name).fill_(state)

     if not self._save2single_file:
         for name, entry in self._trainer_modules.items():
             entry_path = os.path.join(model_dir, f'{name}.pth')
             if not os.path.exists(entry_path):
                 continue
             _apply(name, entry, th.load(entry_path, map_location=self.device))
             logger.info(
                 colorize(
                     f'Resume model from {entry_path} SUCCESSFULLY.',
                     color='green'))
         return

     single_path = os.path.join(model_dir, 'checkpoint.pth')
     if os.path.exists(single_path):
         saved = th.load(single_path, map_location=self.device)
         for name, entry in self._trainer_modules.items():
             _apply(name, entry, saved[name])
         logger.info(
             colorize(f'Resume model from {single_path} SUCCESSFULLY.',
                      color='green'))
예제 #2
0
def load_config(filename: str, not_find_error=True) -> Dict:
    """
    Load a YAML config file into a dict.

    Args:
        filename: path of the YAML file.
        not_find_error: if True, a missing file raises; otherwise the failure
            is logged and an empty dict is returned.
    Returns:
        The parsed mapping; ``{}`` for an empty file or, when
        ``not_find_error`` is False, for a missing file.
    Raises:
        FileNotFoundError: if the file is missing and ``not_find_error`` is
            True. It subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    if not os.path.exists(filename):
        if not_find_error:
            raise FileNotFoundError(f'cannot find this config: {filename}')
        logger.info(
            colorize(
                f'load config from {filename} failed, cannot find file.',
                color='red'))
        # Previously fell through and returned None, contradicting the
        # `-> Dict` contract and the `x or {}` normalisation below.
        return {}

    with open(filename, 'r', encoding='utf-8') as f:
        # safe_load: no arbitrary Python object construction from the file
        x = yaml.safe_load(f)
    logger.info(
        colorize(f'load config from {filename} successfully',
                 color='green'))
    return x or {}
예제 #3
0
def save_config(dicpath: str, config: Dict, filename: str) -> None:
    """
    Dump ``config`` as YAML to ``<dicpath>/<filename>``.

    ``dicpath`` (including missing parents) is created if it does not exist.

    Args:
        dicpath: target directory.
        config: mapping to serialise.
        filename: file name inside ``dicpath``.
    """
    # exist_ok closes the check-then-create race: if the directory appears
    # between an existence check and makedirs, makedirs would otherwise raise.
    os.makedirs(dicpath, exist_ok=True)
    with open(os.path.join(dicpath, filename), 'w', encoding='utf-8') as fw:
        yaml.dump(config, fw)
    logger.info(
        colorize(f'save config to {dicpath} successfully', color='green'))
예제 #4
0
def get_model_info(name: str) -> Tuple[Callable, Dict, str, str]:
    '''
    Look up a registered algorithm and assemble its default configuration.

    Args:
        name: name of algorithms
    Return:
        model: algo class of the algorithm named `name`, imported from
            `rls.algos.<policy_type>.<name>`.
        algo_config: default config, merged from the `general` section, the
            policy-mode section and the `name` section of
            `rls/algos/config.yaml` (later sections override earlier keys).
        policy_mode: mode of the policy, presumably `on-policy` or
            `off-policy` ('-' is mapped to '_' to find its config section).
        policy_type: sub-package of `rls.algos` containing the algorithm.
    '''
    algo_info = registry.get_model_info(name)
    class_name = algo_info['algo_class']
    policy_mode = algo_info['policy_mode']
    policy_type = algo_info['policy_type']
    logger.info(colorize(algo_info.get('logo', ''), color='green'))

    model = getattr(
        importlib.import_module(f'rls.algos.{policy_type}.{name}'),
        class_name)

    # Read the shared config file once instead of once per section.
    all_configs = load_yaml('rls/algos/config.yaml')
    algo_config = {}
    for section in ('general', policy_mode.replace('-', '_'), name):
        algo_config.update(all_configs[section])
    return model, algo_config, policy_mode, policy_type
예제 #5
0
def check_or_create(dicpath: str, name: str = '') -> None:
    """
    Ensure directory ``dicpath`` exists, creating it (with parents) if needed.

    Args:
        dicpath: directory path to check / create.
        name: human-readable label, used only in the creation log message.
    """
    if os.path.exists(dicpath):
        return
    # exist_ok: another process/thread may create the directory between the
    # existence check above and this call; that must not be an error.
    os.makedirs(dicpath, exist_ok=True)
    logger.info(colorize(
        ''.join([f'create {name} directionary :', dicpath]), color='green'))
예제 #6
0
파일: base.py 프로젝트: zhijie-ai/RLs
 def save_checkpoint(self, **kwargs) -> NoReturn:
     """
     Save the model through ``self.saver`` and record training info.

     Does nothing when ``self.no_save`` is truthy. The checkpoint is numbered
     with ``kwargs['train_step']`` (default 0, coerced to int); the full
     ``kwargs`` dict is then passed to ``self.write_training_info``.
     """
     if self.no_save:
         return
     step = int(kwargs.get('train_step', 0))
     self.saver.save(checkpoint_number=step)
     logger.info(colorize(f'Save checkpoint success. Training step: {step}', color='green'))
     self.write_training_info(kwargs)
예제 #7
0
파일: base.py 프로젝트: zhijie-ai/RLs
 def init_or_restore(self, base_dir: Optional[str] = None) -> NoReturn:
     """
     Restore model weights from a checkpoint, or keep them freshly initialized.

     Args:
         base_dir: if given, restore from the latest checkpoint in
             ``<base_dir>/model``, but only when a ``checkpoint`` index file
             exists there; a failed restore is logged and re-raised.
             If None, restore from ``self.saver.latest_checkpoint`` to resume
             this run; when that is empty the model keeps its initial values.
     Raises:
         Exception: if restoring from ``base_dir`` fails (original error is
             chained as ``__cause__``).
     """
     if base_dir is not None:
         cp_dir = os.path.join(base_dir, 'model')
         if os.path.exists(os.path.join(cp_dir, 'checkpoint')):
             try:
                 ckpt = tf.train.latest_checkpoint(cp_dir)
                 # load the model from the specified path
                 self.checkpoint.restore(ckpt).expect_partial()
             except Exception as e:
                 # Narrowed from a bare `except:` so Ctrl-C / SystemExit are
                 # not reported as a restore failure; keep the root cause.
                 logger.error(colorize(f'restore model from {cp_dir} FAILED.', color='red'))
                 raise Exception(f'restore model from {cp_dir} FAILED.') from e
             else:
                 logger.info(colorize(f'restore model from {ckpt} SUCCUESS.', color='green'))
     else:
         ckpt = self.saver.latest_checkpoint
         # load from this model's own directory to resume interrupted training
         self.checkpoint.restore(ckpt).expect_partial()
         # Previously both messages were logged unconditionally, e.g.
         # "restore model from None" when there was nothing to restore.
         if ckpt:
             logger.info(colorize(f'restore model from {ckpt} SUCCUESS.', color='green'))
         else:
             logger.info(colorize('initialize model SUCCUESS.', color='green'))
예제 #8
0
def get_model_info(name: str) -> Tuple[Any, Any]:
    """
    Resolve a registered algorithm name to its implementation class.

    Args:
        name: name of algorithms
    Return:
        The class named by the registry's `class_name`, imported from
        `rls.algorithms.<path>`, and the registry's `is_multi` flag
        (whether the algorithm is single- or multi-agent RL).
    """
    info = registry.get_model_info(name)
    logo = info.get('logo', '')
    logger.info(colorize(logo, color='green'))
    module = importlib.import_module(f"rls.algorithms.{info['path']}")
    model_class = getattr(module, info['class_name'])
    return model_class, info['is_multi']
예제 #9
0
def learner(env, model, ip, port):
    """
    Start the Ape-X learner gRPC server for `model` and block until it ends.

    Args:
        env: environment handle; currently unused here (the evaluation code
            that used it is commented out below).
        model: algorithm instance; must provide `apex_learn` and `apex_cal_td`.
        ip: address to bind.
        port: port to bind (str or int).
    Raises:
        AssertionError: if `model` does not support Ape-X learning.
    """
    # Validate before probing the port so an unsupported model fails fast.
    # An explicit raise (same exception type as the former `assert`) keeps
    # the check active under `python -O`.
    if not (hasattr(model, 'apex_learn') and hasattr(model, 'apex_cal_td')):
        raise AssertionError('this algorithm does not support Ape-X learning for now.')
    check_port_in_use(port, ip, try_times=10, server_name='learner')

    server = grpc.server(futures.ThreadPoolExecutor())
    apex_learner_pb2_grpc.add_LearnerServicer_to_server(LearnerServicer(model), server)
    # f-string accepts an int port as well as a str (':'.join required str)
    server.add_insecure_port(f'{ip}:{port}')
    server.start()
    logger.info(colorize('start learner success.', color='green'))

    # eval_thread = EvalThread(env, model)
    # eval_thread.start()

    # GymCollector.evaluate(env, model)
    server.wait_for_termination()
예제 #10
0
파일: buffer.py 프로젝트: zhijie-ai/RLs
def buffer(ip, port, learner_ip, learner_port, buffer_args):
    """
    Start the Ape-X replay-buffer gRPC server plus its learner-feeding thread.

    Blocks until the server terminates.

    Args:
        ip, port: address to bind the buffer server to (port may be str or int).
        learner_ip, learner_port: learner address handed to `LearnThread`.
        buffer_args: keyword arguments for `PrioritizedExperienceReplay`.
    """
    check_port_in_use(port, ip, try_times=10, server_name='buffer')

    # Named `replay_buffer` so it no longer shadows this function's own name.
    replay_buffer = PrioritizedExperienceReplay(**buffer_args)
    # One lock shared by the gRPC servicer and LearnThread, which both
    # receive the same buffer.
    lock = threading.Lock()

    server = grpc.server(futures.ThreadPoolExecutor())
    apex_buffer_pb2_grpc.add_BufferServicer_to_server(
        BufferServicer(buffer=replay_buffer, lock=lock), server)
    # f-string accepts an int port as well as a str (':'.join required str)
    server.add_insecure_port(f'{ip}:{port}')
    server.start()
    logger.info(colorize('start buffer success.', color='green'))

    learn_thread = LearnThread(learner_ip, learner_port, replay_buffer, lock)
    learn_thread.start()

    server.wait_for_termination()
예제 #11
0
 def save(self) -> NoReturn:
     """
     Save every entry of ``self._trainer_modules`` under ``self.cp_dir``.

     Entries exposing ``state_dict`` are stored as their state dict; other
     entries (tensors / numbers) are stored as-is. In single-file mode all
     entries go into one ``checkpoint.pth`` dict; otherwise each entry is
     written to ``<name>.pth``. Skipped entirely when ``self._is_save`` is
     falsy.
     """
     if not self._is_save:
         return
     payload = {
         name: entry.state_dict() if hasattr(entry, 'state_dict') else entry
         for name, entry in self._trainer_modules.items()
     }
     if self._save2single_file:
         th.save(payload, os.path.join(self.cp_dir, 'checkpoint.pth'))
     else:
         for name, state in payload.items():
             th.save(state, os.path.join(self.cp_dir, f'{name}.pth'))
     logger.info(
         colorize(
             f'Save checkpoint success. Training step: {self._cur_train_step}',
             color='green'))
예제 #12
0
from copy import deepcopy
from collections import deque
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel

from rls.utils.logging_utils import get_logger
from rls.utils.display import colorize
logger = get_logger(__name__)

# OpenCV is optional: it is only needed for visual-observation models.
try:
    import cv2
    cv2.ocl.setUseOpenCL(False)
except Exception:
    # `Exception` (not a bare `except:`) still covers an ImportError or any
    # error raised by the setUseOpenCL call, but no longer swallows
    # KeyboardInterrupt / SystemExit.
    logger.warning(
        colorize('opencv-python is needed to train visual-based model.',
                 color='yellow'))

from rls.common.yaml_ops import load_yaml
from rls.utils.np_utils import int2action_index
from rls.utils.indexs import (SingleAgentEnvArgs, MultiAgentEnvArgs)

# obs: [
#     brain1: [
#         agent1: [],
#         agent2: [],
#         ...
#         agentn: []
#     ],
#     brain2: [
#         agent1: [],
예제 #13
0
#!/usr/bin/env python3
# encoding: utf-8

import gym

from rls.utils.display import colorize
from rls.utils.logging_utils import get_logger
logger = get_logger(__name__)

# OpenCV is optional: it is only needed for visual-observation models.
try:
    import cv2
    cv2.ocl.setUseOpenCL(False)
except Exception:
    # `Exception` (not a bare `except:`) still covers an ImportError or any
    # error raised by the setUseOpenCL call, but no longer swallows
    # KeyboardInterrupt / SystemExit.
    logger.warning(
        colorize('opencv-python is needed to train visual-based model.',
                 color='yellow'))

# imageio is optional: only needed when recording videos.
try:
    import imageio
except ImportError:
    logger.warning(
        colorize('imageio should be installed to record vedio if needed.',
                 color='yellow'))

import numpy as np

from collections import deque
from gym.spaces import (Box, Discrete, Tuple)
예제 #14
0
    def __init__(
            self,
            n_copies=1,
            is_save=True,
            base_dir='',
            device: str = 'cpu',
            max_train_step=sys.maxsize,
            max_frame_step=sys.maxsize,
            max_train_episode=sys.maxsize,
            save_frequency=100,
            save2single_file=False,
            n_step_value=4,
            gamma=0.999,
            logger_types=['none'],
            decay_lr=False,
            normalize_vector_obs=False,
            obs_with_pre_action=False,
            oplr_params=dict(),
            rep_net_params={
                'vector_net_params': {
                    'h_dim': 16,
                    'network_type': 'adaptive'  # rls.nn.represents.vectors
                },
                'visual_net_params': {
                    'h_dim': 128,
                    'network_type': 'simple'  # rls.nn.represents.visuals
                },
                'encoder_net_params': {
                    'h_dim': 16,
                    'network_type': 'identity'  # rls.nn.represents.encoders
                },
                'memory_net_params': {
                    'rnn_units': 16,
                    'network_type': 'lstm'
                }
            },
            **kwargs):
        """
        Store common training settings and build counters, buffer and loggers.

        inputs:
            n_copies: number of copies (stored as ``self.n_copies``).
            is_save: whether to create the model directory and build loggers.
            base_dir: the directory that store data, like model, logs, and
                other data; its last path component is used as the training
                name, and ``<base_dir>/model`` / ``<base_dir>/log`` as the
                checkpoint / log directories.
            device: torch device the step/episode counters are created on.
            max_train_step, max_frame_step, max_train_episode: learning stop
                conditions (wrapped in ``Until``).
            save_frequency: model-saving interval (wrapped in ``Every``).
            save2single_file: save all trainer modules into one file.
            rep_net_params: representation-network settings; its
                ``memory_net_params['network_type']`` decides ``self.use_rnn``.
            kwargs: accepted and not used here.
        NOTE(review): the old docstring also listed ``a_dim: action spaces``,
        which is not a parameter of this method — presumably passed through
        kwargs by callers; confirm.
        """
        from copy import deepcopy  # local import keeps this change self-contained

        self.n_copies = n_copies
        self._is_save = is_save
        self._base_dir = base_dir
        self._training_name = os.path.split(self._base_dir)[-1]
        self.device = device
        logger.info(colorize(f"PyTorch Tensor Device: {self.device}"))
        self._max_train_step = max_train_step

        self._should_learn_cond_train_step = Until(max_train_step)
        self._should_learn_cond_frame_step = Until(max_frame_step)
        self._should_learn_cond_train_episode = Until(max_train_episode)
        self._should_save_model = Every(save_frequency)

        self._save2single_file = save2single_file
        self.gamma = gamma
        # Copies: `logger_types` and `rep_net_params` have mutable defaults
        # shared by every call; storing them directly (or, for the nested
        # dicts, through a shallow dict() copy) let one instance's mutations
        # leak into the defaults and into other instances.
        self._logger_types = deepcopy(logger_types)
        self._n_step_value = n_step_value
        self._decay_lr = decay_lr  # TODO: implement
        self._normalize_vector_obs = normalize_vector_obs  # TODO: implement
        self._obs_with_pre_action = obs_with_pre_action
        self._rep_net_params = deepcopy(rep_net_params)
        self._oplr_params = dict(oplr_params)

        super().__init__()

        # Taken from the private copy so it never aliases the default dict.
        self.memory_net_params = self._rep_net_params.get('memory_net_params', {
            'rnn_units': 16,
            'network_type': 'lstm'
        })
        self.use_rnn = self.memory_net_params.get('network_type',
                                                  'identity') != 'identity'

        self.cp_dir, self.log_dir = [
            os.path.join(base_dir, i) for i in ['model', 'log']
        ]

        if self._is_save:
            check_or_create(self.cp_dir, 'checkpoints(models)')

        self._cur_interact_step = th.tensor(0).long().to(self.device)
        self._cur_train_step = th.tensor(0).long().to(self.device)
        self._cur_frame_step = th.tensor(0).long().to(self.device)
        self._cur_episode = th.tensor(0).long().to(self.device)

        # counters registered here alongside the model's trainable parts
        self._trainer_modules = {
            '_cur_train_step': self._cur_train_step,
            '_cur_frame_step': self._cur_frame_step,
            '_cur_episode': self._cur_episode
        }

        self._buffer = self._build_buffer()
        self._loggers = self._build_loggers() if self._is_save else list()
예제 #15
0
from rls.utils.display import colorize
from rls.envs.gym_wrapper.utils import build_env
from rls.utils.np_utils import get_discrete_action_list
from rls.utils.specs import (ObsSpec,
                             SingleAgentEnvArgs,
                             ModelObservations,
                             SingleModelInformation,
                             GymVectorizedType,
                             NamedTupleStaticClass)
from rls.utils.logging_utils import get_logger
logger = get_logger(__name__)

# Optional environment packages. They are not referenced elsewhere in this
# snippet, so presumably they are imported for their side effect of
# registering environments (TODO confirm); a missing package only logs a
# warning.
try:
    import gym_minigrid
except ImportError:
    logger.warning(colorize("import gym_minigrid failed, using 'pip3 install gym-minigrid' install it.", color='yellow'))

try:
    # To render, pass 'renders=True' (or 'render=True', depending on the env)
    # to gym.make() manually.
    import pybullet_envs
except ImportError:
    logger.warning(colorize("import pybullet_envs failed, using 'pip3 install PyBullet' install it.", color='yellow'))

try:
    import gym_donkeycar
except ImportError:
    # The message previously named gym_minigrid (copy-paste slip).
    logger.warning(colorize("import gym_donkeycar failed, using 'pip install gym_donkeycar' install it.", color='yellow'))