def cli():
    """Command-line entry point.

    Prepends the current working directory to ``sys.path``, registers
    ``read_config_file`` through ``app.call_after_init`` and then hands
    control to ``app.run`` with ``main`` and ``flags_parser``.
    """
    # The working directory has to be on the module search path before any
    # import statement that relies on it is attempted.
    working_dir = os.path.abspath(os.path.curdir)
    sys.path.insert(0, working_dir)

    app.call_after_init(read_config_file)
    app.run(main, flags_parser=flags_parser)
def main(enable_v2_behavior=True, config_logical_devices=True):
    """All-in-one main function for tf.distribute tests.

    Args:
        enable_v2_behavior: when truthy, ``v2_compat.enable_v2_behavior()`` is
            called; otherwise ``v2_compat.disable_v2_behavior()`` is called.
        config_logical_devices: when truthy, ``_set_logical_devices`` is
            registered with ``app.call_after_init``.

    Finishes by delegating to ``multi_process_runner.test_main()``.
    """
    if config_logical_devices:
        app.call_after_init(_set_logical_devices)

    # Pick the matching v2_compat toggle, then invoke it.
    switch_behavior = (
        v2_compat.enable_v2_behavior
        if enable_v2_behavior
        else v2_compat.disable_v2_behavior
    )
    switch_behavior()

    multi_process_runner.test_main()
def config_with_absl(self):
    """Mirror this object's values into absl flags.

    Run this before calling ``app.run(main)`` (or another absl entry point).

    For each ``name, value`` in ``self.values``, ``self.meta[name]`` must be a
    ``(flag_type, meta_args, meta_kwargs)`` triple. ``flag_type`` selects the
    absl definer (``bool``, ``int``, ``str`` or the string ``'enum'``; any
    other key raises ``KeyError``). The definer is then called as
    ``define(name, value, *meta_args, **meta_kwargs)``.

    Side effects: sets ``self.use_absl`` to True and ``self.absl_flags`` to the
    ``absl.flags`` module. It also registers a callback with
    ``app.call_after_init`` that calls
    ``self.complete_absl_config(absl_flags)``.
    """
    import absl.flags as absl_FLAGS  # noqa: F401
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags

    definers = {
        bool: absl_flags.DEFINE_bool,
        int: absl_flags.DEFINE_integer,
        str: absl_flags.DEFINE_string,
        'enum': absl_flags.DEFINE_enum,
    }
    for flag_name, default in self.values.items():
        flag_type, meta_args, meta_kwargs = self.meta[flag_name]
        define = definers[flag_type]
        define(flag_name, default, *meta_args, **meta_kwargs)

    def _finish_config():
        self.complete_absl_config(absl_flags)

    app.call_after_init(_finish_config)
def real_main(argv):
    """The main function.

    Behaves according to flags and environment variables so that a test can
    observe what ``app.run`` did. Depending on configuration it:

    * echoes ``argv`` to stdout (``APP_TEST_PRINT_ARGV``);
    * raises ``MyException`` or ``app.UsageError``;
    * triggers a segfault through ``faulthandler``;
    * prints the recorded init-callback results;
    * otherwise, checks whether the C++ flag
      ``heap_check_before_constructors`` is registered in ``flags.FLAGS``,
      as ``APP_TEST_HELPER_TYPE`` expects.

    Every path that does not raise ends in ``sys.exit``.
    """
    if os.environ.get('APP_TEST_PRINT_ARGV', False):
        sys.stdout.write('argv: {}\n'.format(' '.join(argv)))

    if FLAGS.raise_exception:
        raise MyException

    if FLAGS.raise_usage_error:
        exitcode = FLAGS.usage_error_exitcode
        if exitcode is None:
            raise app.UsageError('Error!')
        raise app.UsageError('Error!', exitcode)

    if FLAGS.faulthandler_sigsegv:
        faulthandler._sigsegv()  # pylint: disable=protected-access
        sys.exit(1)  # Should not reach here.

    if FLAGS.print_init_callbacks:
        # Registered from inside the main function; everything recorded in
        # _callback_results so far is then printed.
        app.call_after_init(
            lambda: _callback_results.append('during real_main'))
        for result in _callback_results:
            print('callback: {}'.format(result))
        sys.exit(0)

    # Ensure that we have a random C++ flag in flags.FLAGS; this shows
    # us that app.run() did the right thing in conjunction with C++ flags.
    helper_type = os.environ['APP_TEST_HELPER_TYPE']
    cpp_flag = 'heap_check_before_constructors'
    if helper_type == 'clif':
        if cpp_flag in flags.FLAGS:
            message = 'PASS: C++ flag present and helper_type is {}'.format(
                helper_type)
            exit_code = 0
        else:
            message = 'FAILED: C++ flag absent but helper_type is {}'.format(
                helper_type)
            exit_code = 1
    elif helper_type == 'pure_python':
        if cpp_flag in flags.FLAGS:
            message = 'FAILED: C++ flag present but helper_type is pure_python'
            exit_code = 1
        else:
            message = 'PASS: C++ flag absent and helper_type is pure_python'
            exit_code = 0
    else:
        message = 'Unexpected helper_type "{}"'.format(helper_type)
        exit_code = 1
    print(message)
    sys.exit(exit_code)
flags.DEFINE_bool('whiten', False, 'Whether to normalize images.')
flags.DEFINE_string(
    'data_dir', None,
    'Data directory. If None then environment variable ML_DATA '
    'will be used as a data directory.')
FLAGS = flags.FLAGS


def _data_setup():
    """Set the module-level ``DATA_DIR`` once flags have been parsed.

    Uses ``--data_dir`` when it is truthy. Otherwise it falls back to the
    ``ML_DATA`` environment variable, raising ``KeyError`` if that is unset.
    """
    global DATA_DIR
    chosen_dir = FLAGS.data_dir
    if not chosen_dir:
        chosen_dir = os.environ['ML_DATA']
    DATA_DIR = chosen_dir


app.call_after_init(_data_setup)


def record_parse_mnist(serialized_example, image_shape=None):
    """Parse one serialized example holding an encoded image and a label.

    Args:
        serialized_example: serialized example containing an ``'image'``
            bytes feature and an ``'label'`` int64 feature.
        image_shape: optional static shape applied via ``set_shape`` to the
            decoded image before padding.

    Returns:
        dict with ``'image'`` and ``'label'``. The image is padded by two
        elements on both sides of axes 0 and 1, cast to float32, and mapped
        linearly from 0..255 to -1..1. The padding spec has three rows, so
        this assumes a rank-3 decoded image; confirm for non-PNG/JPEG inputs.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    pixels = tf.image.decode_image(parsed['image'])
    if image_shape:
        pixels.set_shape(image_shape)
    # Two elements of padding on each side of axes 0 and 1, none on axis 2.
    pixels = tf.pad(pixels, [[2, 2], [2, 2], [0, 0]])
    pixels = tf.cast(pixels, tf.float32) * (2.0 / 255) - 1.0
    return {'image': pixels, 'label': parsed['label']}
@contextlib.contextmanager
def set_availables_tmp(new_ds_types: Iterable[DatasetType]) -> Iterator[None]:
    """Temporarily replace the available dataset types.

    Usable as a context manager or as a decorator. A copy of
    ``_current_available`` is taken on entry. That copy is passed back to
    ``set_availables`` on exit, even when the body raises.

    Args:
        new_ds_types: dataset types to make available while active.
    """
    saved_types = set(_current_available)
    try:
        set_availables(new_ds_types)
        yield
    finally:
        # Restore previous permissions.
        set_availables(saved_types)


def _set_default_visibility() -> None:
    """Overwrites the default visibility for the TFDS scripts.

    If the ``__main__`` module has a ``__file__`` whose path contains a
    ``tensorflow_datasets`` component, only ``DatasetType.TFDS_PUBLIC`` is
    made available. This disables community datasets, which must then be
    requested explicitly. Otherwise nothing changes.
    """
    import __main__  # pytype: disable=import-error # pylint: disable=g-import-not-at-top

    script_path = getattr(__main__, '__file__', None)
    if not script_path:
        return
    if 'tensorflow_datasets' not in pathlib.Path(script_path).parts:
        return
    set_availables([DatasetType.TFDS_PUBLIC])


app.call_after_init(_set_default_visibility)
passed to the worker processes each time a test case is ran. Returns: a TestEnvironment object. """ return _env def _set_total_phsyical_gpus(): if in_main_process(): env().total_phsyical_gpus = len( context.context().list_physical_devices("GPU")) # This is needed in case CUDA is lazily loaded. app.call_after_init(_set_total_phsyical_gpus) _TestResult = collections.namedtuple("_TestResult", ["status", "message"]) def _test_runner(test_id, test_env): """Executes the test with the given test_id. This is a simple wrapper around TestRunner to be used with multi_process_runner. Similar to test.main(), but it executes only one test specified by test_id and returns whether the test succeeds. If the test fails, the function prints failures and errors to stdout. Args: test_id: TestCase.id() test_env: a TestEnvironment object.
'flags_parser, but found: {}'.format(argv)) def flags_parser(argv): print('Function called: flags_parser.') if os.environ.get('APP_TEST_FLAGS_PARSER_PARSE_FLAGS', None): FLAGS(argv) return flags_parser_argv_sentinel # Holds results from callbacks triggered by `app.run_after_init`. _callback_results = [] if __name__ == '__main__': kwargs = {'main': main} main_function_name = os.environ.get('APP_TEST_CUSTOM_MAIN_FUNC', None) if main_function_name: kwargs['main'] = globals()[main_function_name] custom_argv = os.environ.get('APP_TEST_CUSTOM_ARGV', None) if custom_argv: kwargs['argv'] = custom_argv.split(' ') if os.environ.get('APP_TEST_USE_CUSTOM_PARSER', None): kwargs['flags_parser'] = flags_parser app.call_after_init(lambda: _callback_results.append('before app.run')) app.install_exception_handler(MyExceptionHandler('first')) app.install_exception_handler(MyExceptionHandler('second')) app.run(**kwargs) sys.exit('This is not reachable.')