Ejemplo n.º 1
0
def cli():
    """Command-line entry point.

    Makes modules in the current working directory importable, schedules
    ``read_config_file`` to run once absl has initialized, and then hands
    control to ``app.run`` with the custom ``flags_parser``.
    """
    working_dir = os.path.abspath(os.path.curdir)
    # Prepend the working directory before any import statement runs, so
    # that local modules take precedence on the search path.
    sys.path.insert(0, working_dir)

    app.call_after_init(read_config_file)
    app.run(main, flags_parser=flags_parser)
Ejemplo n.º 2
0
def main(enable_v2_behavior=True, config_logical_devices=True):
    """Shared entry point for tf.distribute tests.

    Args:
      enable_v2_behavior: when true, enable TF v2 behavior; otherwise
        explicitly disable it.
      config_logical_devices: when true, schedule ``_set_logical_devices``
        via ``app.call_after_init``.
    """
    if config_logical_devices:
        app.call_after_init(_set_logical_devices)

    switch_behavior = (v2_compat.enable_v2_behavior if enable_v2_behavior
                       else v2_compat.disable_v2_behavior)
    switch_behavior()

    multi_process_runner.test_main()
Ejemplo n.º 3
0
  def config_with_absl(self):
    """Define every known option as an absl flag and hook up absl parsing.

    Run this before calling ``app.run(main)`` etc. Each entry of
    ``self.values`` is registered with the absl definer matching the flag type
    recorded in ``self.meta``, and ``complete_absl_config`` is scheduled via
    ``app.call_after_init``.

    Raises:
      ValueError: if an option's recorded flag type has no absl definer.
    """
    # `from absl import flags` also loads the absl.flags submodule, so no
    # separate `import absl.flags` is needed.
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    # Flag type (as stored in self.meta) -> absl flag definer.
    absl_defs = {bool: absl_flags.DEFINE_bool,
                 int: absl_flags.DEFINE_integer,
                 float: absl_flags.DEFINE_float,
                 str: absl_flags.DEFINE_string,
                 'enum': absl_flags.DEFINE_enum}

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      try:
        define_flag = absl_defs[flag_type]
      except KeyError:
        # Report the offending option instead of a bare KeyError on the type.
        raise ValueError(
            'Unsupported flag type {!r} for option {!r}'.format(
                flag_type, name)) from None
      define_flag(name, val, *meta_args, **meta_kwargs)
    app.call_after_init(lambda: self.complete_absl_config(absl_flags))
Ejemplo n.º 4
0
def real_main(argv):
    """Main function of the app test helper, steered by FLAGS and env vars.

    Optionally echoes argv. Then, depending on FLAGS, it raises MyException,
    raises app.UsageError, forces a segfault through faulthandler, or prints
    the recorded init callbacks. Otherwise it verifies that the C++ flag
    'heap_check_before_constructors' is present exactly when
    APP_TEST_HELPER_TYPE is 'clif', and exits with 0 (pass) or 1 (fail).
    """
    if os.environ.get('APP_TEST_PRINT_ARGV', False):
        sys.stdout.write('argv: {}\n'.format(' '.join(argv)))

    if FLAGS.raise_exception:
        raise MyException

    if FLAGS.raise_usage_error:
        if FLAGS.usage_error_exitcode is None:
            raise app.UsageError('Error!')
        raise app.UsageError('Error!', FLAGS.usage_error_exitcode)

    if FLAGS.faulthandler_sigsegv:
        faulthandler._sigsegv()  # pylint: disable=protected-access
        sys.exit(1)  # Only reached if the forced segfault did not kill us.

    if FLAGS.print_init_callbacks:
        # Registered from inside main, i.e. after absl initialization.
        app.call_after_init(
            lambda: _callback_results.append('during real_main'))
        for recorded in _callback_results:
            print('callback: {}'.format(recorded))
        sys.exit(0)

    # Seeing a random C++ flag in flags.FLAGS shows that app.run() handled
    # C++ flags correctly; it should be there only for the 'clif' helper.
    helper_type = os.environ['APP_TEST_HELPER_TYPE']
    if helper_type == 'clif':
        if 'heap_check_before_constructors' not in flags.FLAGS:
            print('FAILED: C++ flag absent but helper_type is {}'.format(
                helper_type))
            sys.exit(1)
        print('PASS: C++ flag present and helper_type is {}'.format(
            helper_type))
        sys.exit(0)
    elif helper_type == 'pure_python':
        if 'heap_check_before_constructors' not in flags.FLAGS:
            print('PASS: C++ flag absent and helper_type is pure_python')
            sys.exit(0)
        print('FAILED: C++ flag present but helper_type is pure_python')
        sys.exit(1)
    else:
        print('Unexpected helper_type "{}"'.format(helper_type))
        sys.exit(1)
Ejemplo n.º 5
0
# Flags defined by this module; parsed values are read through FLAGS.
flags.DEFINE_bool(name='whiten', default=False,
                  help='Whether to normalize images.')
flags.DEFINE_string(
    name='data_dir',
    default=None,
    help=('Data directory. If None then environment variable ML_DATA will be '
          'used as a data directory.'))

FLAGS = flags.FLAGS


def _data_setup():
    """Resolve the module-global ``DATA_DIR`` from flags or the environment.

    A non-empty ``--data_dir`` wins; otherwise the ``ML_DATA`` environment
    variable is used. Reads FLAGS, so it must run after flag parsing.

    Raises:
      ValueError: if ``--data_dir`` is empty/unset and ``ML_DATA`` is not
        defined in the environment.
    """
    global DATA_DIR
    data_dir = FLAGS.data_dir or os.environ.get('ML_DATA')
    if data_dir is None:
        # Fail with an actionable message rather than a bare KeyError('ML_DATA').
        raise ValueError(
            'No data directory configured: pass --data_dir or set the '
            'ML_DATA environment variable.')
    DATA_DIR = data_dir


# Defer resolving DATA_DIR until absl has initialized (flags are parsed then).
app.call_after_init(_data_setup)


def record_parse_mnist(serialized_example, image_shape=None):
    """Parse one serialized example holding an encoded image and its label.

    The image bytes are decoded and optionally given a static shape. The
    first two axes are padded by 2 on each side (presumably height/width;
    the third axis is left unpadded). Pixel values are then rescaled by
    2/255 - 1, mapping [0, 255] onto [-1, 1].

    Args:
      serialized_example: serialized tf.Example with 'image' (encoded bytes)
        and 'label' (int64) features.
      image_shape: optional static shape to attach to the decoded image.

    Returns:
      dict with 'image' (float32 tensor) and 'label' (int64 tensor).
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    decoded = tf.image.decode_image(parsed['image'])
    if image_shape:
        decoded.set_shape(image_shape)

    padded = tf.pad(decoded, [[2, 2], [2, 2], [0, 0]])
    scaled = tf.cast(padded, tf.float32) * (2.0 / 255) - 1.0
    return {'image': scaled, 'label': parsed['label']}

Ejemplo n.º 6
0
@contextlib.contextmanager
def set_availables_tmp(new_ds_types: Iterable[DatasetType]) -> Iterator[None]:
    """Temporarily apply `set_availables`; usable as a `with` block or decorator.

    The dataset types available on entry are snapshotted and restored on exit,
    including when the body raises.

    Args:
      new_ds_types: dataset types to make available inside the context.
    """
    previous_types = set(_current_available)  # snapshot before replacing
    try:
        set_availables(new_ds_types)
        yield
    finally:
        set_availables(previous_types)


def _set_default_visibility() -> None:
    """Restrict visibility to public TFDS datasets when a TFDS script runs.

    If the ``__main__`` module's file path has a ``tensorflow_datasets``
    component, only ``DatasetType.TFDS_PUBLIC`` is made available, so
    community datasets must be requested explicitly. Otherwise, including
    when ``__main__`` has no ``__file__``, nothing is changed.
    """
    import __main__  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top

    entry_file = getattr(__main__, '__file__', None)
    if not entry_file:
        return
    if 'tensorflow_datasets' not in pathlib.Path(entry_file).parts:
        return
    set_availables([DatasetType.TFDS_PUBLIC])


# Apply the TFDS-script visibility default once absl has initialized.
app.call_after_init(_set_default_visibility)
Ejemplo n.º 7
0
  passed to the worker processes each time a test case is ran.

  Returns:
    a TestEnvironment object.
  """
    return _env


def _set_total_phsyical_gpus():
    """Record the physical GPU count on the test environment (main process only).

    The 'phsyical' spelling is kept as-is: ``total_phsyical_gpus`` is an
    attribute name that other code may read.
    """
    if not in_main_process():
        return
    gpu_devices = context.context().list_physical_devices("GPU")
    env().total_phsyical_gpus = len(gpu_devices)


# Count GPUs only after absl initialization, in case CUDA is loaded lazily.
app.call_after_init(_set_total_phsyical_gpus)

_TestResult = collections.namedtuple("_TestResult", ("status", "message"))


def _test_runner(test_id, test_env):
    """Executes the test with the given test_id.

  This is a simple wrapper around TestRunner to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test
  specified by test_id and returns whether the test succeeds. If the test fails,
  the function prints failures and errors to stdout.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.
Ejemplo n.º 8
0
            'flags_parser, but found: {}'.format(argv))


def flags_parser(argv):
    """Custom flags parser for the app test helper.

    Announces the call and parses ``argv`` into FLAGS only when
    APP_TEST_FLAGS_PARSER_PARSE_FLAGS is set to a non-empty value. It always
    returns ``flags_parser_argv_sentinel``.
    """
    print('Function called: flags_parser.')
    parse_requested = os.environ.get('APP_TEST_FLAGS_PARSER_PARSE_FLAGS', None)
    if parse_requested:
        FLAGS(argv)
    return flags_parser_argv_sentinel


# Holds results from callbacks registered via `app.call_after_init`;
# real_main prints them when --print_init_callbacks is set.
_callback_results = []

if __name__ == '__main__':
    # Environment variables that customize how app.run() is invoked.
    main_function_name = os.environ.get('APP_TEST_CUSTOM_MAIN_FUNC', None)
    custom_argv = os.environ.get('APP_TEST_CUSTOM_ARGV', None)

    kwargs = {'main': main}
    if main_function_name:
        # Look up an alternative module-level main function by name.
        kwargs['main'] = globals()[main_function_name]
    if custom_argv:
        # The override argv is a single space-separated string.
        kwargs['argv'] = custom_argv.split(' ')
    if os.environ.get('APP_TEST_USE_CUSTOM_PARSER', None):
        kwargs['flags_parser'] = flags_parser

    # Registered before app.run(), i.e. before absl has initialized.
    app.call_after_init(lambda: _callback_results.append('before app.run'))
    app.install_exception_handler(MyExceptionHandler('first'))
    app.install_exception_handler(MyExceptionHandler('second'))
    app.run(**kwargs)

    # app.run() is expected to terminate the process itself.
    sys.exit('This is not reachable.')