def train(processed_dir, read_file=None, save_file=None, epochs=2,
          logdir=None, checkpoint_freq=1000):
    test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
    train_chunk_files = [
        os.path.join(processed_dir, fname)
        for fname in os.listdir(processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    if save_file is not None:  # guard: os.path.join(None) would raise TypeError
        save_file = os.path.join(os.getcwd(), save_file)
    n = PolicyNetWork()
    try:
        n.initialize_variables(save_file)
    except Exception:  # fall back to fresh variables if the checkpoint is unusable
        n.initialize_variables(None)
    if logdir is not None:
        n.initialize_logging(logdir)

    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in train_chunk_files:
            print("Using %s" % file)
            train_dataset = DataSet.read(file)
            train_dataset.shuffle()
            with timer("training"):
                n.train(train_dataset)
            n.save_variables(save_file)
            if n.get_global_step() > last_save_checkpoint + checkpoint_freq:
                with timer("test set evaluation"):
                    n.check_accuracy(test_dataset)
                last_save_checkpoint = n.get_global_step()
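# The `timer` context manager used throughout these functions is not defined
# in this file. A minimal sketch is below, assuming it simply reports the
# wall-clock time spent inside the block; the repo's real implementation may
# differ.
import time
from contextlib import contextmanager

@contextmanager
def timer(message):
    tick = time.perf_counter()
    yield
    tock = time.perf_counter()
    print("%s: %.3f seconds" % (message, tock - tick))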
def train(args=args, hps=hps):
    from utils.load_data_sets import DataSet
    from Network import Network

    TRAINING_CHUNK_RE = re.compile(r"train\d+\.chunk\.gz")

    run = Network(args, hps, args.load_model_path)
    test_dataset = DataSet.read(
        os.path.join(args.processed_dir, "test.chunk.gz"))
    train_chunk_files = [
        os.path.join(args.processed_dir, fname)
        for fname in os.listdir(args.processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    random.shuffle(train_chunk_files)

    global_step = 0
    for file in train_chunk_files:
        global_step += 1
        print("Using %s" % file)
        train_dataset = DataSet.read(file)
        train_dataset.shuffle()
        with timer("training"):
            run.train(train_dataset)
        if global_step % 1 == 0:  # always true: evaluate after every chunk
            with timer("test set evaluation"):
                run.test(test_dataset, proportion=.1)
    print('Now, I am the Master.')
def train(flags=FLAGS, hps=HPS):
    from utils.load_data_sets import DataSet
    from Network import Network

    TRAINING_CHUNK_RE = re.compile(r"train\d+\.chunk\.gz")

    run = Network(flags, hps)
    test_dataset = DataSet.read(
        os.path.join(flags.processed_dir, "test.chunk.gz"))
    train_chunk_files = [
        os.path.join(flags.processed_dir, fname)
        for fname in os.listdir(flags.processed_dir)
        if TRAINING_CHUNK_RE.match(fname)
    ]
    random.shuffle(train_chunk_files)

    global_step = 0
    lr = flags.lr
    with open("result.txt", "a") as f:
        for g_epoch in range(flags.global_epoch):
            for file in train_chunk_files:
                global_step += 1
                # prepare training set
                print(f"Using {file}", file=f)
                train_dataset = DataSet.read(file)
                train_dataset.shuffle()
                # train
                with timer("training"):
                    run.train(train_dataset)
                # eval
                if global_step % 1 == 0:  # always true: evaluate after every chunk
                    with timer("test set evaluation"):
                        run.test(test_dataset, proportion=.1)
                print(f'Global step {global_step} finished.', file=f)
            print(f'Global epoch {g_epoch} finished.', file=f)
        print('Now, I am the Master.', file=f)
def test(flags=FLAGS, hps=HPS):
    from utils.load_data_sets import DataSet
    from Network import Network
    import tensorflow as tf

    net = Network(flags, hps)
    # print(net.sess.run({var.name: var for var in tf.global_variables() if 'bn' in var.name}))
    test_dataset = DataSet.read(os.path.join(flags.processed_dir, "test.chunk.gz"))
    with timer("test set evaluation"):
        net.test(test_dataset, proportion=0.25, force_save_model=False)
def selfplay(flags=FLAGS, hps=HPS):
    from utils.load_data_sets import DataSet
    from model.SelfPlayWorker import SelfPlayWorker
    from Network import Network

    test_dataset = DataSet.read(os.path.join(flags.processed_dir, "test.chunk.gz"))
    # test_dataset = None

    """Set the batch size to -1, which is treated as None"""
    flags.n_batch = -1
    net = Network(flags, hps)
    Worker = SelfPlayWorker(net, flags)

    def train(epoch: int):
        lr = schedule_lrn_rate(epoch)
        Worker.run(lr=lr)

    # TODO: consider tensorflow copy_to_graph
    def get_best_model():
        return Network(flags, hps)

    def evaluate_generations():
        best_model = get_best_model()
        Worker.evaluate_model(best_model)

    def evaluate_testset():
        Worker.evaluate_testset(test_dataset)

    """Self-play pipeline starts here"""
    for g_epoch in range(flags.global_epoch):
        logger.info(f'Global epoch {g_epoch} start.')

        """Train"""
        train(g_epoch)

        """Evaluate on test dataset"""
        evaluate_testset()

        """Evaluate against best model"""
        evaluate_generations()

        logger.info(f'Global epoch {g_epoch} finish.')
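# `schedule_lrn_rate` is referenced above and in the train() variant below,
# but defined elsewhere in the repo. A minimal sketch, assuming a step-decay
# schedule keyed on the epoch number; the breakpoints and rates here are
# illustrative placeholders, not the repo's actual values.
def schedule_lrn_rate(train_epoch):
    """Step-decay learning rate by epoch (illustrative values only)."""
    if train_epoch < 5:
        return 1e-2
    elif train_epoch < 10:
        return 1e-3
    else:
        return 1e-4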
def train(flags=FLAGS, hps=HPS):
    from utils.load_data_sets import DataSet
    from Network import Network

    TRAINING_CHUNK_RE = re.compile(r"train\d+\.chunk\.gz")

    net = Network(flags, hps)
    test_dataset = DataSet.read(os.path.join(flags.processed_dir, "test.chunk.gz"))
    train_chunk_files = [os.path.join(flags.processed_dir, fname)
                         for fname in os.listdir(flags.processed_dir)
                         if TRAINING_CHUNK_RE.match(fname)]

    def training_datasets():
        random.shuffle(train_chunk_files)
        return (DataSet.read(file) for file in train_chunk_files)

    global_step = 0
    lr = flags.lr
    with open("result.txt", "a") as f:
        for g_epoch in range(flags.global_epoch):

            """Train"""
            lr = schedule_lrn_rate(g_epoch)
            for train_dataset in training_datasets():
                global_step += 1
                # prepare training set
                logger.info(f"Global step {global_step} start")
                train_dataset.shuffle()
                with timer("training"):
                    net.train(train_dataset, lrn_rate=lr)

                """Evaluate"""
                if global_step % 1 == 0:  # always true: evaluate after every chunk
                    with timer("test set evaluation"):
                        net.test(test_dataset, proportion=0.25,
                                 force_save_model=global_step % 10 == 0)
                logger.info(f'Global step {global_step} finished.')

            logger.info(f'Global epoch {g_epoch} finished.')
def preprocess(*data_sets, processed_dir="processed_data"):
    processed_dir = os.path.join(os.getcwd(), processed_dir)
    if not os.path.isdir(processed_dir):
        os.mkdir(processed_dir)

    test_chunk, training_chunks = parse_data_sets(*data_sets)
    print("Allocating %s positions as test; remainder as training" % len(test_chunk),
          file=sys.stderr)

    print("Writing test chunk")
    test_dataset = DataSet.from_positions_w_context(test_chunk, is_test=True)
    test_filename = os.path.join(processed_dir, "test.chunk.gz")
    test_dataset.write(test_filename)

    training_datasets = map(DataSet.from_positions_w_context, training_chunks)
    for i, train_dataset in enumerate(training_datasets):
        if i % 10 == 0:
            print("Writing training chunk %s" % i)
        train_filename = os.path.join(processed_dir, "train%s.chunk.gz" % i)
        train_dataset.write(train_filename)
    print("%s chunks written" % (i + 1))
def preprocess(*data_sets, processed_dir=r"..\go_data\pre_data"):  # raw string so the backslashes survive
    processed_dir = os.path.join(os.getcwd(), processed_dir)
    if not os.path.isdir(processed_dir):
        os.mkdir(processed_dir)

    test_chunk, training_chunks = parse_data_sets(*data_sets)
    print("Using %s positions as the test set; the rest as the training set"
          % len(test_chunk))  # , file=sys.stderr)

    print("Writing test chunk")
    test_dataset = DataSet.from_positions_w_context(test_chunk, is_test=True)
    test_filename = os.path.join(processed_dir, "test.chunk.gz")
    test_dataset.write(test_filename)

    print("Writing training chunks")
    training_datasets = map(DataSet.from_positions_w_context, training_chunks)
    for i, train_dataset in enumerate(training_datasets):
        if i % 10 == 0:
            print("%s training chunks written so far" % (i + 1))
        train_filename = os.path.join(processed_dir, "train%s.chunk.gz" % i)
        train_dataset.write(train_filename)
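# Both preprocess() variants rely on `parse_data_sets`, which is defined
# elsewhere in the repo. A minimal sketch under these assumptions: every
# position from every data set is pooled and shuffled, the first CHUNK_SIZE
# positions form the test chunk, and the remainder is cut into fixed-size
# training chunks. `load_positions` and CHUNK_SIZE are hypothetical stand-ins
# for the repo's real SGF-parsing helper and chunk-size constant.
import random

CHUNK_SIZE = 16384  # hypothetical value; the real constant lives elsewhere

def load_positions(data_set):
    # Hypothetical stand-in: yield the board positions in one data set.
    yield from data_set

def parse_data_sets(*data_sets):
    positions = [p for ds in data_sets for p in load_positions(ds)]
    random.shuffle(positions)
    test_chunk = positions[:CHUNK_SIZE]
    rest = positions[CHUNK_SIZE:]
    training_chunks = [rest[i:i + CHUNK_SIZE]
                       for i in range(0, len(rest), CHUNK_SIZE)]
    return test_chunk, training_chunks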
def training_datasets():
    # Lazy generator: only one chunk is read and held in memory at a time.
    random.shuffle(train_chunk_files)
    return (DataSet.read(file) for file in train_chunk_files)