Beispiel #1
0
def main(cmd):
    """Rewrite the training step counter of a run's checkpoint.

    Loads the YAML config from ``cmd.cfg`` and echoes it to stdout. It then
    makes sure ``<training.path>/<name>`` exists and builds the network via
    ``TFProcess``. Next it calls ``restore_v2()`` (presumably loading the
    latest checkpoint -- confirm in TFProcess). Finally it sets
    ``global_step`` to ``cmd.start`` and saves a checkpoint numbered
    ``cmd.start`` through ``tfprocess.manager``.

    Args:
        cmd: parsed command-line arguments. ``cmd.cfg`` must support
            ``.read()`` and ``cmd.start`` is the step to write (assumed to
            be an argparse namespace -- verify against the caller).
    """
    config = yaml.safe_load(cmd.cfg.read())
    print(yaml.dump(config, default_flow_style=False))

    training_cfg = config['training']
    run_dir = os.path.join(training_cfg['path'], config['name'])
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)

    process = TFProcess(config)
    process.init_net_v2()
    process.restore_v2()

    # Force the counter to the requested value and persist it under that
    # checkpoint number.
    step = cmd.start
    process.global_step.assign(step)
    process.manager.save(checkpoint_number=step)
Beispiel #2
0
def main(cmd):
    """Train a network described by the YAML config in ``cmd.cfg``.

    The steps are:

    1. Select the training and test chunk files.
    2. Build the train, test and (optional) validation ``tf.data``
       pipelines.
    3. Hand them to ``TFProcess.init_v2``, call ``restore_v2()``
       (presumably resuming from an existing checkpoint -- confirm in
       TFProcess) and run ``process_loop_v2``.
    4. If ``cmd.output`` is set, write the weights there. This uses
       ``save_swa_weights_v2`` when ``training.swa_output`` is true,
       otherwise ``save_leelaz_weights_v2``.

    Any ``ChunkParser`` created for the non-experimental input path is
    shut down on exit, including when training raises.

    Args:
        cmd: parsed command-line arguments. ``cmd.cfg`` must support
            ``.read()`` and ``cmd.output`` is an output path or ``None``
            (assumed to be an argparse namespace -- verify against caller).

    Raises:
        ValueError: if ``training.num_batch_splits`` does not divide
            ``training.batch_size`` evenly.
    """
    cfg = yaml.safe_load(cmd.cfg.read())
    print(yaml.dump(cfg, default_flow_style=False))

    num_chunks = cfg['dataset']['num_chunks']
    allow_less = cfg['dataset'].get('allow_less_chunks', False)
    train_ratio = cfg['dataset']['train_ratio']
    experimental_parser = cfg['dataset'].get('experimental_v5_only_dataset',
                                             False)
    num_train = int(num_chunks * train_ratio)
    num_test = num_chunks - num_train
    if 'input_test' in cfg['dataset']:
        train_chunks = get_latest_chunks(cfg['dataset']['input_train'],
                                         num_train, allow_less)
        test_chunks = get_latest_chunks(cfg['dataset']['input_test'], num_test,
                                        allow_less)
    else:
        chunks = get_latest_chunks(cfg['dataset']['input'], num_chunks,
                                   allow_less)
        if allow_less:
            # Fewer chunks than requested may have been returned; re-split
            # what was actually found using the same ratio.
            num_train = int(len(chunks) * train_ratio)
        train_chunks = chunks[:num_train]
        test_chunks = chunks[num_train:]

    shuffle_size = cfg['training']['shuffle_size']
    total_batch_size = cfg['training']['batch_size']
    batch_splits = cfg['training'].get('num_batch_splits', 1)
    train_workers = cfg['dataset'].get('train_workers')
    test_workers = cfg['dataset'].get('test_workers')
    if total_batch_size % batch_splits != 0:
        raise ValueError('num_batch_splits must divide batch_size evenly')
    split_batch_size = total_batch_size // batch_splits
    # Load data with split batch size, which will be combined to the total
    # batch size in tfprocess.
    ChunkParser.BATCH_SIZE = split_batch_size

    root_dir = os.path.join(cfg['training']['path'], cfg['name'])
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(root_dir, exist_ok=True)
    tfprocess = TFProcess(cfg)
    experimental_reads = max(2, mp.cpu_count() - 2) // 2
    extractor = select_extractor(tfprocess.INPUT_MODE)
    # Fixed record length, in bytes, passed to FixedLengthRecordDataset for
    # the GZIP-compressed chunk files.
    record_bytes = 8308

    def read(x):
        """Return a dataset of fixed-length records from the file(s) ``x``."""
        return tf.data.FixedLengthRecordDataset(
            x,
            record_bytes,
            compression_type='GZIP',
            num_parallel_reads=experimental_reads)

    # ChunkParsers created by build_dataset(). Only the non-experimental path
    # creates any, so this list stays empty when the experimental parser is
    # enabled. Each one is shut down in the finally block below.
    parsers = []

    def build_dataset(chunk_files, buffer_size, workers):
        """Build the batched input pipeline over ``chunk_files``.

        ``buffer_size`` is the shuffle buffer size. ``workers`` is forwarded
        to ChunkParser and is unused on the experimental path.
        """
        if experimental_parser:
            return (tf.data.Dataset.from_tensor_slices(chunk_files)
                    .shuffle(len(chunk_files)).repeat().batch(256)
                    .interleave(read, num_parallel_calls=2)
                    .batch(SKIP_MULTIPLE * SKIP).map(semi_sample).unbatch()
                    .shuffle(buffer_size)
                    .batch(split_batch_size).map(extractor).prefetch(4))
        parser = ChunkParser(chunk_files,
                             tfprocess.INPUT_MODE,
                             shuffle_size=buffer_size,
                             sample=SKIP,
                             batch_size=ChunkParser.BATCH_SIZE,
                             workers=workers)
        parsers.append(parser)
        dataset = tf.data.Dataset.from_generator(
            parser.parse,
            output_types=(tf.string, tf.string, tf.string, tf.string,
                          tf.string))
        return dataset.map(ChunkParser.parse_function).prefetch(4)

    try:
        train_dataset = build_dataset(train_chunks, shuffle_size,
                                      train_workers)
        # The test split gets a proportionally smaller shuffle buffer.
        test_shuffle_size = int(shuffle_size * (1.0 - train_ratio))
        test_dataset = build_dataset(test_chunks, test_shuffle_size,
                                     test_workers)

        validation_dataset = None
        if 'input_validation' in cfg['dataset']:
            valid_chunks = get_all_chunks(cfg['dataset']['input_validation'])
            validation_dataset = (
                read(valid_chunks)
                .batch(split_batch_size, drop_remainder=True)
                .map(extractor).prefetch(4))

        tfprocess.init_v2(train_dataset, test_dataset, validation_dataset)

        tfprocess.restore_v2()

        # If number of test positions is not given, sweep through all test
        # chunks statistically, assuming an average of 10 samples per test
        # game. For simplicity, testing uses the split batch size instead of
        # the total batch size. This does not affect results, because test
        # results are simple averages that are independent of batch size.
        num_evals = cfg['training'].get('num_test_positions',
                                        len(test_chunks) * 10)
        num_evals = max(1, num_evals // ChunkParser.BATCH_SIZE)
        print("Using {} evaluation batches".format(num_evals))

        tfprocess.process_loop_v2(total_batch_size,
                                  num_evals,
                                  batch_splits=batch_splits)

        if cmd.output is not None:
            if cfg['training'].get('swa_output', False):
                tfprocess.save_swa_weights_v2(cmd.output)
            else:
                tfprocess.save_leelaz_weights_v2(cmd.output)
    finally:
        # Train parser first, then test parser, matching creation order.
        for parser in parsers:
            parser.shutdown()