Example #1
# Imports for this test (module paths follow dm-acme's layout; treat them as
# an assumption and adjust for your installed version).
import acme
from acme import specs
from acme import wrappers
from acme.environment_loops import open_spiel_environment_loop
from acme.wrappers import open_spiel_wrapper
from open_spiel.python import rl_environment

import numpy as np
from absl.testing import absltest


class OpenSpielWrapperTest(absltest.TestCase):  # Illustrative class name.

    def test_tic_tac_toe(self):
        raw_env = rl_environment.Environment('tic_tac_toe')
        env = open_spiel_wrapper.OpenSpielWrapper(raw_env)

        # Test converted observation spec.
        observation_spec = env.observation_spec()
        self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT)
        self.assertEqual(type(observation_spec.observation), specs.Array)
        self.assertEqual(type(observation_spec.legal_actions), specs.Array)
        self.assertEqual(type(observation_spec.terminal), specs.Array)

        # Test converted action spec.
        action_spec: specs.DiscreteArray = env.action_spec()
        self.assertEqual(type(action_spec), specs.DiscreteArray)
        self.assertEqual(action_spec.shape, ())
        self.assertEqual(action_spec.minimum, 0)
        self.assertEqual(action_spec.maximum, 8)
        self.assertEqual(action_spec.num_values, 9)
        self.assertEqual(action_spec.dtype, np.dtype('int32'))

        # Test step.
        timestep = env.reset()
        self.assertTrue(timestep.first())
        _ = env.step([0])
        env.close()

    def test_loop_run(self):
        raw_env = rl_environment.Environment('tic_tac_toe')
        env = open_spiel_wrapper.OpenSpielWrapper(raw_env)
        env = wrappers.SinglePrecisionWrapper(env)
        environment_spec = acme.make_environment_spec(env)

        # One actor per player; RandomActor is defined elsewhere in the test
        # module (see the sketch after this example).
        actors = []
        for _ in range(env.num_players):
            actors.append(RandomActor(environment_spec))

        loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(
            env, actors)
        result = loop.run_episode()
        self.assertIn('episode_length', result)
        self.assertIn('episode_return', result)
        self.assertIn('steps_per_second', result)

        loop.run(num_episodes=10)
        loop.run(num_steps=100)
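
The RandomActor used in test_loop_run is defined elsewhere in the test module. A minimal sketch of such an actor, assuming acme's core.Actor interface and the OLT observations produced by OpenSpielWrapper (whose legal_actions field is a mask over the flat action space); the class name matches the test, the rest is illustrative:

import numpy as np

from acme import core


class RandomActor(core.Actor):
    """Selects uniformly at random among the currently legal actions."""

    def __init__(self, environment_spec):
        self._spec = environment_spec

    def select_action(self, observation):
        # Non-zero entries of the legal_actions mask mark the legal moves.
        legal = np.flatnonzero(observation.legal_actions)
        return np.int32(np.random.choice(legal))

    def observe_first(self, timestep):
        pass  # A random policy keeps no state.

    def observe(self, action, next_timestep):
        pass

    def update(self):
        pass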
Example #3
# Imports for this example (module paths follow dm-acme's layout; treat them
# as an assumption and adjust for your installed version).
import acme
from acme import wrappers
from acme.agents.tf import dqn
from acme.environment_loops import open_spiel_environment_loop
from acme.tf.networks import legal_actions
from acme.wrappers import open_spiel_wrapper
from open_spiel.python import rl_environment

import sonnet as snt
from absl import app
from absl import flags

FLAGS = flags.FLAGS


def main(_):
    # Create an environment and grab the spec.
    env_configs = {'players': FLAGS.num_players} if FLAGS.num_players else {}
    raw_environment = rl_environment.Environment(FLAGS.game, **env_configs)

    environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment)
    environment = wrappers.SinglePrecisionWrapper(
        environment)  # type: open_spiel_wrapper.OpenSpielWrapper
    environment_spec = acme.make_environment_spec(environment)

    # Build the networks.
    networks = []
    policy_networks = []
    for _ in range(environment.num_players):
        network = legal_actions.MaskedSequential([
            snt.Flatten(),
            snt.nets.MLP([50, 50, environment_spec.actions.num_values])
        ])
        policy_network = snt.Sequential([
            network,
            legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)
        ])
        networks.append(network)
        policy_networks.append(policy_network)

    # Construct the agents.
    agents = []

    for network, policy_network in zip(networks, policy_networks):
        agents.append(
            dqn.DQN(environment_spec=environment_spec,
                    network=network,
                    policy_network=policy_network))

    # Run the environment loop.
    loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(
        environment, agents)
    loop.run(num_episodes=100000)
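
main() reads FLAGS.game and FLAGS.num_players, whose definitions fall outside the snippet. A minimal sketch of that missing scaffolding, assuming absl flags (the default values here are illustrative, not taken from the original script):

flags.DEFINE_string('game', 'tic_tac_toe', 'Name of the OpenSpiel game to run.')
flags.DEFINE_integer('num_players', None,
                     'Optional override for the number of players.')


if __name__ == '__main__':
    app.run(main)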