def RegisterFlagValidator(
    flag_name: str,
    checker: Callable[[Any], bool],
    message: str = "Flag validation failed",
):
    """Registers a single-flag constraint enforced while the program runs.

    The constraint is checked when flags are first parsed and again whenever
    the value of the named flag changes.

    Args:
      flag_name: str, name of the flag the constraint applies to.
      checker: callable receiving the flag's current value (string, boolean,
        etc., supplied by the library) as its only positional argument. It
        returns True when the constraint holds; otherwise it returns False or
        raises flags.ValidationError(desired_error_message).
      message: str, text shown to the user when checker returns False. When
        checker raises flags.ValidationError, that error's message is shown
        instead.

    Raises:
      AttributeError: Raised when flag_name is not registered as a valid flag
        name.
    """
    # Thin alias: forwards every argument unchanged to absl's registration.
    absl_flags.register_validator(flag_name, checker, message=message)
def define_flags():
    """Defines the flags used to locate the TFLite model.

    Flags:
      --model_file: required path and file name of the TFLite model; the path
        must exist.
      --appended_resource_id: optional id of the appended resource to print.
    """
    flags.DEFINE_string("model_file", None,
                        "Path and file name of the TFLite model file.")
    flags.DEFINE_integer("appended_resource_id", None,
                         "Appended resource file to print")
    flags.mark_flag_as_required("model_file")
    # NOTE: relies on the required-flag check (registered first) rejecting a
    # None value before os.path.exists is consulted.
    flags.register_validator("model_file", os.path.exists,
                             message="model_file does not exist")
def setup_flags(required_flags=()):
    """Defines the script's flags, then marks the given ones as required.

    Args:
      required_flags: iterable of flag names to mark as required once all
        flags below have been defined.
    """
    flags.DEFINE_bool('cross', False, 'Enables cross-validation.')
    flags.DEFINE_bool('balanced', False, 'Set to learn on balanced data.')
    flags.DEFINE_string('outfile', None, 'Path to output file.')
    flags.DEFINE_string('datafile', None, 'Path to file with data.')
    flags.DEFINE_integer('gamescount', 1, 'Number of games to play.')

    def file_exists_if_arg_given(path):
        # --datafile is optional: only an explicitly given path must be a file.
        return path is None or Path(path).is_file()

    flags.register_validator('datafile', file_exists_if_arg_given,
                             message='Datafile must exist.')
    for required_flag in required_flags:
        flags.mark_flag_as_required(required_flag)
def _validate_flags():
    """Registers validators for the checkpoint and generated-output flags.

    Rules enforced:
      - --checkpoint_path must be non-empty.
      - --generated_x_dir is required whenever --image_set_y_glob is set.
      - --generated_y_dir is required whenever --image_set_x_glob is set.

    The cross-flag rules use multi-flag validators. They read both values from
    the dict absl passes in, not from the global FLAGS object, and they are
    re-checked when either flag of a pair changes.

    NOTE(review): image_set_x_glob / image_set_y_glob must already be defined
    when this function runs, since multi-flag registration looks them up.
    """
    flags.register_validator('checkpoint_path', bool,
                             'Must provide `checkpoint_path`.')
    flags.register_multi_flags_validator(
        ['generated_x_dir', 'image_set_y_glob'],
        lambda d: not (d['image_set_y_glob'] and not d['generated_x_dir']),
        'Must provide `generated_x_dir`.')
    flags.register_multi_flags_validator(
        ['generated_y_dir', 'image_set_x_glob'],
        lambda d: not (d['image_set_x_glob'] and not d['generated_y_dir']),
        'Must provide `generated_y_dir`.')
def main():
    """Registers validators for --features, --mode and --targets, then runs.

    Control is handed to `app.run(run)` once the validators are registered.
    """
    def features_known(values):
        # Every requested feature must appear in _FEATURES.
        return all(value in _FEATURES for value in values)

    def mode_given(value):
        return value is not None

    def targets_known(values):
        # Every requested target must appear in _TARGETS.
        return all(value in _TARGETS for value in values)

    flags.register_validator(
        "features", features_known,
        message="Flags --features contains unknown value(s).")
    flags.register_validator(
        "mode", mode_given,
        message="Flag --mode must be set with either `predict` or `train` value")
    flags.register_validator(
        "targets", targets_known,
        message="Flag --targets contains unknown value(s).")
    app.run(run)
from absl import flags
from absl import app
from absl import logging
import json
import jsonlines
import tensorflow as tf
import re

FLAGS = flags.FLAGS


def _is_regex_string(value):
    """Returns True iff `value` is a str that compiles as a regular expression.

    `re.compile` raises `re.error` for a malformed pattern. Catching it makes a
    bad regex flag fail validation with the validator's message, instead of the
    exception escaping flag parsing.
    """
    if not isinstance(value, str):
        return False
    try:
        re.compile(value)
    except re.error:
        return False
    return True


# TODO: Compute basic stats for text fields and labels.
flags.DEFINE_string('text_fields_re', None,
                    'Matcher for names of the text fields.')
flags.register_validator(
    'text_fields_re',
    _is_regex_string,
    message='--text_fields_re must be a regexp string.')

flags.DEFINE_string('label_fields_re', None,
                    'Matcher for names of the label fields.')
flags.register_validator(
    'label_fields_re',
    _is_regex_string,
    message='--label_fields_re must be a regexp string.')

flags.DEFINE_string('input_jsonlines_path', None,
                    'Path to the JSON-lines input file.')
flags.register_validator(
    'input_jsonlines_path',
    lambda value: isinstance(value, str),
    message='--input_jsonlines_path must be a string.')
# limitations under the License.
"""Tests for flagsaver."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver

flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')

flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with')
# Only a falsy value (unset/empty) passes, so any truthy assignment fails.
flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)

flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with')
flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with')


@flags.multi_flags_validator(
    ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
def validate_test_flags(flag_dict):
    """Passes only when the two validated test flags hold equal values."""
    first = flag_dict['flagsaver_test_validated_flag1']
    second = flag_dict['flagsaver_test_validated_flag2']
    return first == second


FLAGS = flags.FLAGS
import json
import logging
import os

from absl import flags
from perfkitbenchmarker import configs
from perfkitbenchmarker import data
from perfkitbenchmarker import sample
from perfkitbenchmarker import vm_util
from perfkitbenchmarker.linux_packages import netperf
from six.moves import range

# Netperf test types that --bidirectional_network_tests may contain. Defined
# ahead of the validator below, which checks membership against it.
ALL_TESTS = ['TCP_STREAM', 'TCP_MAERTS']

# Entries may repeat: the default lists TCP_MAERTS twice.
flags.DEFINE_list('bidirectional_network_tests',
                  ['TCP_STREAM', 'TCP_MAERTS', 'TCP_MAERTS'],
                  'The network tests to run.')
flags.register_validator(
    'bidirectional_network_tests',
    lambda benchmarks: benchmarks and set(benchmarks).issubset(ALL_TESTS),
    message=('--bidirectional_network_tests must be a non-empty list of: ' +
             ', '.join(ALL_TESTS)))
flags.DEFINE_integer('bidirectional_network_test_length', 60,
                     'bidirectional_network test length, in seconds',
                     lower_bound=1)
flags.DEFINE_integer('bidirectional_stream_num_streams', 8,
                     'Number of netperf processes to run.', lower_bound=1)

FLAGS = flags.FLAGS
import numpy as np
import coords
import go
import sys

# NOTE(review): `flags` is used below but not imported in this snippet;
# presumably imported earlier in the file.

# 505 moves for 19x19, 113 for 9x9
flags.DEFINE_integer('max_game_length', int(go.N ** 2 * 1.4),
                     'Move number at which game is forcibly terminated')
flags.DEFINE_float('c_puct', 1.38,
                   'Exploration constant balancing priors vs. value net output.')
# Default scales inversely with board area: 0.03 when go.N == 19.
flags.DEFINE_float('dirichlet_noise_alpha', 0.03 * 361 / (go.N ** 2),
                   'Concentrated-ness of the noise being injected into priors.')
# Accepted range is the half-open interval [0, 1).
flags.register_validator('dirichlet_noise_alpha', lambda x: 0 <= x < 1)
flags.DEFINE_float('dirichlet_noise_weight', 0.25,
                   'How much to weight the priors vs. dirichlet noise when mixing')
# Accepted range is the half-open interval [0, 1).
flags.register_validator('dirichlet_noise_weight', lambda x: 0 <= x < 1)

FLAGS = flags.FLAGS


class DummyNode(object):
    """A fake node of a MCTS search tree.

    This node is intended to be a placeholder for the root node, which would
    otherwise have no parent node. If all nodes have parents, code becomes
    simpler."""
PID_PREFIX = 'TF_PS_PID'

# Model names accepted by --tf_models (enforced by the validator below).
MODELS = [
    'vgg11', 'vgg16', 'vgg19', 'lenet', 'googlenet', 'overfeat', 'alexnet',
    'trivial', 'inception3', 'inception4', 'resnet50', 'resnet101',
    'resnet152'
]
FP16 = 'float16'
FP32 = 'float32'

flags.DEFINE_boolean(
    'tf_forward_only', False, '''whether use forward-only or training for benchmarking''')
flags.DEFINE_list('tf_models',
                  ['inception3', 'vgg16', 'alexnet', 'resnet50', 'resnet152'],
                  'name of the models to run')
# Rejects an empty list as well as any name outside MODELS.
flags.register_validator(
    'tf_models',
    lambda models: models and set(models).issubset(MODELS),
    'Invalid models list. tf_models must be a subset of ' + ', '.join(MODELS))
flags.DEFINE_string(
    'tf_data_dir', None,
    'Path to dataset in TFRecord format (aka Example '
    'protobufs). If not specified, synthetic data will be '
    'used.')
flags.DEFINE_string('tf_data_module', 'tensorflow/ILSVRC2012',
                    'Data path in preprovisioned data bucket.')
flags.DEFINE_integer('tf_num_files_train', 1024,
                     'The number of files for training')
flags.DEFINE_integer('tf_num_files_val', 128,
                     'The number of files for validation')
flags.DEFINE_enum('tf_data_name', 'imagenet', ['imagenet', 'flowers'],
                  'Name of dataset: imagenet or flowers.')
flags.DEFINE_list(
    'tf_batch_sizes', None, 'batch sizes per compute device. '
arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        consistency_weight=FLAGS.consistency_weight,
        # --scales of 0 (the default defined below) is falsy, so the number of
        # downscalings falls back to log_width - 2.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    # `<< 10` multiplies each kimg flag by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)


if __name__ == '__main__':
    utils.setup_tf()
    flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.')
    flags.DEFINE_float('warmup_pos', 0.4,
                       'Relative position at which constraint loss warmup ends.')
    flags.DEFINE_float('wd', 0.02, 'Weight decay.')
    flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
    flags.DEFINE_float('beta', 0.5, 'Mixup beta.')
    flags.DEFINE_integer('scales', 0,
                         'Number of 2x2 downscalings in the classifier.')
    flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
    flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
    FLAGS.set_default('dataset', 'cifar10.3@250-5000')
    FLAGS.set_default('batch', 64)
    FLAGS.set_default('lr', 0.002)
    FLAGS.set_default('train_kimg', 1 << 16)
    # --nu is defined outside this snippet; this entry point only accepts 2.
    flags.register_validator('nu', lambda nu: nu == 2,
                             message='nu must be 2 for pi-model.')
    app.run(main)
'learning rate. Use 0 to do no warmup.')
flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE.')
flags.DEFINE_float('one_minus_momentum', 0.1, 'Optimizer momentum.')
flags.DEFINE_string('output_dir', '/tmp/cifar', 'Output directory.')
flags.DEFINE_integer(
    'per_core_batch_size', 64,
    'Batch size per TPU core/GPU. The number of new '
    'datapoints gathered per batch is this number divided by '
    'ensemble_size (we tile the batch by that # of times).')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('train_epochs', 200, 'Number of training epochs.')
flags.DEFINE_float('train_proportion', default=1.0,
                   help='only use a proportion of training set.')
# Half-open interval: 0 is rejected, 1.0 (the whole training set) is allowed.
flags.register_validator('train_proportion',
                         lambda tp: tp > 0.0 and tp <= 1.0,
                         message='--train_proportion must be in (0, 1].')

# Accelerator flags.
flags.DEFINE_bool('use_gpu', False, 'Whether to run on GPU or otherwise TPU.')
flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores or number of GPUs.')
flags.DEFINE_string('tpu', None,
                    'Name of the TPU. Only used if use_gpu is False.')
flags.DEFINE_bool('use_bfloat16', False, 'Whether to use mixed precision.')


def load_cifar100_c(corruption_name,
                    corruption_intensity,
                    batch_size,
                    use_bfloat16,
                    path,
"`--runs_per_benchmark=5` will execute each benchmark 5 times in "
    "parallel. When calling nitroml.results.overview(), metrics in benchmark "
    "run results can be optionally aggregated to compute means, standard "
    "deviations, and other aggregate metrics.")


def _validate_regex(regex: Text) -> bool:
    """Returns True if `regex` compiles with `re.compile`, else False.

    Used as the --match validator so a malformed pattern is reported with the
    validator message. NOTE(review): a non-string value (e.g. None) makes
    re.compile raise TypeError, which is not caught here.
    """
    try:
        re.compile(regex)
        return True
    except re.error:
        return False


flags.register_validator("match", _validate_regex,
                         message="--match must be a valid regex.")


def _qualified_name(prefix: Text, name: Text) -> Text:
    """Returns "prefix.name", or just `name` when `prefix` is falsy."""
    return "{}.{}".format(prefix, name) if prefix else name


class _BenchmarkPipeline(object):
    """A pipeline for a benchmark."""

    def __init__(self,
                 benchmark_name: Text,
                 base_pipeline: List[base_component.BaseComponent],
                 evaluator: tfx.Evaluator = None,
                 add_evaluator: bool = True):
        self._benchmark_name = benchmark_name
flags.DEFINE_string('db_user', os.getenv(f'{PREFIX}_DB_USER', None),
                    'MAD DB user')
flags.DEFINE_integer('batchsize', os.getenv(f'{PREFIX}_BATCHSIZE', 5),
                     'Queries per request')
flags.DEFINE_integer(
    'loop_interval', os.getenv(f'{PREFIX}_LOOP_INTERVAL', None),
    'Interval in hours to poll repeatedly. A random 5% jitter is applied.')


def not_null(value):
    """Validator that accepts any value except None."""
    return not (value is None)


# Flags that must end up with a value. The API endpoint and token have to be
# provided, and the DB credentials are mandatory with no sensible default.
# Registration order is preserved: token, api, db_pass, db_user.
_MANDATORY_FLAG_MESSAGES = (
    ('token', 'API token must be set'),
    ('api', 'API endpoint must be set'),
    ('db_pass', 'DB password must be set'),
    ('db_user', 'DB user must be set'),
)
for _flag_name, _message in _MANDATORY_FLAG_MESSAGES:
    flags.register_validator(_flag_name, not_null, message=_message,
                             flag_values=FLAGS)
del _flag_name, _message
from ct.crypto import pem
from ct.crypto.asn1 import print_util

FLAGS = gflags.FLAGS

gflags.DEFINE_bool("subject", False, "Print option: prints certificate subject")
gflags.DEFINE_bool("issuer", False, "Print option: prints certificate issuer")
gflags.DEFINE_bool("fingerprint", False, "Print option: prints certificate "
                   "fingerprint")
gflags.DEFINE_string("digest", "sha1", "Print option: fingerprint digest to use")
gflags.DEFINE_bool("debug", False,
                   "Print option: prints full ASN.1 debug information")
gflags.DEFINE_string("filetype", "", "Read option: specify an input file "
                     "format (pem or der). If no format is specified, the "
                     "parser attempts to detect the format automatically.")
# An empty --filetype (the default) passes and means "auto-detect"; any other
# value must be pem or der, compared case-insensitively.
gflags.register_validator("filetype",
                          lambda value: not value or value.lower() in {"pem", "der"},
                          message="--filetype must be one of pem or der")


def print_cert(certificate):
    # With no selection flag set, print the whole certificate (its repr when
    # --debug is set); otherwise print only the selected parts.
    if not FLAGS.subject and not FLAGS.issuer and not FLAGS.fingerprint:
        if FLAGS.debug:
            print "%r" % certificate
        else:
            print certificate
    else:
        if FLAGS.subject:
            print "subject:\n%s" % certificate.print_subject_name()
        if FLAGS.issuer:
            print "issuer:\n%s" % certificate.print_issuer_name()
        if FLAGS.fingerprint:
from absl import logging

FLAGS = flags.FLAGS

flags.DEFINE_string('dataset_path', None, 'Path to the JSON file containing '
                    'the dataset.')
flags.DEFINE_string('split_path', None, 'Path to the JSON file containing '
                    'split information.')
flags.DEFINE_string('save_path', None, 'Path to the directory where to '
                    'save the files to.')
flags.mark_flag_as_required('save_path')


def _is_existing_path(path) -> bool:
    """Returns True iff `path` is set and names an existing filesystem entry.

    --dataset_path and --split_path default to None, and os.path.exists(None)
    raises TypeError instead of returning False. Checking for None first makes
    an omitted path fail validation with the validator's message rather than
    crash flag parsing.
    """
    return path is not None and os.path.exists(path)


flags.register_validator('dataset_path', _is_existing_path, 'Dataset not found.')
flags.register_validator('split_path', _is_existing_path, 'Split not found.')

Dataset = Dict[Text, List[Tuple[Text, Text]]]


def load_json(path: Text) -> Any:
    """Reads the UTF-8 JSON file at `path` and returns the decoded value."""
    logging.info(f'Reading json from {path} into memory...')
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    logging.info(f'Successfully loaded json data from {path} into memory.')
    return data


def tokenize_punctuation(text: Text) -> Text:
    text = map(lambda c: f' {c} ' if c in string.punctuation else c, text)
import os

from absl import app
from absl import flags

from cfq import evaluate as evaluator

FLAGS = flags.FLAGS

flags.DEFINE_string('questions_path', None, 'Path to the input questions.')
flags.DEFINE_string('golden_answers_path', None,
                    'Path to the expected (golden) answers.')
flags.DEFINE_string('inferred_answers_path', None,
                    'Path to the inferred answers.')
flags.DEFINE_string('output_path', None,
                    'Path to write evaluation results to')

flags.mark_flag_as_required('output_path')


def _is_existing_path(path):
    """Returns True iff `path` is set and names an existing filesystem entry.

    The input path flags default to None, and os.path.exists(None) raises
    TypeError. Checking for None first turns an omitted path into a normal
    validation failure with the validator's message.
    """
    return path is not None and os.path.exists(path)


flags.register_validator('questions_path', _is_existing_path,
                         'Questions path not found.')
flags.register_validator('golden_answers_path', _is_existing_path,
                         'Golden answers path not found.')
flags.register_validator('inferred_answers_path', _is_existing_path,
                         'Inferred answers path not found.')


def main(argv):
    """Computes the accuracy result for the flag-given files and writes it.

    Calls evaluator.get_accuracy_result on the questions, golden answers and
    inferred answers, then evaluator.write_accuracy_result to --output_path,
    also printing the result.

    Args:
      argv: arguments remaining after flag parsing; only the program name is
        accepted.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    accuracy_result = evaluator.get_accuracy_result(FLAGS.questions_path,
                                                    FLAGS.golden_answers_path,
                                                    FLAGS.inferred_answers_path)
    evaluator.write_accuracy_result(
        accuracy_result, FLAGS.output_path, print_output=True)
'If testing GSI/LSI, use the primary keyname' 'of the index you want to test') flags.DEFINE_enum('aws_dynamodb_attributetype', 'S', ['S', 'N', 'B'], 'The type of attribute, default to S (String).' 'Alternates are N (Number) and B (Binary).') flags.DEFINE_integer('aws_dynamodb_read_capacity', '5', 'Set RCU for dynamodb table') flags.DEFINE_integer('aws_dynamodb_write_capacity', '5', 'Set WCU for dynamodb table') flags.DEFINE_integer('aws_dynamodb_lsi_count', 0, 'Set amount of Local Secondary Indexes. Only set 0-5') flags.register_validator('aws_dynamodb_lsi_count', lambda value: -1 < value < 6, message='--count must be from 0-5') flags.register_validator('aws_dynamodb_use_sort', lambda sort: sort or not FLAGS.aws_dynamodb_lsi_count, message='--aws_dynamodb_lsi_count requires sort key.') flags.DEFINE_integer('aws_dynamodb_gsi_count', 0, 'Set amount of Global Secondary Indexes. Only set 0-5') flags.register_validator('aws_dynamodb_gsi_count', lambda value: -1 < value < 6, message='--count must be from 0-5') flags.DEFINE_boolean('aws_dynamodb_ycsb_consistentReads', False, "Consistent reads cost 2x eventual reads. " "'false' is default which is eventual") flags.DEFINE_integer('aws_dynamodb_connectMax', 50, 'Maximum number of concurrent dynamodb connections. '
flags.DEFINE_bool('tabular_solver', True, 'Use tabular solver?')
flags.DEFINE_string('env_name', 'grid', 'Environment to evaluate on.')
flags.DEFINE_string('solver_name', 'dice', 'Type of solver to use.')
flags.DEFINE_string('save_dir', None, 'Directory to save results to.')
flags.DEFINE_float('function_exponent', 1.5,
                   'Exponent for f function in DualDICE.')
flags.DEFINE_bool('deterministic_env', False, 'assume deterministic env.')
flags.DEFINE_integer('batch_size', 512, 'batch_size for training models.')
flags.DEFINE_integer('num_steps', 200000, 'num_steps for training models.')
flags.DEFINE_integer('log_every', 500, 'log after certain number of steps.')
flags.DEFINE_float('nu_learning_rate', 0.0001, 'nu lr')
flags.DEFINE_float('zeta_learning_rate', 0.001, 'z lr')

# Only the 'dice' solver and the 'grid' environment are accepted.
flags.register_validator('solver_name', lambda value: value in ['dice'],
                         message='Unknown solver.')
flags.register_validator('env_name', lambda value: value in ['grid'],
                         message='Unknown environment.')
# NOTE(review): --alpha is not defined in this snippet (presumably above);
# its value must lie in the closed interval [0, 1].
flags.register_validator('alpha', lambda value: 0 <= value <= 1,
                         message='Invalid value.')


def get_env_and_policies(env_name, tabular_obs, alpha):
    """Get environment and policies."""
    if env_name == 'grid':
        length = 10
        env = gridworld_envs.GridWalk(length, tabular_obs)
        policy0 = gridworld_policies.get_behavior_gridwalk_policy(
'Defaults to 1.') _MPSTAT_PUBLISH = flags.DEFINE_boolean( 'mpstat_publish', False, 'Whether to publish mpstat statistics.') _MPSTAT_PUBLISH_PER_INTERVAL_SAMPLES = flags.DEFINE_boolean( 'mpstat_publish_per_interval_samples', False, 'Whether to publish a separate mpstat statistics sample ' 'for each interval. If True, --mpstat_publish must be True.') FLAGS = flags.FLAGS _TWENTY_THREE_HOURS_IN_SECONDS = 23 * 60 * 60 flags.register_validator( _MPSTAT_INTERVAL.name, lambda value: value < _TWENTY_THREE_HOURS_IN_SECONDS, message=('If --mpstat_interval must be less than 23 hours (if it\'s set ' 'near or above 24 hours, it becomes hard to infer sample ' 'timestamp from mpstat output.')) flags.register_validator( _MPSTAT_PUBLISH_PER_INTERVAL_SAMPLES.name, lambda value: FLAGS.mpstat_publish or not value, message=('If --mpstat_publish_per_interval is True, --mpstat_publish must ' 'be True.')) def _ParseStartTime(output: str) -> float: """Parse the start time of the mpstat report. Args: output: output of mpstat
'Base learning rate when total batch size is 128. It is '
    'scaled by the ratio of the total batch size to 128.')
flags.DEFINE_float('one_minus_momentum', 0.1, 'Optimizer momentum.')
flags.DEFINE_integer(
    'lr_warmup_epochs', 1,
    'Number of epochs for a linear warmup to the initial '
    'learning rate. Use 0 to do no warmup.')
flags.DEFINE_float('lr_decay_ratio', 0.2, 'Amount to decay learning rate.')
flags.DEFINE_list('lr_decay_epochs', ['60', '120', '160'],
                  'Epochs to decay learning rate by.')
flags.DEFINE_float(
    'train_proportion', 1.,
    'Only a fraction (between 0 and 1) of the train set is used for training. '
    'The remainder can be used for validation.')
# Half-open interval (0, 1]: an empty training split is rejected.
flags.register_validator('train_proportion',
                         lambda tp: tp > 0.0 and tp <= 1.0,
                         message='--train_proportion must be in (0, 1].')
flags.DEFINE_float('l2', 2e-4, 'L2 regularization coefficient.')
flags.DEFINE_float('label_smoothing', 0.,
                   'Label smoothing parameter in [0,1].')
# Closed interval [0, 1].
flags.register_validator('label_smoothing',
                         lambda ls: ls >= 0.0 and ls <= 1.0,
                         message='--label_smoothing must be in [0, 1].')
flags.DEFINE_bool(
    'download_data', False,
    'Whether to download data locally when initializing the dataset.')

# Data Augmentation flags.
flags.DEFINE_bool('augmix', False,
                  'Whether to perform AugMix [4] on the input data.')
'', 'Where the results json file should be written. It will only be written if '
    'all benchmarks are run. Use an empty string to avoid writing this file.')
flags.DEFINE_string(
    'notes', '', 'Any notes to write into the results json file.')


def benchmark_name_validator(benchmark_name):
    """Checks that benchmark_name is "ALL", or refers to exactly one benchmark."""
    # NOTE(review): "exactly one" relies on find_benchmark_with_name returning
    # None unless there is a unique match; confirm in all_benchmarks.
    return (benchmark_name == 'ALL' or
            all_benchmarks.find_benchmark_with_name(benchmark_name) is not None)


flags.register_validator('benchmark_name', benchmark_name_validator,
                         message=('benchmark_name must be "ALL" or refer to '
                                  'exactly one benchmark.'))


def run_on_all_benchmarks():
    """Runs value search on all benchmarks, printing results to stdout."""
    benchmark_count = 0
    benchmark_success = 0
    unsolved_benchmarks = []
    solution_times = []  # Only including successful tasks.
    settings = settings_module.from_list(FLAGS.settings)
from absl.testing import absltest import glob import os import inspect from importlib import import_module from functools import reduce FLAGS = flags.FLAGS flags.DEFINE_list( 'folders', ['./'], 'Comma-separated list of folders to run all tests in. \ If no input is given, runs all tests in this directory.') # Make sure all folders specified exist. flags.register_validator( 'folders', lambda folder: reduce(lambda x, y: x and y, map(os.path.exists, folder)), message="Specified folder(s) does not exist.") def find_modules(folders): """Returns a list of all modules in test_folders directories. Args: folders: list of folders to check for modules Returns: A list of tuples (module_name, imported module) """ # Get all python files list_of_files = [] for folder in folders:
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Module containing redis installation and cleanup functions.""" from typing import Any, Dict from absl import flags from perfkitbenchmarker import linux_packages _VERSION = flags.DEFINE_string('redis_server_version', '6.2.1', 'Version of redis server to use.') flags.register_validator( 'redis_server_version', lambda version: int(version.split('.')[0]) >= 6, message='Redis version must be 6 or greater.') _IO_THREADS = flags.DEFINE_integer( 'redis_server_io_threads', 4, 'Only supported for redis version >= 6, the ' 'number of redis server IO threads to use.') _IO_THREADS_DO_READS = flags.DEFINE_bool( 'redis_server_io_threads_do_reads', False, 'If true, makes both reads and writes use IO threads instead of just ' 'writes.') _IO_THREAD_AFFINITY = flags.DEFINE_bool( 'redis_server_io_threads_cpu_affinity', False, 'If true, attempts to pin IO threads to CPUs.') _ENABLE_SNAPSHOTS = flags.DEFINE_bool( 'redis_server_enable_snapshots', False, 'If true, uses the default redis snapshot policy.') _NUM_PROCESSES = flags.DEFINE_integer(
FLAGS = gflags.FLAGS

gflags.DEFINE_bool("subject", False, "Print option: prints certificate subject")
gflags.DEFINE_bool("issuer", False, "Print option: prints certificate issuer")
gflags.DEFINE_bool("fingerprint", False, "Print option: prints certificate "
                   "fingerprint")
gflags.DEFINE_string("digest", "sha1", "Print option: fingerprint digest to use")
gflags.DEFINE_bool("debug", False,
                   "Print option: prints full ASN.1 debug information")
gflags.DEFINE_string(
    "filetype", "", "Read option: specify an input file "
    "format (pem or der). If no format is specified, the "
    "parser attempts to detect the format automatically.")
# An empty --filetype (the default) passes and means "auto-detect"; any other
# value must be pem or der, compared case-insensitively.
gflags.register_validator(
    "filetype",
    lambda value: not value or value.lower() in {"pem", "der"},
    message="--filetype must be one of pem or der")


def print_cert(certificate):
    # With no selection flag set, print the whole certificate (its repr when
    # --debug is set); otherwise print only the selected parts.
    if not FLAGS.subject and not FLAGS.issuer and not FLAGS.fingerprint:
        if FLAGS.debug:
            print "%r" % certificate
        else:
            print certificate
    else:
        if FLAGS.subject:
            print "subject:\n%s" % certificate.print_subject_name()
        if FLAGS.issuer:
            print "issuer:\n%s" % certificate.print_issuer_name()
        if FLAGS.fingerprint:
flags.DEFINE_string('output_file', None, 'The file to output csv data to.') flags.mark_flag_as_required('output_file') def validate_time_parsable(time_string: Text): try: parser.parse(time_string) except ValueError as e: logging.error('Could not parse {}, {}'.format(time_string, e)) return False return True flags.mark_flags_as_required(['start_time', 'end_time']) flags.register_validator('start_time', validate_time_parsable, message='--start_time must be parsable by dateutil.') flags.register_validator('end_time', validate_time_parsable, message='--end_time must be parsable by dateutil.') @flags.multi_flags_validator(['start_time', 'end_time'], message='start_time must be before end_time') def validate_start_time_before_end_time(flags_dict): return parser.parse(flags_dict['start_time']) < parser.parse( flags_dict['end_time']) flags.mark_flag_as_required('granularity') flags.register_validator(
FLAGS = flags.FLAGS flags.DEFINE_string( "lexicon_dir", "src/analyzer/lexicon/base", "Path to the directory that contains the lexicon TSV dumps.") flags.DEFINE_string( "morphotactics_dir", "src/analyzer/morphotactics/model", "Path to the directory that contains the text files that define" " rewrite rules of morphotactics model.") flags.DEFINE_string( "output_dir", "bin", "Path to the directory to which compiled OpenFST format transducer" " specification and symbols table file will be written to as text file") flags.register_validator("lexicon_dir", lambda v: os.path.isdir(v)) flags.register_validator("morphotactics_dir", lambda v: os.path.isdir(v)) _RewriteRule = rule_pb2.RewriteRule _RewriteRuleSet = rule_pb2.RewriteRuleSet _SYMBOLS_REGEX = re.compile( # First inflectional group. r"\(.+?\[[A-Z\.,:\(\)\'\-\"`\$]+?\]|" # Inflectional group boundaries. r"\)\(\[[A-Z]+?\]|" # Derivational morphemes. r"-(?:[^\W\d_]|')+?\[[A-z]+?=[A-z]+?\]|" # Inflectional morphemes and features. r"\+(?:[^\W\d_]|['\.])*?\[[A-z]+?=[A-z0-9]+?\]|" # Proper noun analysis.
def modify_validators():
    # Registers an extra validator on flagsaver_test_flag0 and asserts the flag
    # then has two validators; `self` comes from the enclosing test method.
    def no_space(value):
        return ' ' not in value

    flags.register_validator('flagsaver_test_flag0', no_space)
    self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
flags.DEFINE_integer('image_size', 256, 'Width and height to crop training and testing frames. ' 'Must be a multiple of 16', lower_bound=16) flags.DEFINE_integer('batch_size', 16, 'Training batch size.', lower_bound=1) flags.DEFINE_float('learning_rate', 2e-5, 'Learning rate for Adam optimization.', lower_bound=0.0) flags.register_validator('image_size', lambda image_size: image_size % 16 == 0, message='\'image_size\' must multiple of 16.') flags.mark_flag_as_required('model_dir') flags.mark_flag_as_required('train_pattern') flags.mark_flag_as_required('test_pattern') def main(_): inference_fn = network.inference hparams = tf.contrib.training.HParams(learning_rate=FLAGS.learning_rate) model_fn = estimator.create_model_fn(inference_fn, hparams) config = tf.estimator.RunConfig(FLAGS.model_dir) tf_estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_dataset_fn = dataset.create_dataset_fn(FLAGS.train_pattern,
from absl import flags import numpy as np import coords import go import mcts import sgf_wrapper from utils import dbg from player_interface import MCTSPlayerInterface flags.DEFINE_integer( 'softpick_move_cutoff', (go.N * go.N // 12) // 2 * 2, 'The move number (<) up to which moves are softpicked from MCTS visits.') # Ensure that both white and black have an equal number of softpicked moves. flags.register_validator('softpick_move_cutoff', lambda x: x % 2 == 0) flags.DEFINE_float( 'resign_threshold', -0.9, 'The post-search Q evaluation at which resign should happen.' 'A threshold of -1 implies resign is disabled.') flags.register_validator('resign_threshold', lambda x: -1 <= x < 0) flags.DEFINE_integer( 'num_readouts', 800, 'Number of searches to add to the MCTS search tree before playing a move.') flags.register_validator('num_readouts', lambda x: x > 0) flags.DEFINE_integer( 'parallel_readouts', 8, 'Number of searches to execute in parallel. This is also the batch size'
from xls.common import gfile
from xls.experimental.smtlib import flags_checks
from xls.experimental.smtlib import n_bit_nested_add_generator
from xls.experimental.smtlib import n_bit_nested_mul_generator
from xls.experimental.smtlib import n_bit_nested_shift_generator
from xls.experimental.smtlib import solvers_op_comparison_functions

FLAGS = flags.FLAGS
flags.DEFINE_string("op", None, "Operation for the smt2 files (add, mul, shl)")
flags.DEFINE_integer("nests", None,
                     "Integer for the number of nested operations.")
flags.DEFINE_list("bits_list", None,
                  "List of n values for each n-bit multiplication proof.")
# NOTE(review): this validator is registered before
# mark_flag_as_required("bits_list") below, so it may receive None when the
# flag is omitted; confirm flags_checks.list_contains_only_integers handles it.
flags.register_validator(
    "bits_list",
    flags_checks.list_contains_only_integers,
    message="--bits_list must contain only integers.")
flags.DEFINE_list("solvers", None, "List of solvers to test.")
flags.DEFINE_string("fname", None, "Name for the file to store the data.")
flags.mark_flag_as_required("op")
flags.mark_flag_as_required("nests")
flags.mark_flag_as_required("bits_list")
flags.mark_flag_as_required("fname")
flags.mark_flag_as_required("solvers")


def create_and_get_smt_files_bits_list(op, nests_val, bits_list):
    """Creates smt2 files for the necessary proof and return them in a list.

    Given an operation, the number of nests, and a list of bits, create SMTLIB2