Пример #1
0
def main():
    """Pretrain a ResNet policy network and save its config and weights.

    Runs supervised pretraining ('pre' stage, index 0) with a step-decay
    learning-rate schedule, then dumps the network config as JSON and the
    trained weights to their standard file locations.
    """
    policy = ResNetPolicy(stack_nb=2)
    # policy = UnitizedResNetPolicy(stack_nb=1)
    network = policy.network
    trainer = get_trainer(PrePolicyTrainer,
                          network,
                          0,
                          get_data_file('pre', 0),
                          batch_size=128,
                          epochs=200,
                          verbose=1)

    # Piecewise-constant schedule: rate drops after epochs 60, 120 and 160.
    def scheduler(epoch):
        for bound, rate in ((60, 0.05), (120, 0.01), (160, 0.002)):
            if epoch <= bound:
                return rate
        return 0.0004

    callbacks = [LearningRateScheduler(scheduler)]

    config_file_path = get_config_file('pre', network.name, 0)
    weight_file_path = get_weight_file('pre', network.name, 0)

    trainer.train(True, 0.9, callbacks=callbacks)

    json_dump_tuple(network.get_config(), config_file_path)
    network.save_weights(weight_file_path)
Пример #2
0
def main(index):
    """Pretrain a ResNet mixture network and save its config and weights.

    Trains on the fixed 'pre' dataset (index 0) with a step-decay learning
    rate and an early-stop callback, then writes the config and weights
    under the given output *index*.
    """
    mixture = ResNetMixture(stack_nb=2)
    network = mixture.network
    trainer = get_trainer(PreMixtureTrainer,
                          network,
                          0,
                          get_data_file('pre', 0),
                          batch_size=128,
                          epochs=200,
                          verbose=1)

    # Piecewise-constant schedule: rate drops after epochs 60, 120 and 160.
    def scheduler(epoch):
        for bound, rate in ((60, 0.05), (120, 0.01), (160, 0.002)):
            if epoch <= bound:
                return rate
        return 0.0004

    callbacks = [LearningRateScheduler(scheduler), Stop(0.16, 0.5)]

    config_file_path = get_config_file('pre', network.name, index)
    weight_file_path = get_weight_file('pre', network.name, index)

    trainer.train(True, 0.9, callbacks=callbacks)

    json_dump_tuple(network.get_config(), config_file_path)
    network.save_weights(weight_file_path)
Пример #3
0
from AlphaGomoku.mcts.rollout_mcts import RolloutMCTS
from AlphaGomoku.neural_networks import get_network
from AlphaGomoku.neural_networks.keras.weights import get_weight_file
from AlphaGomoku.board import Board
from AlphaGomoku.play import Human, Game

# Build the policy network, restore pretrained weights, and wrap it in a
# rollout-based MCTS used both for tree search and rollouts.
policy = get_network('policy', 'resnet', 'keras', stack_nb=2)
policy.load_weights(get_weight_file('pre', policy.network.name, 0))
mcts = RolloutMCTS(policy, policy, 100, thread_number=4,
                   depth=6, delete_threshold=10)

# Route the player interface straight to a single-search MCTS call.
mcts.get_action = lambda board: mcts.mcts(board, 1)

pattern = int(input('1: player vs ai, 2: ai vs player, 3: ai vs ai: '))
if pattern == 1:
    players = (Human(), mcts)
elif pattern == 2:
    players = (mcts, Human())
else:
    players = (mcts, mcts)
black, white = players

game = Game(black, white, visualization=True, visualization_time=3)
game.play()
Пример #4
0
from AlphaGomoku.mcts.evaluation_mcts import EvaluationMCTS
from AlphaGomoku.mcts.rl_evaluation_mcts import RLEvaluationMCTS
from AlphaGomoku.neural_networks import get_network
from AlphaGomoku.neural_networks.keras.weights import get_weight_file
from AlphaGomoku.board import Board
from AlphaGomoku.play import Human, Game

# Build the mixture (policy+value) network, restore weights from the
# 'zero' stage (checkpoint '4'), and wrap it in an evaluation MCTS.
mixture = get_network('mixture', 'resnet', 'keras', stack_nb=2)
# mixture.load_weights(get_weight_file('pre', mixture.network.name, 0))
# mixture.load_weights(get_weight_file('zero', mixture.network.name, '1_0'))
mixture.load_weights(get_weight_file('zero', mixture.network.name, '4'))
mcts = EvaluationMCTS(mixture, 500, thread_number=1, delete_threshold=10)
# mcts = RLEvaluationMCTS(mixture, 500, thread_number=4, delete_threshold=10)

# Route the player interface straight to a single-search MCTS call.
mcts.get_action = lambda board: mcts.mcts(board, 1)

pattern = int(input('1: player vs ai, 2: ai vs player, 3: ai vs ai: '))
if pattern == 1:
    players = (Human(), mcts)
elif pattern == 2:
    players = (mcts, Human())
else:
    players = (mcts, mcts)
black, white = players

game = Game(black, white, visualization=True, visualization_time=3)
game.play()
Пример #5
0
from AlphaGomoku.rl.rl import EvaluationMainLoop
from AlphaGomoku.mcts.evaluation_mcts import EvaluationMCTS
from AlphaGomoku.mcts.rl_evaluation_mcts import RLEvaluationMCTS
from AlphaGomoku.neural_networks import get_network
from AlphaGomoku.neural_networks.keras.weights import get_weight_file

# Build the mixture (policy+value) network and load weights from the
# 'pre' stage at index 0 — presumably the supervised pretraining
# checkpoint; confirm against the weight-file naming convention.
mixture = get_network('mixture', 'resnet', 'keras', stack_nb=2)
mixture.load_weights(get_weight_file('pre', mixture.network.name, 0))

kwargs = {
    'self_play_mcts_config': {
        'traverse_time': 10,
        'c_puct': None,
        'thread_number': 1,
        'delete_threshold': 100
    },
    'self_play_number': 3,
    'self_play_batch_size': 3,
    'self_play_cache_step': 10,
    'evaluate_mcts_config': {
        'traverse_time': 10,
        'c_puct': None,
        'thread_number': 1,
        'delete_threshold': 100
    },
    'evaluate_number': 5,
    'evaluate_batch_size': 10,
    'evaluate_win_ratio': 0.55,
    'evaluate_cache_step': 5,
    'train': {
        'batch_size': 128,