def train(self, max_training_step):
    for training_step in range(max_training_step):
        # sample a trajectory from the worker
        trajectory = self.worker.sample_trajectory(-1, False)
        obs, acs, ac_logprobs, rews, nobs, dones, values, next_values = trajectory
        # REINFORCE: score the sampled actions against the episode rewards
        loss = self.estimate_policy_loss(ac_logprobs, rews)
        TorchUtils.update_network(self.policy_network_optimizer, loss)
        print(sum(rews), loss)
        self.logger.log("episodic reward", sum(rews), training_step)
        self.logger.log("policy loss", loss, training_step)
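Neither estimate_policy_loss nor TorchUtils.update_network appears in the listing above. A minimal sketch of both, assuming ac_logprobs is a 1-D tensor of per-step log-probabilities and rews a list of scalar rewards; the reward-to-go computation and the return normalization are illustrative choices, not necessarily this repository's exact implementation:

import torch

def estimate_policy_loss(self, ac_logprobs, rews):
    # discounted reward-to-go: G_t = r_t + gamma * G_{t+1}
    returns, g = [], 0.0
    for r in reversed(rews):
        g = r + self.gamma * g
        returns.insert(0, g)
    returns = torch.as_tensor(returns, dtype=torch.float32,
                              device=ac_logprobs.device)
    # normalize to reduce gradient variance (optional but common)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    # REINFORCE: maximize E[log pi(a_t|s_t) * G_t]
    return -(ac_logprobs * returns).mean()

def update_network(optimizer, loss):
    # one standard gradient step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Dropping the normalization recovers the textbook estimator; keeping it usually stabilizes training on small control tasks.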
def _build_network(self):
    model = []
    last_input_dim = self.input_dim
    # hidden layers: Linear followed by the configured activation
    for i in range(self.num_layers):
        model.append(nn.Linear(last_input_dim, self.hidden_units[i]))
        model.append(TorchUtils.get_activation_fn(self.hidden_activations[i])())
        last_input_dim = self.hidden_units[i]
    # output layer with an optional output activation
    model.append(nn.Linear(last_input_dim, self.output_dim))
    last_activation = TorchUtils.get_activation_fn(self.output_activation)
    if last_activation is not None:
        model.append(last_activation())
    return nn.Sequential(*model).to(self.device)
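TorchUtils.get_activation_fn is used here but not shown. Because the hidden-layer path instantiates its result with (), it presumably returns an activation class (or None for a linear output). A plausible sketch; the exact set of supported names is an assumption:

import torch.nn as nn

def get_activation_fn(name):
    # map a config string to an nn activation class;
    # None means "no activation" (linear output layer)
    activations = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "identity": None,
    }
    return activations[name.lower()]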
def train(self, max_training_step, timesteps_per_learning):
    self.policy_network.train()
    self.value_network.train()
    for training_step in range(max_training_step):
        trajectory = self.worker.sample_trajectory(timesteps_per_learning, False)
        obs, acs, ac_logprobs, rews, nobs, dones, values, next_values = trajectory
        # n-step TD targets serve as both the critic's regression target
        # and the baseline-corrected signal for the actor
        rets = RLUtils.get_nstep_td_return(rews, next_values, dones,
                                           self.gamma, self.n_step)
        vf_loss = self.estimate_value_loss(values, rets)
        TorchUtils.update_network(self.value_network_optimizer, vf_loss)
        pg_loss = self.estimate_policy_loss(ac_logprobs, values, rets)
        TorchUtils.update_network(self.policy_network_optimizer, pg_loss)
        print(sum(rews))
        self.logger.log("episodic reward", sum(rews), training_step)
        self.logger.log("policy loss", pg_loss, training_step)
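RLUtils.get_nstep_td_return and the two loss helpers are also not shown. A minimal sketch under standard actor-critic assumptions: each step's return sums up to n discounted rewards, truncates at episode boundaries, and bootstraps from the critic's next-state value; the critic regresses toward those targets, and the actor weights log-probabilities by the detached advantage. Shapes and dtypes here are illustrative:

import torch
import torch.nn.functional as F

def get_nstep_td_return(rews, next_values, dones, gamma, n_step):
    T = len(rews)
    returns = []
    for t in range(T):
        g, k = 0.0, 0
        # sum up to n discounted rewards, stopping at episode ends
        while True:
            g += (gamma ** k) * rews[t + k]
            last = t + k
            if dones[last] or k == n_step - 1 or last == T - 1:
                break
            k += 1
        # bootstrap from the critic unless the episode terminated
        if not dones[last]:
            g += (gamma ** (k + 1)) * float(next_values[last])
        returns.append(g)
    return torch.as_tensor(returns, dtype=torch.float32)

def estimate_value_loss(self, values, rets):
    # critic: regress predicted values onto the n-step targets
    return F.mse_loss(values, rets.to(values.device))

def estimate_policy_loss(self, ac_logprobs, values, rets):
    # actor: advantage-weighted log-probabilities; detach the advantage
    # so the policy loss does not backpropagate into the critic
    advantages = (rets.to(values.device) - values).detach()
    return -(ac_logprobs * advantages).mean()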
parser.add_argument("--gamma", type=float, help="discounted rate", default=0.99) parser.add_argument( "--max-training-step", type=int, help="max episode size for learning", default=1000, ) parser.add_argument("--epoch", type=int, help="epoch for learning policy network", default=1) args = parser.parse_args() env = gym.make(args.env_name) input_space = env.observation_space output_space = env.action_space policy_network_factory = PolicyNetworkFactory() network_setting = MLPNetworkSetting() policy_network = policy_network_factory.get_network( input_space, output_space, network_setting, TorchUtils.get_device()) algo = REINFORCEAlgorithm(env, policy_network, args.gamma, args.lr, args.epoch) algo.train(args.max_training_step) TorchUtils.save_model(policy_network, str(algo), env.unwrapped.spec.id) env.close()
def forward(self, state):
    transformed_state = TorchUtils.transform_input(state)
    return self.model(transformed_state)
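TorchUtils.transform_input is not shown. A minimal sketch, assuming observations arrive as numpy arrays and the model expects a float tensor with a leading batch dimension (device placement is assumed to happen elsewhere):

import numpy as np
import torch

def transform_input(state):
    # raw gym observation -> float32 tensor with a batch dimension
    state = np.asarray(state, dtype=np.float32)
    tensor = torch.from_numpy(state)
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # single state -> batch of one
    return tensor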
import argparse

import gym

from common.network_setting import MLPNetworkSetting
from common.torch_utils import TorchUtils
from common.il_utils import ILUtils
from algorithm.reinforcement_learning.reinforce_algorithm import REINFORCEAlgorithm

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="collect demo data from a trained REINFORCE policy for gym env")
    parser.add_argument("--env-name", type=str, default="CartPole-v0")
    parser.add_argument("--policy-load-path", type=str,
                        default="model/REINFORCECartPole-v0")
    parser.add_argument("--demo-save-path", type=str, default="demodata/")
    parser.add_argument("--num-traj", type=int, default=100)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    input_space = env.observation_space
    output_space = env.action_space

    policy_network_factory = PolicyNetworkFactory()
    network_setting = MLPNetworkSetting()
    policy_network = policy_network_factory.get_network(
        input_space, output_space, network_setting, TorchUtils.get_device())
    policy_network = TorchUtils.load_model(policy_network, args.policy_load_path)

    save_path = args.demo_save_path + str(env)
    ILUtils.collect_demodata_from_model(env, policy_network, args.num_traj,
                                        save_path)
    env.close()
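ILUtils.collect_demodata_from_model is not shown either. A sketch of what it plausibly does, assuming the pre-0.26 gym step API (four return values, matching CartPole-v0), a hypothetical policy_network.get_action helper, and pickle as the storage format:

import pickle

import torch

def collect_demodata_from_model(env, policy_network, num_traj, save_path):
    policy_network.eval()
    demos = []
    for _ in range(num_traj):
        obs, done, episode = env.reset(), False, []
        while not done:
            with torch.no_grad():
                action = policy_network.get_action(obs)  # assumed helper
            next_obs, reward, done, _ = env.step(action)
            episode.append((obs, action))
            obs = next_obs
        demos.append(episode)
    # persist (observation, action) pairs for imitation learning
    with open(save_path, "wb") as f:
        pickle.dump(demos, f)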
type=int, help="epoch for learning policy network", default=3) parser.add_argument("--n-step", type=int, help="n-step for td learning", default=4) args = parser.parse_args() env = gym.make(args.env_name) input_space = env.observation_space output_space = env.action_space network_setting = MLPNetworkSetting() policy_network_factory = PolicyNetworkFactory() policy_network = policy_network_factory.get_network( input_space, output_space, network_setting, TorchUtils.get_device()) value_network_factory = ValueNetworkFactory() value_network = value_network_factory.get_network(input_space, output_space, network_setting, TorchUtils.get_device()) algo = A2CAlgorithm(env, policy_network, value_network, args.gamma, args.lr, args.epoch, args.n_step) algo.train(args.max_training_step, args.timesteps_per_learning) # TorchUtils.save_model(policy_network, str(algo), env.unwrapped.spec.id) env.close()