def test_tic_tac_toe(self):
  raw_env = rl_environment.Environment('tic_tac_toe')
  env = open_spiel_wrapper.OpenSpielWrapper(raw_env)

  # Test converted observation spec.
  observation_spec = env.observation_spec()
  self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT)
  self.assertEqual(type(observation_spec.observation), specs.Array)
  self.assertEqual(type(observation_spec.legal_actions), specs.Array)
  self.assertEqual(type(observation_spec.terminal), specs.Array)

  # Test converted action spec.
  action_spec: specs.DiscreteArray = env.action_spec()
  self.assertEqual(type(action_spec), specs.DiscreteArray)
  self.assertEqual(action_spec.shape, ())
  self.assertEqual(action_spec.minimum, 0)
  self.assertEqual(action_spec.maximum, 8)
  self.assertEqual(action_spec.num_values, 9)
  self.assertEqual(action_spec.dtype, np.dtype('int32'))

  # Test step.
  timestep = env.reset()
  self.assertTrue(timestep.first())
  _ = env.step([0])
  env.close()
def test_loop_run(self):
  raw_env = rl_environment.Environment('tic_tac_toe')
  env = open_spiel_wrapper.OpenSpielWrapper(raw_env)
  env = wrappers.SinglePrecisionWrapper(env)
  environment_spec = acme.make_environment_spec(env)

  actors = []
  for _ in range(env.num_players):
    actors.append(RandomActor(environment_spec))

  loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors)
  result = loop.run_episode()
  self.assertIn('episode_length', result)
  self.assertIn('episode_return', result)
  self.assertIn('steps_per_second', result)
  loop.run(num_episodes=10)
  loop.run(num_steps=100)
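# The loop test above drives the game with a `RandomActor` helper defined
# elsewhere in this test module. A minimal sketch of such an actor, assuming
# the acme.core.Actor interface (with `from acme import core` and
# `import dm_env` available); the real helper may differ:
class RandomActor(core.Actor):
  """Fake actor that picks uniformly among the currently legal actions."""

  def __init__(self, environment_spec):
    self._spec = environment_spec

  def select_action(self, observation: open_spiel_wrapper.OLT) -> int:
    # The OLT observation carries a binary legal-actions mask.
    legal = np.nonzero(observation.legal_actions)[0]
    return int(np.random.choice(legal))

  def observe_first(self, timestep: dm_env.TimeStep):
    pass

  def observe(self, action, next_timestep: dm_env.TimeStep):
    pass

  def update(self, wait: bool = False):
    pass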
def main(_):
  # Create an environment and grab the spec.
  env_configs = {'players': FLAGS.num_players} if FLAGS.num_players else {}
  raw_environment = rl_environment.Environment(FLAGS.game, **env_configs)
  environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment)
  environment = wrappers.SinglePrecisionWrapper(
      environment)  # type: open_spiel_wrapper.OpenSpielWrapper
  environment_spec = acme.make_environment_spec(environment)

  # Build the networks.
  networks = []
  policy_networks = []
  for _ in range(environment.num_players):
    network = legal_actions.MaskedSequential([
        snt.Flatten(),
        snt.nets.MLP([50, 50, environment_spec.actions.num_values])
    ])
    policy_network = snt.Sequential([
        network,
        legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)
    ])
    networks.append(network)
    policy_networks.append(policy_network)

  # Construct the agents.
  agents = []
  for network, policy_network in zip(networks, policy_networks):
    agents.append(
        dqn.DQN(
            environment_spec=environment_spec,
            network=network,
            policy_network=policy_network))

  # Run the environment loop.
  loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(
      environment, agents)
  loop.run(num_episodes=100000)
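# `main` above reads FLAGS.game and FLAGS.num_players, which are defined at
# module level in the full script. A sketch of the flag definitions and the
# absl entry point assumed here (defaults are illustrative, not taken from the
# original script):
from absl import app
from absl import flags

flags.DEFINE_string('game', 'tic_tac_toe', 'Name of the OpenSpiel game to load.')
flags.DEFINE_integer('num_players', None,
                     'Number of players (use the game default if unset).')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)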