Example #1
    def initialize_policy(self):
        if self.args.policy == 'dqn':
            # DQN over the augmented observation (discrete action spaces)
            q_network = FlattenMlp(input_size=self.args.augmented_obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers).to(ptu.device)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        else:
            # assert self.args.act_space.__class__.__name__ == "Box", (
            #     "Can't train SAC with discrete action space!")
            # Otherwise default to SAC: twin Q-networks and a tanh-Gaussian policy.
            q1_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            q2_network = FlattenMlp(
                input_size=self.args.augmented_obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers).to(ptu.device)
            policy = TanhGaussianPolicy(
                obs_dim=self.args.augmented_obs_dim,
                action_dim=self.args.action_dim,
                hidden_sizes=self.args.policy_layers).to(ptu.device)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                use_cql=self.args.use_cql if 'use_cql' in self.args else False,
                alpha_cql=self.args.alpha_cql if 'alpha_cql' in self.args else None,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr,
                clip_grad_value=self.args.clip_grad_value,
            ).to(ptu.device)
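Note that use_cql and alpha_cql above are only read from the arguments when those fields exist; otherwise defaults are passed. A minimal sketch of that optional-argument pattern (the Namespace fields here are illustrative, not from the original example):

from argparse import Namespace

# argparse.Namespace supports `in`, so a missing flag can fall back to a default.
args = Namespace(policy='sac', gamma=0.99)          # no use_cql / alpha_cql set
use_cql = args.use_cql if 'use_cql' in args else False
alpha_cql = args.alpha_cql if 'alpha_cql' in args else None
print(use_cql, alpha_cql)  # False None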
Example #2
def load_agent(args, agent_path):
    """Rebuild a SAC agent from ``args`` and load its weights from ``agent_path``."""
    q1_network = FlattenMlp(input_size=args.obs_dim + args.action_dim,
                            output_size=1,
                            hidden_sizes=args.dqn_layers)
    q2_network = FlattenMlp(input_size=args.obs_dim + args.action_dim,
                            output_size=1,
                            hidden_sizes=args.dqn_layers)
    policy = TanhGaussianPolicy(obs_dim=args.obs_dim,
                                action_dim=args.action_dim,
                                hidden_sizes=args.policy_layers)
    agent = SAC(policy,
                q1_network,
                q2_network,
                actor_lr=args.actor_lr,
                critic_lr=args.critic_lr,
                gamma=args.gamma,
                tau=args.soft_target_tau,
                entropy_alpha=args.entropy_alpha,
                automatic_entropy_tuning=args.automatic_entropy_tuning,
                alpha_lr=args.alpha_lr).to(ptu.device)
    agent.load_state_dict(torch.load(agent_path))
    return agent
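A hedged usage sketch for load_agent: the Namespace fields mirror the attributes the function reads, but the concrete values and the checkpoint path are placeholders rather than values taken from the original code.

from argparse import Namespace

args = Namespace(
    obs_dim=20, action_dim=6,
    dqn_layers=(256, 256), policy_layers=(256, 256),
    actor_lr=3e-4, critic_lr=3e-4,
    gamma=0.99, soft_target_tau=5e-3,
    entropy_alpha=0.2, automatic_entropy_tuning=True, alpha_lr=3e-4,
)
agent = load_agent(args, 'path/to/agent.pt')  # hypothetical checkpoint path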
Example #3
    def initialize_policy(self):

        if self.args.policy == 'dqn':
            assert self.args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DQN with continuous action space!")
            q_network = FlattenMlp(input_size=self.args.obs_dim,
                                   output_size=self.args.act_space.n,
                                   hidden_sizes=self.args.dqn_layers)
            self.agent = DQN(
                q_network,
                # optimiser_vae=self.optimizer_vae,
                lr=self.args.policy_lr,
                gamma=self.args.gamma,
                eps_init=self.args.dqn_epsilon_init,
                eps_final=self.args.dqn_epsilon_final,
                exploration_iters=self.args.dqn_exploration_iters,
                tau=self.args.soft_target_tau,
            ).to(ptu.device)
        # elif self.args.policy == 'ddqn':
        #     assert self.args.act_space.__class__.__name__ == "Discrete", (
        #         "Can't train DDQN with continuous action space!")
        #     q_network = FlattenMlp(input_size=self.args.obs_dim,
        #                            output_size=self.args.act_space.n,
        #                            hidden_sizes=self.args.dqn_layers)
        #     self.agent = DoubleDQN(
        #         q_network,
        #         # optimiser_vae=self.optimizer_vae,
        #         lr=self.args.policy_lr,
        #         eps_optim=self.args.dqn_eps,
        #         alpha_optim=self.args.dqn_alpha,
        #         gamma=self.args.gamma,
        #         eps_init=self.args.dqn_epsilon_init,
        #         eps_final=self.args.dqn_epsilon_final,
        #         exploration_iters=self.args.dqn_exploration_iters,
        #         tau=self.args.soft_target_tau,
        #     ).to(ptu.device)
        elif self.args.policy == 'sac':
            assert self.args.act_space.__class__.__name__ == "Box", (
                "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(
                input_size=self.args.obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers)
            q2_network = FlattenMlp(
                input_size=self.args.obs_dim + self.args.action_dim,
                output_size=1,
                hidden_sizes=self.args.dqn_layers)
            policy = TanhGaussianPolicy(obs_dim=self.args.obs_dim,
                                        action_dim=self.args.action_dim,
                                        hidden_sizes=self.args.policy_layers)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                gamma=self.args.gamma,
                tau=self.args.soft_target_tau,
                entropy_alpha=self.args.entropy_alpha,
                automatic_entropy_tuning=self.args.automatic_entropy_tuning,
                alpha_lr=self.args.alpha_lr).to(ptu.device)
        else:
            raise NotImplementedError
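The asserts above branch on the class name of the environment's action space. A minimal illustration with gym (not part of the original snippets; environment IDs depend on the installed gym version):

import gym

env = gym.make('CartPole-v1')
print(env.action_space.__class__.__name__)  # 'Discrete' -> the DQN branch

env = gym.make('Pendulum-v1')
print(env.action_space.__class__.__name__)  # 'Box' -> the SAC branch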