def __init__(
    self,
    actor: torch.nn.Module,
    critic: torch.nn.Module,
    optim: torch.optim.Optimizer,
    dist_fn: Type[torch.distributions.Distribution],
    vf_coef: float = 0.5,
    ent_coef: float = 0.01,
    max_grad_norm: Optional[float] = None,
    gae_lambda: float = 0.95,
    max_batchsize: int = 256,
    **kwargs: Any
) -> None:
    """Initialise the parent policy and record the actor-critic settings.

    ``actor``, ``optim``, ``dist_fn`` and every extra keyword argument are
    forwarded unchanged to the parent constructor; the remaining arguments
    are stored on the instance.

    :param actor: actor network, handed to the parent constructor.
    :param critic: critic network, stored as ``self.critic``.
    :param optim: optimizer, handed to the parent constructor.
    :param dist_fn: distribution class, handed to the parent constructor.
    :param vf_coef: stored as ``self._weight_vf``.
    :param ent_coef: stored as ``self._weight_ent``.
    :param max_grad_norm: stored as ``self._grad_norm``; may be ``None``.
    :param gae_lambda: stored as ``self._lambda``; must lie in ``[0, 1]``.
    :param max_batchsize: stored as ``self._batch``.
    :raises AssertionError: if ``gae_lambda`` is outside ``[0, 1]``.
    """
    super().__init__(actor, optim, dist_fn, **kwargs)
    self.critic = critic
    assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."

    # Plain scalar settings; there is no ordering constraint among them.
    self._weight_vf, self._weight_ent = vf_coef, ent_coef
    self._lambda = gae_lambda
    self._batch = max_batchsize
    self._grad_norm = max_grad_norm

    # Bundle both networks in one container (built last, once both exist).
    self._actor_critic = ActorCritic(self.actor, self.critic)
def main(cfg):
    """Run one training job that is described entirely by ``cfg``.

    Environments, actor/critic networks, optimiser, policy, logger,
    collectors and the trainer are all created through ``hydra.utils``.
    When ``cfg.checkpoints.save_checkpoints`` is truthy, two callbacks that
    write the policy's ``state_dict`` to disk are passed to the trainer;
    otherwise both callbacks are ``None``.

    :param cfg: configuration object. The fields read here are ``env``,
        ``vec_env``, ``seed``, ``actor_network``, ``critic_network``,
        ``device``, ``optimiser``, ``policy``, ``logger``, ``hydra_logdir``,
        ``checkpoints``, ``train_collector``, ``test_collector`` and
        ``trainer``.
    """

    def make_env():
        return hydra.utils.instantiate(cfg.env)

    # A standalone environment, only used to read the state/action spaces.
    env = make_env()
    train_envs = ts.env.DummyVectorEnv(
        [make_env for _ in range(cfg.vec_env.n_train_envs)]
    )
    test_envs = ts.env.DummyVectorEnv(
        [make_env for _ in range(cfg.vec_env.n_test_envs)]
    )

    # Seed everything for reproducibility. The env seeds may be a single
    # int or a list with one entry per environment.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    train_envs.seed(cfg.vec_env.train_env_seed)
    test_envs.seed(cfg.vec_env.test_env_seed)

    # Actor and critic networks (created in this order so that seeded
    # parameter initialisation stays the same).
    obs_space = env.observation_space
    act_space = env.action_space
    state_shape = obs_space.shape or obs_space.n
    action_shape = act_space.shape or act_space.n

    net_a = hydra.utils.instantiate(
        cfg.actor_network, state_shape=state_shape, activation=nn.Tanh
    )
    actor = Actor(net_a, action_shape=action_shape, device=cfg.device)
    actor = actor.to(cfg.device)

    net_c = hydra.utils.instantiate(
        cfg.critic_network, state_shape=state_shape, activation=nn.Tanh
    )
    critic = Critic(net_c, device=cfg.device)
    critic = critic.to(cfg.device)

    # One optimiser over the parameters of both networks.
    shared_params = ActorCritic(actor, critic).parameters()
    optim = hydra.utils.instantiate(cfg.optimiser, params=shared_params)

    policy = hydra.utils.instantiate(
        cfg.policy,
        actor=actor,
        critic=critic,
        optim=optim,
        dist_fn=dist_fn,
        action_space=env.action_space,
    )

    logger = hydra.utils.instantiate(cfg.logger)
    logger.log_hyperparameters(OmegaConf.to_container(cfg, resolve=True))
    # Log the Hydra config files as artifacts for easy replication.
    logger.experiment.log_artifacts(logger.run_id, cfg.hydra_logdir)

    # Checkpoint callbacks stay None unless checkpointing is switched on.
    save_fn = None
    save_checkpoint_fn = None
    if cfg.checkpoints.save_checkpoints:
        cp_path = Path(cfg.checkpoints.path)
        cp_path.mkdir(parents=True, exist_ok=True)
        best_models_path = cp_path / "best_models"
        best_models_path.mkdir(parents=True, exist_ok=True)

        def save_fn(policy, cp_path=cp_path, mlflow_logger=logger):
            # Called when there is a new best mean evaluation reward.
            now_time = datetime.now().strftime("%H-%M-%S")
            model_checkpoint_path = (
                cp_path / f"best_models/policy_{now_time}.pt"
            )
            torch.save(policy.state_dict(), model_checkpoint_path)
            mlflow_logger.experiment.log_artifact(
                mlflow_logger.run_id, model_checkpoint_path, "best_models"
            )

        def save_checkpoint_fn(epoch, env_step, gradient_step, cp_path=cp_path):
            # Called at the frequency set by the logger's save interval.
            checkpoint_dir = cp_path / f"epoch_{epoch:04}"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            checkpoint_filepath = checkpoint_dir / "policy.pt"
            torch.save({"model": policy.state_dict()}, checkpoint_filepath)
            return checkpoint_filepath

    train_collector = hydra.utils.instantiate(
        cfg.train_collector, policy=policy, env=train_envs
    )
    test_collector = hydra.utils.instantiate(
        cfg.test_collector,
        policy=policy,
        env=test_envs,
        preprocess_fn=logger.info_logger.preprocess_fn,
    )

    result = hydra.utils.call(
        cfg.trainer,
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        logger=logger,
        save_fn=save_fn,
        save_checkpoint_fn=save_checkpoint_fn,
    )
    print(f'Finished training in {result["duration"]}')
    logger.close()
def test_ppo(args=get_args()):
    """Train PPO wrapped in an intrinsic curiosity module on ``args.task``.

    The function builds vectorised environments, a shared-backbone
    actor/critic, a ``PPOPolicy`` wrapped by ``ICMPolicy``, runs
    ``onpolicy_trainer`` and asserts that the best reward reaches
    ``args.reward_threshold``. When executed as a script it additionally
    prints the result and rolls out one episode.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``state_shape``, ``action_shape``
        and (if unset) ``reward_threshold`` are written back onto it.
    """
    env = gym.make(args.task)
    obs_space, act_space = env.observation_space, env.action_space
    args.state_shape = obs_space.shape or obs_space.n
    args.action_shape = act_space.shape or act_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    # Vectorised environments (tianshou.env.SubprocVectorEnv also works).
    def make_env():
        return gym.make(args.task)

    train_envs = DummyVectorEnv([make_env for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([make_env for _ in range(args.test_num)])

    # Seeding.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # Model: actor and critic share the same feature network.
    backbone = Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    actor = Actor(backbone, args.action_shape, device=args.device)
    actor = actor.to(args.device)
    critic = Critic(backbone, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)

    # Orthogonal weights and zero biases for every linear layer.
    for layer in actor_critic.modules():
        if not isinstance(layer, torch.nn.Linear):
            continue
        torch.nn.init.orthogonal_(layer.weight)
        torch.nn.init.zeros_(layer.bias)

    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    ppo_policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        gae_lambda=args.gae_lambda,
        reward_normalization=args.rew_norm,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        action_space=env.action_space,
        deterministic_eval=True,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
    )

    # Curiosity module: the last hidden size is the feature dimension, the
    # preceding sizes form the feature MLP.
    feature_dim = args.hidden_sizes[-1]
    action_dim = np.prod(args.action_shape)
    feature_net = MLP(
        np.prod(args.state_shape),
        output_dim=feature_dim,
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device,
    )
    icm_net = IntrinsicCuriosityModule(
        feature_net,
        feature_dim,
        action_dim,
        hidden_sizes=args.hidden_sizes[-1:],
        device=args.device,
    ).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(
        ppo_policy,
        icm_net,
        icm_optim,
        args.lr_scale,
        args.reward_scale,
        args.forward_loss_weight,
    )

    # Collectors.
    train_buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, train_buffer)
    test_collector = Collector(policy, test_envs)

    # Logging.
    log_path = os.path.join(args.logdir, args.task, 'ppo_icm')
    logger = TensorboardLogger(SummaryWriter(log_path))

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    # Training.
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Roll out a single episode with the trained policy.
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_ppo(args=get_args()):
    """Train, or only evaluate, PPO (optionally with ICM) on an Atari task.

    Steps: build the environments with ``make_atari_env``, create a DQN
    feature network shared by actor and critic, build a ``PPOPolicy`` and,
    when ``args.icm_lr_scale > 0``, wrap it in an ``ICMPolicy``. With
    ``args.watch`` set the function only evaluates and exits; otherwise it
    runs ``onpolicy_trainer`` and evaluates afterwards.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``state_shape``, ``action_shape``
        and ``algo_name`` are written back onto it.
    """
    env, train_envs, test_envs = make_atari_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        scale=args.scale_obs,
        frame_stack=args.frames_stack,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # define model: actor and critic share one feature extractor
    net = DQN(
        *args.state_shape,
        args.action_shape,
        device=args.device,
        features_only=True,
        output_dim=args.hidden_size,
    )
    actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
    critic = Critic(net, device=args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect
        ) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
        )

    # define policy; the actor emits logits (softmax_output=False above)
    def dist(p):
        return torch.distributions.Categorical(logits=p)

    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=False,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
    ).to(args.device)

    if args.icm_lr_scale > 0:
        feature_net = DQN(
            *args.state_shape, args.action_shape, args.device, features_only=True
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        # The rest of this function only reads ``args.hidden_size``
        # (singular). Use ``args.hidden_sizes`` when it is provided and fall
        # back to a single hidden layer of ``args.hidden_size`` units
        # otherwise, instead of failing with AttributeError.
        icm_hidden_sizes = getattr(args, "hidden_sizes", [args.hidden_size])
        icm_net = IntrinsicCuriosityModule(
            feature_net.net,
            feature_dim,
            action_dim,
            hidden_sizes=icm_hidden_sizes,
            device=args.device,
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy,
            icm_net,
            icm_optim,
            args.icm_lr_scale,
            args.icm_reward_scale,
            args.icm_forward_loss_weight,
        ).to(args.device)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device)
        )
        print("Loaded agent from: ", args.resume_path)

    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack,
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger: the wandb logger is created first and attached to the
    # tensorboard writer afterwards
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif "Pong" in args.task:
            return mean_rewards >= 20
        else:
            return False

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        torch.save({"model": policy.state_dict()}, ckpt_path)
        return ckpt_path

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack,
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
    )

    pprint.pprint(result)
    watch()
def test_a2c_with_il(args=get_args()):
    """Train A2C on envpool environments, then an imitation policy on its data.

    Phase one trains an ``A2CPolicy`` with ``onpolicy_trainer``. Phase two
    resets the same training collector and uses it with
    ``offpolicy_trainer`` to train an ``ImitationPolicy``. Both phases
    assert that the best reward reaches ``args.reward_threshold``.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``state_shape``, ``action_shape``
        and (if unset) ``reward_threshold`` are written back onto it.
    """
    # For a Python vector env, see the other test scripts.
    train_envs = env = envpool.make_gym(
        args.task, num_envs=args.training_num, seed=args.seed
    )
    test_envs = envpool.make_gym(
        args.task, num_envs=args.test_num, seed=args.seed
    )
    obs_space, act_space = env.observation_space, env.action_space
    args.state_shape = obs_space.shape or obs_space.n
    args.action_shape = act_space.shape or act_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    # Seeding.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Model: actor and critic share the same feature network.
    backbone = Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    actor = Actor(backbone, args.action_shape, device=args.device)
    actor = actor.to(args.device)
    critic = Critic(backbone, device=args.device).to(args.device)
    optim = torch.optim.Adam(
        ActorCritic(actor, critic).parameters(), lr=args.lr
    )
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        max_grad_norm=args.max_grad_norm,
        reward_normalization=args.rew_norm,
        action_space=env.action_space,
    )

    # Collectors.
    train_buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, train_buffer)
    test_collector = Collector(policy, test_envs)

    # Logging.
    log_path = os.path.join(args.logdir, args.task, 'a2c')
    logger = TensorboardLogger(SummaryWriter(log_path))

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    # Phase one: on-policy A2C training.
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Roll out a single episode with the trained A2C policy.
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

    policy.eval()

    # Phase two: imitation learning with a fresh actor network. The goal
    # could be lowered for CartPole-v0 via env.spec.reward_threshold = 190.
    il_backbone = Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    il_actor = Actor(il_backbone, args.action_shape, device=args.device)
    il_actor = il_actor.to(args.device)
    il_optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr)
    il_policy = ImitationPolicy(
        il_actor, il_optim, action_space=env.action_space
    )
    il_test_envs = envpool.make_gym(
        args.task, num_envs=args.test_num, seed=args.seed
    )
    il_test_collector = Collector(il_policy, il_test_envs)

    train_collector.reset()
    result = offpolicy_trainer(
        il_policy,
        train_collector,
        il_test_collector,
        args.epoch,
        args.il_step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Roll out a single episode with the imitation policy.
        env = gym.make(args.task)
        il_policy.eval()
        collector = Collector(il_policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_ppo(args=get_args()):
    """Train PPO with a Gaussian actor on a continuous-action task.

    The function builds the environments, separate actor and critic
    networks, a ``PPOPolicy`` with an ``Independent(Normal)`` action
    distribution, optionally restores policy and optimizer from
    ``checkpoint.pth``, then iterates an ``OnpolicyTrainer`` epoch by epoch
    and asserts that the best reward reaches ``args.reward_threshold``.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``state_shape``,
        ``action_shape``, ``max_action`` and (if unset) ``reward_threshold``
        are written back onto it.
    """
    env = gym.make(args.task)
    obs_space, act_space = env.observation_space, env.action_space
    args.state_shape = obs_space.shape or obs_space.n
    args.action_shape = act_space.shape or act_space.n
    args.max_action = act_space.high[0]
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    # Vectorised environments (tianshou.env.SubprocVectorEnv also works).
    def make_env():
        return gym.make(args.task)

    train_envs = DummyVectorEnv([make_env for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([make_env for _ in range(args.test_num)])

    # Seeding.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # Model: the actor and the critic each get their own feature network.
    actor_net = Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    actor = ActorProb(
        actor_net,
        args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    critic_net = Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    critic = Critic(critic_net, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)

    # Orthogonal weights and zero biases for every linear layer.
    for layer in actor_critic.modules():
        if not isinstance(layer, torch.nn.Linear):
            continue
        torch.nn.init.orthogonal_(layer.weight)
        torch.nn.init.zeros_(layer.bias)

    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    # Independent(Normal) stands in for a diagonal Gaussian; the variadic
    # signature keeps it consistent with policy.forward.
    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        gae_lambda=args.gae_lambda,
        action_space=env.action_space,
    )

    # Collectors.
    train_buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, train_buffer)
    test_collector = Collector(policy, test_envs)

    # Logging.
    log_path = os.path.join(args.logdir, args.task, "ppo")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        # To keep one file per epoch, a name such as
        # f"checkpoint_{epoch}.pth" could be used instead.
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        state = {
            "model": policy.state_dict(),
            "optim": optim.state_dict(),
        }
        torch.save(state, ckpt_path)
        return ckpt_path

    if args.resume:
        # Restore policy and optimizer from an existing checkpoint, if any.
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            policy.load_state_dict(checkpoint["model"])
            optim.load_state_dict(checkpoint["optim"])
            print("Successfully restore policy and optim.")
        else:
            print("Fail to restore policy and optim.")

    # Training, driven epoch by epoch through the trainer iterator.
    trainer = OnpolicyTrainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn,
    )

    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)

    assert stop_fn(info["best_reward"])

    if __name__ == "__main__":
        pprint.pprint(info)
        # Roll out a single episode with the trained policy.
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_ppo(args=get_args()):
    """Train, or only evaluate, PPO (optionally with ICM) on a ViZDoom map.

    Steps: derive the map paths and frame resolution from ``args``, build
    shared-memory vector environments, create a DQN feature network shared
    by actor and critic, build a ``PPOPolicy`` and, when
    ``args.icm_lr_scale > 0``, wrap it in an ``ICMPolicy``. With
    ``args.watch`` set the function only evaluates and exits; otherwise it
    runs ``onpolicy_trainer`` and evaluates afterwards.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``cfg_path``, ``wad_path``,
        ``res``, ``state_shape`` and ``action_shape`` are written back onto
        it.
    """
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res)
            for _ in range(args.training_num)
        ]
    )
    # Leave one CPU free for the main process, but always create at least
    # one test environment: os.cpu_count() may return None, and on a
    # single-CPU host "cpu_count - 1" would otherwise give zero environments.
    n_test_envs = max(1, min((os.cpu_count() or 1) - 1, args.test_num))
    test_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
            for _ in range(n_test_envs)
        ]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model: actor and critic share one feature extractor
    net = DQN(
        *args.state_shape,
        args.action_shape,
        device=args.device,
        features_only=True,
        output_dim=args.hidden_size,
    )
    actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
    critic = Critic(net, device=args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect
        ) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
        )

    # define policy; the actor emits logits (softmax_output=False above)
    def dist(p):
        return torch.distributions.Categorical(logits=p)

    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=False,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
    ).to(args.device)

    if args.icm_lr_scale > 0:
        feature_net = DQN(
            *args.state_shape,
            args.action_shape,
            device=args.device,
            features_only=True,
            output_dim=args.hidden_size,
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.net, feature_dim, action_dim, device=args.device
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy,
            icm_net,
            icm_optim,
            args.icm_lr_scale,
            args.icm_reward_scale,
            args.icm_forward_loss_weight,
        ).to(args.device)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device)
        )
        print("Loaded agent from: ", args.resume_path)

    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack,
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
    log_path = os.path.join(args.logdir, args.task, log_name)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack,
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        # Episode length is scaled by skip_num before it is reported.
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        test_in_train=False,
    )

    pprint.pprint(result)
    watch()
def test_discrete_crr(args=get_args()):
    """Train, or only evaluate, a discrete CRR policy from an offline buffer.

    Steps: build the Atari test environments, create a DQN feature network
    shared by actor and critic, build a ``DiscreteCRRPolicy``, load the
    expert replay buffer from ``args.load_buffer_name`` (``.pkl`` or
    ``.hdf5``) and run ``offline_trainer``. With ``args.watch`` set the
    function only evaluates and exits.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``state_shape``, ``action_shape``
        and ``algo_name`` are written back onto it.
    :raises AssertionError: if ``args.load_buffer_name`` does not exist.
    """
    # envs
    env, _, test_envs = make_atari_env(
        args.task,
        args.seed,
        1,
        args.test_num,
        scale=args.scale_obs,
        frame_stack=args.frames_stack,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # model: actor and critic share one feature extractor
    feature_net = DQN(
        *args.state_shape, args.action_shape, device=args.device, features_only=True
    ).to(args.device)
    actor = Actor(
        feature_net,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax_output=False,
    ).to(args.device)
    critic = Critic(
        feature_net,
        hidden_sizes=args.hidden_sizes,
        last_size=np.prod(args.action_shape),
        device=args.device,
    ).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCRRPolicy(
        actor,
        critic,
        optim,
        args.gamma,
        policy_improvement_mode=args.policy_improvement_mode,
        ratio_upper_bound=args.ratio_upper_bound,
        beta=args.beta,
        min_q_weight=args.min_q_weight,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_qrdqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith(".pkl"):
        # NOTE: unpickling can execute arbitrary code, so only load buffers
        # from a trusted source. The context manager closes the file handle.
        with open(args.load_buffer_name, "rb") as buffer_file:
            buffer = pickle.load(buffer_file)
    elif args.load_buffer_name.endswith(".hdf5"):
        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "crr"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger: the wandb logger is created first and attached to the
    # tensorboard writer afterwards
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        # Never stop early; training always runs for all epochs.
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num, render=args.render)
        pprint.pprint(result)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    result = offline_trainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.update_per_epoch,
        args.test_num,
        args.batch_size,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    )

    pprint.pprint(result)
    watch()
def test_discrete_bcq(args=get_args()):
    """Train a discrete BCQ policy from an offline replay buffer.

    Steps: build the test environments, create a policy network and an
    imitation network on a shared feature network, build a
    ``DiscreteBCQPolicy``, load the replay buffer from
    ``args.load_buffer_name`` (or call ``gather_data()`` when that file is
    missing), optionally restore from ``checkpoint.pth``, run
    ``offline_trainer`` and assert that the best reward reaches
    ``args.reward_threshold``.

    :param args: parsed arguments; the default is produced by ``get_args()``
        once, when the function is defined. ``state_shape``, ``action_shape``
        and (if unset) ``reward_threshold`` are written back onto it.
    """
    # envs
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 190}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model: both heads share one feature network
    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
    policy_net = Actor(
        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
    imitation_net = Actor(
        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
    actor_critic = ActorCritic(policy_net, imitation_net)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    policy = DiscreteBCQPolicy(
        policy_net,
        imitation_net,
        optim,
        args.gamma,
        args.n_step,
        args.target_update_freq,
        args.eps_test,
        args.unlikely_action_threshold,
        args.imitation_logits_penalty,
    )
    # buffer: os.path.isfile is already False for a missing path
    if os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE: unpickling can execute arbitrary code, so only load
            # buffers from a trusted source. The context manager closes the
            # file handle.
            with open(args.load_buffer_name, "rb") as buffer_file:
                buffer = pickle.load(buffer_file)
    else:
        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        torch.save(
            {
                'model': policy.state_dict(),
                'optim': optim.state_dict(),
            },
            ckpt_path,
        )
        # Hand the path back so the caller can see where the checkpoint went.
        return ckpt_path

    if args.resume:
        # load from existing checkpoint
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            policy.load_state_dict(checkpoint['model'])
            optim.load_state_dict(checkpoint['optim'])
            print("Successfully restore policy and optim.")
        else:
            print("Fail to restore policy and optim.")

    result = offline_trainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.update_per_epoch,
        args.test_num,
        args.batch_size,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
class A2CPolicy(PGPolicy):
    """Synchronous Advantage Actor-Critic (A2C) policy. arXiv:1602.01783.

    :param torch.nn.Module actor: actor network obeying the rules described in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: optimizer shared by the actor and the
        critic network.
    :param dist_fn: distribution class used to compute the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param float discount_factor: in [0, 1]. Default to 0.99.
    :param float vf_coef: weight of the value loss. Default to 0.5.
    :param float ent_coef: weight of the entropy loss. Default to 0.01.
    :param float max_grad_norm: gradient clipping threshold used during back
        propagation. Default to None.
    :param float gae_lambda: in [0, 1], parameter of Generalized Advantage
        Estimation. Default to 0.95.
    :param bool reward_normalization: normalize estimated values so that their
        std is close to 1. Default to False.
    :param int max_batchsize: maximum batch size used when computing GAE;
        it depends on the available memory and on the memory cost of the model
        and should be as large as the memory constraint allows.
        Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to
        range [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: how to bound the action to range [-1, 1]:
        "clip" (simple clipping), "tanh" (tanh squashing), or an empty string
        for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: the env's action space, mandatory
        when "action_scaling" or "action_bound_method" is used.
        Default to None.
    :param lr_scheduler: learning rate scheduler that adjusts the optimizer's
        learning rate in each policy.update(). Default to None (no
        lr_scheduler).
    :param bool deterministic_eval: whether to use a deterministic action
        instead of a stochastic action sampled by the policy.
        Default to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: Optional[float] = None,
        gae_lambda: float = 0.95,
        max_batchsize: int = 256,
        **kwargs: Any
    ) -> None:
        """Store the critic and the loss/GAE hyper-parameters."""
        super().__init__(actor, optim, dist_fn, **kwargs)
        self.critic = critic
        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
        self._lambda = gae_lambda
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._grad_norm = max_grad_norm
        self._batch = max_batchsize
        # Wrapper used to reach the parameters of both networks at once.
        self._actor_critic = ActorCritic(self.actor, self.critic)

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach returns/advantages to the batch and cast ``act`` like ``v_s``."""
        processed = self._compute_returns(batch, buffer, indices)
        processed.act = to_torch_as(processed.act, processed.v_s)
        return processed

    def _compute_returns(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Fill ``batch.v_s``, ``batch.returns`` and ``batch.adv``."""
        critic_out, critic_out_next = [], []
        with torch.no_grad():
            chunks = batch.split(self._batch, shuffle=False, merge_last=True)
            for chunk in chunks:
                critic_out.append(self.critic(chunk.obs))
                critic_out_next.append(self.critic(chunk.obs_next))
        batch.v_s = torch.cat(critic_out, dim=0).flatten()  # old value
        values = batch.v_s.cpu().numpy()
        next_values = torch.cat(critic_out_next, dim=0).flatten().cpu().numpy()
        # When normalizing values we do not subtract self.ret_rms.mean, to stay
        # numerically consistent with the OpenAI baselines value normalization
        # pipeline. Empirical study also shows that "minus mean" harms
        # performance a tiny little bit for unknown reasons (on Mujoco envs,
        # not confident, though).
        if self._rew_norm:
            # unnormalize both value estimates with the current statistics
            scale = np.sqrt(self.ret_rms.var + self._eps)
            values = values * scale
            next_values = next_values * scale
        raw_returns, advantages = self.compute_episodic_return(
            batch,
            buffer,
            indices,
            next_values,
            values,
            gamma=self._gamma,
            gae_lambda=self._lambda,
        )
        if not self._rew_norm:
            batch.returns = raw_returns
        else:
            # normalize with the statistics *before* they are updated
            batch.returns = raw_returns / np.sqrt(self.ret_rms.var + self._eps)
            self.ret_rms.update(raw_returns)
        batch.returns = to_torch_as(batch.returns, batch.v_s)
        batch.adv = to_torch_as(advantages, batch.v_s)
        return batch

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        """Run ``repeat`` passes of minibatch updates and return the loss logs."""
        history: Dict[str, List[float]] = {
            "loss": [],
            "loss/actor": [],
            "loss/vf": [],
            "loss/ent": [],
        }
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                # actor term
                action_dist = self(minibatch).dist
                log_prob = action_dist.log_prob(minibatch.act)
                log_prob = log_prob.reshape(len(minibatch.adv), -1)
                log_prob = log_prob.transpose(0, 1)
                actor_loss = -(log_prob * minibatch.adv).mean()
                # critic term
                value = self.critic(minibatch.obs).flatten()
                vf_loss = F.mse_loss(minibatch.returns, value)
                # entropy regularization and overall loss
                ent_loss = action_dist.entropy().mean()
                weighted_vf = self._weight_vf * vf_loss
                weighted_ent = self._weight_ent * ent_loss
                loss = actor_loss + weighted_vf - weighted_ent
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:
                    # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(),
                        max_norm=self._grad_norm,
                    )
                self.optim.step()
                history["loss/actor"].append(actor_loss.item())
                history["loss/vf"].append(vf_loss.item())
                history["loss/ent"].append(ent_loss.item())
                history["loss"].append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return history
def test_discrete_crr(args=get_args()):
    """Train a ``DiscreteCRRPolicy`` with the offline trainer and check the result.

    The replay buffer is read from ``args.load_buffer_name`` (HDF5 when the name
    ends with ``.hdf5``, pickle otherwise); when that file does not exist the
    buffer is produced by ``gather_data()``. After training, the function
    asserts that the best test reward satisfies ``stop_fn``.

    :param args: argument namespace. Note that the default is evaluated once,
        when this function is defined, and that the namespace is mutated in
        place (``state_shape``, ``action_shape`` and possibly
        ``reward_threshold`` are written to it).
    :raises AssertionError: if the best reward is below ``args.reward_threshold``.
    """
    # envs
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 180}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    # The actor and the critic are both built on the same ``net`` instance.
    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
    actor = Actor(
        net,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax_output=False
    )
    critic = Critic(
        net,
        hidden_sizes=args.hidden_sizes,
        last_size=np.prod(args.action_shape),
        device=args.device
    )
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    policy = DiscreteCRRPolicy(
        actor,
        critic,
        optim,
        args.gamma,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    # ``isfile`` already implies that the path exists.
    if os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # buffer files that come from a trusted source.
            # The context manager guarantees the file handle is closed.
            with open(args.load_buffer_name, "rb") as buffer_file:
                buffer = pickle.load(buffer_file)
    else:
        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    result = offline_trainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.update_per_epoch,
        args.test_num,
        args.batch_size,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger
    )

    assert stop_fn(result['best_reward'])

    # Only roll out a demo episode when the module is run as a script.
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_gail(args=get_args()):
    """Train a ``GAILPolicy`` with the on-policy trainer and check the result.

    A replay buffer is first read from ``args.load_buffer_name`` (HDF5 when the
    name ends with ``.hdf5``, pickle otherwise) or, when that file does not
    exist, produced by ``gather_data()``. The buffer is then handed to
    ``GAILPolicy`` (presumably as the expert demonstrations -- confirm against
    the policy's signature). After training, the function asserts that the
    best test reward satisfies ``stop_fn``.

    :param args: argument namespace. Note that the default is evaluated once,
        when this function is defined, and that the namespace is mutated in
        place (``state_shape``, ``action_shape``, ``max_action`` and possibly
        ``reward_threshold`` are written to it).
    :raises AssertionError: if the best reward is below ``args.reward_threshold``.
    """
    # ``isfile`` already implies that the path exists.
    if os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # buffer files that come from a trusted source.
            # The context manager guarantees the file handle is closed.
            with open(args.load_buffer_name, "rb") as buffer_file:
                buffer = pickle.load(buffer_file)
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net, args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    critic = Critic(
        Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
        device=args.device
    ).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # discriminator
    disc_net = Critic(
        Net(
            args.state_shape,
            action_shape=args.action_shape,
            hidden_sizes=args.hidden_sizes,
            activation=torch.nn.Tanh,
            device=args.device,
            concat=True,
        ),
        device=args.device
    ).to(args.device)
    for m in disc_net.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)

    # replace DiagGaussian with Independent(Normal) which is equivalent
    # pass *logits to be consistent with policy.forward
    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = GAILPolicy(
        actor,
        critic,
        optim,
        dist,
        buffer,
        disc_net,
        disc_optim,
        disc_update_num=args.disc_update_num,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        gae_lambda=args.gae_lambda,
        action_space=env.action_space,
    )
    # collector
    train_collector = Collector(
        policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
    )
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'gail')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        torch.save(
            {
                'model': policy.state_dict(),
                'optim': optim.state_dict(),
            }, os.path.join(log_path, 'checkpoint.pth')
        )

    if args.resume:
        # load from existing checkpoint
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            policy.load_state_dict(checkpoint['model'])
            optim.load_state_dict(checkpoint['optim'])
            print("Successfully restore policy and optim.")
        else:
            print("Fail to restore policy and optim.")

    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn,
    )
    assert stop_fn(result['best_reward'])

    # Only roll out a demo episode when the module is run as a script.
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_gail(args=get_args()):
    """Run GAIL on ``args.task`` with a D4RL dataset as the expert buffer.

    Builds observation-normalized vectorized envs (training envs are wrapped
    in ``NoRewardEnv``), the actor/critic/discriminator networks, and an
    expert ``ReplayBuffer`` filled from ``d4rl.qlearning_dataset``. Unless
    ``args.watch`` is set, the policy is trained with ``onpolicy_trainer``;
    afterwards ``args.test_num`` test episodes are collected and the mean
    reward and length are printed.

    :param args: argument namespace (the default is evaluated once, at
        definition time); ``state_shape``, ``action_shape`` and ``max_action``
        are written to it.
    """

    def make_train_env():
        return NoRewardEnv(gym.make(args.task))

    def make_test_env():
        return gym.make(args.task)

    def orthogonal_init(modules):
        # orthogonal weights (gain sqrt(2)) and zero bias for every Linear layer
        for layer in modules:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
                torch.nn.init.zeros_(layer.bias)

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    env = gym.make(args.task)
    obs_space, act_space = env.observation_space, env.action_space
    args.state_shape = obs_space.shape or obs_space.n
    args.action_shape = act_space.shape or act_space.n
    args.max_action = act_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(act_space.low), np.max(act_space.high))

    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [make_train_env for _ in range(args.training_num)],
        norm_obs=True,
    )
    # test_envs = gym.make(args.task)
    # the test envs reuse the training statistics and never update them
    test_envs = SubprocVectorEnv(
        [make_test_env for _ in range(args.test_num)],
        norm_obs=True,
        obs_rms=train_envs.obs_rms,
        update_obs_rms=False,
    )

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model
    actor_backbone = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    actor = ActorProb(
        actor_backbone,
        args.action_shape,
        max_action=args.max_action,
        unbounded=True,
        device=args.device,
    ).to(args.device)
    critic_backbone = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    critic = Critic(critic_backbone, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    # actor layers are initialized first, then the critic layers
    orthogonal_init(actor.modules())
    orthogonal_init(critic.modules())
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for layer in actor.mu.modules():
        if isinstance(layer, torch.nn.Linear):
            torch.nn.init.zeros_(layer.bias)
            layer.weight.data.copy_(0.01 * layer.weight.data)

    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)

    # discriminator
    disc_backbone = Net(
        args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
        concat=True,
    )
    disc_net = Critic(disc_backbone, device=args.device).to(args.device)
    orthogonal_init(disc_net.modules())
    disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        collects_per_epoch = np.ceil(args.step_per_epoch / args.step_per_collect)
        max_update_num = collects_per_epoch * args.epoch

        def linear_decay(epoch):
            return 1 - epoch / max_update_num

        lr_scheduler = LambdaLR(optim, lr_lambda=linear_decay)

    # expert replay buffer
    dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
    dataset_size = dataset['rewards'].size
    print("dataset_size", dataset_size)
    expert_buffer = ReplayBuffer(dataset_size)
    observations = dataset['observations']
    actions = dataset['actions']
    rewards = dataset['rewards']
    terminals = dataset['terminals']
    next_observations = dataset['next_observations']
    for idx in range(dataset_size):
        transition = Batch(
            obs=observations[idx],
            act=actions[idx],
            rew=rewards[idx],
            done=terminals[idx],
            obs_next=next_observations[idx],
        )
        expert_buffer.add(transition)
    print("dataset loaded")

    policy_options = dict(
        disc_update_num=args.disc_update_num,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
    )
    policy = GAILPolicy(
        actor,
        critic,
        optim,
        dist,
        expert_buffer,
        disc_net,
        disc_optim,
        **policy_options,
    )

    # load a previous policy
    if args.resume_path:
        saved_state = torch.load(args.resume_path, map_location=args.device)
        policy.load_state_dict(saved_state)
        print("Loaded agent from: ", args.resume_path)

    # collector
    buffer = (
        VectorReplayBuffer(args.buffer_size, len(train_envs))
        if args.training_num > 1
        else ReplayBuffer(args.buffer_size)
    )
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)

    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail'
    log_path = os.path.join(args.logdir, args.task, 'gail', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.repeat_per_collect,
            args.test_num,
            args.batch_size,
            step_per_collect=args.step_per_collect,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_train=False,
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(
        f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}'
    )