Example #1
def train(processed_dir="processed_data"):
    checkpoint_freq = 10000
    read_file = None
    save_file = 'tmp2'
    epochs = 10
    logdir = 'logs2'

    # Load the held-out test chunk and gather the training chunk files.
    test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
    train_chunk_files = [
        os.path.join(processed_dir, fname)
        for fname in os.listdir(processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    if read_file is not None:
        read_file = os.path.join(os.getcwd(), read_file)
    n = PolicyNetwork()
    n.initialize_variables(read_file)
    if logdir is not None:
        n.initialize_logging(logdir)
    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in train_chunk_files:
            print("提取 %s" % file)
            with timer("load dataset"):
                train_dataset = DataSet.read(file)
            with timer("training"):
                n.train(train_dataset)
            with timer("save model"):
                n.save_variables(save_file)
            if n.get_global_step() > last_save_checkpoint + checkpoint_freq:
                with timer("test set evaluation"):
                    n.check_accuracy(test_dataset)
                last_save_checkpoint = n.get_global_step()
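
All of the train() variants on this page lean on helpers from the surrounding project (TRAINING_CHUNK_RE, timer, DataSet, PolicyNetwork). As a rough sketch, the first two could look like this; the regex pattern and the timer body are assumptions inferred from how the snippets use them:

import re
import time
from contextlib import contextmanager

# Assumed pattern: training chunks named like "train0.chunk.gz",
# alongside the "test.chunk.gz" file read above.
TRAINING_CHUNK_RE = re.compile(r"train\d+\.chunk\.gz")

@contextmanager
def timer(message):
    # Matches the usage `with timer("training"): ...` in the snippets:
    # report how long the wrapped block took.
    start = time.time()
    yield
    print("%s: %.3f seconds" % (message, time.time() - start))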
Example #2
def train(processed_dir,
          save_file=None,
          epochs=10,
          logdir=None,
          checkpoint_freq=10000):
    test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
    train_chunk_files = [
        os.path.join(processed_dir, fname)
        for fname in os.listdir(processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    if save_file is not None:
        save_file = os.path.join(os.getcwd(), save_file)
    n = PolicyNetwork()
    try:
        n.initialize_variables(save_file)
    except Exception:
        # Fall back to fresh weights if no usable checkpoint exists yet.
        n.initialize_variables(None)
    if logdir is not None:
        n.initialize_logging(logdir)
    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in train_chunk_files:
            print("Using %s" % file)
            train_dataset = DataSet.read(file)
            train_dataset.shuffle()
            with timer("training"):
                n.train(train_dataset)
            n.save_variables(save_file)
            if n.get_global_step() > last_save_checkpoint + checkpoint_freq:
                with timer("test set evaluation"):
                    n.check_accuracy(test_dataset)
                last_save_checkpoint = n.get_global_step()
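
Example #2 also calls DataSet.read() and DataSet.shuffle(). The real class parses encoded Go positions and next-move labels; purely as a stand-in to make the control flow concrete, something shaped like this would satisfy the interface (the storage format here is entirely an assumption):

import gzip
import pickle
import random

class DataSet:
    # Hypothetical stand-in: the project's real DataSet holds feature
    # planes and move labels extracted from game records.
    def __init__(self, examples):
        self.examples = examples

    @staticmethod
    def read(path):
        # The snippets read gzip-compressed chunk files.
        with gzip.open(path, "rb") as f:
            return DataSet(pickle.load(f))

    def shuffle(self):
        random.shuffle(self.examples)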
Example #3
def train(processed_dir,
          read_file=None,
          save_file=None,
          epochs=10,
          logdir=None,
          checkpoint_freq=10000):
    test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
    train_chunk_files = [
        os.path.join(processed_dir, fname)
        for fname in os.listdir(processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    n = PolicyNetwork(DEFAULT_FEATURES.planes)
    n.initialize_variables(read_file)
    if logdir is not None:
        n.initialize_logging(logdir)
    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in train_chunk_files:
            print("Using %s" % file)
            train_dataset = DataSet.read(file)
            n.train(train_dataset)
            if (save_file is not None
                    and n.get_global_step() > last_save_checkpoint + checkpoint_freq):
                n.check_accuracy(test_dataset)
                print("Saving checkpoint to %s" % save_file, file=sys.stderr)
                last_save_checkpoint = n.get_global_step()
                n.save_variables(save_file)

    if save_file is not None:
        n.save_variables(save_file)
        print("Finished training. New model saved to %s" % save_file,
              file=sys.stderr)
Example #4
class PolicyNetworkBestMovePlayer(GtpInterface):
    def __init__(self, read_file):
        self.policy_network = PolicyNetwork(DEFAULT_FEATURES.planes,
                                            use_cpu=True)
        self.read_file = read_file
        super().__init__()

    def clear(self):
        super().clear()
        self.refresh_network()

    def refresh_network(self):
        # Ensure that the player is using the latest version of the network
        # so that the network can be continually trained even as it's playing.
        self.policy_network.initialize_variables(self.read_file)

    def suggest_move(self, position):
        if position.recent and position.n > 100 and position.recent[-1] is None:
            # Pass if the opponent passes
            return None
        move_probabilities = self.policy_network.run(position)
        for move in sorted_moves(move_probabilities):
            if go.is_reasonable(position, move):
                return move
        return None
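
suggest_move() iterates over sorted_moves(move_probabilities), a helper not shown on this page. Presumably it orders board coordinates by descending probability, along these lines (a sketch, assuming move_probabilities is an N×N array indexed by coordinate and go.N is the board size):

import go  # project module; assumed to define the board size go.N

def sorted_moves(move_probabilities):
    # Yield coordinates from most to least probable.
    coords = [(i, j) for i in range(go.N) for j in range(go.N)]
    coords.sort(key=lambda c: move_probabilities[c], reverse=True)
    return coords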
Example #5
def train(processed_dir, read_file=None, save_file=None,
          epochs=10, logdir=None, checkpoint_freq=10000):
    test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
    train_chunk_files = [os.path.join(processed_dir, fname) 
        for fname in os.listdir(processed_dir)
        if TRAINING_CHUNK_RE.match(fname)]
    if read_file is not None:
        read_file = os.path.join(os.getcwd(), read_file)
    n = PolicyNetwork()
    n.initialize_variables(read_file)
    if logdir is not None:
        n.initialize_logging(logdir)
    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in train_chunk_files:
            print("Using %s" % file)
            with timer("load dataset"):
                train_dataset = DataSet.read(file)
            with timer("training"):
                n.train(train_dataset)
            with timer("save model"):
                n.save_variables(save_file)
            if n.get_global_step() > last_save_checkpoint + checkpoint_freq:
                with timer("test set evaluation"):
                    n.check_accuracy(test_dataset)
                last_save_checkpoint = n.get_global_step()
Example #6
def make_gtp_instance(strategy_name, read_file):
    n = PolicyNetwork(use_cpu=True)
    n.initialize_variables(read_file)
    if strategy_name == 'random':
        instance = RandomPlayer()
    elif strategy_name == 'policy':
        instance = GreedyPolicyPlayer(n)
    elif strategy_name == 'randompolicy':
        instance = RandomPolicyPlayer(n)
    elif strategy_name == 'mcts':
        instance = MCTSPlayer(n)
    else:
        return None
    gtp_engine = gtp.Engine(instance)
    return gtp_engine
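
A typical way to drive the returned engine is a read-eval loop over stdin. The sketch below is a hypothetical driver; it assumes the gtp module's Engine exposes a send() method that takes one GTP command line and returns the formatted reply:

import sys

def gtp_loop(strategy_name, read_file):
    # Hypothetical driver for make_gtp_instance.
    engine = make_gtp_instance(strategy_name, read_file)
    if engine is None:
        sys.exit("Unknown strategy: %s" % strategy_name)
    for line in sys.stdin:
        if not line.strip():
            continue
        sys.stdout.write(engine.send(line))
        sys.stdout.flush()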
Example #7
def train(processed_dir,
          read_file=None,
          save_file=None,
          epochs=10,
          logdir=None,
          checkpoint_freq=10000):
    test_dataset = DataSet.read(os.path.join(processed_dir, 'test.chunk.gz'))
    #print(test_dataset)
    train_chunk_files = [
        os.path.join(processed_dir, fname)
        for fname in os.listdir(processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    print(train_chunk_files)
    if read_file is not None:
        read_file = os.path.join(os.getcwd(), read_file)
    n = PolicyNetwork()
    n.initialize_variables(read_file)
    if logdir is not None:
        n.initialize_logging(logdir)

    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in tqdm.tqdm(train_chunk_files, desc='epochs ' + str(i)):
            #print('Using %s' % file)
            with timer('load dataset'):
                train_dataset = DataSet.read(file)
            with timer('training'):
                n.train(train_dataset)
            if n.get_global_step() > last_save_checkpoint + checkpoint_freq:
                with timer('save model'):
                    n.save_variables(save_file)
                with timer('test set evaluation'):
                    n.check_accuracy(test_dataset)
                last_save_checkpoint = n.get_global_step()
        with timer('test set evaluation'):
            n.check_accuracy(test_dataset)
Example #8
        # Fragment (likely from the save_trained_policy helper called below):
        # copy each supervised-learning variable into a new variable renamed
        # for the self-play "PlayerNetwork" scope.
        for name, shape in policy_vars:
            v = tf.contrib.framework.load_variable('model/sl/', name)
            new_vars.append(
                tf.Variable(v,
                            name=name.replace('PolicNetwork',
                                              'PlayerNetwork')))
        saver = tf.train.Saver(new_vars)
        sess.run(tf.global_variables_initializer())
        saver.save(sess,
                   os.path.join(save_dir, str(t), 'player' + str(t) + '.ckpt'))


g1 = tf.Graph()
with g1.as_default():
    train_net = PolicyNetwork(scope="PolicNetwork")
    train_net.initialize_variables('model/sl/epoch_48.ckpt')

pos = go.Position()
train_net.run(pos)

g2 = tf.Graph()
with g2.as_default():
    player_net = PolicyNetwork(scope="PlayerNetwork")
    player_net.initialize_variables('model/rl/2/player2.ckpt')
pos = go.Position()
player_net.run(pos)

save_trained_policy(1, 'model/rl')

print("===========load new model=================")
g2 = tf.Graph()
Example #9
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from IPython import display
import pylab as pl
import numpy as np
import os
import random
import re
import sys
import go
from policy import PolicyNetwork
from strategies import MCTSPlayerMixin
read_file = "saved_models/20170718"
WHITE, EMPTY, BLACK, FILL, KO, UNKNOWN = range(-1, 5)
n = PolicyNetwork(use_cpu=True)
n.initialize_variables(read_file)
instance = MCTSPlayerMixin(n)


class User():
    def __init__(self, name, state_size, action_size):
        self.name = name
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state, tau):
        action = int(input('Enter your chosen action: '))
        pi = np.zeros(self.action_size)
        pi[action] = 1
        value = None
        NN_value = None
Example #10
class MCTS(GtpInterface):
    def __init__(self, read_file, seconds_per_move=5):
        self.seconds_per_move = seconds_per_move
        self.max_rollout_depth = go.N * go.N * 3
        self.policy_network = PolicyNetwork(DEFAULT_FEATURES.planes,
                                            use_cpu=True)
        self.read_file = read_file
        super().__init__()

    def clear(self):
        super().clear()
        self.refresh_network()

    def refresh_network(self):
        # Ensure that the player is using the latest version of the network
        # so that the network can be continually trained even as it's playing.
        self.policy_network.initialize_variables(self.read_file)

    def suggest_move(self, position):
        # Resign if hopelessly behind on captures.
        if position.caps[0] + 50 < position.caps[1]:
            return gtp.RESIGN
        start = time.time()
        move_probs = self.policy_network.run(position)
        root = MCTSNode.root_node(position, move_probs)
        while time.time() - start < self.seconds_per_move:
            self.tree_search(root)
        # there's a theoretical bug here: if you refuse to pass, this AI will
        # eventually start filling in its own eyes.
        return max(root.children.keys(),
                   key=lambda move: root.children[move].N)

    def tree_search(self, root):
        print("tree search", file=sys.stderr)
        # selection
        chosen_leaf = root.select_leaf()
        # expansion
        position = chosen_leaf.compute_position()
        if position is None:
            print("illegal move!", file=sys.stderr)
            # See go.Position.play_move for notes on detecting legality
            del chosen_leaf.parent.children[chosen_leaf.move]
            return
        print("Investigating following position:\n%s" %
              (chosen_leaf.position, ),
              file=sys.stderr)
        move_probs = self.policy_network.run(position)
        chosen_leaf.expand(move_probs)
        # evaluation
        value = self.estimate_value(chosen_leaf)
        # backup
        print("value: %s" % value, file=sys.stderr)
        chosen_leaf.backup_value(value)

    def estimate_value(self, chosen_leaf):
        # Estimate value of position using rollout only (for now).
        # (TODO: Value network; average the value estimations from rollout + value network)
        leaf_position = chosen_leaf.position
        current = leaf_position
        while current.n < self.max_rollout_depth:
            move_probs = self.policy_network.run(current)
            current = self.play_valid_move(current, move_probs)
            # Two consecutive passes end the rollout.
            if (len(current.recent) > 2
                    and current.recent[-1] is None
                    and current.recent[-2] is None):
                break
        else:
            print("max rollout depth exceeded!", file=sys.stderr)

        perspective = 1 if leaf_position.player1turn else -1
        return current.score() * perspective

    def play_valid_move(self, position, move_probs):
        for move in sorted_moves(move_probs):
            if go.is_eyeish(position.board, move):
                continue
            candidate_pos = position.play_move(move, mutate=True)
            if candidate_pos is not None:
                return candidate_pos
        return position.pass_move(mutate=True)
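
A minimal usage sketch for the class above, assuming GtpInterface.clear() resets the game state (it also reloads the network via refresh_network) and that go.Position() constructs an empty board:

# Hypothetical usage: pick one move from the empty board.
player = MCTS(read_file="saved_models/20170718", seconds_per_move=2)
player.clear()  # resets the board and reloads the latest weights
move = player.suggest_move(go.Position())
print("suggested move:", move)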