Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--strategy", type=str, default="SARSA")
    parser.add_argument("--stop", type=int, default=1)

    args = parser.parse_args()

    env_config = {
        "slate_size": 2,
        "seed": 0,
        "convert_to_discrete_action_space": False,
    }

    config = {
        "num_workers": 0,
        "slateq_strategy": args.strategy,
        "env_config": env_config
    }

    ray.init()

    trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)

    for i in range(args.stop):
        trainer.train()

    ray.shutdown()
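
Example #1 is an excerpt; run standalone it would also need the original script's imports. A minimal sketch, assuming the pre-2.0 ray.rllib.agents layout (the import location of recsim_env_name varies across Ray releases and is deliberately left open):

# Assumed imports for Example #1 (verify module paths against your Ray
# version; these follow the pre-2.0 ray.rllib.agents layout).
import argparse

import ray
from ray.rllib.agents import slateq

# recsim_env_name is the name under which the RecSim interest-evolution
# environment is registered with RLlib; its import path differs between
# Ray releases, so it is treated as an assumption here.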
Example #2
File: test_slateq.py  Project: vakker/ray
    def test_slateq_compilation(self):
        """Test whether an A2CTrainer can be built with both frameworks."""
        config = {
            "env": LongTermSatisfactionRecSimEnv,
        }

        num_iterations = 1

        # Test only against torch (no other frameworks supported so far).
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = slateq.SlateQTrainer(config=config)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer)
            trainer.stop()
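
Examples #2 and #3 are RLlib unit tests. The helpers they call live in RLlib's test utilities; a sketch of the imports they assume (the environment classes' path matches the Ray 1.12-era examples tree and is an assumption):

# Assumed imports for the test snippets in Examples #2 and #3.
import ray.rllib.agents.slateq as slateq
from ray.rllib.examples.env.recommender_system_envs_with_recsim import (
    InterestEvolutionRecSimEnv, LongTermSatisfactionRecSimEnv)
from ray.rllib.utils.test_utils import (check_compute_single_action,
                                        check_train_results,
                                        framework_iterator)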
Example #3
    def test_slateq_compilation(self):
        """Test whether a SlateQTrainer can be built with both frameworks."""
        config = {
            "env": InterestEvolutionRecSimEnv,
            "learning_starts": 1000,
        }

        num_iterations = 1

        for _ in framework_iterator(config, with_eager_tracing=True):
            trainer = slateq.SlateQTrainer(config=config)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer)
            trainer.stop()
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-slate-size", type=int, default=2)
    parser.add_argument("--env-seed", type=int, default=0)
    parser.add_argument("--strategy", type=str, default="SARSA")
    parser.add_argument("--stop", type=int, default=1)

    args = parser.parse_args()

    assert args.strategy in ALL_SLATEQ_STRATEGIES, "Invalid SlateQ strategy {}".format(
        args.strategy)

    env_config = {
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": False,
    }

    # config = slateq.DEFAULT_CONFIG.copy()
    # config["num_gpus"] = 0
    config = {}
    config["num_workers"] = 5
    config["slateq_strategy"] = args.strategy
    config["env_config"] = env_config

    ray.init()

    trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)

    result = trainer.train()
    best_checkpoint = trainer.save()
    best_reward = result['episode_reward_mean']
    print("Mean Reward {}:{}".format(1, result['episode_reward_mean']))

    for i in range(1, args.stop):
        result = trainer.train()
        print("Mean Reward {}:{}".format(i + 1, result['episode_reward_mean']))
        best_reward = max(best_reward, result['episode_reward_mean'])
        if best_reward == result['episode_reward_mean']:
            best_checkpoint = trainer.save()

    print("BEST Mean Reward  :", best_reward)
    print("BEST Checkpoint at:", best_checkpoint)
    ray.shutdown()
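
If the best checkpoint saved by Example #4 should be reused later, a trainer built with the same config can restore it. A minimal sketch, assuming config, recsim_env_name and best_checkpoint from the script above are still in scope:

# Sketch: continue training from the best checkpoint saved above.
restored = slateq.SlateQTrainer(config=config, env=recsim_env_name)
restored.restore(best_checkpoint)  # Trainer.restore() reloads the saved state
result = restored.train()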
Example #5
def main():
    args = parser.parse_args()
    ray.init()

    if args.agent not in ["DQN", "SlateQ"]:
        raise ValueError(args.agent)

    env_config = {
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": args.agent == "DQN",
    }

    if args.use_tune:
        time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        name = f"SlateQ/{args.agent}-seed{args.env_seed}-{time_signature}"
        if args.agent == "DQN":
            tune.run(
                "DQN",
                stop={"timesteps_total": 4000000},
                name=name,
                config={
                    "env": recsim_env_name,
                    "num_gpus": args.num_gpus,
                    "num_workers": args.num_workers,
                    "env_config": env_config,
                },
                num_samples=args.tune_num_samples,
                verbose=1,
            )
        else:
            tune.run(
                "SlateQ",
                stop={"timesteps_total": 4000000},
                name=name,
                config={
                    "env": recsim_env_name,
                    "num_gpus": args.num_gpus,
                    "num_workers": args.num_workers,
                    "slateq_strategy": tune.grid_search(ALL_SLATEQ_STRATEGIES),
                    "env_config": env_config,
                },
                num_samples=args.tune_num_samples,
                verbose=1,
            )
    else:
        # directly run using the trainer interface (good for debugging)
        if args.agent == "DQN":
            config = dqn.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["env_config"] = env_config
            trainer = dqn.DQNTrainer(config=config, env=recsim_env_name)
        else:
            config = slateq.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["slateq_strategy"] = args.strategy
            config["env_config"] = env_config
            trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
        for i in range(10):
            result = trainer.train()
            print(pretty_print(result))
    ray.shutdown()
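
Examples #5 through #8 also rely on imports from their original scripts. A sketch of the ones visible in the code, again using Ray 1.x-era paths that should be treated as assumptions:

# Assumed imports for Examples #5-#8 (verify paths against your Ray version).
from datetime import datetime

import numpy as np
from scipy.stats import sem

import ray
from ray import tune
from ray.rllib.agents import dqn, slateq
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import pretty_print

# ALL_SLATEQ_STRATEGIES is defined alongside the SlateQ agent; in the Ray 1.x
# layout this import would be (an assumption):
# from ray.rllib.agents.slateq.slateq import ALL_SLATEQ_STRATEGIES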
Example #6
def main():
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    env_config = {
        "num_candidates": args.env_num_candidates,
        "resample_documents": not args.env_dont_resample_documents,
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": args.run == "DQN",
    }

    config = {
        "env": (InterestEvolutionRecSimEnv
                if args.env == "interest-evolution" else
                InterestExplorationRecSimEnv
                if args.env == "interest-exploration" else
                LongTermSatisfactionRecSimEnv),
        "framework": args.framework,
        "num_gpus": args.num_gpus,
        "num_workers": args.num_workers,
        "env_config": env_config,
        "learning_starts": args.learning_starts,
    }

    # Perform a test run on the env with a random agent to see what the
    # random baseline reward is.
    if args.random_test_episodes:
        print(f"Running {args.random_test_episodes} episodes to get a random "
              "agent's baseline reward ...")
        env = config["env"](config=env_config)
        env.reset()
        num_episodes = 0
        episode_rewards = []
        episode_reward = 0.0
        while num_episodes < args.random_test_episodes:
            action = env.action_space.sample()
            _, r, d, _ = env.step(action)
            episode_reward += r
            if d:
                num_episodes += 1
                episode_rewards.append(episode_reward)
                episode_reward = 0.0
                env.reset()
        print(f"Ran {args.random_test_episodes} episodes with a random agent "
              "reaching a mean episode return of "
              f"{np.mean(episode_rewards)}+/-{sem(episode_rewards)}.")

    if args.use_tune:
        stop = {
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        }

        if args.run == "SlateQ":
            config.update({
                "slateq_strategy": args.slateq_strategy,
            })
        results = tune.run(
            args.run,
            stop=stop,
            config=config,
            num_samples=args.tune_num_samples,
            verbose=2,
        )

        if args.as_test:
            check_learning_achieved(results, args.stop_reward)

    else:
        # Directly run using the trainer interface (good for debugging).
        if args.run == "DQN":
            trainer = dqn.DQNTrainer(config=config)
        else:
            config.update({
                "slateq_strategy": args.slateq_strategy,
            })
            trainer = slateq.SlateQTrainer(config=config)
        for i in range(10):
            result = trainer.train()
            print(pretty_print(result))
    ray.shutdown()
Example #7
def main():
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    env_config = {
        "num_candidates": args.env_num_candidates,
        "resample_documents": not args.env_dont_resample_documents,
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": args.run == "DQN",
    }

    config = {
        "env": (InterestEvolutionRecSimEnv
                if args.env == "interest-evolution" else
                InterestExplorationRecSimEnv
                if args.env == "interest-exploration" else
                LongTermSatisfactionRecSimEnv),
        "hiddens": [1024, 1024],
        "num_gpus": args.num_gpus,
        "num_workers": args.num_workers,
        "env_config": env_config,
        "lr_choice_model": 0.003,
        "lr_q_model": 0.003,
        "rollout_fragment_length": 4,
        "exploration_config": {
            "epsilon_timesteps": 50000,
            "final_epsilon": 0.02,
        },
        "target_network_update_freq": 1,
        "tau": 5e-3,
        "evaluation_interval": 1,
        "evaluation_num_workers": 4,
        "evaluation_duration": 200,
        "evaluation_duration_unit": "episodes",
        "evaluation_parallel_to_training": True,
    }

    # Perform a test run on the env with a random agent to see what the
    # random baseline reward is.
    if args.random_test_episodes:
        print(f"Running {args.random_test_episodes} episodes to get a random "
              "agent's baseline reward ...")
        env = config["env"](config=env_config)
        env.reset()
        num_episodes = 0
        episode_rewards = []
        episode_reward = 0.0
        while num_episodes < args.random_test_episodes:
            action = env.action_space.sample()
            _, r, d, _ = env.step(action)
            episode_reward += r
            if d:
                num_episodes += 1
                episode_rewards.append(episode_reward)
                episode_reward = 0.0
                env.reset()
        print(f"Ran {args.random_test_episodes} episodes with a random agent "
              "reaching a mean episode return of "
              f"{np.mean(episode_rewards)}+/-{sem(episode_rewards)}.")

    if args.use_tune:
        stop = {
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        }

        time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        name = f"SlateQ/{args.run}-seed{args.env_seed}-{time_signature}"
        if args.run == "SlateQ":
            config.update({
                "slateq_strategy": args.slateq_strategy,
            })
        results = tune.run(
            args.run,
            stop=stop,
            name=name,
            config=config,
            num_samples=args.tune_num_samples,
            verbose=2,
        )

        if args.as_test:
            check_learning_achieved(results, args.stop_reward)

    else:
        # Directly run using the trainer interface (good for debugging).
        if args.run == "DQN":
            trainer = dqn.DQNTrainer(config=config)
        else:
            config.update({
                "slateq_strategy": args.slateq_strategy,
            })
            trainer = slateq.SlateQTrainer(config=config)
        for i in range(10):
            result = trainer.train()
            print(pretty_print(result))
    ray.shutdown()
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--agent",
        type=str,
        default="SlateQ",
        help=("Select agent policy. Choose from: DQN and SlateQ. "
              "Default value: SlateQ."),
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="QL",
        help=("Strategy for the SlateQ agent. Choose from: " +
              ", ".join(ALL_SLATEQ_STRATEGIES) + ". "
              "Default value: QL. Ignored when using Tune."),
    )
    parser.add_argument(
        "--use-tune",
        action="store_true",
        help=("Run with Tune so that the results are logged into Tensorboard. "
              "For debugging, it's easier to run without Ray Tune."),
    )
    parser.add_argument("--tune-num-samples", type=int, default=10)
    parser.add_argument("--env-slate-size", type=int, default=2)
    parser.add_argument("--env-seed", type=int, default=0)
    parser.add_argument("--num-gpus",
                        type=float,
                        default=0.,
                        help="Only used if running with Tune.")
    parser.add_argument("--num-workers",
                        type=int,
                        default=0,
                        help="Only used if running with Tune.")
    args = parser.parse_args()

    if args.agent not in ["DQN", "SlateQ"]:
        raise ValueError(args.agent)

    env_config = {
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": args.agent == "DQN",
    }

    ray.init()
    if args.use_tune:
        time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        name = f"SlateQ/{args.agent}-seed{args.env_seed}-{time_signature}"
        if args.agent == "DQN":
            tune.run("DQN",
                     stop={"timesteps_total": 4000000},
                     name=name,
                     config={
                         "env": recsim_env_name,
                         "num_gpus": args.num_gpus,
                         "num_workers": args.num_workers,
                         "env_config": env_config,
                     },
                     num_samples=args.tune_num_samples,
                     verbose=1)
        else:
            tune.run("SlateQ",
                     stop={"timesteps_total": 4000000},
                     name=name,
                     config={
                         "env": recsim_env_name,
                         "num_gpus": args.num_gpus,
                         "num_workers": args.num_workers,
                         "slateq_strategy":
                         tune.grid_search(ALL_SLATEQ_STRATEGIES),
                         "env_config": env_config,
                     },
                     num_samples=args.tune_num_samples,
                     verbose=1)
    else:
        # directly run using the trainer interface (good for debugging)
        if args.agent == "DQN":
            config = dqn.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["env_config"] = env_config
            trainer = dqn.DQNTrainer(config=config, env=recsim_env_name)
        else:
            config = slateq.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["slateq_strategy"] = args.strategy
            config["env_config"] = env_config
            trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
        for i in range(10):
            result = trainer.train()
            print(pretty_print(result))
    ray.shutdown()
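
Given the flags defined in Example #8's parser, the script could be launched along these lines (the file name recsim_slateq.py is a placeholder):

# Example invocations (the file name is a placeholder):
#   python recsim_slateq.py --agent SlateQ --strategy QL --env-slate-size 2
#   python recsim_slateq.py --agent DQN --use-tune --num-workers 4 --num-gpus 1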