Esempio n. 1
0
  def test_booleans(self):
    """Check that passing a bare boolean flag on the command line enables it."""
    argv = [__file__, "--use_synthetic_data"]
    flags_core.parse_flags(argv)
    assert flags.FLAGS.use_synthetic_data
    def test_booleans(self):
        """A boolean flag given without an explicit value must end up enabled."""
        cli_args = [__file__, "--use_synthetic_data"]
        flags_core.parse_flags(cli_args)

        use_synth = flags.FLAGS.use_synthetic_data
        assert use_synth
Esempio n. 3
0
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
  """Performs a minimal run of a model.

  This function is intended to test for syntax errors throughout a model. A
  very limited run is performed using synthetic data.

  Args:
    main: The primary function used to exercise a code path. It is called
      with the parsed `flags.FLAGS` object (generally "<MODULE>.main").
    tmp_root: Root path for the temp directory created by the test class.
    extra_flags: Additional command-line flags passed by the caller of this
      function. Any iterable of strings is accepted; None means no extras.
    synth: Use synthetic data.
    max_train: Maximum number of allowed training steps. If None, the
      "--max_train_steps" flag is not passed at all.
  """
  model_dir = tempfile.mkdtemp(dir=tmp_root)

  # Everything after creating the temp dir lives inside `try` so the
  # directory is removed even if building argv (e.g. an unexpected
  # `extra_flags` type) or flag parsing fails.
  try:
    args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
            "--epochs_between_evals", "1"]
    args.extend([] if extra_flags is None else extra_flags)

    if synth:
      args.append("--use_synthetic_data")

    if max_train is not None:
      args.extend(["--max_train_steps", str(max_train)])

    flags_core.parse_flags(argv=args)
    main(flags.FLAGS)
  finally:
    # `main` may already have removed the directory itself.
    if os.path.exists(model_dir):
      shutil.rmtree(model_dir)
Esempio n. 4
0
def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
    """Performs a minimal run of a model.

    This function is intended to test for syntax errors throughout a model. A
    very limited run is performed using synthetic data.

    Args:
      main: The primary function used to exercise a code path. It is called
        with the parsed `flags.FLAGS` object (generally "<MODULE>.main").
      tmp_root: Root path for the temp directory created by the test class.
      extra_flags: Additional command-line flags passed by the caller of this
        function. Any iterable of strings is accepted; None means no extras.
      synth: Use synthetic data.
    """
    model_dir = tempfile.mkdtemp(dir=tmp_root)

    # Build argv inside `try` so the temp dir is cleaned up even if argv
    # construction or flag parsing raises.
    try:
        args = [
            sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
            "--epochs_between_evals", "1"
        ]
        args.extend([] if extra_flags is None else extra_flags)

        if synth:
            args.append("--use_synthetic_data")

        flags_core.parse_flags(argv=args)
        main(flags.FLAGS)
    finally:
        # `main` may already have removed the directory itself.
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
def run_end_to_end(main: Callable[[Any], None],
                   extra_flags: Optional[Iterable[str]] = None,
                   model_dir: Optional[str] = None):
  """Runs the classifier trainer end-to-end.

  Builds a command line from `model_dir` and `extra_flags`, parses it into
  `flags.FLAGS`, and calls `main` with the parsed flags.

  Args:
    main: Entry point invoked with the parsed `flags.FLAGS`.
    extra_flags: Additional command-line flags. Any iterable of strings is
      accepted; None means no extra flags.
    model_dir: Value passed as "--model_dir". When None, the flag is omitted
      instead of placing a literal None into argv.
  """
  args = [sys.argv[0]]
  # A None entry is not a valid command-line token, so only forward
  # --model_dir when a value was actually supplied.
  if model_dir is not None:
    args.extend(['--model_dir', model_dir])
  if extra_flags is not None:
    # extend() honours the Iterable[str] annotation (tuples, generators, ...).
    args.extend(extra_flags)
  flags_core.parse_flags(argv=args)
  main(flags.FLAGS)
    def test_benchmark_setting(self):
        """Benchmark-related flag defaults can be set and are read back."""
        expected = {
            "hooks": ["LoggingMetricHook"],
            "benchmark_log_dir": "/tmp/12345",
            "gcp_project": "project_abc",
        }

        flags_core.set_defaults(**expected)
        flags_core.parse_flags()

        for flag_name in expected:
            actual = flags.FLAGS.get_flag_value(name=flag_name, default=None)
            assert actual == expected[flag_name]
Esempio n. 7
0
  def test_benchmark_setting(self):
    """Benchmark flags must report the defaults assigned via set_defaults."""
    benchmark_pairs = [
        ("hooks", ["LoggingMetricHook"]),
        ("benchmark_log_dir", "/tmp/12345"),
        ("gcp_project", "project_abc"),
    ]
    flags_core.set_defaults(**dict(benchmark_pairs))
    flags_core.parse_flags()

    for name, wanted in benchmark_pairs:
      got = flags.FLAGS.get_flag_value(name=name, default=None)
      assert got == wanted
Esempio n. 8
0
    def test_default_setting(self):
        """Ensure the expected flag fields exist and accept new defaults."""
        overrides = {
            "data_dir": "dfgasf",
            "model_dir": "dfsdkjgbs",
            "train_epochs": 534,
            "epochs_between_evals": 15,
            "batch_size": 256,
            "hooks": ["LoggingTensorHook"],
            "num_parallel_calls": 18,
            "inter_op_parallelism_threads": 5,
            "intra_op_parallelism_threads": 10,
            "data_format": "channels_first",
        }

        flags_core.set_defaults(**overrides)
        flags_core.parse_flags()

        for name in overrides:
            parsed = flags.FLAGS.get_flag_value(name=name, default=None)
            assert parsed == overrides[name]
    def test_get_nondefault_flags_as_str(self):
        """Only flags whose value differs from the default are rendered."""
        baseline = {
            "clean": True,
            "data_dir": "abc",
            "hooks": ["LoggingTensorHook"],
            "stop_threshold": 1.5,
            "use_synthetic_data": False,
        }
        flags_core.set_defaults(**baseline)
        flags_core.parse_flags()

        # Nothing has been overridden yet, so the rendering is empty.
        expected = ""
        self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected)

        # Each step overrides one flag and appends its rendering; the order
        # matters because the expected output accumulates.
        steps = [
            ("clean", False, "--noclean"),
            ("data_dir", "xyz", " --data_dir=xyz"),
            ("hooks", ["aaa", "bbb", "ccc"], " --hooks=aaa,bbb,ccc"),
            ("stop_threshold", 3., " --stop_threshold=3.0"),
            ("use_synthetic_data", True, " --use_synthetic_data"),
        ]
        for flag_name, new_value, rendered in steps:
            setattr(flags.FLAGS, flag_name, new_value)
            expected += rendered
            self.assertEqual(flags_core.get_nondefault_flags_as_str(),
                             expected)

        # Explicitly setting a flag back to its default value must drop it
        # from the rendered string again.
        flags.FLAGS.use_synthetic_data = False
        synthetic_suffix = " --use_synthetic_data"
        expected = expected[:-len(synthetic_suffix)]
        self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected)
Esempio n. 10
0
  def test_parse_dtype_info(self):
    """Check dtype parsing, loss-scale defaults and overrides, and bad input."""
    cases = (("fp16", tf.float16, 128), ("fp32", tf.float32, 1))
    for dtype_name, expected_dtype, default_scale in cases:
      # Immutable base; each parse call receives a fresh list.
      base = (__file__, "--dtype", dtype_name)

      flags_core.parse_flags(list(base))
      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), expected_dtype)
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), default_scale)

      flags_core.parse_flags(list(base) + ["--loss_scale", "5"])
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

      flags_core.parse_flags(list(base) + ["--loss_scale", "dynamic"])
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), "dynamic")

    # An unknown dtype and a non-numeric loss scale both make parsing exit.
    with self.assertRaises(SystemExit):
      flags_core.parse_flags([__file__, "--dtype", "int8"])

    with self.assertRaises(SystemExit):
      flags_core.parse_flags(
          [__file__, "--dtype", "fp16", "--loss_scale", "abc"])
Esempio n. 11
0
  def test_default_setting(self):
    """Set new defaults on common flags and confirm parsing reports them."""
    pairs = [
        ("data_dir", "dfgasf"),
        ("model_dir", "dfsdkjgbs"),
        ("train_epochs", 534),
        ("epochs_between_evals", 15),
        ("batch_size", 256),
        ("hooks", ["LoggingTensorHook"]),
        ("num_parallel_calls", 18),
        ("inter_op_parallelism_threads", 5),
        ("intra_op_parallelism_threads", 10),
        ("data_format", "channels_first"),
    ]
    flags_core.set_defaults(**dict(pairs))
    flags_core.parse_flags()

    for flag_name, wanted in pairs:
      observed = flags.FLAGS.get_flag_value(name=flag_name, default=None)
      assert observed == wanted
Esempio n. 12
0
    def test_parse_dtype_info(self):
        """Verify dtype and loss-scale parsing for fp16/fp32, plus errors."""
        expectations = [
            ("fp16", tf.float16, 128),
            ("fp32", tf.float32, 1),
        ]
        for dtype_str, tf_dtype, default_loss_scale in expectations:
            dtype_args = [__file__, "--dtype", dtype_str]

            flags_core.parse_flags(list(dtype_args))
            self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
            self.assertEqual(flags_core.get_loss_scale(flags.FLAGS),
                             default_loss_scale)

            # An explicit --loss_scale overrides the per-dtype default.
            for override, expected_scale in (("5", 5),
                                             ("dynamic", "dynamic")):
                flags_core.parse_flags(
                    dtype_args + ["--loss_scale", override])
                self.assertEqual(flags_core.get_loss_scale(flags.FLAGS),
                                 expected_scale)

        invalid_argvs = (
            [__file__, "--dtype", "int8"],
            [__file__, "--dtype", "fp16", "--loss_scale", "abc"],
        )
        for bad_argv in invalid_argvs:
            with self.assertRaises(SystemExit):
                flags_core.parse_flags(bad_argv)
Esempio n. 13
0
    def test_parse_dtype_info(self):
        """Exercise dtype parsing and loss-scale resolution (fp16 default 2)."""

        def parse(*argv_tail):
            # Prepend the program name, as a real command line would have.
            flags_core.parse_flags([__file__] + list(argv_tail))

        def loss_scale():
            return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2)

        parse("--dtype", "fp16")
        self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
        self.assertEqual(loss_scale(), 2)

        parse("--dtype", "fp16", "--loss_scale", "5")
        self.assertEqual(loss_scale(), 5)

        parse("--dtype", "fp16", "--loss_scale", "dynamic")
        self.assertEqual(loss_scale(), "dynamic")

        parse("--dtype", "fp32")
        self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
        self.assertEqual(loss_scale(), 1)

        parse("--dtype", "fp32", "--loss_scale", "5")
        self.assertEqual(loss_scale(), 5)

        # Invalid dtype or loss-scale values make the parser exit.
        with self.assertRaises(SystemExit):
            parse("--dtype", "int8")

        with self.assertRaises(SystemExit):
            parse("--dtype", "fp16", "--loss_scale", "abc")