Example No. 1
    def _init_env(self):
        if self.base_config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        n_agents=self.base_config['n_agents'])
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.base_config['build_path'][sys.platform],
                    base_port=self.base_config['port'],
                    no_graphics=self.base_config['no_graphics']
                    and not self.render,
                    scene=self.base_config['scene'],
                    additional_args=self.additional_args,
                    n_agents=self.base_config['n_agents'])

        elif self.base_config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.base_config['build_path'],
                                  render=self.render,
                                  n_agents=self.base_config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.base_config["env_type"]}')

        self.obs_shapes, self.d_action_size, self.c_action_size = self.env.init()
        self.action_size = self.d_action_size + self.c_action_size

        self._logger.info(f'{self.base_config["build_path"]} initialized')
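
The _init_env above only wires the wrapper up; the contract it relies on (init, reset, step, close) is exercised elsewhere in these examples. Below is a minimal, self-contained sketch of that usage through the GYM branch. The environment name, n_agents=1, and reset_config=None are placeholder assumptions, not values taken from the example.

import numpy as np

from algorithm.env_wrapper.gym_wrapper import GymWrapper

env = GymWrapper(train_mode=False,
                 env_name='CartPole-v1',  # placeholder for base_config['build_path']
                 render=False,
                 n_agents=1)

obs_shapes, d_action_size, c_action_size = env.init()
action_size = d_action_size + c_action_size

obs_list = env.reset(reset_config=None)
for _ in range(10):
    # Random action, split into its discrete and continuous parts as above
    action = np.random.rand(1, action_size)
    obs_list, reward, local_done, max_reached = env.step(
        action[..., :d_action_size], action[..., d_action_size:])

env.close()
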
    def _init_env(self, config_path, replay_config, sac_config,
                  model_root_path):
        self._stub = StubController(self.net_config)

        if self.config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        base_port=5004)
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.config['build_path'][sys.platform],
                    base_port=self.config['build_port'],
                    scene=self.config['scene'],
                    n_agents=self.config['n_agents'])

        elif self.config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.config['build_path'],
                                  render=self.render,
                                  n_agents=self.config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.config["env_type"]}')

        self.obs_dims, self.action_dim, is_discrete = self.env.init()

        self.logger.info(f'{self.config["build_path"]} initialized')

        # If a saved sac_model.py exists in the run directory, load it; otherwise copy the configured model there
        if os.path.isfile(f'{model_root_path}/sac_model.py'):
            custom_sac_model = importlib.import_module(
                f'{model_root_path.replace("/",".")}.sac_model')
        else:
            custom_sac_model = importlib.import_module(
                f'{config_path.replace("/",".")}.{self.config["sac"]}')
            shutil.copyfile(f'{config_path}/{self.config["sac"]}.py',
                            f'{model_root_path}/sac_model.py')

        self.sac = SAC_DS_Base(obs_dims=self.obs_dims,
                               action_dim=self.action_dim,
                               is_discrete=is_discrete,
                               model_root_path=model_root_path,
                               model=custom_sac_model,
                               train_mode=self.train_mode,
                               last_ckpt=self.last_ckpt,
                               **sac_config)
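
The load-or-copy block at the end of _init_env can be read as a small helper: reuse the sac_model.py saved with a previous run, otherwise import the configured model and keep a copy next to the run. A standalone sketch of that pattern (the helper name is illustrative, and both directories are assumed to be importable packages):

import importlib
import os
import shutil


def load_or_copy_sac_model(model_root_path, config_path, sac_name):
    # Reuse the model saved with this run, if any
    if os.path.isfile(f'{model_root_path}/sac_model.py'):
        return importlib.import_module(
            f'{model_root_path.replace("/", ".")}.sac_model')

    # First run: import the configured model and keep a copy with the run
    module = importlib.import_module(
        f'{config_path.replace("/", ".")}.{sac_name}')
    shutil.copyfile(f'{config_path}/{sac_name}.py',
                    f'{model_root_path}/sac_model.py')
    return module
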
Example No. 3
class Main(object):
    def __init__(self, root_dir, config_dir, args):
        """
        config_path: the directory of config file
        args: command arguments generated by argparse
        """
        self._logger = logging.getLogger('test_env')

        config_abs_dir = self._init_config(root_dir, config_dir, args)

        self._init_env()
        self._run()

    def _init_config(self, root_dir, config_dir, args):
        config_abs_dir = Path(root_dir).joinpath(config_dir)
        config_abs_path = config_abs_dir.joinpath('config.yaml')
        default_config_abs_path = Path(__file__).resolve().parent.joinpath(
            'algorithm', 'default_config.yaml')
        # Merge default_config.yaml and custom config.yaml
        config = config_helper.initialize_config_from_yaml(
            default_config_abs_path, config_abs_path, args.config)

        # Initialize config from command line arguments
        self.train_mode = not args.run
        self.render = args.render
        self.run_in_editor = args.editor
        self.additional_args = args.additional_args
        self.save_image = args.save_image

        if args.port is not None:
            config['base_config']['port'] = args.port
        if args.agents is not None:
            config['base_config']['n_agents'] = args.agents

        config['base_config']['name'] = config_helper.generate_base_name(
            config['base_config']['name'])

        # The absolute directory of a specific training run
        model_abs_dir = Path(root_dir).joinpath('models',
                                                config['base_config']['scene'],
                                                config['base_config']['name'])
        model_abs_dir.mkdir(parents=True, exist_ok=True)
        self.model_abs_dir = model_abs_dir

        config_helper.display_config(config, self._logger)

        self.base_config = config['base_config']
        self.reset_config = config['reset_config']

        return config_abs_dir

    def _init_env(self):
        if self.base_config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        n_agents=self.base_config['n_agents'])
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.base_config['build_path'][sys.platform],
                    base_port=self.base_config['port'],
                    no_graphics=self.base_config['no_graphics']
                    and not self.render,
                    scene=self.base_config['scene'],
                    additional_args=self.additional_args,
                    n_agents=self.base_config['n_agents'])

        elif self.base_config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.base_config['build_path'],
                                  render=self.render,
                                  n_agents=self.base_config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.base_config["env_type"]}')

        self.obs_shapes, self.d_action_size, self.c_action_size = self.env.init()
        self.action_size = self.d_action_size + self.c_action_size

        self._logger.info(f'{self.base_config["build_path"]} initialized')

    def _run(self):
        obs_list = self.env.reset(reset_config=self.reset_config)

        agents = [Agent(i) for i in range(self.base_config['n_agents'])]

        agent_size = len(agents)
        fig_size = sum([len(s) == 3 for s in self.obs_shapes])

        img_save_index = 0

        plt.ion()
        fig, axes = plt.subplots(nrows=agent_size,
                                 ncols=fig_size,
                                 squeeze=False,
                                 figsize=(3 * fig_size, 3 * agent_size))
        ims = [[] for _ in range(agent_size)]
        for i in range(agent_size):
            j = 0
            for obs_shape in self.obs_shapes:
                if len(obs_shape) == 3:
                    axes[i][j].axis('off')
                    ims[i].append(axes[i][j].imshow(np.zeros(obs_shape)))
                    j += 1

        iteration = 0

        step_timer = elapsed_timer(self._logger, 'One step interacting', 200)

        while iteration != self.base_config['max_iter']:
            if self.base_config['reset_on_iteration'] or any(
                [a.max_reached for a in agents]):
                obs_list = self.env.reset(reset_config=self.reset_config)
                for agent in agents:
                    agent.clear()
            else:
                for agent in agents:
                    agent.reset()

            action = np.zeros([agent_size, self.action_size], dtype=np.float32)
            step = 0

            try:
                while not all([a.done for a in agents]):
                    j = 0
                    for obs in obs_list:
                        if len(obs.shape) > 2:
                            for i, image in enumerate(obs):
                                ims[i][j].set_data(image)
                            j += 1

                    fig.canvas.draw()
                    fig.canvas.flush_events()

                    action = np.random.rand(len(agents), self.action_size)

                    with step_timer:
                        next_obs_list, reward, local_done, max_reached = self.env.step(
                            action[..., :self.d_action_size],
                            action[..., self.d_action_size:])

                    if step == self.base_config['max_step_each_iter']:
                        local_done = [True] * len(agents)
                        max_reached = [True] * len(agents)

                    episode_trans_list = [
                        agents[i].add_transition([o[i] for o in obs_list],
                                                 action[i], reward[i],
                                                 local_done[i], max_reached[i],
                                                 [o[i] for o in next_obs_list],
                                                 None)
                        for i in range(len(agents))
                    ]

                    if self.save_image:
                        episode_trans_list = [
                            t for t in episode_trans_list if t is not None
                        ]
                        for episode_trans in episode_trans_list:
                            # n_obses_list: list([1, episode_len, *obs_shapes_i], ...)
                            n_obses_list, *_ = episode_trans

                            for i, n_obses in enumerate(n_obses_list):
                                n_obses = n_obses[0]
                                if len(n_obses.shape) > 2:
                                    img = Image.fromarray(
                                        np.uint8(n_obses[0] * 255))
                                    self._logger.info(
                                        f'Saved {img_save_index}-{i}')
                                    img.save(self.model_abs_dir.joinpath(
                                        f'{img_save_index}-{i}.gif'),
                                             save_all=True,
                                             append_images=[
                                                 Image.fromarray(
                                                     np.uint8(o * 255))
                                                 for o in n_obses[1:]
                                             ])

                        img_save_index += 1

                    obs_list = next_obs_list

                    step += 1

            except Exception as e:
                self._logger.error(e)
                self._logger.error('Exiting...')
                break

            iteration += 1

        self.env.close()
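
The command-line surface is not included in this example; the class only reads a handful of attributes from args (config, run, render, editor, additional_args, save_image, port, agents). A rough argparse sketch that provides exactly those attributes follows; the flag names and the config_dir value are assumptions for illustration, not taken from the project.

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument('--config', default=None)          # named config variant
parser.add_argument('--run', action='store_true')      # inference only: train_mode = not run
parser.add_argument('--render', action='store_true')
parser.add_argument('--editor', action='store_true')   # connect to the Unity editor
parser.add_argument('--port', type=int, default=None)
parser.add_argument('--agents', type=int, default=None)
parser.add_argument('--save_image', action='store_true')
parser.add_argument('--additional_args', nargs='*', default=None)
args = parser.parse_args()

Main(root_dir=str(Path(__file__).resolve().parent),
     config_dir='envs/some_env',  # placeholder config directory
     args=args)
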
class Learner(object):
    train_mode = True
    _agent_class = Agent

    _training_lock = threading.Lock()
    _is_training = False

    def __init__(self, config_path, args):
        self._now = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))

        (self.config, self.net_config, self.reset_config, replay_config,
         sac_config, model_root_path) = self._init_config(config_path, args)

        self._init_env(config_path, replay_config, sac_config, model_root_path)
        try:
            self._run()
        except KeyboardInterrupt:
            self.logger.warning('KeyboardInterrupt in _run')
            self.close()

    def _init_config(self, config_path, args):
        config_file_path = f'{config_path}/config_ds.yaml'
        config = config_helper.initialize_config_from_yaml(
            f'{Path(__file__).resolve().parent}/default_config.yaml',
            config_file_path, args.config)

        # Initialize config from command line arguments
        self.train_mode = not args.run
        self.last_ckpt = args.ckpt
        self.render = args.render
        self.run_in_editor = args.editor

        if args.name is not None:
            config['base_config']['name'] = args.name
        if args.build_port is not None:
            config['base_config']['build_port'] = args.build_port
        if args.sac is not None:
            config['base_config']['sac'] = args.sac
        if args.agents is not None:
            config['base_config']['n_agents'] = args.agents

        config['base_config']['name'] = config['base_config']['name'].replace(
            '{time}', self._now)
        model_root_path = f'models/ds/{config["base_config"]["scene"]}/{config["base_config"]["name"]}'

        logger_file = f'{model_root_path}/{args.logger_file}' if args.logger_file is not None else None
        self.logger = config_helper.set_logger('ds.learner', logger_file)

        if self.train_mode:
            config_helper.save_config(config, model_root_path, 'config.yaml')

        config_helper.display_config(config, self.logger)

        return (config['base_config'], config['net_config'],
                config['reset_config'], config['replay_config'],
                config['sac_config'], model_root_path)

    def _init_env(self, config_path, replay_config, sac_config,
                  model_root_path):
        self._stub = StubController(self.net_config)

        if self.config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        base_port=5004)
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.config['build_path'][sys.platform],
                    base_port=self.config['build_port'],
                    scene=self.config['scene'],
                    n_agents=self.config['n_agents'])

        elif self.config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.config['build_path'],
                                  render=self.render,
                                  n_agents=self.config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.config["env_type"]}')

        self.obs_dims, self.action_dim, is_discrete = self.env.init()

        self.logger.info(f'{self.config["build_path"]} initialized')

        # If a saved sac_model.py exists in the run directory, load it; otherwise copy the configured model there
        if os.path.isfile(f'{model_root_path}/sac_model.py'):
            custom_sac_model = importlib.import_module(
                f'{model_root_path.replace("/",".")}.sac_model')
        else:
            custom_sac_model = importlib.import_module(
                f'{config_path.replace("/",".")}.{self.config["sac"]}')
            shutil.copyfile(f'{config_path}/{self.config["sac"]}.py',
                            f'{model_root_path}/sac_model.py')

        self.sac = SAC_DS_Base(obs_dims=self.obs_dims,
                               action_dim=self.action_dim,
                               is_discrete=is_discrete,
                               model_root_path=model_root_path,
                               model=custom_sac_model,
                               train_mode=self.train_mode,
                               last_ckpt=self.last_ckpt,
                               **sac_config)

    def _get_policy_variables(self):
        with self._training_lock:
            variables = self.sac.get_policy_variables()

        return [v.numpy() for v in variables]

    def _get_action(self, obs_list, rnn_state=None):
        if self.sac.use_rnn:
            assert rnn_state is not None

        with self._training_lock:
            if self.sac.use_rnn:
                action, next_rnn_state = self.sac.choose_rnn_action(
                    obs_list, rnn_state)
                return action.numpy(), next_rnn_state.numpy()
            else:
                action = self.sac.choose_action(obs_list)
                return action.numpy()

    def _get_td_error(self,
                      n_obses_list,
                      n_actions,
                      n_rewards,
                      next_obs_list,
                      n_dones,
                      n_mu_probs,
                      n_rnn_states=None):
        """
        n_obses_list: list([1, episode_len, obs_dim_i], ...)
        n_actions: [1, episode_len, action_dim]
        n_rewards: [1, episode_len]
        next_obs_list: list([1, obs_dim_i], ...)
        n_dones: [1, episode_len]
        n_rnn_states: [1, episode_len, rnn_state_dim]
        """
        td_error = self.sac.get_episode_td_error(
            n_obses_list=n_obses_list,
            n_actions=n_actions,
            n_rewards=n_rewards,
            next_obs_list=next_obs_list,
            n_dones=n_dones,
            n_mu_probs=n_mu_probs,
            n_rnn_states=n_rnn_states if self.sac.use_rnn else None)

        return td_error

    def _post_rewards(self, peer, n_rewards):
        pass

    def _policy_evaluation(self):
        try:
            use_rnn = self.sac.use_rnn

            iteration = 0
            start_time = time.time()

            obs_list = self.env.reset(reset_config=self.reset_config)

            agents = [
                self._agent_class(i, use_rnn=use_rnn)
                for i in range(self.config['n_agents'])
            ]

            if use_rnn:
                initial_rnn_state = self.sac.get_initial_rnn_state(len(agents))
                rnn_state = initial_rnn_state

            while True:
                # not training, waiting...
                if self.train_mode and not self._is_training:
                    if self._t_evaluation_terminated:
                        break
                    time.sleep(EVALUATION_WAITING_TIME)
                    continue

                if self.config['reset_on_iteration']:
                    obs_list = self.env.reset(reset_config=self.reset_config)
                    for agent in agents:
                        agent.clear()

                    if use_rnn:
                        rnn_state = initial_rnn_state
                else:
                    for agent in agents:
                        agent.reset()

                action = np.zeros([len(agents), self.action_dim],
                                  dtype=np.float32)
                step = 0

                while False in [a.done
                                for a in agents] and (not self.train_mode
                                                      or self._is_training):
                    with self._training_lock:
                        if use_rnn:
                            action, next_rnn_state = self.sac.choose_rnn_action(
                                [o.astype(np.float32) for o in obs_list],
                                action, rnn_state)
                            next_rnn_state = next_rnn_state.numpy()
                        else:
                            action = self.sac.choose_action(
                                [o.astype(np.float32) for o in obs_list])

                    action = action.numpy()

                    next_obs_list, reward, local_done, max_reached = self.env.step(
                        action)

                    if step == self.config['max_step']:
                        local_done = [True] * len(agents)
                        max_reached = [True] * len(agents)

                    for i, agent in enumerate(agents):
                        agent.add_transition([o[i] for o in obs_list],
                                             action[i], reward[i],
                                             local_done[i], max_reached[i],
                                             [o[i] for o in next_obs_list],
                                             rnn_state[i] if use_rnn else None)

                    obs_list = next_obs_list
                    action[local_done] = np.zeros(self.action_dim)
                    if use_rnn:
                        rnn_state = next_rnn_state
                        rnn_state[local_done] = initial_rnn_state[local_done]

                    step += 1

                if self.train_mode:
                    with self._training_lock:
                        self._log_episode_summaries(iteration, agents)

                self._log_episode_info(iteration, start_time, agents)

                iteration += 1

                if self.train_mode:
                    time.sleep(EVALUATION_INTERVAL)
        except KeyboardInterrupt:
            self.logger.warning('KeyboardInterrupt in _policy_evaluation')
        except Exception as e:
            self.logger.error(e)

    def _log_episode_summaries(self, iteration, agents):
        rewards = np.array([a.reward for a in agents])
        self.sac.write_constant_summaries([{
            'tag': 'reward/mean',
            'simple_value': rewards.mean()
        }, {
            'tag': 'reward/max',
            'simple_value': rewards.max()
        }, {
            'tag': 'reward/min',
            'simple_value': rewards.min()
        }], iteration)

    def _log_episode_info(self, iteration, start_time, agents):
        time_elapse = (time.time() - start_time) / 60
        rewards = [a.reward for a in agents]
        rewards = ", ".join([f"{i:6.1f}" for i in rewards])
        self.logger.info(f'{iteration}, {time_elapse:.2f}, rewards {rewards}')

    def _get_sampled_data(self):
        while True:
            sampled = self._stub.get_sampled_data()

            if sampled is None:
                self.logger.warning('no data sampled')
                self._is_training = False
                time.sleep(RESAMPLE_TIME)
                continue
            else:
                self._is_training = True
                return sampled

    def _run_training_client(self):
        self._stub.clear_replay_buffer()

        while True:
            (pointers, n_obses_list, n_actions, n_rewards, next_obs_list,
             n_dones, n_mu_probs, rnn_state,
             priority_is) = self._get_sampled_data()

            with self._training_lock:
                td_error, update_data = self.sac.train(
                    pointers=pointers,
                    n_obses_list=n_obses_list,
                    n_actions=n_actions,
                    n_rewards=n_rewards,
                    next_obs_list=next_obs_list,
                    n_dones=n_dones,
                    n_mu_probs=n_mu_probs,
                    priority_is=priority_is,
                    rnn_state=rnn_state)

            self._stub.update_td_error(pointers, td_error)
            for pointers, key, data in update_data:
                self._stub.update_transitions(pointers, key, data)

    def _run(self):
        t_evaluation = threading.Thread(target=self._policy_evaluation)
        self._t_evaluation_terminated = False
        t_evaluation.start()

        if self.train_mode:
            servicer = LearnerService(self._get_action,
                                      self._get_policy_variables,
                                      self._get_td_error, self._post_rewards)
            self.server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=20))
            learner_pb2_grpc.add_LearnerServiceServicer_to_server(
                servicer, self.server)
            self.server.add_insecure_port(
                f'[::]:{self.net_config["learner_port"]}')
            self.server.start()
            self.logger.info(
                f'Learner server is running on [{self.net_config["learner_port"]}]...'
            )

            self._run_training_client()

    def close(self):
        self._t_evaluation_terminated = True
        self.env.close()
        if self.train_mode:
            self.server.stop(None)
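
Internally the Learner runs _policy_evaluation on a background thread while _run_training_client keeps calling sac.train in the main thread; the shared _training_lock is what keeps the two from using the networks at the same time. Stripped of the RL details, the coordination pattern reduces to the following sketch (all names in it are illustrative, not from the project):

import threading
import time

lock = threading.Lock()
stop = threading.Event()


def evaluation_loop():
    # Background thread: read-only use of the shared model
    while not stop.is_set():
        with lock:
            pass  # e.g. choose_action(...) / write summaries
        time.sleep(0.1)  # stands in for EVALUATION_INTERVAL


def training_loop(steps):
    # Foreground: exclusive use of the shared model
    for _ in range(steps):
        with lock:
            pass  # e.g. sac.train(...)


t = threading.Thread(target=evaluation_loop)
t.start()
training_loop(1000)
stop.set()
t.join()
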
Example No. 5
class Actor(object):
    _agent_class = Agent

    def __init__(self, root_dir, config_dir, args):
        self._logger = logging.getLogger('ds.actor')

        constant_config, config_abs_dir = self._init_constant_config(
            root_dir, config_dir, args)

        # The evolver stub is fixed,
        # but the learner stub will be generated by evolver
        self._evolver_stub = EvolverStubController(
            constant_config['net_config']['evolver_host'],
            constant_config['net_config']['evolver_port'])

        self._sac_actor_lock = ReadWriteLock(5, 1, 1, logger=self._logger)

        learner_host, learner_port = self._evolver_stub.register_to_evolver()

        self._logger.info(f'Assigned to learner {learner_host}:{learner_port}')

        # The learner stub is generated by the evolver
        self._stub = StubController(learner_host, learner_port)

        self._init_config(constant_config, args)
        self._init_env()
        self._init_sac(config_abs_dir)
        self._init_episode_sender(learner_host, learner_port)

        self._run()

        self.close()

    def _init_constant_config(self, root_dir, config_dir, args):
        default_config_abs_path = Path(__file__).resolve().parent.joinpath(
            'default_config.yaml')
        config_abs_dir = Path(root_dir).joinpath(config_dir)
        config_abs_path = config_abs_dir.joinpath('config_ds.yaml')
        config = config_helper.initialize_config_from_yaml(
            default_config_abs_path, config_abs_path, args.config)

        # Initialize config from command line arguments
        self.additional_args = args.additional_args
        self.device = args.device
        self.run_in_editor = args.editor
        self.logger_in_file = args.logger_in_file

        if args.evolver_host is not None:
            config['net_config']['evolver_host'] = args.evolver_host
        if args.evolver_port is not None:
            config['net_config']['evolver_port'] = args.evolver_port

        return config, config_abs_dir

    def _init_config(self, config, args):
        if args.build_port is not None:
            config['base_config']['build_port'] = args.build_port
        if args.nn is not None:
            config['base_config']['nn'] = args.nn
        if args.agents is not None:
            config['base_config']['n_agents'] = args.agents

        self.base_config = config['base_config']

        register_response = self._stub.register_to_learner()

        (model_abs_dir, _id, reset_config, model_config,
         sac_config) = register_response
        self._logger.info(f'Assigned to id {_id}')

        config['reset_config'] = self.reset_config = reset_config
        config['model_config'] = self.model_config = model_config
        config['sac_config'] = sac_config

        self.sac_config = config['sac_config']
        self.model_abs_dir = model_abs_dir

        # Set logger file if available
        if self.logger_in_file:
            logger_file = Path(model_abs_dir).joinpath(f'actor-{_id}.log')
            config_helper.set_logger(logger_file)
            self._logger.info(f'Set to logger {logger_file}')

        config_helper.display_config(config, self._logger)

    def _init_env(self):
        if self.base_config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper()
            else:
                self.env = UnityWrapper(
                    file_name=self.base_config['build_path'][sys.platform],
                    base_port=self.base_config['build_port'],
                    no_graphics=self.base_config['no_graphics'],
                    scene=self.base_config['scene'],
                    additional_args=self.additional_args,
                    n_agents=self.base_config['n_agents'])

        elif self.base_config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(env_name=self.base_config['build_path'],
                                  n_agents=self.base_config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.base_config["env_type"]}')

        self.obs_shapes, self.d_action_size, self.c_action_size = self.env.init()
        self.action_size = self.d_action_size + self.c_action_size

        self._logger.info(f'{self.base_config["build_path"]} initialized')

    def _init_sac(self, config_abs_dir):
        nn_abs_path = Path(config_abs_dir).joinpath(
            f'{self.base_config["nn"]}.py')
        spec = importlib.util.spec_from_file_location('nn', str(nn_abs_path))
        custom_nn_model = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(custom_nn_model)

        self.sac_actor = SAC_DS_Base(obs_shapes=self.obs_shapes,
                                     d_action_size=self.d_action_size,
                                     c_action_size=self.c_action_size,
                                     model_abs_dir=None,
                                     model=custom_nn_model,
                                     model_config=self.model_config,
                                     device=self.device,
                                     train_mode=False,
                                     **self.sac_config)

        self._logger.info(f'SAC_ACTOR started')

    def _init_episode_sender(self, learner_host, learner_port):
        max_episode_size = self.base_config['max_episode_size']
        episode_shapes = [[(1, max_episode_size, *o) for o in self.obs_shapes],
                          (1, max_episode_size,
                           self.d_action_size + self.c_action_size),
                          (1, max_episode_size),
                          [(1, *o) for o in self.obs_shapes],
                          (1, max_episode_size), (1, max_episode_size),
                          (1, max_episode_size,
                           *self.sac_actor.rnn_state_shape)
                          if self.sac_actor.use_rnn else None]
        episode_dtypes = [[np.float32
                           for _ in self.obs_shapes], np.float32, np.float32,
                          [np.float32 for _ in self.obs_shapes], bool,
                          np.float32,
                          np.float32 if self.sac_actor.use_rnn else None]

        self._episode_buffer = SharedMemoryManager(
            self.base_config['episode_queue_size'],
            logger=self._logger,
            counter_get_shm_index_empty_log='Episode shm index is empty',
            timer_get_shm_index_log='Get an episode shm index',
            timer_get_data_log='Get an episode',
            log_repeat=ELAPSED_REPEAT)

        self._episode_buffer.init_from_shapes(episode_shapes, episode_dtypes)
        self._episode_size_array = mp.Array(
            'i', range(self.base_config['episode_queue_size']))

        for _ in range(self.base_config['episode_sender_process_num']):
            mp.Process(target=EpisodeSender,
                       kwargs={
                           'logger_in_file': self.logger_in_file,
                           'model_abs_dir': self.model_abs_dir,
                           'learner_host': learner_host,
                           'learner_port': learner_port,
                           'episode_buffer': self._episode_buffer,
                           'episode_size_array': self._episode_size_array
                       }).start()

    def _update_policy_variables(self):
        variables = self._stub.get_policy_variables()
        if variables is not None:
            if not any([np.isnan(np.min(v)) for v in variables]):
                with self._sac_actor_lock.write():
                    self.sac_actor.update_policy_variables(variables)
                self._logger.info('Policy variables updated')
            else:
                self._logger.warning('NAN in variables, skip updating')

    def _add_trans(self,
                   l_obses_list,
                   l_actions,
                   l_rewards,
                   next_obs_list,
                   l_dones,
                   l_probs,
                   l_rnn_states=None):
        """
        Args:
            l_obses_list: list([1, episode_len, *obs_shapes_i], ...)
            l_actions: [1, episode_len, action_size]
            l_rewards: [1, episode_len]
            next_obs_list: list([1, *obs_shapes_i], ...)
            l_dones: [1, episode_len]
            l_probs: [1, episode_len]
            l_rnn_states: [1, episode_len, *rnn_state_shape]
        """
        if l_obses_list[0].shape[
                1] < self.sac_actor.burn_in_step + self.sac_actor.n_step:
            return
        episode_idx = self._episode_buffer.put([
            l_obses_list, l_actions, l_rewards, next_obs_list, l_dones,
            l_probs, l_rnn_states
        ])
        self._episode_size_array[episode_idx] = l_obses_list[0].shape[1]

    def _run(self):
        use_rnn = self.sac_actor.use_rnn

        obs_list = self.env.reset(reset_config=self.reset_config)

        agents = [
            self._agent_class(
                i,
                use_rnn=use_rnn,
                max_return_episode_trans=self.base_config['max_episode_size'])
            for i in range(self.base_config['n_agents'])
        ]

        if use_rnn:
            initial_rnn_state = self.sac_actor.get_initial_rnn_state(
                len(agents))
            rnn_state = initial_rnn_state

        iteration = 0

        while self._stub.connected and self._evolver_stub.connected:
            if self.base_config['reset_on_iteration'] or any(
                [a.max_reached for a in agents]):
                obs_list = self.env.reset(reset_config=self.reset_config)
                for agent in agents:
                    agent.clear()

                if use_rnn:
                    rnn_state = initial_rnn_state
            else:
                for agent in agents:
                    agent.reset()

            action = np.zeros([len(agents), self.action_size],
                              dtype=np.float32)
            step = 0

            if self.base_config['update_policy_mode']:
                self._update_policy_variables()

            try:
                while not all([a.done
                               for a in agents]) and self._stub.connected:
                    # burn-in padding
                    for agent in agents:
                        if agent.is_empty():
                            for _ in range(self.sac_actor.burn_in_step):
                                agent.add_transition(
                                    obs_list=[
                                        np.zeros(t) for t in self.obs_shapes
                                    ],
                                    action=np.zeros(self.action_size),
                                    reward=0.,
                                    local_done=False,
                                    max_reached=False,
                                    next_obs_list=[
                                        np.zeros(t) for t in self.obs_shapes
                                    ],
                                    prob=0.,
                                    rnn_state=initial_rnn_state[0])

                    if self.base_config['update_policy_mode']:
                        with self._sac_actor_lock.read():
                            if use_rnn:
                                action, prob, next_rnn_state = self.sac_actor.choose_rnn_action(
                                    [o.astype(np.float32) for o in obs_list],
                                    action,
                                    rnn_state,
                                    force_rnd_if_avaiable=True)
                            else:
                                action, prob = self.sac_actor.choose_action(
                                    [o.astype(np.float32) for o in obs_list],
                                    force_rnd_if_avaiable=True)

                    else:
                        # Get action from learner each step
                        # TODO need prob
                        raise Exception('TODO need prob')
                        if use_rnn:
                            action_rnn_state = self._stub.get_action(
                                [o.astype(np.float32) for o in obs_list],
                                rnn_state)
                            if action_rnn_state is None:
                                break
                            action, next_rnn_state = action_rnn_state
                        else:
                            action = self._stub.get_action(
                                [o.astype(np.float32) for o in obs_list])
                            if action is None:
                                break

                    next_obs_list, reward, local_done, max_reached = self.env.step(
                        action[..., :self.d_action_size],
                        action[..., self.d_action_size:])

                    if step == self.base_config['max_step_each_iter']:
                        local_done = [True] * len(agents)
                        max_reached = [True] * len(agents)

                    episode_trans_list = [
                        agents[i].add_transition(
                            obs_list=[o[i] for o in obs_list],
                            action=action[i],
                            reward=reward[i],
                            local_done=local_done[i],
                            max_reached=max_reached[i],
                            next_obs_list=[o[i] for o in next_obs_list],
                            prob=prob[i],
                            rnn_state=rnn_state[i] if use_rnn else None)
                        for i in range(len(agents))
                    ]

                    episode_trans_list = [
                        t for t in episode_trans_list if t is not None
                    ]
                    if len(episode_trans_list) != 0:
                        # ep_obses_list, ep_actions, ep_rewards, next_obs_list, ep_dones, ep_probs,
                        # ep_rnn_states
                        for episode_trans in episode_trans_list:
                            self._add_trans(*episode_trans)

                    obs_list = next_obs_list
                    action[local_done] = np.zeros(self.action_size)
                    if use_rnn:
                        rnn_state = next_rnn_state
                        rnn_state[local_done] = initial_rnn_state[local_done]

                    step += 1

            except:
                self._logger.error(traceback.format_exc())
                self._logger.error('Exiting...')
                break

            self._log_episode_info(iteration, agents)
            iteration += 1

        self.close()

    def _log_episode_info(self, iteration, agents):
        rewards = [a.reward for a in agents]
        rewards = ", ".join([f"{i:6.1f}" for i in rewards])
        max_step = max([a.steps for a in agents])
        self._logger.info(f'{iteration}, S {max_step}, R {rewards}')

    def close(self):
        if hasattr(self, 'env'):
            self.env.close()
        if hasattr(self, '_stub'):
            self._stub.close()

        self._evolver_stub.close()

        self._logger.warning('Closed')
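
When the policy is recurrent, the loop above pads every freshly reset agent with burn_in_step all-zero transitions before its first real step, so each stored episode is long enough to pass the burn_in_step + n_step length check in _add_trans. The padding in isolation looks like the sketch below; the shapes, action_size and burn_in_step are assumed values, not the example's configuration, and the real calls also pass rnn_state=initial_rnn_state[0].

import numpy as np

burn_in_step = 4                    # assumed
obs_shapes = [(8,), (30, 30, 3)]    # assumed vector + visual observations
action_size = 2                     # assumed

# One all-zero transition per burn-in step, mirroring the
# agent.add_transition(...) calls in the loop above
padding = [dict(obs_list=[np.zeros(s, dtype=np.float32) for s in obs_shapes],
                action=np.zeros(action_size, dtype=np.float32),
                reward=0.,
                local_done=False,
                max_reached=False,
                prob=0.)
           for _ in range(burn_in_step)]

print(len(padding), [o.shape for o in padding[0]['obs_list']])
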
    def _init_env(self):
        # Re-initialize the env each time the actor connects to the learner and replay buffer

        # Initialize config
        config = config_helper.initialize_config_from_yaml(
            f'{Path(__file__).resolve().parent}/default_config.yaml',
            self.config_file_path, self.config_cat)

        if self.cmd_args.build_port is not None:
            config['base_config']['build_port'] = self.cmd_args.build_port
        if self.cmd_args.sac is not None:
            config['base_config']['sac'] = self.cmd_args.sac
        if self.cmd_args.agents is not None:
            config['base_config']['n_agents'] = self.cmd_args.agents
        if self.cmd_args.noise is not None:
            config['sac_config']['noise'] = self.cmd_args.noise

        self.config = config['base_config']
        sac_config = config['sac_config']
        self.reset_config = config['reset_config']

        # Initialize environment
        if self.config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        base_port=5004)
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.config['build_path'][sys.platform],
                    no_graphics=self.train_mode,
                    base_port=self.config['build_port'],
                    scene=self.config['scene'],
                    n_agents=self.config['n_agents'])

        elif self.config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.config['build_path'],
                                  n_agents=self.config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.config["env_type"]}')

        self.obs_dims, self.action_dim, is_discrete = self.env.init()

        self.logger.info(f'{self.config["build_path"]} initialized')

        custom_sac_model = importlib.import_module(
            f'{self.config_path.replace("/",".")}.{self.config["sac"]}')

        self.sac_actor = SAC_DS_Base(obs_dims=self.obs_dims,
                                     action_dim=self.action_dim,
                                     is_discrete=is_discrete,
                                     model_root_path=None,
                                     model=custom_sac_model,
                                     train_mode=False,
                                     **sac_config)

        self.logger.info(f'sac_actor initialized')
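
config_helper.initialize_config_from_yaml is not shown in any of these examples; judging by its call sites, it layers a default_config.yaml under an experiment-specific YAML (the third argument appears to select a named config variant, which the sketch below ignores). A rough standalone equivalent using PyYAML, offered as an assumption about the helper's behavior rather than its actual implementation:

import yaml


def merge_yaml_configs(default_path, custom_path):
    """Layer a custom YAML over a default YAML, section by section."""
    with open(default_path) as f:
        config = yaml.safe_load(f) or {}
    with open(custom_path) as f:
        custom = yaml.safe_load(f) or {}

    for section, values in custom.items():
        if isinstance(values, dict):
            config.setdefault(section, {}).update(values)
        else:
            config[section] = values

    return config
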
class Actor(object):
    train_mode = True
    _agent_class = Agent
    _logged_waiting_for_connection = False

    def __init__(self, config_path, args):
        self.config_path = config_path
        self.cmd_args = args
        net_config = self._init_constant_config(self.config_path, args)

        self._stub = StubController(net_config)
        self._run()

    def _init_constant_config(self, config_path, args):
        config_file_path = f'{config_path}/config_ds.yaml'
        config = config_helper.initialize_config_from_yaml(
            f'{Path(__file__).resolve().parent}/default_config.yaml',
            config_file_path, args.config)
        self.config_file_path = config_file_path

        # Initialize config from command line arguments
        self.train_mode = not args.run
        self.run_in_editor = args.editor
        self.config_cat = args.config

        self.logger = config_helper.set_logger('ds.actor', args.logger_file)

        return config['net_config']

    def _init_env(self):
        # Re-initialize the env each time the actor connects to the learner and replay buffer

        # Initialize config
        config = config_helper.initialize_config_from_yaml(
            f'{Path(__file__).resolve().parent}/default_config.yaml',
            self.config_file_path, self.config_cat)

        if self.cmd_args.build_port is not None:
            config['base_config']['build_port'] = self.cmd_args.build_port
        if self.cmd_args.sac is not None:
            config['base_config']['sac'] = self.cmd_args.sac
        if self.cmd_args.agents is not None:
            config['base_config']['n_agents'] = self.cmd_args.agents
        if self.cmd_args.noise is not None:
            config['sac_config']['noise'] = self.cmd_args.noise

        self.config = config['base_config']
        sac_config = config['sac_config']
        self.reset_config = config['reset_config']

        # Initialize environment
        if self.config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        base_port=5004)
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.config['build_path'][sys.platform],
                    no_graphics=self.train_mode,
                    base_port=self.config['build_port'],
                    scene=self.config['scene'],
                    n_agents=self.config['n_agents'])

        elif self.config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.config['build_path'],
                                  n_agents=self.config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.config["env_type"]}')

        self.obs_dims, self.action_dim, is_discrete = self.env.init()

        self.logger.info(f'{self.config["build_path"]} initialized')

        custom_sac_model = importlib.import_module(
            f'{self.config_path.replace("/",".")}.{self.config["sac"]}')

        self.sac_actor = SAC_DS_Base(obs_dims=self.obs_dims,
                                     action_dim=self.action_dim,
                                     is_discrete=is_discrete,
                                     model_root_path=None,
                                     model=custom_sac_model,
                                     train_mode=False,
                                     **sac_config)

        self.logger.info(f'sac_actor initialized')

    def _update_policy_variables(self):
        variables = self._stub.update_policy_variables()
        if variables is not None:
            self.sac_actor.update_policy_variables(variables)

    def _add_trans(self,
                   n_obses_list,
                   n_actions,
                   n_rewards,
                   next_obs_list,
                   n_dones,
                   n_rnn_states=None):

        self._stub.post_rewards(n_rewards)

        if n_obses_list[0].shape[
                1] < self.sac_actor.burn_in_step + self.sac_actor.n_step:
            return

        if self.sac_actor.use_rnn:
            n_mu_probs = self.sac_actor.get_n_probs(n_obses_list, n_actions,
                                                    n_rnn_states[:, 0,
                                                                 ...]).numpy()
            self._stub.add_transitions(n_obses_list, n_actions, n_rewards,
                                       next_obs_list, n_dones, n_mu_probs,
                                       n_rnn_states)
        else:
            n_mu_probs = self.sac_actor.get_n_probs(n_obses_list,
                                                    n_actions).numpy()
            self._stub.add_transitions(n_obses_list, n_actions, n_rewards,
                                       next_obs_list, n_dones, n_mu_probs)

    def _run(self):
        iteration = 0

        while True:
            # Replay or learner is offline, waiting...
            if not self._stub.connected:
                if iteration != 0:
                    self.env.close()
                    self.logger.info(f'{self.config["build_path"]} closed')
                    iteration = 0

                if not self._logged_waiting_for_connection:
                    self.logger.warning('waiting for connection')
                    self._logged_waiting_for_connection = True
                time.sleep(WAITING_CONNECTION_TIME)
                continue
            self._logged_waiting_for_connection = False

            # Learner is online, reset all settings
            if iteration == 0 and self._stub.connected:
                self._init_env()
                use_rnn = self.sac_actor.use_rnn

                obs_list = self.env.reset(reset_config=self.reset_config)

                agents = [
                    self._agent_class(i, use_rnn=use_rnn)
                    for i in range(self.config['n_agents'])
                ]

                if use_rnn:
                    initial_rnn_state = self.sac_actor.get_initial_rnn_state(
                        len(agents))
                    rnn_state = initial_rnn_state

            if self.config['reset_on_iteration']:
                obs_list = self.env.reset(reset_config=self.reset_config)
                for agent in agents:
                    agent.clear()

                if use_rnn:
                    rnn_state = initial_rnn_state
            else:
                for agent in agents:
                    agent.reset()

            action = np.zeros([len(agents), self.action_dim], dtype=np.float32)
            step = 0

            if self.config['update_policy_mode'] and self.config[
                    'update_policy_variables_per_step'] == -1:
                self._update_policy_variables()

            while False in [a.done for a in agents] and self._stub.connected:
                # burn-in padding
                for agent in agents:
                    if agent.is_empty():
                        for _ in range(self.sac_actor.burn_in_step):
                            agent.add_transition(
                                [np.zeros(t) for t in self.obs_dims],
                                np.zeros(self.action_dim), 0, False, False,
                                [np.zeros(t)
                                 for t in self.obs_dims], initial_rnn_state[0])

                if self.config['update_policy_mode']:
                    # Update policy variables every "update_policy_variables_per_step" steps
                    if self.config[
                            'update_policy_variables_per_step'] != -1 and step % self.config[
                                'update_policy_variables_per_step'] == 0:
                        self._update_policy_variables()

                    if use_rnn:
                        action, next_rnn_state = self.sac_actor.choose_rnn_action(
                            [o.astype(np.float32) for o in obs_list], action,
                            rnn_state)
                        next_rnn_state = next_rnn_state.numpy()
                    else:
                        action = self.sac_actor.choose_action(
                            [o.astype(np.float32) for o in obs_list])

                    action = action.numpy()
                else:
                    # Get action from learner each step
                    if use_rnn:
                        action_rnn_state = self._stub.get_action(
                            [o.astype(np.float32) for o in obs_list],
                            rnn_state)
                        if action_rnn_state is None:
                            break
                        action, next_rnn_state = action_rnn_state
                    else:
                        action = self._stub.get_action(
                            [o.astype(np.float32) for o in obs_list])
                        if action is None:
                            break

                next_obs_list, reward, local_done, max_reached = self.env.step(
                    action)

                if step == self.config['max_step']:
                    local_done = [True] * len(agents)
                    max_reached = [True] * len(agents)

                episode_trans_list = [
                    agents[i].add_transition([o[i] for o in obs_list],
                                             action[i], reward[i],
                                             local_done[i], max_reached[i],
                                             [o[i] for o in next_obs_list],
                                             rnn_state[i] if use_rnn else None)
                    for i in range(len(agents))
                ]

                if self.train_mode:
                    episode_trans_list = [
                        t for t in episode_trans_list if t is not None
                    ]
                    if len(episode_trans_list) != 0:
                        for episode_trans in episode_trans_list:
                            self._add_trans(*episode_trans)

                obs_list = next_obs_list
                action[local_done] = np.zeros(self.action_dim)
                if use_rnn:
                    rnn_state = next_rnn_state
                    rnn_state[local_done] = initial_rnn_state[local_done]

                step += 1

            self._log_episode_info(iteration, agents)
            iteration += 1

    def _log_episode_info(self, iteration, agents):
        rewards = [a.reward for a in agents]
        rewards = ", ".join([f"{i:6.1f}" for i in rewards])
        self.logger.info(f'{iteration}, rewards {rewards}')
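
The update_policy_variables_per_step switch in the loop above has two modes: -1 pulls the policy weights from the learner once per iteration (before the step loop), while any positive value pulls them every that-many environment steps. The schedule reduces to a single check, sketched standalone below; should_refresh is an illustrative name standing in for the inline condition around self._update_policy_variables().

def should_refresh(step, per_step):
    """per_step == -1: never refresh inside the step loop
    (the refresh happens once before the loop instead);
    otherwise refresh every per_step steps."""
    return per_step != -1 and step % per_step == 0


assert should_refresh(0, 100)
assert should_refresh(200, 100)
assert not should_refresh(50, 100)
assert not should_refresh(0, -1)
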
Example No. 8
class Main(object):
    train_mode = True
    _agent_class = Agent  # For different environments

    def __init__(self, config_path, args):
        """
        config_path: the directory of config file
        args: command arguments generated by argparse
        """
        self._now = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))

        (self.config, self.reset_config,
         replay_config,
         sac_config,
         model_root_path) = self._init_config(config_path, args)
        self._init_env(model_root_path, config_path,
                       sac_config,
                       replay_config)
        self._run()

    def _init_config(self, config_path, args):
        config_file_path = f'{config_path}/config.yaml'
        # Merge default_config.yaml and custom config.yaml
        config = config_helper.initialize_config_from_yaml(f'{Path(__file__).resolve().parent}/default_config.yaml',
                                                           config_file_path,
                                                           args.config)

        # Initialize config from command line arguments
        self.train_mode = not args.run
        self.last_ckpt = args.ckpt
        self.render = args.render
        self.run_in_editor = args.editor

        if args.name is not None:
            config['base_config']['name'] = args.name
        if args.port is not None:
            config['base_config']['port'] = args.port
        if args.sac is not None:
            config['base_config']['sac'] = args.sac
        if args.agents is not None:
            config['base_config']['n_agents'] = args.agents

        # Replace {time} with the current time plus random letters
        rand = ''.join(random.sample(string.ascii_letters, 4))
        config['base_config']['name'] = config['base_config']['name'].replace('{time}', self._now + rand)
        model_root_path = f'models/{config["base_config"]["scene"]}/{config["base_config"]["name"]}'

        logger_file = f'{model_root_path}/{args.logger_file}' if args.logger_file is not None else None
        self.logger = config_helper.set_logger('sac', logger_file)

        if self.train_mode:
            config_helper.save_config(config, model_root_path, 'config.yaml')

        config_helper.display_config(config, self.logger)

        return (config['base_config'],
                config['reset_config'],
                config['replay_config'],
                config['sac_config'],
                model_root_path)

    def _init_env(self, model_root_path, config_path,
                  sac_config,
                  replay_config):
        if self.config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode, base_port=5004)
            else:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        file_name=self.config['build_path'][sys.platform],
                                        base_port=self.config['port'],
                                        scene=self.config['scene'],
                                        n_agents=self.config['n_agents'])

        elif self.config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.config['build_path'],
                                  render=self.render,
                                  n_agents=self.config['n_agents'])
        else:
            raise RuntimeError(f'Undefined Environment Type: {self.config["env_type"]}')

        self.obs_dims, self.action_dim, is_discrete = self.env.init()

        # If a saved model exists, load it; otherwise copy a new one
        if os.path.isfile(f'{model_root_path}/sac_model.py'):
            custom_sac_model = importlib.import_module(f'{model_root_path.replace("/",".")}.sac_model')
        else:
            custom_sac_model = importlib.import_module(f'{config_path.replace("/",".")}.{self.config["sac"]}')
            shutil.copyfile(f'{config_path}/{self.config["sac"]}.py', f'{model_root_path}/sac_model.py')

        self.sac = SAC_Base(obs_dims=self.obs_dims,
                            action_dim=self.action_dim,
                            is_discrete=is_discrete,
                            model_root_path=model_root_path,
                            model=custom_sac_model,
                            train_mode=self.train_mode,
                            last_ckpt=self.last_ckpt,

                            replay_config=replay_config,

                            **sac_config)

    def _run(self):
        use_rnn = self.sac.use_rnn

        obs_list = self.env.reset(reset_config=self.reset_config)

        agents = [self._agent_class(i, use_rnn=self.sac.use_rnn)
                  for i in range(self.config['n_agents'])]

        if use_rnn:
            initial_rnn_state = self.sac.get_initial_rnn_state(len(agents))
            rnn_state = initial_rnn_state

        is_max_reached = False

        for iteration in range(self.config['max_iter'] + 1):
            if self.config['reset_on_iteration'] or is_max_reached:
                obs_list = self.env.reset(reset_config=self.reset_config)
                for agent in agents:
                    agent.clear()

                if use_rnn:
                    rnn_state = initial_rnn_state
            else:
                for agent in agents:
                    agent.reset()

            is_max_reached = False
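            # The previous action is fed back into choose_rnn_action below, so keep
            # a persistent buffer and zero only the rows of agents that just finished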
            action = np.zeros([len(agents), self.action_dim], dtype=np.float32)
            step = 0

            while not all(a.done for a in agents):
                if use_rnn:
                    # burn-in padding
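                    # Agents with an empty transition buffer are padded with
                    # burn_in_step zero transitions, a warm-up prefix for the
                    # recurrent state rather than real experience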
                    for agent in [a for a in agents if a.is_empty()]:
                        for _ in range(self.sac.burn_in_step):
                            agent.add_transition([np.zeros(t) for t in self.obs_dims],
                                                 np.zeros(self.action_dim),
                                                 0, False, False,
                                                 [np.zeros(t) for t in self.obs_dims],
                                                 initial_rnn_state[0])

                    action, next_rnn_state = self.sac.choose_rnn_action([o.astype(np.float32) for o in obs_list],
                                                                        action,
                                                                        rnn_state)
                    next_rnn_state = next_rnn_state.numpy()
                else:
                    action = self.sac.choose_action([o.astype(np.float32) for o in obs_list])

                action = action.numpy()

                next_obs_list, reward, local_done, max_reached = self.env.step(action)

                if step == self.config['max_step']:
                    local_done = [True] * len(agents)
                    max_reached = [True] * len(agents)
                    is_max_reached = True

                episode_trans_list = [agents[i].add_transition([o[i] for o in obs_list],
                                                               action[i],
                                                               reward[i],
                                                               local_done[i],
                                                               max_reached[i],
                                                               [o[i] for o in next_obs_list],
                                                               rnn_state[i] if use_rnn else None)
                                      for i in range(len(agents))]

                if self.train_mode:
                    episode_trans_list = [t for t in episode_trans_list if t is not None]
                    if len(episode_trans_list) != 0:
                        # n_obses_list, n_actions, n_rewards, next_obs_list, n_dones,
                        # n_rnn_states
                        for episode_trans in episode_trans_list:
                            self.sac.fill_replay_buffer(*episode_trans)
                    self.sac.train()

                obs_list = next_obs_list
                action[local_done] = np.zeros(self.action_dim)
                if use_rnn:
                    rnn_state = next_rnn_state
                    rnn_state[local_done] = initial_rnn_state[local_done]

                step += 1

            if self.train_mode:
                self._log_episode_summaries(iteration, agents)

            self._log_episode_info(iteration, agents)

        self.sac.save_model()
        self.env.close()

    def _log_episode_summaries(self, iteration, agents):
        rewards = np.array([a.reward for a in agents])
        self.sac.write_constant_summaries([
            {'tag': 'reward/mean', 'simple_value': rewards.mean()},
            {'tag': 'reward/max', 'simple_value': rewards.max()},
            {'tag': 'reward/min', 'simple_value': rewards.min()}
        ], iteration)

    def _log_episode_info(self, iteration, agents):
        rewards = [a.reward for a in agents]
        rewards = ", ".join([f"{i:6.1f}" for i in rewards])
        self.logger.info(f'{iteration}, rewards {rewards}')
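
A minimal sketch of a launcher for this Main class, assuming our own argparse entry point; the flag names simply mirror the attributes that _init_config reads from args, and the config_path value is purely illustrative:

import argparse

# Hypothetical launcher; assumes the Main class above is importable here
parser = argparse.ArgumentParser()
parser.add_argument('--config', default=None)          # forwarded to config_helper.initialize_config_from_yaml
parser.add_argument('--run', action='store_true')      # inference only (train_mode = not args.run)
parser.add_argument('--ckpt', default=None)            # checkpoint to restore
parser.add_argument('--render', action='store_true')
parser.add_argument('--editor', action='store_true')   # connect to the Unity editor instead of a build
parser.add_argument('--logger_file', default=None)
parser.add_argument('--name', default=None)
parser.add_argument('--port', type=int, default=None)
parser.add_argument('--sac', default=None)             # name of the custom sac model module
parser.add_argument('--agents', type=int, default=None)
args = parser.parse_args()

Main('envs/simple_roller', args)  # config_path is illustrative
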
Example No. 9
class Main(object):
    train_mode = True
    _agent_class = Agent  # For different environments

    def __init__(self, root_dir, config_dir, args):
        """
        root_dir: the root directory that config_dir and models/ are resolved against
        config_dir: the directory of the config file, relative to root_dir
        args: command arguments generated by argparse
        """
        self._logger = logging.getLogger('sac')

        config_abs_dir = self._init_config(root_dir, config_dir, args)

        self._init_env()
        self._init_sac(config_abs_dir)

        self._run()

    def _init_config(self, root_dir, config_dir, args):
        config_abs_dir = Path(root_dir).joinpath(config_dir)
        config_abs_path = config_abs_dir.joinpath('config.yaml')
        default_config_abs_path = Path(__file__).resolve().parent.joinpath(
            'default_config.yaml')
        # Merge default_config.yaml and custom config.yaml
        config = config_helper.initialize_config_from_yaml(
            default_config_abs_path, config_abs_path, args.config)

        # Initialize config from command line arguments
        self.train_mode = not args.run
        self.render = args.render
        self.run_in_editor = args.editor
        self.additional_args = args.additional_args
        self.disable_sample = args.disable_sample
        self.always_use_env_nn = args.use_env_nn
        self.device = args.device
        self.last_ckpt = args.ckpt

        if args.name is not None:
            config['base_config']['name'] = args.name
        if args.port is not None:
            config['base_config']['port'] = args.port
        if args.nn is not None:
            config['base_config']['nn'] = args.nn
        if args.agents is not None:
            config['base_config']['n_agents'] = args.agents
        if args.max_iter is not None:
            config['base_config']['max_iter'] = args.max_iter

        config['base_config']['name'] = config_helper.generate_base_name(
            config['base_config']['name'])

        # The absolute directory of this specific training run
        model_abs_dir = Path(root_dir).joinpath('models',
                                                config['base_config']['scene'],
                                                config['base_config']['name'])
        model_abs_dir.mkdir(parents=True, exist_ok=True)
        self.model_abs_dir = model_abs_dir

        if args.logger_in_file:
            config_helper.set_logger(model_abs_dir.joinpath('log.log'))

        if self.train_mode:
            config_helper.save_config(config, model_abs_dir, 'config.yaml')

        config_helper.display_config(config, self._logger)

        self.base_config = config['base_config']
        self.reset_config = config['reset_config']
        self.model_config = config['model_config']
        self.replay_config = config['replay_config']
        self.sac_config = config['sac_config']

        return config_abs_dir

    def _init_env(self):
        if self.base_config['env_type'] == 'UNITY':
            from algorithm.env_wrapper.unity_wrapper import UnityWrapper

            if self.run_in_editor:
                self.env = UnityWrapper(train_mode=self.train_mode,
                                        n_agents=self.base_config['n_agents'])
            else:
                self.env = UnityWrapper(
                    train_mode=self.train_mode,
                    file_name=self.base_config['build_path'][sys.platform],
                    base_port=self.base_config['port'],
                    no_graphics=self.base_config['no_graphics']
                    and not self.render,
                    scene=self.base_config['scene'],
                    additional_args=self.additional_args,
                    n_agents=self.base_config['n_agents'])

        elif self.base_config['env_type'] == 'GYM':
            from algorithm.env_wrapper.gym_wrapper import GymWrapper

            self.env = GymWrapper(train_mode=self.train_mode,
                                  env_name=self.base_config['build_path'],
                                  render=self.render,
                                  n_agents=self.base_config['n_agents'])
        else:
            raise RuntimeError(
                f'Undefined Environment Type: {self.base_config["env_type"]}')

        self.obs_shapes, self.d_action_size, self.c_action_size = self.env.init()
        self.action_size = self.d_action_size + self.c_action_size

        self._logger.info(f'{self.base_config["build_path"]} initialized')

    def _init_sac(self, config_abs_dir: Path):
        # If an nn model exists, load it; otherwise copy a new one
        nn_model_abs_path = self.model_abs_dir.joinpath('nn_models.py')
        if not self.always_use_env_nn and nn_model_abs_path.exists():
            spec = importlib.util.spec_from_file_location(
                'nn', str(nn_model_abs_path))
            self._logger.info(f'Loaded nn from existing {nn_model_abs_path}')
        else:
            nn_abs_path = config_abs_dir.joinpath(
                f'{self.base_config["nn"]}.py')
            spec = importlib.util.spec_from_file_location(
                'nn', str(nn_abs_path))
            self._logger.info(f'Loaded nn in env dir: {nn_abs_path}')
            if not self.always_use_env_nn:
                shutil.copyfile(nn_abs_path, nn_model_abs_path)

        custom_nn_model = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(custom_nn_model)

        self.sac = SAC_Base(obs_shapes=self.obs_shapes,
                            d_action_size=self.d_action_size,
                            c_action_size=self.c_action_size,
                            model_abs_dir=self.model_abs_dir,
                            model=custom_nn_model,
                            model_config=self.model_config,
                            device=self.device,
                            train_mode=self.train_mode,
                            last_ckpt=self.last_ckpt,
                            replay_config=self.replay_config,
                            **self.sac_config)

    def _run(self):
        use_rnn = self.sac.use_rnn

        obs_list = self.env.reset(reset_config=self.reset_config)

        agents = [
            self._agent_class(i, use_rnn=self.sac.use_rnn)
            for i in range(self.base_config['n_agents'])
        ]

        if use_rnn:
            initial_rnn_state = self.sac.get_initial_rnn_state(len(agents))
            rnn_state = initial_rnn_state

        iteration = 0
        trained_steps = 0

        while iteration != self.base_config['max_iter']:
            if (self.base_config['max_step'] != -1
                    and trained_steps >= self.base_config['max_step']):
                break

            if self.base_config['reset_on_iteration'] or any(
                [a.max_reached for a in agents]):
                obs_list = self.env.reset(reset_config=self.reset_config)
                for agent in agents:
                    agent.clear()

                if use_rnn:
                    rnn_state = initial_rnn_state
            else:
                for agent in agents:
                    agent.reset()

            action = np.zeros([len(agents), self.action_size],
                              dtype=np.float32)
            step = 0
            iter_time = time.time()

            try:
                while not all([a.done for a in agents]):
                    if use_rnn:
                        # burn-in padding
                        for agent in [a for a in agents if a.is_empty()]:
                            for _ in range(self.sac.burn_in_step):
                                agent.add_transition(
                                    obs_list=[
                                        np.zeros(t) for t in self.obs_shapes
                                    ],
                                    action=np.zeros(self.action_size),
                                    reward=0.,
                                    local_done=False,
                                    max_reached=False,
                                    next_obs_list=[
                                        np.zeros(t) for t in self.obs_shapes
                                    ],
                                    prob=0.,
                                    rnn_state=initial_rnn_state[0])

                        action, prob, next_rnn_state = self.sac.choose_rnn_action(
                            [o.astype(np.float32) for o in obs_list],
                            action,
                            rnn_state,
                            disable_sample=self.disable_sample)
                    else:
                        action, prob = self.sac.choose_action(
                            [o.astype(np.float32) for o in obs_list],
                            disable_sample=self.disable_sample)

                    next_obs_list, reward, local_done, max_reached = self.env.step(
                        action[..., :self.d_action_size],
                        action[..., self.d_action_size:])

                    if step == self.base_config['max_step_each_iter']:
                        local_done = [True] * len(agents)
                        max_reached = [True] * len(agents)

                    episode_trans_list = [
                        agents[i].add_transition(
                            obs_list=[o[i] for o in obs_list],
                            action=action[i],
                            reward=reward[i],
                            local_done=local_done[i],
                            max_reached=max_reached[i],
                            next_obs_list=[o[i] for o in next_obs_list],
                            prob=prob[i],
                            rnn_state=rnn_state[i] if use_rnn else None)
                        for i in range(len(agents))
                    ]

                    if self.train_mode:
                        episode_trans_list = [
                            t for t in episode_trans_list if t is not None
                        ]
                        if len(episode_trans_list) != 0:
                            # ep_obses_list, ep_actions, ep_rewards, next_obs_list, ep_dones, ep_probs,
                            # ep_rnn_states
                            for episode_trans in episode_trans_list:
                                self.sac.fill_replay_buffer(*episode_trans)
                        trained_steps = self.sac.train()

                    obs_list = next_obs_list
                    action[local_done] = np.zeros(self.action_size)
                    if use_rnn:
                        rnn_state = next_rnn_state
                        rnn_state[local_done] = initial_rnn_state[local_done]

                    step += 1

            except:
                self._logger.error(traceback.format_exc())
                self._logger.error('Exiting...')
                break

            if self.train_mode:
                self._log_episode_summaries(agents)

            self._log_episode_info(iteration, time.time() - iter_time, agents)

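            # Manual checkpoint trigger: creating a file named 'save_model' in the
            # model directory saves the model once, then the trigger file is removed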
            if self.train_mode and (
                    p := self.model_abs_dir.joinpath('save_model')).exists():
                self.sac.save_model()
                p.unlink()

            iteration += 1

        if self.train_mode:
            self.sac.save_model()
        self.env.close()
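
In the loop above, env.step receives the concatenated action split back into its discrete and continuous halves. A standalone sketch of that slicing, with illustrative sizes:

import numpy as np

d_action_size, c_action_size = 3, 2      # illustrative sizes
n_agents = 4
action = np.random.rand(n_agents, d_action_size + c_action_size).astype(np.float32)

d_action = action[..., :d_action_size]   # discrete part: first d_action_size columns
c_action = action[..., d_action_size:]   # continuous part: remaining columns

assert d_action.shape == (n_agents, d_action_size)
assert c_action.shape == (n_agents, c_action_size)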