def initialize_policy(self):
    if self.args.policy == 'dqn':
        q_network = FlattenMlp(input_size=self.args.augmented_obs_dim,
                               output_size=self.args.act_space.n,
                               hidden_sizes=self.args.dqn_layers).to(ptu.device)
        self.agent = DQN(
            q_network,
            # optimiser_vae=self.optimizer_vae,
            lr=self.args.policy_lr,
            gamma=self.args.gamma,
            tau=self.args.soft_target_tau,
        ).to(ptu.device)
    else:
        # assert self.args.act_space.__class__.__name__ == "Box", (
        #     "Can't train SAC with discrete action space!")
        q1_network = FlattenMlp(
            input_size=self.args.augmented_obs_dim + self.args.action_dim,
            output_size=1,
            hidden_sizes=self.args.dqn_layers).to(ptu.device)
        q2_network = FlattenMlp(
            input_size=self.args.augmented_obs_dim + self.args.action_dim,
            output_size=1,
            hidden_sizes=self.args.dqn_layers).to(ptu.device)
        policy = TanhGaussianPolicy(
            obs_dim=self.args.augmented_obs_dim,
            action_dim=self.args.action_dim,
            hidden_sizes=self.args.policy_layers).to(ptu.device)
        self.agent = SAC(
            policy,
            q1_network,
            q2_network,
            actor_lr=self.args.actor_lr,
            critic_lr=self.args.critic_lr,
            gamma=self.args.gamma,
            tau=self.args.soft_target_tau,
            use_cql=self.args.use_cql if 'use_cql' in self.args else False,
            alpha_cql=self.args.alpha_cql if 'alpha_cql' in self.args else None,
            entropy_alpha=self.args.entropy_alpha,
            automatic_entropy_tuning=self.args.automatic_entropy_tuning,
            alpha_lr=self.args.alpha_lr,
            clip_grad_value=self.args.clip_grad_value,
        ).to(ptu.device)
def load_agent(args, agent_path):
    q1_network = FlattenMlp(input_size=args.obs_dim + args.action_dim,
                            output_size=1,
                            hidden_sizes=args.dqn_layers)
    q2_network = FlattenMlp(input_size=args.obs_dim + args.action_dim,
                            output_size=1,
                            hidden_sizes=args.dqn_layers)
    policy = TanhGaussianPolicy(obs_dim=args.obs_dim,
                                action_dim=args.action_dim,
                                hidden_sizes=args.policy_layers)
    agent = SAC(policy,
                q1_network,
                q2_network,
                actor_lr=args.actor_lr,
                critic_lr=args.critic_lr,
                gamma=args.gamma,
                tau=args.soft_target_tau,
                entropy_alpha=args.entropy_alpha,
                automatic_entropy_tuning=args.automatic_entropy_tuning,
                alpha_lr=args.alpha_lr).to(ptu.device)
    agent.load_state_dict(torch.load(agent_path))
    return agent
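# Illustrative usage sketch for `load_agent` (assumption, not part of the original
# file): `args` must carry the same network hyperparameters that were used when the
# checkpoint was written, and the checkpoint filename 'agent.pt' is hypothetical.
#
#   import os
#   agent = load_agent(args, os.path.join(log_dir, 'agent.pt'))
#   agent.eval()  # standard torch.nn.Module call before evaluation rollouts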
def initialize_policy(self):
    if self.args.policy == 'dqn':
        assert self.args.act_space.__class__.__name__ == "Discrete", (
            "Can't train DQN with continuous action space!")
        q_network = FlattenMlp(input_size=self.args.obs_dim,
                               output_size=self.args.act_space.n,
                               hidden_sizes=self.args.dqn_layers)
        self.agent = DQN(
            q_network,
            # optimiser_vae=self.optimizer_vae,
            lr=self.args.policy_lr,
            gamma=self.args.gamma,
            eps_init=self.args.dqn_epsilon_init,
            eps_final=self.args.dqn_epsilon_final,
            exploration_iters=self.args.dqn_exploration_iters,
            tau=self.args.soft_target_tau,
        ).to(ptu.device)
    # elif self.args.policy == 'ddqn':
    #     assert self.args.act_space.__class__.__name__ == "Discrete", (
    #         "Can't train DDQN with continuous action space!")
    #     q_network = FlattenMlp(input_size=self.args.obs_dim,
    #                            output_size=self.args.act_space.n,
    #                            hidden_sizes=self.args.dqn_layers)
    #     self.agent = DoubleDQN(
    #         q_network,
    #         # optimiser_vae=self.optimizer_vae,
    #         lr=self.args.policy_lr,
    #         eps_optim=self.args.dqn_eps,
    #         alpha_optim=self.args.dqn_alpha,
    #         gamma=self.args.gamma,
    #         eps_init=self.args.dqn_epsilon_init,
    #         eps_final=self.args.dqn_epsilon_final,
    #         exploration_iters=self.args.dqn_exploration_iters,
    #         tau=self.args.soft_target_tau,
    #     ).to(ptu.device)
    elif self.args.policy == 'sac':
        assert self.args.act_space.__class__.__name__ == "Box", (
            "Can't train SAC with discrete action space!")
        q1_network = FlattenMlp(input_size=self.args.obs_dim + self.args.action_dim,
                                output_size=1,
                                hidden_sizes=self.args.dqn_layers)
        q2_network = FlattenMlp(input_size=self.args.obs_dim + self.args.action_dim,
                                output_size=1,
                                hidden_sizes=self.args.dqn_layers)
        policy = TanhGaussianPolicy(obs_dim=self.args.obs_dim,
                                    action_dim=self.args.action_dim,
                                    hidden_sizes=self.args.policy_layers)
        self.agent = SAC(
            policy,
            q1_network,
            q2_network,
            actor_lr=self.args.actor_lr,
            critic_lr=self.args.critic_lr,
            gamma=self.args.gamma,
            tau=self.args.soft_target_tau,
            entropy_alpha=self.args.entropy_alpha,
            automatic_entropy_tuning=self.args.automatic_entropy_tuning,
            alpha_lr=self.args.alpha_lr).to(ptu.device)
    else:
        raise NotImplementedError