def make_dataset_from_selfplay(data_extracts):
    """Build tf.Example records from self-play extracts.

    Args:
        data_extracts: An iterable of (position, pi, result) tuples.

    Returns:
        A lazy generator yielding one tf.Example per extract. The feature
        set is looked up once, when this function is called; each extract
        is converted only when the generator is advanced.
    """
    feature_set = dual_net.get_features()

    def _to_example(extract):
        # Unpacking happens per item, at iteration time.
        position, policy, outcome = extract
        encoded = features_lib.extract_features(position, feature_set)
        return make_tf_example(encoded, policy, outcome)

    return (_to_example(extract) for extract in data_extracts)
def test_make_dataset_from_sgf(self):
    """Round-trip TEST_SGF through make_dataset_from_sgf and check the records.

    The two moves played in the SGF ('fd' then 'cf') should come back as two
    examples, each with a value target of -1.
    """
    with tempfile.NamedTemporaryFile() as sgf_file, \
            tempfile.NamedTemporaryFile() as record_file:
        sgf_file.write(TEST_SGF.encode('utf8'))
        # Rewinding also flushes the buffered write, so the SGF content is
        # on disk when make_dataset_from_sgf opens the file by name.
        sgf_file.seek(0)
        preprocessing.make_dataset_from_sgf(sgf_file.name, record_file.name)
        recovered_data = self.extract_data(record_file.name)

        start_pos = go.Position()
        first_move = coords.from_sgf('fd')
        next_pos = start_pos.play_move(first_move)
        second_move = coords.from_sgf('cf')
        feature_set = dual_net.get_features()

        # Each expected record pairs a position with the move played from it.
        steps = ((start_pos, first_move), (next_pos, second_move))
        expected_data = [
            (features.extract_features(position, feature_set),
             preprocessing._one_hot(coords.to_flat(move)),
             -1)
            for position, move in steps
        ]
        self.assertEqualData(expected_data, recovered_data)
def _make_tf_example_from_pwc(position_w_context):
    """Convert one position-with-context record into a tf.Example.

    Args:
        position_w_context: an object exposing ``.position``, ``.next_move``
            and ``.result``.

    Returns:
        The tf.Example built by ``make_tf_example`` from:
        - the features extracted from ``.position`` with the feature set
          returned by ``dual_net.get_features()``;
        - ``_one_hot`` of the flattened coordinate of ``.next_move``
          (the policy target);
        - ``.result`` (the value target).
    """
    feature_set = dual_net.get_features()
    # Named `encoded` rather than `features` so it is not confused with the
    # features module, which this file imports as `features_lib`.
    encoded = features_lib.extract_features(
        position_w_context.position, feature_set)
    policy_target = _one_hot(coords.to_flat(position_w_context.next_move))
    return make_tf_example(encoded, policy_target, position_w_context.result)