def build_model(game):
  """Build a small MLP model (2 layers, 64 hidden units) for `game`.

  Args:
    game: A pyspiel game; its observation tensor shape and distinct action
      count size the network's input and policy head.

  Returns:
    A model_lib.Model wrapping the keras MLP, trained on CPU with L2
    regularization 1e-4 and learning rate 0.01.
  """
  net = model_lib.keras_mlp(
      game.observation_tensor_shape(),
      game.num_distinct_actions(),
      num_layers=2,
      num_hidden=64)
  return model_lib.Model(
      net, l2_regularization=1e-4, learning_rate=0.01, device="cpu")
def build_model(game, model_type):
  """Construct a model of the requested type sized for `game`.

  Args:
    game: A pyspiel game providing observation shape and action count.
    model_type: Model architecture name forwarded to model_lib.Model.

  Returns:
    A model_lib.Model (width 64, depth 2, weight decay 1e-4, lr 0.01,
    no checkpoint path).
  """
  observation_shape = game.observation_tensor_shape()
  num_actions = game.num_distinct_actions()
  return model_lib.Model(
      model_type,
      observation_shape,
      num_actions,
      nn_width=64,
      nn_depth=2,
      weight_decay=1e-4,
      learning_rate=0.01,
      path=None)
def main(_):
  """Train an AlphaZero bot on tic-tac-toe, periodically testing vs minimax.

  Builds a model from FLAGS, wraps it in an MCTS bot and an AlphaZero
  trainer, then alternates evaluation (against a perfect minimax player,
  where draws are the success criterion), self-play, and network updates
  for FLAGS.num_rounds rounds.
  """
  game = pyspiel.load_game("tic_tac_toe")

  # 1. Define a model
  model = model_lib.Model(
      FLAGS.nn_model,
      game.observation_tensor_shape(),
      game.num_distinct_actions(),
      nn_width=FLAGS.nn_width,
      nn_depth=FLAGS.nn_depth,
      weight_decay=1e-4,
      learning_rate=0.01,
      path=None)
  print("Model type: {}({}, {}), size: {} variables".format(
      FLAGS.nn_model, FLAGS.nn_width, FLAGS.nn_depth,
      model.num_trainable_variables))

  # 2. Create an MCTS bot using the model
  evaluator = evaluator_lib.AlphaZeroEvaluator(game, model)
  bot = mcts.MCTSBot(game, 1., 20, evaluator, solve=False,
                     dirichlet_noise=(0.25, 1.))

  # 3. Build an AlphaZero instance
  a0 = alpha_zero.AlphaZero(
      game, bot, model,
      replay_buffer_capacity=FLAGS.replay_buffer_capacity,
      action_selection_transition=4)

  # 4. Create a bot using min-max search. It can never lose tic-tac-toe, so
  # a success condition for our AlphaZero bot is to draw all games with it.
  minimax_bot = MinimaxBot(game)

  # 5. Run training loop
  for num_round in range(FLAGS.num_rounds):
    logging.info("------------- Starting round %s out of %s -------------",
                 num_round, FLAGS.num_rounds)

    if num_round % FLAGS.evaluation_frequency == 0:
      num_evaluations = 50
      logging.info("Playing %s games against the minimax player.",
                   num_evaluations)
      # Fix: pass num_evaluations rather than re-hard-coding 50, so the
      # logged game count always matches the number of games played.
      (_, losses, draws) = bot_evaluation(
          game, [minimax_bot, a0.bot], num_evaluations=num_evaluations)
      logging.info("Result against Minimax player: %s losses and %s draws.",
                   losses, draws)

    logging.info("Running %s games of self play", FLAGS.num_self_play_games)
    a0.self_play(num_self_play_games=FLAGS.num_self_play_games)

    logging.info("Training the net for %s epochs.",
                 FLAGS.num_training_epochs)
    a0.update(FLAGS.num_training_epochs,
              batch_size=FLAGS.batch_size,
              verbose=True)

    # Evaluator caches net outputs keyed by state; clear after each update
    # since the network weights just changed.
    logging.info("Cache: %s", evaluator.cache_info())
    evaluator.clear_cache()
def build_model(game, model_type):
  """Create a keras-backed model for `game` of the given architecture.

  Args:
    game: A pyspiel game providing observation shape and action count.
    model_type: Either "resnet" (2 residual blocks, 64 filters) or
      "mlp" (2 layers, 64 hidden units).

  Returns:
    A model_lib.Model on CPU with L2 regularization 1e-4 and lr 0.01.

  Raises:
    ValueError: if model_type is not "resnet" or "mlp".
  """
  observation_shape = game.observation_tensor_shape()
  num_actions = game.num_distinct_actions()
  if model_type == "mlp":
    net = model_lib.keras_mlp(
        observation_shape, num_actions, num_layers=2, num_hidden=64)
  elif model_type == "resnet":
    net = model_lib.keras_resnet(
        observation_shape,
        num_actions,
        num_residual_blocks=2,
        num_filters=64,
        value_head_hidden_size=64,
        data_format="channels_first")
  else:
    raise ValueError("Invalid model_type: {}".format(model_type))
  return model_lib.Model(
      net, l2_regularization=1e-4, learning_rate=0.01, device="cpu")
def main(_):
  """Train an AlphaZero bot on tic-tac-toe with a keras net.

  Builds a keras net per FLAGS.net_type, wraps it in an MCTS bot and an
  AlphaZero trainer, then alternates evaluation against a perfect minimax
  player (draws are the success criterion), self-play, and network updates
  for FLAGS.num_rounds rounds.

  Raises:
    ValueError: if FLAGS.net_type is not "mlp" or "resnet".
  """
  game = pyspiel.load_game("tic_tac_toe")
  num_actions = game.num_distinct_actions()
  observation_shape = game.observation_tensor_shape()

  # 1. Define a keras net
  if FLAGS.net_type == "resnet":
    net = model_lib.keras_resnet(
        observation_shape, num_actions, num_residual_blocks=1,
        num_filters=256, data_format="channels_first")
  elif FLAGS.net_type == "mlp":
    net = model_lib.keras_mlp(
        observation_shape, num_actions, num_layers=2, num_hidden=64)
  else:
    raise ValueError(("Invalid value for 'net_type'. Must be either 'mlp' or "
                      "'resnet', but was %s") % FLAGS.net_type)

  model = model_lib.Model(
      net, l2_regularization=1e-4, learning_rate=0.01, device=FLAGS.device)

  # 2. Create an MCTS bot using the previous keras net
  evaluator = evaluator_lib.AlphaZeroEvaluator(game, model)
  bot = mcts.MCTSBot(game, 1., 20, evaluator, solve=False,
                     dirichlet_noise=(0.25, 1.))

  # 3. Build an AlphaZero instance
  a0 = alpha_zero.AlphaZero(
      game, bot, model,
      replay_buffer_capacity=FLAGS.replay_buffer_capacity,
      action_selection_transition=4)

  # 4. Create a bot using min-max search. It can never lose tic-tac-toe, so
  # a success condition for our AlphaZero bot is to draw all games with it.
  minimax_bot = MinimaxBot(game)

  # 5. Run training loop
  for num_round in range(FLAGS.num_rounds):
    logging.info("------------- Starting round %s out of %s -------------",
                 num_round, FLAGS.num_rounds)

    if num_round % FLAGS.evaluation_frequency == 0:
      num_evaluations = 50
      logging.info("Playing %s games against the minimax player.",
                   num_evaluations)
      # Fix: pass num_evaluations rather than re-hard-coding 50, so the
      # logged game count always matches the number of games played.
      (_, losses, draws) = bot_evaluation(
          game, [minimax_bot, a0.bot], num_evaluations=num_evaluations)
      logging.info("Result against Minimax player: %s losses and %s draws.",
                   losses, draws)

    logging.info("Running %s games of self play", FLAGS.num_self_play_games)
    a0.self_play(num_self_play_games=FLAGS.num_self_play_games)

    logging.info("Training the net for %s epochs.",
                 FLAGS.num_training_epochs)
    a0.update(FLAGS.num_training_epochs,
              batch_size=FLAGS.batch_size,
              verbose=True)

    # Evaluator caches net outputs keyed by state; clear after each update
    # since the network weights just changed.
    logging.info("Cache: %s", evaluator.cache_info())
    evaluator.clear_cache()