Example #1
 def _setup_model(self) -> None:
     super(DQN, self)._setup_model()
     self._create_aliases()
     # Linearly decay the exploration rate (epsilon) from the initial value
     # to the final value over the first exploration_fraction of training
     self.exploration_schedule = get_linear_fn(
         self.exploration_initial_eps,
         self.exploration_final_eps,
         self.exploration_fraction,
     )
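
For context, get_linear_fn(start, end, end_fraction) in Stable-Baselines3 returns a schedule that is evaluated on the remaining training progress (1.0 at the start of training, 0.0 at the end). A minimal sketch of how the epsilon schedule above behaves, assuming a recent SB3 version:

    from stable_baselines3.common.utils import get_linear_fn

    # Anneal epsilon from 1.0 down to 0.05 over the first 10% of training
    schedule = get_linear_fn(start=1.0, end=0.05, end_fraction=0.1)

    print(schedule(1.0))   # training just started -> 1.0
    print(schedule(0.95))  # 5% of training done   -> ~0.525
    print(schedule(0.5))   # past end_fraction     -> 0.05
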
Example #2
 def _setup_model(self) -> None:
     super(OffPAC, self)._setup_model()
     self._create_aliases()
     # Off-policy trajectory storage, also exposed through the standard
     # replay_buffer attribute so the rest of the training loop can use it
     self.trajectory_buffer = TrajectoryBuffer(
         self.buffer_size,
         self.observation_space,
         self.action_space,
         self.device,
     )
     self.replay_buffer = self.trajectory_buffer
     # Linearly decay the exploration rate (epsilon)
     self.exploration_schedule = get_linear_fn(
         self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
     )
Example #3
    def _setup_model(self) -> None:
        super(DQN, self)._setup_model()
        self._create_aliases()
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )
        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        if self.n_envs > 1:
            if self.n_envs > self.target_update_interval:
                warnings.warn(
                    "The number of environments used is greater than the target network "
                    f"update interval ({self.n_envs} > {self.target_update_interval}), "
                    "therefore the target network will be updated after each call to env.step() "
                    f"which corresponds to {self.n_envs} steps.")

            self.target_update_interval = max(
                self.target_update_interval // self.n_envs, 1)
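
To make the rescaling above concrete, a small illustrative sketch (the numbers are made up, not taken from the source):

    n_envs = 8
    target_update_interval = 4  # fewer steps than parallel environments

    # Mirrors the adjustment above: the warning fires because 8 > 4, and the
    # effective interval becomes max(4 // 8, 1) == 1, i.e. the target network
    # is refreshed after every call to env.step() (which is 8 transitions)
    target_update_interval = max(target_update_interval // n_envs, 1)
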
Example #4
                       n_envs=args.n_envs,
                       vec_env_cls=SubprocVecEnv,
                       vec_env_kwargs={"start_method": "fork"})

    # Custom Feature Extractor backbone
    policy_kwargs = {
        "features_extractor_class": CustomFeatureExtractor,
        "features_extractor_kwargs": {
            "features_dim": args.features_dim,
            "model_arch": args.model_arch
        },
        "normalize_images": False
    }

    # Define the model; the PPO clip range is linearly annealed over training
    clip_schedule = get_linear_fn(args.clip_start_val, args.clip_end_val,
                                  args.clip_progress_ratio)
    if args.load is not None:
        model = PPO.load(args.load, env, clip_range=clip_schedule)
    else:
        model = PPO("CnnPolicy",
                    env,
                    policy_kwargs=policy_kwargs,
                    verbose=1,
                    n_steps=args.num_rollout_steps,
                    learning_rate=args.lr,
                    gamma=args.gamma,
                    tensorboard_log=args.tb_log,
                    n_epochs=args.n_epochs,
                    clip_range=clip_schedule,
                    batch_size=args.batch_size,
                    seed=args.seed)
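
The snippet ends with the PPO constructor; a plausible continuation sketch (hypothetical: args.total_timesteps and args.save_path are assumed names that do not appear in the original):

    model.learn(total_timesteps=args.total_timesteps)  # assumed argument name
    model.save(args.save_path)                         # assumed argument name
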
Example #5
    def create_model(
        self,
        seed,
        algo_name,
        env,
        tensorboard_log_dir,
        hyperparams,
        best_model_save_path=None,
        model_to_load=None,
        continue_learning=False,
        env_name="CartPole-v1",
        n_timesteps=-1,
        save_replay_buffer: bool = True,
    ):

        old_hyperparams = dict()

        # Create learning rate / clip range schedules for ppo2, sac and td3
        if algo_name in ["ppo2", "sac", "td3"]:
            for key in ["learning_rate", "cliprange", "cliprange_vf"]:
                if key not in hyperparams:
                    continue
                if isinstance(hyperparams[key], str):
                    self.logger.debug("Key {}, value {}".format(key, hyperparams[key]))
                    old_hyperparams[key] = hyperparams[key]
                    schedule, initial_value = hyperparams[key].split("_")
                    initial_value = float(initial_value)
                    hyperparams[key] = linear_schedule(initial_value)
                elif isinstance(hyperparams[key], (float, int)):
                    # Negative value: ignore (e.g. for clipping)
                    if hyperparams[key] < 0:
                        continue
                    old_hyperparams[key] = float(hyperparams[key])
                    hyperparams[key] = constfn(float(hyperparams[key]))
                else:
                    raise ValueError("Invalid value for {}: {}".format(key, hyperparams[key]))
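        # For reference, linear_schedule / constfn are assumed to be small helpers
        # along these lines (they are not shown in this snippet):
        #   linear_schedule(v) -> lambda progress: progress * v  (SB2-style progress, 1 -> 0)
        #   constfn(v)         -> lambda _: v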

        if algo_name == "ppo2":

            if self.sb_version == "sb3":
                raise NotImplementedError("PPO still in sb2")

            if best_model_save_path and continue_learning:
                model = PPO2.load(
                    self.load_model(best_model_save_path, model_to_load),
                    env=env,
                    tensorboard_log=tensorboard_log_dir,
                    verbose=1,
                )
                key = "cliprange"
                cl_cliprange_value = 0.08  # new policy can be a bit different than the old one
                if key in old_hyperparams:
                    if isinstance(old_hyperparams[key], str):
                        self.logger.debug("Setting cliprange to lin_{}".format(cl_cliprange_value))
                        model.cliprange = linear_schedule(cl_cliprange_value)
                    elif isinstance(old_hyperparams[key], (float, int)):
                        self.logger.debug("Setting cliprange to value {}".format(cl_cliprange_value))
                        model.cliprange = constfn(cl_cliprange_value)
                else:
                    # default value is too high for continual learning (0.2)
                    self.logger.debug("Setting cliprange to value {}".format(cl_cliprange_value))
                    model.cliprange = cl_cliprange_value

                return model
            elif best_model_save_path:
                return PPO2.load(
                    self.load_model(best_model_save_path, model_to_load),
                    env=env,
                    tensorboard_log=tensorboard_log_dir,
                    verbose=1,
                    n_cpu_tf_sess=n_cpu_tf_sess,
                )
            # n_cpu_tf_sess is presumably defined elsewhere in the module (SB2-only option)
            return PPO2(
                env=env, verbose=1, tensorboard_log=tensorboard_log_dir,
                **hyperparams, n_cpu_tf_sess=n_cpu_tf_sess,
            )

        elif algo_name == "sac":
            if self.sb_version == "sb3":
                if best_model_save_path and continue_learning:
                    model = stable_baselines3.SAC.load(
                        self.load_model(best_model_save_path, model_to_load),
                        env=env,
                        seed=seed,
                        tensorboard_log=tensorboard_log_dir,
                        verbose=1,
                    )
                    model.load_replay_buffer(path=best_model_save_path + "/replay_buffer")
                    self.logger.debug("Model replay buffer size: {}".format(model.replay_buffer.size()))
                    self.logger.debug("Setting learning_starts to 0")
                    model.learning_starts = 0

                    value = get_value_given_key(best_model_save_path + "/progress.csv", "ent_coef")
                    if value:
                        ent_coef = float(value)
                        self.logger.debug("Restore model old ent_coef: {}".format("auto_" + str(ent_coef)))
                        model.ent_coef = "auto_" + str(ent_coef)
                        model.target_entropy = str(ent_coef)

                    return model
                elif best_model_save_path:
                    return stable_baselines3.SAC.load(
                        self.load_model(best_model_save_path, model_to_load),
                        env=env,
                        seed=seed,
                        tensorboard_log=tensorboard_log_dir,
                        verbose=1,
                        n_cpu_tf_sess=n_cpu_tf_sess,
                    )
                assert n_timesteps > 0, "expected n_timesteps > 0, got {}".format(n_timesteps)
                return stable_baselines3.SAC(env=env, verbose=0, seed=seed, tensorboard_log=tensorboard_log_dir, **hyperparams)

            else:
                if best_model_save_path and continue_learning:
                    model = CustomSAC.load(
                        self.load_model(best_model_save_path, model_to_load),
                        env=env,
                        tensorboard_log=tensorboard_log_dir,
                        verbose=1,
                    )
                    self.logger.debug("Model replay buffer size: {}".format(len(model.replay_buffer)))
                    self.logger.debug("Setting learning_starts to 0")
                    model.learning_starts = 0
                    if not save_replay_buffer:
                        self.logger.debug("Setting save_replay_buffer to False")
                        model.save_replay_buffer = False

                    value = get_value_given_key(best_model_save_path + "/progress.csv", "ent_coef")
                    if value:
                        ent_coef = float(value)
                        self.logger.debug("Restore model old ent_coef: {}".format("auto_" + str(ent_coef)))
                        model.ent_coef = "auto_" + str(ent_coef)
                        model.target_entropy = str(ent_coef)

                    return model

                elif best_model_save_path:
                    # do not load replay buffer since we are in testing mode (no continue_learning)
                    return SAC.load(
                        self.load_model(best_model_save_path, model_to_load),
                        env=env,
                        tensorboard_log=tensorboard_log_dir,
                        verbose=1,
                        n_cpu_tf_sess=n_cpu_tf_sess,
                    )
                return CustomSAC(
                    total_timesteps=n_timesteps,
                    env=env,
                    verbose=1,
                    tensorboard_log=tensorboard_log_dir,
                    **hyperparams,
                    n_cpu_tf_sess=n_cpu_tf_sess,
                    save_replay_buffer=save_replay_buffer,
                )

        elif algo_name == "dqn":

            if self.sb_version == "sb3":

                if best_model_save_path:
                    if continue_learning:
                        model = stable_baselines3.DQN.load(
                            self.load_model(best_model_save_path, model_to_load),
                            env=env,
                            seed=seed,
                            tensorboard_log=tensorboard_log_dir,
                            verbose=0,
                        )
                        model.load_replay_buffer(path=best_model_save_path + "/replay_buffer")
                        model.learning_starts = 0
                        model.exploration_fraction = 0.0005
                        model.exploration_initial_eps = model.exploration_final_eps
                        model.exploration_schedule = get_linear_fn(
                            model.exploration_initial_eps, model.exploration_final_eps, model.exploration_fraction
                        )
                        self.logger.debug("Model replay buffer size: {}".format(model.replay_buffer.size()))
                        self.logger.debug("Setting learning_starts to {}".format(model.learning_starts))
                        self.logger.debug("Setting exploration_fraction to {}".format(model.exploration_fraction))
                        self.logger.debug("Setting exploration_initial_eps to {}".format(model.exploration_initial_eps))
                        return model
                    return stable_baselines3.DQN.load(
                        self.load_model(best_model_save_path, model_to_load),
                        env=env,
                        seed=seed,
                        tensorboard_log=tensorboard_log_dir,
                        verbose=1,
                    )
                return stable_baselines3.DQN(env=env, verbose=0, seed=seed, tensorboard_log=tensorboard_log_dir, **hyperparams)
            else:
                if best_model_save_path:
                    if continue_learning:
                        model = CustomDQN.load(
                            self.load_model(best_model_save_path, model_to_load),
                            env=env,
                            tensorboard_log=tensorboard_log_dir,
                            verbose=1,
                        )
                        self.logger.debug("Model replay buffer size: {}".format(len(model.replay_buffer)))
                        self.logger.debug(
                            "Setting exploration initial eps to exploration final eps {}".format(model.exploration_final_eps)
                        )
                        self.logger.debug("Setting learning_starts to 0")
                        if not save_replay_buffer:
                            self.logger.debug("Setting save_replay_buffer to False")
                            model.save_replay_buffer = False
                        model.learning_starts = 0
                        model.exploration_fraction = 0.005
                        model.exploration_initial_eps = model.exploration_final_eps
                        return model
                    return DQN.load(
                        self.load_model(best_model_save_path, model_to_load),
                        env=env,
                        tensorboard_log=tensorboard_log_dir,
                        verbose=1,
                        n_cpu_tf_sess=n_cpu_tf_sess,
                    )
                return CustomDQN(
                    env=env,
                    save_replay_buffer=save_replay_buffer,
                    verbose=1,
                    tensorboard_log=tensorboard_log_dir,
                    **hyperparams,
                    n_cpu_tf_sess=n_cpu_tf_sess,
                )
        raise NotImplementedError("algo_name {} not supported yet".format(algo_name))
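
A hedged usage sketch (hypothetical: trainer is assumed to be an instance of the class defining create_model with sb_version == "sb3", and env an already-built environment; neither name appears in the original):

    # The string learning rate is turned into linear_schedule(0.0003) by the
    # schedule-parsing loop at the top of create_model
    hyperparams = {"learning_rate": "lin_0.0003"}
    model = trainer.create_model(
        seed=0,
        algo_name="sac",
        env=env,
        tensorboard_log_dir="./tb",
        hyperparams=hyperparams,
        n_timesteps=100_000,
    )
    model.learn(total_timesteps=100_000)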