Пример #1
0
    def _initialize(self):
        """Fill in env-derived config fields and build the agent components.

        Copies the observation/action dimensions and the action bounds of
        ``self.env`` into ``self.experiment_info.env``. Then it builds, in
        order:

        * the learner;
        * the replay buffer, wrapped in ``PrioritizedReplayBuffer`` when
          ``hyper_params.use_per`` is set;
        * the action selector, wrapped in ``OUNoise``;
        * a ``Logger`` stored on ``self.logger``, only when
          ``experiment_info.log_wandb`` is set.
        """
        env_info = self.experiment_info.env
        act_space = self.env.action_space

        # Env-specific model params; presumably consumed by build_learner via
        # experiment_info, so they are written before the learner is built.
        env_info.state_dim = self.env.observation_space.shape[0]
        env_info.action_dim = act_space.shape[0]
        lows, highs = act_space.low.tolist(), act_space.high.tolist()
        env_info.action_range = [lows, highs]

        self.learner = build_learner(
            self.experiment_info, self.hyper_params, self.model_cfg
        )

        # Plain replay buffer first; replaced by the PER wrapper when enabled.
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params
            )

        # Base selector, then wrapped with OUNoise (presumably
        # Ornstein-Uhlenbeck exploration noise -- confirm in OUNoise).
        self.action_selector = build_action_selector(self.experiment_info)
        self.action_selector = OUNoise(self.action_selector, act_space)

        # The logger is optional; nothing else to build when it is disabled.
        if not self.experiment_info.log_wandb:
            return
        run_cfg = OmegaConf.create(
            {
                "experiment_info": self.experiment_info,
                "hyper_params": self.hyper_params,
                "model": self.learner.model_cfg,
            }
        )
        self.logger = Logger(run_cfg)
Пример #2
0
    def _initialize(self):
        """Fill in env-derived model config and build the agent components.

        Writes the full observation shape and the number of discrete actions
        (``action_space.n``) of ``self.env`` into
        ``self.model_cfg.params.model_cfg``. Then it builds, in order:

        * the learner;
        * the replay buffer, wrapped in ``PrioritizedReplayBuffer`` when
          ``hyper_params.use_per`` is set;
        * the action selector, wrapped with ``EpsGreedy`` exploration;
        * a ``Logger`` stored on ``self.logger``, only when
          ``experiment_info.log_wandb`` is set.
        """
        net_cfg = self.model_cfg.params.model_cfg
        act_space = self.env.action_space

        # The whole shape tuple (not shape[0]) is stored as state_dim.
        net_cfg.state_dim = self.env.observation_space.shape
        net_cfg.action_dim = act_space.n

        self.learner = build_learner(
            self.experiment_info, self.hyper_params, self.model_cfg
        )

        # Plain replay buffer first; replaced by the PER wrapper when enabled.
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params
            )

        # Base selector, then wrapped with epsilon-greedy exploration.
        self.action_selector = build_action_selector(self.experiment_info)
        self.action_selector = EpsGreedy(
            self.action_selector, act_space, self.hyper_params
        )

        # The logger is optional; nothing else to build when it is disabled.
        if not self.experiment_info.log_wandb:
            return
        run_cfg = OmegaConf.create(
            {
                "experiment_info": self.experiment_info,
                "hyper_params": self.hyper_params,
                "model": self.learner.model_cfg,
            }
        )
        self.logger = Logger(run_cfg)
Пример #3
0
    def _initialize(self):
        """Fill in env-derived config fields and build learner, selector, logger.

        Writes ``state_dim`` (first observation dimension) into
        ``self.experiment_info.env``. When ``env.is_discrete`` is set,
        ``action_dim`` is ``action_space.n``. Otherwise ``action_dim`` is
        ``action_space.shape[0]`` and ``action_range`` holds the low/high
        bounds as lists.

        Then it builds the learner and the action selector (passing
        ``self.use_cuda``). When ``experiment_info.log_wandb`` is set, it also
        builds a ``Logger`` stored on ``self.logger``.
        """
        env_info = self.experiment_info.env
        act_space = self.env.action_space

        env_info.state_dim = self.env.observation_space.shape[0]
        if not env_info.is_discrete:
            # Continuous actions: record their size and per-dimension bounds.
            env_info.action_dim = act_space.shape[0]
            lows, highs = act_space.low.tolist(), act_space.high.tolist()
            env_info.action_range = [lows, highs]
        else:
            # Discrete actions: action_dim is the number of actions.
            env_info.action_dim = act_space.n

        self.learner = build_learner(
            self.experiment_info, self.hyper_params, self.model_cfg
        )
        self.action_selector = build_action_selector(
            self.experiment_info, self.use_cuda
        )

        # The logger is optional; nothing else to build when it is disabled.
        if not self.experiment_info.log_wandb:
            return
        run_cfg = OmegaConf.create(
            {
                "experiment_info": self.experiment_info,
                "hyper_params": self.hyper_params,
                "model": self.learner.model_cfg,
            }
        )
        self.logger = Logger(run_cfg)
Пример #4
0
def main(cfg: DictConfig):
    """Build each agent component from ``cfg`` and print it.

    Prints the config, then constructs and prints, in turn: the environment
    (with one reset observation), the model (with its output on that
    observation), the action selector, the loss, the learner and the agent.

    Args:
        cfg: Composed config. ``cfg.experiment_info`` and ``cfg.model`` are
            read directly. ``cfg.model.params.model_cfg`` is updated with
            env-derived sizes. The whole config is unpacked as keyword
            arguments into ``build_learner`` and ``build_agent``.
    """
    # print all configs
    # NOTE(review): DictConfig.pretty() was deprecated in OmegaConf 2.0 and
    # removed in 2.1; switch to OmegaConf.to_yaml(cfg) when upgrading.
    print(cfg.pretty())

    # build env
    print("===INITIALIZING ENV===")
    env = build_env(cfg.experiment_info)
    # Reset once and reuse this observation for the model check below, so the
    # printed observation is the one actually fed to the model.
    obs = env.reset()
    print(obs)
    print("=================")

    # build model
    print("===INITIALIZING MODEL===")
    # Uses action_space.n, i.e. assumes a discrete action space; the output
    # layer is sized to the number of actions.
    model_params = cfg.model.params.model_cfg
    model_params.state_dim = env.observation_space.shape
    model_params.action_dim = env.action_space.n
    model_params.fc.output.params.output_size = env.action_space.n
    model = build_model(cfg.model)
    test_input = torch.FloatTensor(obs).unsqueeze(0)  # add a batch dimension
    print(model)
    # Call the module itself (not .forward) so registered hooks run; no
    # gradients are needed for this inspection-only forward pass.
    with torch.no_grad():
        print(model(test_input))
    print("===================")

    # build action_selector
    print("===INITIALIZING ACTION SELECTOR===")
    action_selector = build_action_selector(cfg.experiment_info)
    print(action_selector)
    print("==============================")

    # build loss
    print("===INITIALIZING LOSS===")
    loss = build_loss(cfg.experiment_info)
    print(loss)
    print("==================")

    # build learner
    print("===INITIALIZING LEARNER===")
    learner = build_learner(**cfg)
    print(learner)
    print("=====================")

    # build agent
    print("===INITIALIZING AGENT===")
    agent = build_agent(**cfg)
    print(agent)
    print("=====================")