Esempio n. 1
0
    def _spawn(self):
        """Create the central replay buffer, the learner, the workers and the logger.

        The remote wrappers and the logger built here are gathered in
        ``self.processes`` at the end.
        """
        hp = self.hyper_params

        # Central buffer: plain buffer -> prioritized wrapper -> remote wrapper.
        base_buffer = ReplayBuffer(hp.buffer_size, hp.batch_size)
        prioritized = PrioritizedBufferWrapper(base_buffer, alpha=hp.per_alpha)
        self.global_buffer = ApeXBufferWrapper.remote(
            prioritized, self.args, hp, self.comm_cfg
        )

        # The local learner object is kept around to read its state dict below.
        local_learner = build_learner(self.learner_cfg)
        self.learner = ApeXLearnerWrapper.remote(local_learner, self.comm_cfg)

        # Every worker is built from the same argument dict; only "rank" changes.
        shared_worker_args = dict(
            args=self.args, state_dict=local_learner.get_state_dict()
        )

        self.workers = []
        self.num_workers = hp.num_workers
        for worker_rank in range(self.num_workers):
            shared_worker_args["rank"] = worker_rank
            local_worker = build_worker(
                self.worker_cfg, build_args=shared_worker_args
            )
            self.workers.append(
                ApeXWorkerWrapper.remote(local_worker, self.args, self.comm_cfg)
            )

        self.logger = build_logger(self.logger_cfg)

        self.processes = self.workers + [
            self.learner,
            self.global_buffer,
            self.logger,
        ]
Esempio n. 2
0
    def _initialize(self):
        """Set up demo-seeded replay memories (training only) and the learner."""
        hp = self.hyper_params

        if not self.args.test:
            demo_data = self._load_demos()

            if self.use_n_step:
                # The helper returns two values: the first replaces the loaded
                # demos, the second seeds the n-step buffer.
                demo_data, n_step_demo_data = (
                    common_utils.get_n_step_info_from_demo(
                        demo_data, hp.n_step, hp.gamma
                    )
                )
                self.memory_n = ReplayBuffer(
                    max_len=hp.buffer_size,
                    batch_size=hp.batch_size,
                    n_step=hp.n_step,
                    gamma=hp.gamma,
                    demo=n_step_demo_data,
                )

            # Main memory: demo-seeded buffer behind the prioritized wrapper.
            demo_buffer = ReplayBuffer(
                hp.buffer_size, hp.batch_size, demo=demo_data
            )
            self.memory = PrioritizedBufferWrapper(
                demo_buffer,
                alpha=hp.per_alpha,
                epsilon_d=hp.per_eps_demo,
            )

        # The learner type is forced here regardless of the incoming config.
        self.learner_cfg.type = "DQfDLearner"
        self.learner = build_learner(self.learner_cfg)
Esempio n. 3
0
    def __init__(
        self,
        env: gym.Env,
        env_info: ConfigDict,
        args: argparse.Namespace,
        hyper_params: ConfigDict,
        learner_cfg: ConfigDict,
        log_cfg: ConfigDict,
    ):
        """Run the base Agent setup, fill in the learner config and build the learner."""
        Agent.__init__(self, env, env_info, args, log_cfg)

        # Per-episode bookkeeping.
        self.transition: list = []
        self.episode_step = 0
        self.i_episode = 0

        self.hyper_params = hyper_params
        self.learner_cfg = learner_cfg

        # Attach everything the learner is built from to its config.
        cfg = self.learner_cfg
        cfg.args = self.args
        cfg.env_info = self.env_info
        cfg.hyper_params = self.hyper_params
        cfg.log_cfg = self.log_cfg
        cfg.device = device

        self.learner = build_learner(cfg)
Esempio n. 4
0
    def _initialize(self):
        """Initialize either the teacher or the student side of distillation."""
        self.save_distillation_dir = None

        if not self.hyper_params.is_student:
            # Teacher training needs no DistillationBuffer, so the plain
            # DQNAgent initialization is reused.
            print("[INFO] Teacher mode.")
            DQNAgent._initialize(self)
            self.make_distillation_dir()
            return

        # Student training, or generating distillation data (test).
        print("[INFO] Student mode.")
        self.softmax_tau = 0.01

        hp = self.hyper_params
        env_info = self.env_info
        learner_kwargs = {
            "hyper_params": hp,
            "log_cfg": self.log_cfg,
            "env_name": env_info.name,
            "state_size": env_info.observation_space.shape,
            "output_size": env_info.action_space.n,
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        self.learner = build_learner(self.learner_cfg, learner_kwargs)

        self.dataset_path = hp.dataset_path
        self.memory = DistillationBuffer(hp.batch_size, self.dataset_path)

        if self.is_test:
            self.make_distillation_dir()
Esempio n. 5
0
    def _initialize(self):
        """Create the recurrent replay memories (training only) and the learner."""
        if not self.args.test:
            hp = self.hyper_params

            # Both buffers are constructed with exactly the same arguments.
            buffer_args = (
                hp.buffer_size,
                hp.batch_size,
                hp.sequence_size,
                hp.overlap_size,
            )
            buffer_kwargs = {"n_step": hp.n_step, "gamma": hp.gamma}

            sequence_buffer = RecurrentReplayBuffer(*buffer_args, **buffer_kwargs)
            self.memory = PrioritizedBufferWrapper(
                sequence_buffer, alpha=hp.per_alpha
            )

            # Extra, unwrapped buffer used when self.use_n_step is set.
            if self.use_n_step:
                self.memory_n = RecurrentReplayBuffer(
                    *buffer_args, **buffer_kwargs
                )

        self.learner = build_learner(self.learner_cfg)
Esempio n. 6
0
    def _initialize(self):
        """Prepare the replay memories (training only) and build the learner."""
        hp = self.hyper_params

        if not self.is_test:
            # Single-step transitions go behind the prioritized wrapper.
            one_step_buffer = ReplayBuffer(hp.buffer_size, hp.batch_size)
            self.memory = PrioritizedBufferWrapper(
                one_step_buffer, alpha=hp.per_alpha
            )

            # A second, unwrapped buffer is kept when self.use_n_step is set.
            if self.use_n_step:
                self.memory_n = ReplayBuffer(
                    hp.buffer_size,
                    hp.batch_size,
                    n_step=hp.n_step,
                    gamma=hp.gamma,
                )

        env_info = self.env_info
        learner_kwargs = {
            "hyper_params": hp,
            "log_cfg": self.log_cfg,
            "env_name": env_info.name,
            "state_size": env_info.observation_space.shape,
            "output_size": env_info.action_space.n,
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        self.learner = build_learner(self.learner_cfg, learner_kwargs)
Esempio n. 7
0
    def _initialize(self):
        """Initialize non-common things.

        Stores ``per_beta`` and the n-step flag, and in training mode loads
        the demonstrations from ``self.args.demo_path`` and builds the replay
        memories seeded with them. The learner is always built, with its type
        forced to ``"DDPGfDLearner"``.
        """
        self.per_beta = self.hyper_params.per_beta

        self.use_n_step = self.hyper_params.n_step > 1

        if not self.args.test:
            # load demo replay memory
            # NOTE(review): pickle.load can execute arbitrary code; the demo
            # file must come from a trusted source.
            with open(self.args.demo_path, "rb") as f:
                demos = pickle.load(f)

            if self.use_n_step:
                demos, demos_n_step = common_utils.get_n_step_info_from_demo(
                    demos, self.hyper_params.n_step, self.hyper_params.gamma)

                # replay memory for multi-steps
                self.memory_n = ReplayBuffer(
                    max_len=self.hyper_params.buffer_size,
                    batch_size=self.hyper_params.batch_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                    demo=demos_n_step,
                )

            # replay memory for a single step
            # The demos must be handed to this buffer as well; previously
            # they were loaded but never stored here, so without n-step the
            # demonstrations were not used at all.
            self.memory = ReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
                demo=demos,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory, alpha=self.hyper_params.per_alpha)

        self.learner_cfg.type = "DDPGfDLearner"
        self.learner = build_learner(self.learner_cfg)
Esempio n. 8
0
    def _initialize(self):
        """Create the replay memory (training runs only) and build the learner."""
        training = not self.args.test
        if training:
            hp = self.hyper_params
            self.memory = ReplayBuffer(hp.buffer_size, hp.batch_size)

        self.learner = build_learner(self.learner_cfg)
Esempio n. 9
0
    def _initialize(self):
        """Initialize non-common things.

        Steps, in order:
          1. Load the demonstrations from ``self.args.demo_path``.
          2. Optionally build HER and turn the demo into HER transitions.
          3. In training mode, create the demo memory and the regular memory
             and set ``hyper_params["lambda2"]``.
          4. Point ``args.cfg_path`` / ``args.load_from`` at the offer-specific
             values and temporarily expose the ``sac_*`` sizes as
             ``buffer_size`` / ``batch_size`` while the learner is built.
          5. Create the stack buffers and the per-run result lists.
        """
        # load demo replay memory
        # NOTE(review): pickle.load can execute arbitrary code; the demo file
        # must come from a trusted source.
        with open(self.args.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params.use_her:
            self.her = build_her(self.hyper_params.her)
            print(f"[INFO] Build {str(self.her)}.")

            if self.hyper_params.desired_states_from_demo:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1, ))
            # From here on `demo` holds the transitions generated by HER.
            demo = self.her.generate_demo_transitions(demo)

            if not self.her.is_goal_in_state:
                # The first state dimension is doubled
                # (presumably state and goal are concatenated; TODO confirm).
                self.state_dim = (self.state_dim[0] * 2, )
        else:
            self.her = None

        if not self.args.test:
            # Replay buffers
            demo_batch_size = self.hyper_params.demo_batch_size
            # The demo memory is sized to hold exactly the loaded demo.
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)

            # NOTE(review): this memory uses demo_batch_size, not
            # sac_batch_size, as its batch size; confirm this is intended.
            self.memory = ReplayBuffer(self.hyper_params.sac_buffer_size,
                                       demo_batch_size)

            # set hyper parameters
            self.hyper_params["lambda2"] = 1.0 / demo_batch_size

        # Redirect the shared args to the offer-specific config and checkpoint.
        self.args.cfg_path = self.args.offer_cfg_path
        self.args.load_from = self.args.load_offer_from
        # Temporarily expose the SAC sizes under the generic names; they are
        # deleted again right after the learner is built.
        self.hyper_params.buffer_size = self.hyper_params.sac_buffer_size
        self.hyper_params.batch_size = self.hyper_params.sac_batch_size

        self.learner_cfg.type = "BCSACLearner"
        self.learner_cfg.hyper_params = self.hyper_params

        self.learner = build_learner(self.learner_cfg)

        # Remove the temporary generic names set above.
        del self.hyper_params.buffer_size
        del self.hyper_params.batch_size

        # init stack
        self.stack_size = self.args.stack_size
        self.stack_buffer = deque(maxlen=self.args.stack_size)
        self.stack_buffer_2 = deque(maxlen=self.args.stack_size)

        # Per-run result lists, all starting empty.
        self.scores = list()
        self.utilities = list()
        self.rounds = list()
        self.opp_utilities = list()
Esempio n. 10
0
    def __init__(
        self,
        env: gym.Env,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        learner_cfg: ConfigDict,
        log_cfg: ConfigDict,
        is_test: bool,
        load_from: str,
        is_render: bool,
        render_after: int,
        is_log: bool,
        save_period: int,
        episode_num: int,
        max_episode_steps: int,
        interim_test_num: int,
    ):
        """Forward the run settings to Agent, then build the learner."""
        Agent.__init__(
            self, env, env_info, log_cfg, is_test, load_from, is_render,
            render_after, is_log, save_period, episode_num,
            max_episode_steps, interim_test_num,
        )

        # Per-episode bookkeeping.
        self.transition: list = []
        self.episode_step = 0
        self.i_episode = 0

        self.hyper_params = hyper_params
        self.learner_cfg = learner_cfg

        # The output size is taken from the first dimension of the action space shape.
        info = self.env_info
        learner_kwargs = {
            "hyper_params": self.hyper_params,
            "log_cfg": self.log_cfg,
            "env_name": info.name,
            "state_size": info.observation_space.shape,
            "output_size": info.action_space.shape[0],
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        self.learner = build_learner(self.learner_cfg, learner_kwargs)
Esempio n. 11
0
    def __init__(
        self,
        env: gym.Env,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        learner_cfg: ConfigDict,
        log_cfg: ConfigDict,
        is_test: bool,
        load_from: str,
        is_render: bool,
        render_after: int,
        is_log: bool,
        save_period: int,
        episode_num: int,
        max_episode_steps: int,
        interim_test_num: int,
    ):
        """Forward the run settings to Agent, then build the learner and the memory."""
        Agent.__init__(
            self, env, env_info, log_cfg, is_test, load_from, is_render,
            render_after, is_log, save_period, episode_num,
            max_episode_steps, interim_test_num,
        )

        # Counters.
        self.episode_step = 0
        self.i_episode = 0

        self.episode_num = episode_num

        self.hyper_params = hyper_params
        self.learner_cfg = learner_cfg

        # The learner receives the constructor arguments as they were passed in.
        learner_kwargs = {
            "hyper_params": hyper_params,
            "log_cfg": log_cfg,
            "env_info": env_info,
            "is_test": is_test,
            "load_from": load_from,
        }
        self.learner = build_learner(self.learner_cfg, learner_kwargs)

        hp = self.hyper_params
        self.memory = ReplayMemory(hp.buffer_size, hp.n_rollout)
Esempio n. 12
0
    def __init__(
        self,
        env: gym.Env,
        env_info: ConfigDict,
        args: argparse.Namespace,
        hyper_params: ConfigDict,
        learner_cfg: ConfigDict,
        noise_cfg: ConfigDict,
        log_cfg: ConfigDict,
    ):
        """Set up counters, learner config, exploration noise, memory and learner.

        Args:
            env (gym.Env): openAI Gym environment
            env_info (ConfigDict): environment description, forwarded to Agent
            args (argparse.Namespace): arguments including hyperparameters and training settings
            hyper_params (ConfigDict): supplies buffer_size and batch_size
            learner_cfg (ConfigDict): config the learner is built from
            noise_cfg (ConfigDict): supplies exploration_noise
            log_cfg (ConfigDict): logging config, forwarded to Agent

        """
        Agent.__init__(self, env, env_info, args, log_cfg)

        # Step and episode bookkeeping.
        self.curr_state = np.zeros((1,))
        self.total_step = 0
        self.episode_step = 0
        self.update_step = 0
        self.i_episode = 0

        self.hyper_params = hyper_params
        self.learner_cfg = learner_cfg

        # Attach everything the learner is built from to its config.
        cfg = self.learner_cfg
        cfg.args = self.args
        cfg.env_info = self.env_info
        cfg.hyper_params = self.hyper_params
        cfg.log_cfg = self.log_cfg
        cfg.noise_cfg = noise_cfg
        cfg.device = device

        # Exploration noise; the same configured value is passed twice.
        action_dim = self.env_info.action_space.shape[0]
        noise_level = noise_cfg.exploration_noise
        self.exploration_noise = GaussianNoise(action_dim, noise_level, noise_level)

        # The replay memory is only created for training runs.
        if not self.args.test:
            hp = self.hyper_params
            self.memory = ReplayBuffer(hp.buffer_size, hp.batch_size)

        self.learner = build_learner(cfg)
Esempio n. 13
0
    def _initialize(self):
        """Initialize non-common things.

        Loads the demonstrations from ``hyper_params.demo_path``, optionally
        builds HER and converts the demo into HER transitions, creates the
        replay buffers in training mode, and finally builds the learner.
        """
        # load demo replay memory
        # NOTE(review): pickle.load can execute arbitrary code; the demo file
        # must come from a trusted source.
        with open(self.hyper_params.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params.use_her:
            self.her = build_her(self.hyper_params.her)
            print(f"[INFO] Build {str(self.her)}.")

            if self.hyper_params.desired_states_from_demo:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1, ))
            demo = self.her.generate_demo_transitions(demo)

            if not self.her.is_goal_in_state:
                # Double the first observation dimension. The original read
                # `self.self.env_info`, which raised AttributeError.
                self.env_info.observation_space.shape = (
                    self.env_info.observation_space.shape[0] * 2, )
        else:
            self.her = None

        if not self.is_test:
            # Replay buffers
            demo_batch_size = self.hyper_params.demo_batch_size
            # The demo memory is sized to hold exactly the loaded demo.
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)

            self.memory = ReplayBuffer(self.hyper_params.buffer_size,
                                       self.hyper_params.batch_size)

            # set hyper parameters
            self.hyper_params["lambda2"] = 1.0 / demo_batch_size

        # state_size is read after the HER branch, so it reflects the
        # doubled observation shape when that branch ran.
        build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            noise_cfg=self.noise_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.shape[0],
            is_test=self.is_test,
            load_from=self.load_from,
        )
        self.learner = build_learner(self.learner_cfg, build_args)
Esempio n. 14
0
    def _initialize(self):
        """Load the demos, seed the replay memories (training only) and build the learner."""
        hp = self.hyper_params

        self.per_beta = hp.per_beta
        self.use_n_step = hp.n_step > 1

        if not self.is_test:
            # Demonstrations are read from a pickle file.
            with open(hp.demo_path, "rb") as demo_file:
                demo_data = pickle.load(demo_file)

            if self.use_n_step:
                # The helper returns two values: the first replaces the loaded
                # demos, the second seeds the n-step buffer.
                demo_data, n_step_demo_data = (
                    common_utils.get_n_step_info_from_demo(
                        demo_data, hp.n_step, hp.gamma
                    )
                )
                self.memory_n = ReplayBuffer(
                    max_len=hp.buffer_size,
                    batch_size=hp.batch_size,
                    n_step=hp.n_step,
                    gamma=hp.gamma,
                    demo=n_step_demo_data,
                )

            # Single-step memory: demo-seeded buffer behind the prioritized wrapper.
            demo_buffer = ReplayBuffer(
                hp.buffer_size, hp.batch_size, demo=demo_data
            )
            self.memory = PrioritizedBufferWrapper(
                demo_buffer,
                alpha=hp.per_alpha,
                epsilon_d=hp.per_eps_demo,
            )

        info = self.env_info
        learner_kwargs = {
            "hyper_params": hp,
            "log_cfg": self.log_cfg,
            "noise_cfg": self.noise_cfg,
            "env_name": info.name,
            "state_size": info.observation_space.shape,
            "output_size": info.action_space.shape[0],
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        self.learner = build_learner(self.learner_cfg, learner_kwargs)
Esempio n. 15
0
    def _initialize(self):
        """Build the learner, resolve the buffer directory and create the distillation buffer."""
        self.softmax_tau = 0.01
        self.learner = build_learner(self.learner_cfg)

        log_cfg = self.log_cfg
        # Default location, built from the logging config.
        default_dir = (
            f"./data/distillation_buffer/{log_cfg.env_name}/"
            f"{log_cfg.agent}/{log_cfg.curr_time}/"
        )
        # A non-empty command-line path replaces the default.
        custom_dir = self.args.distillation_buffer_path
        self.buffer_path = ("./" + custom_dir) if custom_dir else default_dir
        os.makedirs(self.buffer_path, exist_ok=True)

        self.memory = DistillationBuffer(
            self.hyper_params.batch_size,
            self.buffer_path,
            log_cfg.curr_time,
        )
Esempio n. 16
0
    def __init__(
        self,
        env: gym.Env,
        env_info: ConfigDict,
        args: argparse.Namespace,
        hyper_params: ConfigDict,
        learner_cfg: ConfigDict,
        log_cfg: ConfigDict,
    ):
        """Initialize.

        Args:
            env (gym.Env): openAI Gym environment
            env_info (ConfigDict): environment description, forwarded to Agent
            args (argparse.Namespace): arguments including hyperparameters and training settings
            hyper_params (ConfigDict): supplies n_workers and max_epsilon
            learner_cfg (ConfigDict): config the learner is built from
            log_cfg (ConfigDict): logging config, forwarded to Agent

        """
        # The multi-environment is created up front; it only replaces
        # self.env further down when not running in test mode.
        env_gen = env_generator(env.spec.id, args)
        env_multi = make_envs(env_gen, n_envs=hyper_params.n_workers)

        Agent.__init__(self, env, env_info, args, log_cfg)

        # `np.int` was only an alias of the builtin `int` and was removed in
        # NumPy 1.24; the builtin gives the same dtype on every version.
        self.episode_steps = np.zeros(hyper_params.n_workers, dtype=int)
        self.states: list = []
        self.actions: list = []
        self.rewards: list = []
        self.values: list = []
        self.masks: list = []
        self.log_probs: list = []
        self.i_episode = 0
        self.next_state = np.zeros((1, ))

        self.hyper_params = hyper_params
        self.learner_cfg = learner_cfg
        self.learner_cfg.args = self.args
        self.learner_cfg.env_info = self.env_info
        self.learner_cfg.hyper_params = self.hyper_params
        self.learner_cfg.log_cfg = self.log_cfg
        self.learner_cfg.device = device

        if not self.args.test:
            self.env = env_multi

        self.epsilon = hyper_params.max_epsilon

        self.learner = build_learner(self.learner_cfg)
Esempio n. 17
0
    def _initialize(self):
        """Create the replay memory (training only) and build the learner."""
        hp = self.hyper_params

        training = not self.is_test
        if training:
            self.memory = ReplayBuffer(hp.buffer_size, hp.batch_size)

        info = self.env_info
        learner_kwargs = {
            "hyper_params": hp,
            "log_cfg": self.log_cfg,
            "noise_cfg": self.noise_cfg,
            "env_name": info.name,
            "state_size": info.observation_space.shape,
            "output_size": info.action_space.shape[0],
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        self.learner = build_learner(self.learner_cfg, learner_kwargs)
Esempio n. 18
0
    def _initialize(self):
        """Initialize either the teacher or the student side of distillation."""
        self.save_distillation_dir = None

        if not self.args.student:
            # Teacher training needs no DistillationBuffer, so the plain
            # DQNAgent initialization is reused.
            DQNAgent._initialize(self)
            self.make_distillation_dir()
            return

        # Student training, or generating distillation data (test).
        self.softmax_tau = 0.01
        self.learner = build_learner(self.learner_cfg)

        hp = self.hyper_params
        self.dataset_path = hp.dataset_path
        self.memory = DistillationBuffer(hp.batch_size, self.dataset_path)

        if self.args.test:
            self.make_distillation_dir()
Esempio n. 19
0
    def _initialize(self):
        """Set up demo-seeded replay memories (training only) and the DQfD learner."""
        hp = self.hyper_params

        if not self.is_test:
            demo_data = self._load_demos()

            if self.use_n_step:
                # The helper returns two values: the first replaces the loaded
                # demos, the second seeds the n-step buffer.
                demo_data, n_step_demo_data = (
                    common_utils.get_n_step_info_from_demo(
                        demo_data, hp.n_step, hp.gamma
                    )
                )
                self.memory_n = ReplayBuffer(
                    max_len=hp.buffer_size,
                    batch_size=hp.batch_size,
                    n_step=hp.n_step,
                    gamma=hp.gamma,
                    demo=n_step_demo_data,
                )

            # Main memory: demo-seeded buffer behind the prioritized wrapper.
            demo_buffer = ReplayBuffer(
                hp.buffer_size, hp.batch_size, demo=demo_data
            )
            self.memory = PrioritizedBufferWrapper(
                demo_buffer,
                alpha=hp.per_alpha,
                epsilon_d=hp.per_eps_demo,
            )

        info = self.env_info
        learner_kwargs = {
            "hyper_params": hp,
            "log_cfg": self.log_cfg,
            "env_name": info.name,
            "state_size": info.observation_space.shape,
            "output_size": info.action_space.n,
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        # The learner type is forced here regardless of the incoming config.
        self.learner_cfg.type = "DQfDLearner"
        self.learner = build_learner(self.learner_cfg, learner_kwargs)
Esempio n. 20
0
    def _spawn(self):
        """Create the central replay buffer, the learner, the workers and the logger.

        The remote wrappers and the logger built here are gathered in
        ``self.processes`` at the end.
        """
        hp = self.hyper_params
        info = self.env_info
        cfg = self.learner_cfg

        # Central buffer: plain buffer -> prioritized wrapper -> remote wrapper.
        base_buffer = ReplayBuffer(hp.buffer_size, hp.batch_size)
        prioritized = PrioritizedBufferWrapper(base_buffer, alpha=hp.per_alpha)
        self.global_buffer = ApeXBufferWrapper.remote(
            prioritized, hp, self.comm_cfg
        )

        # Environment facts shared by the learner, worker and logger arguments.
        env_name = info.name
        state_size = info.observation_space.shape
        output_size = info.action_space.n

        # Learner: the local object is kept to read its state dict below.
        learner_kwargs = {
            "hyper_params": hp,
            "log_cfg": self.log_cfg,
            "env_name": env_name,
            "state_size": state_size,
            "output_size": output_size,
            "is_test": self.is_test,
            "load_from": self.load_from,
        }
        local_learner = build_learner(cfg, learner_kwargs)
        self.learner = ApeXLearnerWrapper.remote(local_learner, self.comm_cfg)

        # Workers: one shared argument dict; only "rank" changes per worker.
        shared_worker_args = {
            "hyper_params": hp,
            "backbone": cfg.backbone,
            "head": cfg.head,
            "loss_type": cfg.loss_type,
            "state_dict": local_learner.get_state_dict(),
            "env_name": env_name,
            "state_size": state_size,
            "output_size": output_size,
            "is_atari": info.is_atari,
            "max_episode_steps": self.max_episode_steps,
        }
        self.workers = []
        self.num_workers = hp.num_workers
        for worker_rank in range(self.num_workers):
            shared_worker_args["rank"] = worker_rank
            local_worker = build_worker(
                self.worker_cfg, build_args=shared_worker_args
            )
            self.workers.append(
                ApeXWorkerWrapper.remote(local_worker, self.comm_cfg)
            )

        # Logger.
        logger_kwargs = {
            "log_cfg": self.log_cfg,
            "comm_cfg": self.comm_cfg,
            "backbone": cfg.backbone,
            "head": cfg.head,
            "env_name": env_name,
            "is_atari": info.is_atari,
            "state_size": state_size,
            "output_size": output_size,
            "max_update_step": hp.max_update_step,
            "episode_num": self.episode_num,
            "max_episode_steps": self.max_episode_steps,
            "is_log": self.is_log,
            "is_render": self.is_render,
            "interim_test_num": self.interim_test_num,
        }
        self.logger = build_logger(self.logger_cfg, logger_kwargs)

        self.processes = self.workers + [
            self.learner,
            self.global_buffer,
            self.logger,
        ]
Esempio n. 21
0
    def __init__(
        self,
        env: gym.Env,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        learner_cfg: ConfigDict,
        log_cfg: ConfigDict,
        is_test: bool,
        load_from: str,
        is_render: bool,
        render_after: int,
        is_log: bool,
        save_period: int,
        episode_num: int,
        max_episode_steps: int,
        interim_test_num: int,
    ):
        """Initialize the agent, its rollout storage and its learner.

        The run settings are forwarded to ``Agent.__init__``. When not in
        test mode, ``self.env`` is replaced by a parallel environment with
        ``hyper_params.n_workers`` workers.
        """
        Agent.__init__(
            self,
            env,
            env_info,
            log_cfg,
            is_test,
            load_from,
            is_render,
            render_after,
            is_log,
            save_period,
            episode_num,
            max_episode_steps,
            interim_test_num,
        )

        # The parallel environment is only created for training runs.
        env_multi = (env if is_test else self.make_parallel_env(
            max_episode_steps, hyper_params.n_workers))

        # `np.int` was only an alias of the builtin `int` and was removed in
        # NumPy 1.24; the builtin gives the same dtype on every version.
        self.episode_steps = np.zeros(hyper_params.n_workers, dtype=int)
        self.states: list = []
        self.actions: list = []
        self.rewards: list = []
        self.values: list = []
        self.masks: list = []
        self.log_probs: list = []
        self.i_episode = 0
        self.next_state = np.zeros((1, ))

        self.hyper_params = hyper_params
        self.learner_cfg = learner_cfg

        if not self.is_test:
            self.env = env_multi

        self.epsilon = hyper_params.max_epsilon

        # Discrete action spaces expose `n`; otherwise the first dimension of
        # the action space shape is used.
        output_size = (self.env_info.action_space.n if self.is_discrete else
                       self.env_info.action_space.shape[0])

        build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=output_size,
            is_test=self.is_test,
            load_from=self.load_from,
        )
        self.learner = build_learner(self.learner_cfg, build_args)