Example #1
0
    def setUpClass(cls):
        """Populate class-level test settings from flags and build API managers.

        Copies the xds / xds_k8s flag values onto the class (GCP settings,
        resource naming, test server and client settings, suite options) and
        then instantiates the Kubernetes and GCP API managers.

        Raises:
            flags.IllegalFlagValueError: if neither --resource_prefix nor
                --namespace provides a non-empty value.
        """
        # --- GCP settings ---
        cls.project: str = xds_flags.PROJECT.value
        cls.network: str = xds_flags.NETWORK.value
        cls.gcp_service_account: str = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value
        cls.td_bootstrap_image = xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value
        cls.xds_server_uri = xds_flags.XDS_SERVER_URI.value
        cls.ensure_firewall = xds_flags.ENSURE_FIREWALL.value
        cls.firewall_allowed_ports = xds_flags.FIREWALL_ALLOWED_PORTS.value

        # --- Resource naming ---
        # --resource_prefix takes precedence; --namespace is only a fallback.
        # TODO(sergiitk): Drop namespace parsing when --namespace is removed.
        prefix = xds_flags.RESOURCE_PREFIX.value or xds_flags.NAMESPACE.value
        cls.resource_prefix = prefix
        if not prefix:
            raise flags.IllegalFlagValueError(
                'Required one of the flags: --resource_prefix or --namespace')

        # An explicitly supplied suffix is stored as-is, and the suffix
        # randomization toggle is switched off.
        suffix_override = xds_flags.RESOURCE_SUFFIX.value
        if suffix_override is not None:
            cls._resource_suffix_randomize = False
            cls.resource_suffix = suffix_override

        # --- Test server ---
        cls.server_image = xds_k8s_flags.SERVER_IMAGE.value
        cls.server_name = xds_flags.SERVER_NAME.value
        cls.server_port = xds_flags.SERVER_PORT.value
        cls.server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
        # NOTE: the xDS host is read from the same flag as server_name.
        cls.server_xds_host = xds_flags.SERVER_NAME.value
        cls.server_xds_port = xds_flags.SERVER_XDS_PORT.value

        # --- Test client ---
        cls.client_image = xds_k8s_flags.CLIENT_IMAGE.value
        cls.client_name = xds_flags.CLIENT_NAME.value
        cls.client_port = xds_flags.CLIENT_PORT.value

        # --- Test suite settings ---
        cls.force_cleanup = _FORCE_CLEANUP.value
        cls.debug_use_port_forwarding = (
            xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value)
        cls.check_local_certs = _CHECK_LOCAL_CERTS.value

        # --- Resource managers ---
        kube_context = xds_k8s_flags.KUBE_CONTEXT.value
        cls.k8s_api_manager = k8s.KubernetesApiManager(kube_context)
        cls.gcp_api_manager = gcp.api.GcpApiManager()
Example #2
0
    def _parse(self, value):
        """Return the shared dict, rejecting any real attempt to override it.

        A `DictFlag` is not meant to be overridden from the command line; only
        its dotted `Item` flags are. This method is nevertheless invoked in two
        situations, both of which are accepted:

        1. The base `Flag` constructor calls `_parse()` on the default value,
           which is the shared dict itself.
        2. Some libraries serialize and deserialize every flag (e.g. to pass
           values between processes). A dummy empty serialized value is
           accepted for that case, since users are unlikely to set the dict
           flag to an empty string on the command line.

        Args:
            value: The value being parsed.

        Returns:
            The shared dict backing this flag.

        Raises:
            flags.IllegalFlagValueError: for any other value.
        """
        is_default_or_dummy = value is self._shared_dict or value == _EMPTY
        if not is_default_or_dummy:
            raise flags.IllegalFlagValueError(
                "Can't override a dict flag directly. Did you mean to override one of "
                "its `Item`s instead?")
        return self._shared_dict
Example #3
0
def get_relative_artifacts_dir() -> str:
  """Build the relative artifacts path 'tf/math/<function>/<dims>_<complex>'.

  With a single function, the function's name is the directory component.
  With several functions, the component is 'multiple_functions__<backend>',
  which requires exactly one target backend.

  Returns:
    The relative directory path, assembled with `os.path.join`.

  Raises:
    flags.IllegalFlagValueError: if more than one function is requested and
      the number of target backends is not exactly one.
  """
  requested_functions = FLAGS.functions
  if len(requested_functions) <= 1:
    function_str = requested_functions[0]
  else:
    # Testing multiple functions is only allowed with a single target backend,
    # so the artifacts can live under:
    #   'artifacts_dir/multiple_functions__backend/...'
    # The 'multiple_functions' dir is specialized per backend so that
    # tf_input.mlir and iree_input.mlir are not overwritten. They are usually
    # identical across backends, but not when the functions to compile change
    # per-backend.
    backends = FLAGS.target_backends
    if len(backends) != 1:
      raise flags.IllegalFlagValueError(
          "Expected len(target_backends) == 1 when len(functions) > 1, but got "
          f"the following values for target_backends: {FLAGS.target_backends}.")
    function_str = f"multiple_functions__{FLAGS.target_backends[0]}"

  if FLAGS.dynamic_dims:
    dim_str = "dynamic_dims"
  else:
    dim_str = "static_dims"

  if FLAGS.test_complex:
    complex_str = "complex"
  else:
    complex_str = "non_complex"

  leaf_dir = f"{dim_str}_{complex_str}"
  return os.path.join("tf", "math", function_str, leaf_dir)
Example #4
0
def main(argv):
  """Entry point: list functions with complex tests, or run the unit tests.

  When `--list_functions_with_complex_tests` is set, prints one quoted,
  comma-terminated line per unit test spec whose input signature is complex
  (a function name is printed once per matching spec) and returns without
  running any tests. Otherwise `--functions` must be set; the unit tests are
  then generated and handed to `tf.test.main()`.

  Args:
    argv: Command-line arguments; unused.

  Raises:
    flags.IllegalFlagValueError: if neither listing mode nor `--functions`
      was specified.
  """
  del argv  # Unused.

  # Only call enable_v2_behavior when this TensorFlow build exposes it.
  enable_v2 = getattr(tf, "enable_v2_behavior", None) if hasattr(
      tf, "enable_v2_behavior") else None
  if enable_v2 is not None:
    tf.enable_v2_behavior()

  if FLAGS.list_functions_with_complex_tests:
    # Lazy generator: each spec is checked and printed in turn, in the same
    # order as a nested loop would.
    names_with_complex_specs = (
        name
        for name, specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items()
        for spec in specs
        if tf_utils.is_complex(spec.input_signature))
    for function_name in names_with_complex_specs:
      print(f'    "{function_name}",')
    return

  if FLAGS.functions is None:
    raise flags.IllegalFlagValueError(
        "'--functions' must be specified if "
        "'--list_functions_with_complex_tests' isn't")

  TfMathTest.generate_unit_tests(TfMathModule)
  tf.test.main()