def test_booleans(self):
    """Test to ensure boolean flags trigger as expected.

    Parses an argv containing ``--use_synthetic_data`` and checks that the
    corresponding flag value is truthy afterwards.
    """
    flags_core.parse_flags([__file__, "--use_synthetic_data"])
    # A TestCase assertion (rather than a bare ``assert``) keeps the check
    # active when Python runs with -O, which strips assert statements.
    self.assertTrue(flags.FLAGS.use_synthetic_data)
def test_benchmark_setting(self):
    """Test that defaults passed to set_defaults show up in the parsed flags.

    Overrides the defaults of ``hooks``, ``benchmark_log_dir`` and
    ``gcp_project``, parses with no explicit argv, and checks that each flag
    reports the overridden value.
    """
    defaults = dict(
        hooks=["LoggingMetricHook"],
        benchmark_log_dir="/tmp/12345",
        gcp_project="project_abc",
    )

    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
        # assertEqual is not stripped under -O and reports both values when
        # the comparison fails, unlike a bare ``assert``.
        self.assertEqual(
            flags.FLAGS.get_flag_value(name=key, default=None), value)
def test_get_nondefault_flags_as_str(self):
    """Test the string rendering of flags that differ from their defaults.

    Sets a group of defaults, then changes the flags one at a time and checks
    after every change that get_nondefault_flags_as_str() returns the
    expected, incrementally growing, string.
    """
    flags_core.set_defaults(
        clean=True,
        data_dir="abc",
        hooks=["LoggingTensorHook"],
        stop_threshold=1.5,
        use_synthetic_data=False,
    )
    flags_core.parse_flags()

    # Before any flag is changed the rendering is expected to be empty.
    expected = ""
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected)

    # (flag name, new value, text that the change appends to the rendering)
    overrides = (
        ("clean", False, "--noclean"),
        ("data_dir", "xyz", " --data_dir=xyz"),
        ("hooks", ["aaa", "bbb", "ccc"], " --hooks=aaa,bbb,ccc"),
        ("stop_threshold", 3., " --stop_threshold=3.0"),
        ("use_synthetic_data", True, " --use_synthetic_data"),
    )
    for flag_name, new_value, rendered in overrides:
        setattr(flags.FLAGS, flag_name, new_value)
        expected += rendered
        self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected)

    # Assert that explicitly setting a flag to its default value does not
    # cause it to appear in the string.
    flags.FLAGS.use_synthetic_data = False
    synthetic_suffix = " --use_synthetic_data"
    expected = expected[:len(expected) - len(synthetic_suffix)]
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected)
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, train_epochs=1,
                  epochs_between_evals=1):
    """Performs a minimal run of a model.

    This function is intended to test for syntax errors throughout a model. A
    very limited run is performed using synthetic data.

    A fresh model directory is created under ``tmp_root``, the flags are
    parsed from a constructed argv, ``main`` is called with ``flags.FLAGS``,
    and the model directory is removed afterwards whether or not the run
    succeeded.

    Args:
      main: The primary function used to exercise a code path. It is called
        with ``flags.FLAGS`` as its only argument.
      tmp_root: Root path for the temp directory created by the test class.
      extra_flags: Additional flags passed by the caller of this function.
        Any sequence of strings (list, tuple, ...) is accepted; None means no
        extra flags.
      synth: Use synthetic data.
      train_epochs: Value of the --train_epochs flag. None omits the flag.
      epochs_between_evals: Value of the --epochs_between_evals flag. None
        omits the flag.
    """
    # Copy into a list so that non-list sequences (e.g. a tuple) can be
    # concatenated below and the caller's object is never aliased.
    extra_flags = [] if extra_flags is None else list(extra_flags)

    model_dir = tempfile.mkdtemp(dir=tmp_root)

    # Everything that happens after the directory exists runs inside the try
    # block, so the directory is cleaned up even if building argv raises.
    try:
        args = [sys.argv[0], "--model_dir", model_dir] + extra_flags

        if synth:
            args.append("--use_synthetic_data")

        if train_epochs is not None:
            args.extend(["--train_epochs", str(train_epochs)])

        if epochs_between_evals is not None:
            args.extend(["--epochs_between_evals", str(epochs_between_evals)])

        flags_core.parse_flags(argv=args)
        main(flags.FLAGS)
    finally:
        # The run itself may already have removed the directory.
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
def test_default_setting(self):
    """Test to ensure fields exist and defaults can be set.

    Overrides the defaults of a set of flags, parses with no explicit argv,
    and checks that each flag reports the overridden value.
    """
    defaults = dict(
        data_dir="dfgasf",
        model_dir="dfsdkjgbs",
        train_epochs=534,
        epochs_between_evals=15,
        batch_size=256,
        hooks=["LoggingTensorHook"],
        num_parallel_calls=18,
        inter_op_parallelism_threads=5,
        intra_op_parallelism_threads=10,
        data_format="channels_first",
    )

    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
        # assertEqual is not stripped under -O and reports both values when
        # the comparison fails, unlike a bare ``assert``.
        self.assertEqual(
            flags.FLAGS.get_flag_value(name=key, default=None), value)
def test_parse_dtype_info(self):
    """Test parsing of the --dtype and --loss_scale flags.

    Covers the dtype and loss scale reported for fp16 and fp32, with and
    without an explicit --loss_scale, and checks that an unsupported dtype or
    a malformed loss scale makes flag parsing exit.
    """
    def parse(*cli_args):
        # Prefix the program name the same way for every scenario.
        flags_core.parse_flags([__file__] + list(cli_args))

    def tf_dtype():
        return flags_core.get_tf_dtype(flags.FLAGS)

    def loss_scale():
        return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2)

    # fp16 without an explicit loss scale: default_for_fp16 is expected.
    parse("--dtype", "fp16")
    self.assertEqual(tf_dtype(), tf.float16)
    self.assertEqual(loss_scale(), 2)

    # fp16 with an explicit numeric loss scale.
    parse("--dtype", "fp16", "--loss_scale", "5")
    self.assertEqual(loss_scale(), 5)

    # fp16 with the "dynamic" loss scale.
    parse("--dtype", "fp16", "--loss_scale", "dynamic")
    self.assertEqual(loss_scale(), "dynamic")

    # fp32 without an explicit loss scale: a loss scale of 1 is expected.
    parse("--dtype", "fp32")
    self.assertEqual(tf_dtype(), tf.float32)
    self.assertEqual(loss_scale(), 1)

    # fp32 with an explicit numeric loss scale.
    parse("--dtype", "fp32", "--loss_scale", "5")
    self.assertEqual(loss_scale(), 5)

    # Invalid values are expected to terminate parsing with SystemExit.
    with self.assertRaises(SystemExit):
        parse("--dtype", "int8")

    with self.assertRaises(SystemExit):
        parse("--dtype", "fp16", "--loss_scale", "abc")