def resume(self, base_dir: Optional[str] = None):
    """
    Check whether a checkpoint exists in `cp_dir`; if so, restore from it,
    otherwise keep the random initialization.
    """
    cp_dir = os.path.join(base_dir or self._base_dir, 'model')
    if self._save2single_file:
        ckpt_path = os.path.join(cp_dir, 'checkpoint.pth')
        if os.path.exists(ckpt_path):
            checkpoint = th.load(ckpt_path, map_location=self.device)
            for k, v in self._trainer_modules.items():
                if hasattr(v, 'load_state_dict'):
                    self._trainer_modules[k].load_state_dict(checkpoint[k])
                else:
                    # plain tensors (e.g. step counters) are restored in place
                    getattr(self, k).fill_(checkpoint[k])
            logger.info(
                colorize(f'Resume model from {ckpt_path} SUCCESSFULLY.',
                         color='green'))
    else:
        for k, v in self._trainer_modules.items():
            model_path = os.path.join(cp_dir, f'{k}.pth')
            if os.path.exists(model_path):
                if hasattr(v, 'load_state_dict'):
                    self._trainer_modules[k].load_state_dict(
                        th.load(model_path, map_location=self.device))
                else:
                    getattr(self, k).fill_(
                        th.load(model_path, map_location=self.device))
                logger.info(
                    colorize(f'Resume model from {model_path} SUCCESSFULLY.',
                             color='green'))
def load_config(filename: str, not_find_error: bool = True) -> Dict:
    if os.path.exists(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            x = yaml.safe_load(f.read())
        logger.info(
            colorize(f'load config from {filename} successfully.',
                     color='green'))
        return x or {}
    else:
        if not_find_error:
            raise FileNotFoundError(f'cannot find config file: {filename}')
        logger.info(
            colorize(f'load config from {filename} failed, cannot find file.',
                     color='red'))
        return {}  # return an empty config when the file is missing
def save_config(dicpath: str, config: Dict, filename: str) -> NoReturn:
    if not os.path.exists(dicpath):
        os.makedirs(dicpath)
    with open(os.path.join(dicpath, filename), 'w', encoding='utf-8') as fw:
        yaml.dump(config, fw)
    logger.info(
        colorize(f'save config to {dicpath} successfully.', color='green'))
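# Usage sketch for the two helpers above (the /tmp paths are illustrative):
# save a config dict and read it back. A missing file raises by default;
# with not_find_error=False it is only logged and an empty dict is returned.
cfg = {'gamma': 0.99, 'lr': 5e-4}
save_config('/tmp/rls_demo', cfg, 'train.yaml')
assert load_config('/tmp/rls_demo/train.yaml') == cfg
missing = load_config('/tmp/rls_demo/absent.yaml', not_find_error=False)  # -> {}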
def get_model_info(name: str) -> Tuple[Callable, Dict, str, str]:
    '''
    Args:
        name: name of the algorithm
    Returns:
        algo_class of the algorithm model named `name`
        default config of the specified algorithm
        policy_mode of the policy, `on-policy` or `off-policy`
        policy_type of the policy
    '''
    algo_info = registry.get_model_info(name)
    class_name = algo_info['algo_class']
    policy_mode = algo_info['policy_mode']
    policy_type = algo_info['policy_type']
    LOGO = algo_info.get('logo', '')
    logger.info(colorize(LOGO, color='green'))

    model = getattr(
        importlib.import_module(f'rls.algos.{policy_type}.{name}'),
        class_name)

    # load the shared config once, then layer general, mode-specific and
    # algorithm-specific sections on top of each other
    config = load_yaml('rls/algos/config.yaml')
    algo_config = {}
    algo_config.update(config['general'])
    algo_config.update(config[policy_mode.replace('-', '_')])
    algo_config.update(config[name])
    return model, algo_config, policy_mode, policy_type
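# Usage sketch (assumes 'ppo' is a registered algorithm name): unpack the
# 4-tuple and merge experiment-specific overrides into the defaults read
# from rls/algos/config.yaml. The constructor call is illustrative only.
model, algo_config, policy_mode, policy_type = get_model_info('ppo')
algo_config.update({'gamma': 0.99})  # hypothetical override
agent = model(**algo_config)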
def check_or_create(dicpath: str, name: str = '') -> NoReturn:
    """
    Check whether the directory exists; if not, create it.
    """
    if not os.path.exists(dicpath):
        os.makedirs(dicpath)
        logger.info(colorize(
            f'create {name} directory: {dicpath}', color='green'))
def save_checkpoint(self, **kwargs) -> NoReturn:
    """
    Save the training model.
    """
    if not self.no_save:
        train_step = int(kwargs.get('train_step', 0))
        self.saver.save(checkpoint_number=train_step)
        logger.info(colorize(
            f'Save checkpoint success. Training step: {train_step}',
            color='green'))
        self.write_training_info(kwargs)
def init_or_restore(self, base_dir: Optional[str] = None) -> NoReturn:
    """
    Check whether a checkpoint exists in `cp_dir`; if so, restore from it,
    otherwise initialize the model randomly.
    """
    if base_dir is not None:
        cp_dir = os.path.join(base_dir, 'model')
        if os.path.exists(os.path.join(cp_dir, 'checkpoint')):
            try:
                ckpt = tf.train.latest_checkpoint(cp_dir)
                self.checkpoint.restore(ckpt).expect_partial()  # load the model from the specified path
            except Exception:
                logger.error(colorize(f'restore model from {cp_dir} FAILED.',
                                      color='red'))
                raise Exception(f'restore model from {cp_dir} FAILED.')
            else:
                logger.info(colorize(f'restore model from {ckpt} SUCCESSFULLY.',
                                     color='green'))
    else:
        # load from this model's own directory to resume interrupted training
        ckpt = self.saver.latest_checkpoint
        if ckpt is not None:
            self.checkpoint.restore(ckpt).expect_partial()
            logger.info(colorize(f'restore model from {ckpt} SUCCESSFULLY.',
                                 color='green'))
        else:
            logger.info(colorize('initialize model SUCCESSFULLY.',
                                 color='green'))
def get_model_info(name: str) -> Tuple[Any, Any]:
    """
    Args:
        name: name of the algorithm
    Returns:
        class of the algorithm model
        whether the algorithm is single-agent (SARL) or multi-agent (MARL)
    """
    algo_info = registry.get_model_info(name)
    logger.info(colorize(algo_info.get('logo', ''), color='green'))
    model_class = getattr(
        importlib.import_module(f"rls.algorithms.{algo_info['path']}"),
        algo_info['class_name'])
    return model_class, algo_info['is_multi']
def learner(env, model, ip, port):
    check_port_in_use(port, ip, try_times=10, server_name='learner')
    assert hasattr(model, 'apex_learn') and hasattr(model, 'apex_cal_td'), \
        'this algorithm does not support Ape-X learning for now.'

    server = grpc.server(futures.ThreadPoolExecutor())
    apex_learner_pb2_grpc.add_LearnerServicer_to_server(
        LearnerServicer(model), server)
    server.add_insecure_port(':'.join([ip, port]))
    server.start()
    logger.info(colorize('start learner successfully.', color='green'))

    # eval_thread = EvalThread(env, model)
    # eval_thread.start()

    # GymCollector.evaluate(env, model)
    server.wait_for_termination()
def buffer(ip, port, learner_ip, learner_port, buffer_args):
    check_port_in_use(port, ip, try_times=10, server_name='buffer')

    buffer = PrioritizedExperienceReplay(**buffer_args)
    threadLock = threading.Lock()

    server = grpc.server(futures.ThreadPoolExecutor())
    apex_buffer_pb2_grpc.add_BufferServicer_to_server(
        BufferServicer(buffer=buffer, lock=threadLock), server)
    server.add_insecure_port(':'.join([ip, port]))
    server.start()
    logger.info(colorize('start buffer successfully.', color='green'))

    learn_thread = LearnThread(learner_ip, learner_port, buffer, threadLock)
    learn_thread.start()
    server.wait_for_termination()
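# Launch sketch for a minimal Ape-X deployment (IPs, ports and buffer_args are
# illustrative, not the project's defaults; `env` and `model` are assumed to be
# already constructed): start the learner first, then the buffer pointing back
# at it. Both block on wait_for_termination(), hence the separate processes.
from multiprocessing import Process

buffer_args = {'batch_size': 256, 'capacity': 2**17,
               'alpha': 0.6, 'beta': 0.4}  # hypothetical PER settings
Process(target=learner, args=(env, model, '127.0.0.1', '50051')).start()
Process(target=buffer,
        args=('127.0.0.1', '50052', '127.0.0.1', '50051', buffer_args)).start()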
def save(self) -> NoReturn:
    """
    Save the training model.
    """
    if self._is_save:
        _data = {}
        for k, v in self._trainer_modules.items():
            if hasattr(v, 'state_dict'):
                _data[k] = v.state_dict()
            else:
                _data[k] = v  # tensor/Number
        if self._save2single_file:
            th.save(_data, os.path.join(self.cp_dir, 'checkpoint.pth'))
        else:
            for k, v in _data.items():
                th.save(v, os.path.join(self.cp_dir, f'{k}.pth'))
        logger.info(
            colorize(
                f'Save checkpoint success. Training step: {self._cur_train_step}',
                color='green'))
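# Sketch of the convention save() and resume() rely on (attribute names are
# illustrative): every value in self._trainer_modules either exposes
# state_dict()/load_state_dict() (networks, optimizers) or is a plain tensor
# registered under its attribute name, so resume() can restore it in place
# via getattr(self, k).fill_(...).
self._trainer_modules.update(model=self.q_net,                       # has state_dict()
                             oplr=self.oplr,                         # optimizer wrapper, has state_dict()
                             _cur_train_step=self._cur_train_step)   # plain tensor, restored with fill_()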
from copy import deepcopy
from collections import deque

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel

from rls.utils.logging_utils import get_logger
from rls.utils.display import colorize

logger = get_logger(__name__)

try:
    import cv2
    cv2.ocl.setUseOpenCL(False)
except ImportError:
    logger.warning(
        colorize('opencv-python is needed to train visual-based models.',
                 color='yellow'))

from rls.common.yaml_ops import load_yaml
from rls.utils.np_utils import int2action_index
from rls.utils.indexs import (SingleAgentEnvArgs,
                              MultiAgentEnvArgs)

# obs: [
#     brain1: [
#         agent1: [],
#         agent2: [],
#         ...
#         agentn: []
#     ],
#     brain2: [
#         agent1: [],
#!/usr/bin/env python3
# encoding: utf-8

import gym
import numpy as np

from collections import deque
from gym.spaces import (Box, Discrete, Tuple)

from rls.utils.display import colorize
from rls.utils.logging_utils import get_logger

logger = get_logger(__name__)

try:
    import cv2
    cv2.ocl.setUseOpenCL(False)
except ImportError:
    logger.warning(
        colorize('opencv-python is needed to train visual-based models.',
                 color='yellow'))

try:
    import imageio
except ImportError:
    logger.warning(
        colorize('imageio should be installed to record video if needed.',
                 color='yellow'))
def __init__(self,
             n_copies=1,
             is_save=True,
             base_dir='',
             device: str = 'cpu',
             max_train_step=sys.maxsize,
             max_frame_step=sys.maxsize,
             max_train_episode=sys.maxsize,
             save_frequency=100,
             save2single_file=False,
             n_step_value=4,
             gamma=0.999,
             logger_types=['none'],
             decay_lr=False,
             normalize_vector_obs=False,
             obs_with_pre_action=False,
             oplr_params=dict(),
             rep_net_params={
                 'vector_net_params': {
                     'h_dim': 16,
                     'network_type': 'adaptive'  # rls.nn.represents.vectors
                 },
                 'visual_net_params': {
                     'h_dim': 128,
                     'network_type': 'simple'  # rls.nn.represents.visuals
                 },
                 'encoder_net_params': {
                     'h_dim': 16,
                     'network_type': 'identity'  # rls.nn.represents.encoders
                 },
                 'memory_net_params': {
                     'rnn_units': 16,
                     'network_type': 'lstm'
                 }
             },
             **kwargs):
    """
    Args:
        base_dir: the directory that stores data, such as models and logs
    """
    self.n_copies = n_copies
    self._is_save = is_save
    self._base_dir = base_dir
    self._training_name = os.path.split(self._base_dir)[-1]
    self.device = device
    logger.info(colorize(f'PyTorch Tensor Device: {self.device}'))

    self._max_train_step = max_train_step
    self._should_learn_cond_train_step = Until(max_train_step)
    self._should_learn_cond_frame_step = Until(max_frame_step)
    self._should_learn_cond_train_episode = Until(max_train_episode)
    self._should_save_model = Every(save_frequency)

    self._save2single_file = save2single_file
    self.gamma = gamma
    self._logger_types = logger_types
    self._n_step_value = n_step_value
    self._decay_lr = decay_lr  # TODO: implement
    self._normalize_vector_obs = normalize_vector_obs  # TODO: implement
    self._obs_with_pre_action = obs_with_pre_action
    self._rep_net_params = dict(rep_net_params)
    self._oplr_params = dict(oplr_params)
    super().__init__()

    self.memory_net_params = rep_net_params.get('memory_net_params', {
        'rnn_units': 16,
        'network_type': 'lstm'
    })
    self.use_rnn = self.memory_net_params.get('network_type', 'identity') != 'identity'

    self.cp_dir, self.log_dir = [
        os.path.join(base_dir, i) for i in ['model', 'log']
    ]
    if self._is_save:
        check_or_create(self.cp_dir, 'checkpoints(models)')

    self._cur_interact_step = th.tensor(0).long().to(self.device)
    self._cur_train_step = th.tensor(0).long().to(self.device)
    self._cur_frame_step = th.tensor(0).long().to(self.device)
    self._cur_episode = th.tensor(0).long().to(self.device)

    self._trainer_modules = {
        '_cur_train_step': self._cur_train_step,
        '_cur_frame_step': self._cur_frame_step,
        '_cur_episode': self._cur_episode
    }

    self._buffer = self._build_buffer()
    self._loggers = self._build_loggers() if self._is_save else list()
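# Sketch: turning on the recurrent path by overriding memory_net_params
# (the SomeAlgo class and 'gru' network type are assumptions; see
# rls.nn.represents for what is actually registered). Any network_type other
# than 'identity' sets use_rnn to True. Note that passing rep_net_params
# replaces the whole default dict, so include any other sub-configs you need.
agent = SomeAlgo(base_dir='./data/demo',
                 rep_net_params={'memory_net_params': {'rnn_units': 64,
                                                       'network_type': 'gru'}})
assert agent.use_rnn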
from rls.utils.display import colorize
from rls.envs.gym_wrapper.utils import build_env
from rls.utils.np_utils import get_discrete_action_list
from rls.utils.specs import (ObsSpec,
                             SingleAgentEnvArgs,
                             ModelObservations,
                             SingleModelInformation,
                             GymVectorizedType,
                             NamedTupleStaticClass)
from rls.utils.logging_utils import get_logger

logger = get_logger(__name__)

try:
    import gym_minigrid
except ImportError:
    logger.warning(colorize("import gym_minigrid failed, use 'pip3 install gym-minigrid' to install it.",
                            color='yellow'))

try:
    # To render, manually add 'renders=True' or (depending on the env) 'render=True' to the gym.make() call.
    import pybullet_envs
except ImportError:
    logger.warning(colorize("import pybullet_envs failed, use 'pip3 install PyBullet' to install it.",
                            color='yellow'))

try:
    import gym_donkeycar
except ImportError:
    logger.warning(colorize("import gym_donkeycar failed, use 'pip install gym_donkeycar' to install it.",
                            color='yellow'))