def test_booleans(self):
    """Test to ensure boolean flags trigger as expected."""
    flags_core.parse_flags([__file__, "--use_synthetic_data"])
    # Use the TestCase assertion (as the dtype test does) rather than a bare
    # ``assert``: it is not stripped under ``python -O`` and reports uniformly.
    self.assertTrue(flags.FLAGS.use_synthetic_data)
def test_benchmark_setting(self):
    """Test that defaults for hooks, benchmark_log_dir and gcp_project stick.

    Each default is installed via ``flags_core.set_defaults``, flags are parsed,
    and every flag is read back and compared with the value that was set.
    """
    defaults = dict(
        hooks=["LoggingMetricHook"],
        benchmark_log_dir="/tmp/12345",
        gcp_project="project_abc",
    )
    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
        # TestCase assertion instead of bare ``assert``: survives ``python -O``
        # and the msg names the flag that mismatched.
        self.assertEqual(
            flags.FLAGS.get_flag_value(name=key, default=None), value, msg=key)
def test_default_setting(self):
    """Test to ensure fields exist and defaults can be set.

    Each default is installed via ``flags_core.set_defaults``, flags are parsed,
    and every flag is read back and compared with the value that was set.
    """
    defaults = dict(
        data_dir="dfgasf",
        model_dir="dfsdkjgbs",
        train_epochs=534,
        epochs_between_evals=15,
        batch_size=256,
        hooks=["LoggingTensorHook"],
        num_parallel_calls=18,
        inter_op_parallelism_threads=5,
        intra_op_parallelism_threads=10,
        data_format="channels_first",
    )
    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
        # TestCase assertion instead of bare ``assert``: survives ``python -O``
        # and the msg names the flag that mismatched.
        self.assertEqual(
            flags.FLAGS.get_flag_value(name=key, default=None), value, msg=key)
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
    """Performs a minimal run of a model.

    This function is intended to test for syntax errors throughout a model. A
    very limited run is performed using synthetic data.

    Args:
      main: The primary function used to exercise a code path. After the flags
        below are parsed, it is called with a single positional argument,
        ``flags.FLAGS`` (the parsed flag values, not an argv list).
      tmp_root: Root path for the temp directory created by the test class.
        A fresh model directory is created under it and removed afterwards.
      extra_flags: Additional command-line flags (a list of strings) passed by
        the caller; appended after the default flags. ``None`` means none.
      synth: If True, append ``--use_synthetic_data``.
      max_train: Maximum number of allowed training steps, passed as
        ``--max_train_steps``. ``None`` omits the flag entirely.
    """
    extra_flags = [] if extra_flags is None else extra_flags
    model_dir = tempfile.mkdtemp(dir=tmp_root)
    # Everything after mkdtemp lives inside the try so that any failure --
    # including while building argv (e.g. a non-list ``extra_flags``) -- still
    # removes the temporary model directory.
    try:
        args = [
            sys.argv[0],
            "--model_dir", model_dir,
            "--train_epochs", "1",
            "--epochs_between_evals", "1",
        ] + extra_flags
        if synth:
            args.append("--use_synthetic_data")
        if max_train is not None:
            args.extend(["--max_train_steps", str(max_train)])

        flags_core.parse_flags(argv=args)
        main(flags.FLAGS)
    finally:
        # The directory may already be gone (e.g. removed by ``main``).
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
def test_parse_dtype_info(self):
    """Check dtype parsing, loss-scale values, and rejection of a bad dtype.

    For each supported dtype string, verifies the TF dtype and loss scale
    obtained without ``--loss_scale``, then that an explicit ``--loss_scale 5``
    is reported back as 5. Finally, ``--dtype int8`` must make parsing exit.
    """
    cases = (
        ("fp16", tf.float16, 128),
        ("fp32", tf.float32, 1),
    )
    for dtype_name, expected_dtype, expected_scale in cases:
        base_argv = [__file__, "--dtype", dtype_name]

        # No --loss_scale given: expect the per-dtype scale listed in ``cases``.
        flags_core.parse_flags(base_argv)
        self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), expected_dtype)
        self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), expected_scale)

        # An explicit --loss_scale must be reflected in get_loss_scale.
        flags_core.parse_flags(base_argv + ["--loss_scale", "5"])
        self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

    # An unsupported dtype makes flag parsing raise SystemExit.
    with self.assertRaises(SystemExit):
        flags_core.parse_flags([__file__, "--dtype", "int8"])