def test_config_benchmark_file_logger(self):
  # Set the benchmark_log_dir first since the benchmark_logger_type will need
  # the value to be set when it does the validation.
  with flagsaver.flagsaver(benchmark_log_dir='/tmp'):
    with flagsaver.flagsaver(benchmark_logger_type='BenchmarkFileLogger'):
      logger.config_benchmark_logger()
      self.assertIsInstance(logger.get_benchmark_logger(),
                            logger.BenchmarkFileLogger)
def test_run_training(self, targets):
  """Tests whether the training loop can be run successfully.

  Generates test input files and runs the main driving code.

  Args:
    targets: the targets to train on.
  """
  # Create test input and metadata files.
  num_examples, read_len = 20, 5
  train_file = test_utils.create_tmp_train_file(num_examples, read_len)
  metadata_path = test_utils.create_tmp_metadata(num_examples, read_len)

  # Check that the training loop runs as expected.
  logdir = os.path.join(FLAGS.test_tmpdir, 'train:{}'.format(len(targets)))
  with flagsaver.flagsaver(
      train_files=train_file,
      metadata_path=metadata_path,
      targets=targets,
      logdir=logdir,
      hparams='train_steps=10,min_read_length=5',
      batch_size=10):
    run_training.main(FLAGS)

    # Check training loop ran by confirming existence of a checkpoint file.
    self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.logdir))

    # Check training loop ran by confirming existence of a measures file.
    self.assertTrue(
        os.path.exists(os.path.join(FLAGS.logdir, 'measures.pbtxt')))
def test_run(self):
  hparams = deep_q_networks.get_hparams(
      replay_buffer_size=100,
      num_episodes=10,
      batch_size=10,
      update_frequency=1,
      save_frequency=1,
      dense_layers=[32],
      fingerprint_length=128,
      fingerprint_radius=2,
      num_bootstrap_heads=12,
      prioritized=True,
      double_q=True)
  hparams_file = os.path.join(self.mount_point, 'config.json')
  core.write_hparams(hparams, hparams_file)
  with flagsaver.flagsaver(model_dir=self.model_dir, hparams=hparams_file):
    optimize_logp_of_800_molecules.main(None)
def testCentos7OnStartup(self, flag_disable_yum_cron, additional_command):
  vm = CreateCentos7Vm()
  mock_remote = mock.Mock(return_value=('', ''))
  vm.RemoteHostCommand = mock_remote  # pylint: disable=invalid-name
  if flag_disable_yum_cron is not None:
    with flagsaver.flagsaver(disable_yum_cron=flag_disable_yum_cron):
      vm.OnStartup()
  else:
    # tests the default value of the flag
    vm.OnStartup()
  common_call = ("echo 'Defaults:perfkit !requiretty' | "
                 'sudo tee /etc/sudoers.d/pkb')
  calls = [mock.call(common_call, login_shell=True)]
  if additional_command:
    calls.append(additional_command)
  mock_remote.assert_has_calls(calls)
def test_multi_objective_dqn(self):
  hparams = deep_q_networks.get_hparams(
      replay_buffer_size=100,
      num_episodes=10,
      batch_size=10,
      update_frequency=1,
      save_frequency=1,
      dense_layers=[32],
      fingerprint_length=128,
      num_bootstrap_heads=0,
      prioritized=False,
      double_q=False,
      fingerprint_radius=2)
  hparams_file = os.path.join(self.mount_point, 'config.json')
  core.write_hparams(hparams, hparams_file)
  with flagsaver.flagsaver(model_dir=self.model_dir, hparams=hparams_file):
    run_dqn.run_dqn(True)
def testWget_NoAdminURLSpecified(self):
  data_store.REL_DB.WriteClientMetadata(
      client_id=VfsTest.FAKE_CLIENT_ID, fleetspeak_enabled=False)
  api_client = self._get_fake_api_client()
  client = grr_colab.Client(api_client)
  vfs_obj = vfs.VFS(api_client, jobs_pb2.PathSpec.OS)

  with flagsaver.flagsaver(grr_admin_ui_url=''):
    with temp.AutoTempFilePath() as temp_file:
      with io.open(temp_file, 'wb') as filedesc:
        filedesc.write(b'foo bar')
      with client.open(temp_file):
        pass

      with self.assertRaises(ValueError):
        vfs_obj.wget(temp_file)
def _run_main(self, **kwargs):
  subprocess.run(['git', 'add', '*.pbtxt', 'data/*/*.pbtxt'], check=True)
  changed = subprocess.run(['git', 'diff', '--name-status', '--staged'],
                           check=True,
                           capture_output=True)
  with open('changed.txt', 'wb') as f:
    f.write(changed.stdout)
  subprocess.run(['git', 'commit', '-m', 'Submission'], check=True)
  run_flags = {'input_file': 'changed.txt', 'update': True, 'cleanup': True}
  run_flags.update(kwargs)
  with flagsaver.flagsaver(**run_flags):
    process_dataset.main(())
  return glob.glob(
      os.path.join(self.test_subdirectory, '**/*.pbtxt'), recursive=True)
def setUpClass(cls):
  # create training data
  output_path = FLAGS.test_tmpdir
  output_name = 'temp'
  with flagsaver.flagsaver(
      dataset_path=output_path,
      dataset_name=output_name,
      dataset_type='time_evolution',
      num_shards=1,
      total_time_steps=5,
      example_num_time_steps=8,
      time_step_interval=1,
      num_seeds=4):
    create_training_data.main([], runner=beam.runners.DirectRunner())
  metadata_path = '{}/{}.metadata.json'.format(output_path, output_name)
  cls.metadata = readers.load_metadata(metadata_path)
  super(IntegrationTest, cls).setUpClass()
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
  model_dir = self.get_temp_dir()
  flags_dict = dict(
      experiment='mock',
      mode=flag_mode,
      model_dir=model_dir,
      params_override=json.dumps(self._test_config))
  with flagsaver.flagsaver(**flags_dict):
    params = train_utils.parse_configuration(flags.FLAGS)
    train_utils.serialize_config(params, model_dir)
    with distribution_strategy.scope():
      task = task_factory.get_task(params.task, logging_dir=model_dir)

    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=flag_mode,
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    if 'eval' in flag_mode:
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(model_dir,
                           params.trainer.validation_summary_subdir)))
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))

    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))

    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval)
def test_evaluate_repeat(self, mock_eval):
  """Tests evaluate with repeats."""
  mock_eval.train_and_evaluate.return_value = 'unused_output'
  # Create extra directories not created in setUp for repeat_2.
  for model_id in self.toy_data:
    eval_dir = os.path.join(self.output_dir, 'ab', model_id, 'repeat_2')
    tf.io.gfile.makedirs(eval_dir)

  with flagsaver.flagsaver(num_repeats=2):
    evaluator = run_evaluation.Evaluator(self.models_file, self.output_dir)
    evaluator.run_evaluation()

  expected_dir = os.path.join(self.output_dir, 'ab')
  mock_eval.train_and_evaluate.assert_has_calls([
      tf.compat.v1.test.mock.call(
          tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
          os.path.join(expected_dir, 'abc', 'repeat_1')),
      tf.compat.v1.test.mock.call(
          tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
          os.path.join(expected_dir, 'abd', 'repeat_1')),
      tf.compat.v1.test.mock.call(
          tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
          os.path.join(expected_dir, 'abe', 'repeat_1')),
      tf.compat.v1.test.mock.call(
          tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
          os.path.join(expected_dir, 'abc', 'repeat_2')),
      tf.compat.v1.test.mock.call(
          tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
          os.path.join(expected_dir, 'abd', 'repeat_2')),
      tf.compat.v1.test.mock.call(
          tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
          os.path.join(expected_dir, 'abe', 'repeat_2'))
  ])
  for model_id in self.toy_data:
    for repeat in range(2):
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(expected_dir, model_id,
                           'repeat_%d' % (repeat + 1), 'results.json')))
def test_train_and_eval(self, use_document_interaction):
  # `create_tempdir` creates an isolated directory for this test, and
  # will be properly cleaned up by the test.
  data_dir = self.create_tempdir()
  data_file = os.path.join(data_dir, "elwc.tfrecord")
  vocab_file = os.path.join(data_dir, "vocab.txt")
  model_dir = os.path.join(data_dir, "model")

  # Save data.
  with tf.io.TFRecordWriter(data_file) as writer:
    for _ in range(10):
      writer.write(ELWC.SerializeToString())
  with tf.io.gfile.GFile(vocab_file, "w") as writer:
    writer.write("\n".join(VOCAB) + "\n")

  with flagsaver.flagsaver(
      train_input_pattern=data_file,
      eval_input_pattern=data_file,
      test_input_pattern=data_file,
      model_dir=model_dir,
      vocab_file_path=vocab_file,
      train_batch_size=2,
      eval_batch_size=2,
      num_epochs=1,
      vocab_size=len(VOCAB),
      hidden_layer_dims=[16, 8],
      num_train_steps=2,
      use_document_interaction=use_document_interaction):
    antique_kpl_din.train_and_eval()

  # Check SavedModel is exported correctly.
  saved_model_path = os.path.join(model_dir, "export/")
  self.assertTrue(tf.saved_model.contains_saved_model(saved_model_path))

  # Check SavedModel can be loaded and called correctly.
  saved_model = tf.keras.models.load_model(saved_model_path)
  elwc_predictor = saved_model.signatures[tf.saved_model.PREDICT_METHOD_NAME]
  listwise_logits = elwc_predictor(
      tf.convert_to_tensor(
          [ELWC.SerializeToString(),
           ELWC.SerializeToString()]))[tf.saved_model.PREDICT_OUTPUTS]
  self.assertAllEqual([2, 2], listwise_logits.get_shape().as_list())
def testEndToEndSuccess(self, use_tf_data_input):
  if use_tf_data_input and (LooseVersion(tf.__version__) <
                            LooseVersion("2.5.0")):
    return
  logging.info("Using testdata in %s", self.get_temp_dir())
  avg_model_dir = self._export_global_average_model()
  image_dir = self._write_cmy_dataset()
  saved_model_dir = os.path.join(self.get_temp_dir(), "final_saved_model")
  saved_model_expected_file = os.path.join(saved_model_dir, "saved_model.pb")
  tflite_output_file = os.path.join(self.get_temp_dir(), "final_model.tflite")
  labels_output_file = os.path.join(self.get_temp_dir(), "labels.txt")
  # Make sure we don't test for pre-existing files.
  self.assertFalse(os.path.isfile(saved_model_expected_file))
  self.assertFalse(os.path.isfile(tflite_output_file))
  self.assertFalse(os.path.isfile(labels_output_file))

  with flagsaver.flagsaver(
      image_dir=image_dir,
      tfhub_module=avg_model_dir,
      # This dataset is expected to be fit perfectly.
      assert_accuracy_at_least=0.9,
      use_tf_data_input=use_tf_data_input,
      saved_model_dir=saved_model_dir,
      tflite_output_file=tflite_output_file,
      labels_output_file=labels_output_file,
      **self.DEFAULT_FLAGS):
    make_image_classifier.main([])

  # Test that the SavedModel was written.
  self.assertTrue(os.path.isfile(saved_model_expected_file))

  # Test that the TFLite model works.
  labels = self._load_labels(labels_output_file)
  lite_model = self._load_lite_model(tflite_output_file)
  for class_name, rgb in self.CMY_NAMES_AND_RGB_VALUES:
    input_batch = (_fill_image(rgb, self.IMAGE_SIZE)[None, ...] /
                   np.array(255., dtype=np.float32))
    output_batch = lite_model(input_batch)
    prediction = labels[np.argmax(output_batch[0])]
    self.assertEqual(class_name, prediction)
def test(self):
  input_path = self.create_tempdir('source').full_path
  output_path = self.create_tempdir('destination').full_path

  input_ds = test_util.dummy_era5_surface_dataset(times=90 * 24, freq='1H')
  input_ds.chunk({'time': 31}).to_zarr(input_path)

  expected = input_ds.groupby('time.month').apply(
      lambda x: x.groupby('time.hour').mean('time'))

  with flagsaver.flagsaver(
      input_path=input_path,
      output_path=output_path,
  ):
    era5_climatology.main([])

  actual = xarray.open_zarr(output_path)
  xarray.testing.assert_allclose(actual, expected)
def test_train_multiple_clients(self, disable_parallel, expected_num_examples):
  with flagsaver.flagsaver(
      fedjax_experimental_disable_parallel=disable_parallel):
    init_state = self.init_state()
    states = client_trainer.train_multiple_clients(
        federated_data=self._federated_data,
        client_ids=self._federated_data.client_ids,
        client_trainer=self._trainer,
        init_client_trainer_state=init_state,
        rng_seq=self._federated_algorithm._rng_seq,
        client_data_hparams=self._client_data_hparams)
    states = list(states)

    self.assertLen(states, 5)
    for s in states:
      self.assertEqual(s.num_examples, expected_num_examples)
      jax.tree_multimap(self.assertAllEqual, s.control_variate,
                        init_state.control_variate)
def test_train_with_config_file(self, hparams_config_filename):
  with flagsaver.flagsaver(
      num_eval_steps=1,
      eval_batch_size=1,
      max_target_length=1,
      eval_dataset_name='mock_data',
      max_eval_target_length=1,
      max_predict_length=1,
      model_dir=FLAGS.test_tmpdir,
  ):
    hparams = os_hparams_utils.load_dataclass_from_config_dict(
        training_hparams.TrainingHParams, hparams_config_filename.get_config())
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as io_executor:
      train.run_training(
          datasets=self.datasets,
          hparams=hparams,
          io_executor=io_executor,
      )
def test(self):
  output_path = os.path.join(FLAGS.test_tmpdir, 'temp.h5')

  # run the beam job
  with flagsaver.flagsaver(
      output_path=output_path,
      equation_name='burgers',
      equation_kwargs='{"num_points": 400}',
      num_tasks=2,
      time_max=1.0,
      time_delta=0.1,
      warmup=0):
    create_training_data.main([])

  # verify the results
  with utils.read_h5py(output_path) as f:
    data = f['v'][...]
    metadata = dict(f.attrs)
  self.assertEqual(data.shape, (20, 400))
  self.assertEqual(metadata, {'num_points': 400})
def test_experiment_run(self):
  # Stores all flags defined in bert_ranking.py.
  with flagsaver.flagsaver(
      train_input_pattern=self._train_file,
      eval_input_pattern=self._eval_file,
      learning_rate=0.001,
      train_batch_size=2,
      eval_batch_size=2,
      model_dir=self._model_dir,
      num_train_steps=5,
      num_eval_steps=2,
      loss="softmax_loss",
      local_training=True,
      list_size=5,
      dropout_rate=0.1,
      bert_config_file=self._bert_config_file,
      bert_init_ckpt=self._bert_init_ckpt,
      bert_max_seq_length=self._bert_max_seq_length,
      bert_num_warmup_steps=1):
    tfrbert_example.train_and_eval()
def test(self):
  output_path = os.path.join(FLAGS.test_tmpdir, 'temp.nc')

  # run the beam job
  with flagsaver.flagsaver(
      output_path=output_path,
      equation_name='burgers',
      equation_kwargs='{"num_points": 400}',
      num_samples=2,
      accuracy_orders=[1, 3, 5],
      time_max=1.0,
      time_delta=0.1,
      warmup=0):
    create_baseline_data.main([])

  # verify the results
  with xarray.open_dataset(output_path) as ds:
    self.assertEqual(ds['y'].dims, ('sample', 'accuracy_order', 'time', 'x'))
    self.assertEqual(ds['y'].shape, (2, 3, 10, 400))
def test_get(self):
  with self.subTest('default'):
    self.assertEqual(
        self.FEDERATED_EXPERIMENT_CONFIG.get(),
        federated_experiment.FederatedExperimentConfig(
            root_dir='foo', num_rounds=1234))

  with self.subTest('custom'):
    with flagsaver.flagsaver(
        root_dir='bar',
        num_rounds=567,
        checkpoint_frequency=2,
        num_checkpoints_to_keep=3,
        eval_frequency=4):
      self.assertEqual(
          self.FEDERATED_EXPERIMENT_CONFIG.get(),
          federated_experiment.FederatedExperimentConfig(
              root_dir='bar',
              num_rounds=567,
              checkpoint_frequency=2,
              num_checkpoints_to_keep=3,
              eval_frequency=4))
def test_train_mnist(self):
  # Create the random data and write it to the disk.
  test_subdirectory = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
  # Create the model parameters.
  model_path = os.path.join(test_subdirectory, 'temp_model')
  with flagsaver.flagsaver(
      model_path=model_path,
      save_period=1,
      num_dense_units='4,4',
      epochs=1,
      learning_rate=0.1,
      dropout=0.0,
      batch_size=32):
    train_mnist.main(argv=())
  # Verify that the trained model was saved.
  self.assertTrue(gfile.Exists(os.path.join(model_path, 'test_accuracy.txt')))
  self.assertLen(gfile.Glob(os.path.join(model_path, 'weights_epoch*')), 1)
def testWget_LinkWorksWithOfflineClient(self):
  data_store.REL_DB.WriteClientMetadata(
      client_id=VfsTest.FAKE_CLIENT_ID, fleetspeak_enabled=False)
  api_client = self._get_fake_api_client()
  client = grr_colab.Client(api_client)
  vfs_obj = vfs.VFS(api_client, jobs_pb2.PathSpec.OS)
  content = b'foo bar'

  with flagsaver.flagsaver(grr_admin_ui_url=self.endpoint):
    with temp.AutoTempFilePath() as temp_file:
      with io.open(temp_file, 'wb') as filedesc:
        filedesc.write(content)
      with client.open(temp_file):
        pass

      link = vfs_obj.wget(temp_file)
      self.assertEqual(requests.get(link).content, content)
def testProgramSchedule(self,
                        dataset_list_override=None,
                        train_executions_per_eval_override=None,
                        eval_steps_per_loop_override=None,
                        decode_steps_per_loop_override=None):
  with flagsaver.flagsaver(
      executor_datasets_to_eval=dataset_list_override,
      executor_train_executions_per_eval=train_executions_per_eval_override,
      executor_eval_steps_per_loop=eval_steps_per_loop_override,
      executor_decode_steps_per_loop=decode_steps_per_loop_override):
    ps_params = model_registry.GetProgramSchedule('test.DummyModel')

    if dataset_list_override is not None:
      self.assertAllEqual(ps_params.dataset_names,
                          dataset_list_override.split(';'))
    else:
      self.assertAllEqual(ps_params.dataset_names, ['Dev', 'Test'])

    if train_executions_per_eval_override is not None:
      self.assertEqual(ps_params.train_executions_per_eval,
                       train_executions_per_eval_override)
    else:
      self.assertEqual(ps_params.train_executions_per_eval, 0)

    # Assume only Dev and Test are available eval datasets.
    eval_dev, eval_test, decode_dev, decode_test = 0, 0, 0, 0
    if dataset_list_override is None or 'Dev' in dataset_list_override:
      if eval_steps_per_loop_override != 0:
        eval_dev += 1
      if decode_steps_per_loop_override != 0:
        decode_dev += 1
    if dataset_list_override is None or 'Test' in dataset_list_override:
      if eval_steps_per_loop_override != 0:
        eval_test += 1
      if decode_steps_per_loop_override != 0:
        decode_test += 1

    self.assertLen(ps_params.eval_programs,
                   eval_dev + decode_dev + eval_test + decode_test)
    self._CheckProgramParams(ps_params.eval_programs, eval_dev, eval_test,
                             decode_dev, decode_test)
def test_restore_text(self):
  testdata_dir = (
      pathlib.Path(absltest.get_default_test_srcdir()) / TESTDATA_DIR)
  input_file = testdata_dir / "example_mentions_no_text.jsonl"
  expected_file = testdata_dir / "example_mentions.jsonl"

  output_dir = self.create_tempdir()
  output_file = pathlib.Path(output_dir.full_path) / "restored.jsonl"

  with flagsaver.flagsaver(
      input=str(input_file),
      index_dir=str(testdata_dir),
      output=str(output_file)):
    restore_text.main([])

  self.assertTrue(output_file.exists(), msg=str(output_file))

  # Compare the dataclass representations of the output and expected files
  # rather than their string contents to be robust against fluctuations in the
  # serialization (e.g. due to dictionary order, etc).
  got = schema.load_jsonl(output_file, schema.ContextualMentions)
  expected = schema.load_jsonl(expected_file, schema.ContextualMentions)
  self.assertEqual(got, expected)
def test_train_and_eval(self, listwise_inference):
  tmp_dir = self.create_tempdir()
  data_file = os.path.join(tmp_dir, "elwc.tfrecord")
  if tf.io.gfile.exists(data_file):
    tf.io.gfile.remove(data_file)

  with tf.io.TFRecordWriter(data_file) as writer:
    for elwc in [ELWC] * 10:
      writer.write(elwc.SerializeToString())

  model_dir = os.path.join(tmp_dir, "model")

  with flagsaver.flagsaver(
      train_path=data_file,
      eval_path=data_file,
      data_format="example_list_with_context",
      model_dir=model_dir,
      num_train_steps=10,
      listwise_inference=listwise_inference,
      group_size=1,
      weights_feature_name="doc_weight"):
    tf_ranking_tfrecord.train_and_eval()
def test_train_and_eval(self):
  data_dir = tf.compat.v1.test.get_temp_dir()
  data_file = os.path.join(data_dir, "elwc.tfrecord")
  if tf.io.gfile.exists(data_file):
    tf.io.gfile.remove(data_file)

  with tf.io.TFRecordWriter(data_file) as writer:
    for elwc in [ELWC] * 10:
      writer.write(elwc.SerializeToString())

  model_dir = os.path.join(data_dir, "model")

  with flagsaver.flagsaver(
      train_input_pattern=data_file,
      eval_input_pattern=data_file,
      model_dir=model_dir,
      num_train_steps=10,
      list_size=2):
    pipeline_example.train_and_eval()

  if tf.io.gfile.exists(model_dir):
    tf.io.gfile.rmtree(model_dir)
def test_get_final_aggregated_documents(self):
  new_flags = get_updated_default_flags(
      num_documents_to_retrieve=3, use_aggregated_documents=1)
  with flagsaver.flagsaver(**new_flags):
    # stop_after_seeing_new_results with multiple steps in the history
    nqenv = create_env(max_history_entries=3)
    nqenv.stop_after_seeing_new_results = True
    final_documents = nqenv.get_final_document_list()
    self.assertLen(final_documents, 3)
    self.assertEqual(final_documents[0].content, '4')
    self.assertEqual(final_documents[1].content, '3')
    self.assertEqual(final_documents[2].content, '2')

    # stop_after_seeing_new_results=False with multiple steps in the history
    nqenv.stop_after_seeing_new_results = False
    final_documents = nqenv.get_final_document_list()
    self.assertLen(final_documents, 3)
    self.assertEqual(final_documents[0].content, '5')
    self.assertEqual(final_documents[1].content, '4')
    self.assertEqual(final_documents[2].content, '3')
def test(self):
  input_path = self.create_tempdir('source').full_path
  output_path = self.create_tempdir('destination').full_path

  input_ds = test_util.dummy_era5_surface_dataset(times=365)
  input_ds.chunk({'time': 31}).to_zarr(input_path)

  with flagsaver.flagsaver(
      input_path=input_path,
      output_path=output_path,
  ):
    era5_rechunk.main([])

  output_ds = xarray.open_zarr(output_path)
  self.assertEqual({k: v[0] for k, v in output_ds.chunks.items()}, {
      'latitude': 5,
      'longitude': 5,
      'time': 365
  })
  xarray.testing.assert_identical(input_ds, output_ds)
def testGenerateDataSample(self, level, data_format, data_percent,
                           uniform_domain_distribution):
  temp_output = os.path.join(self.create_tempdir(), 'output')
  ref_output = os.path.join(
      FLAGS.test_srcdir, TEST_DIR,
      f'sgd_text_v2_uniform_{uniform_domain_distribution}_{data_percent}')
  with flagsaver.flagsaver(
      level=level,
      delimiter='=',
      data_format=data_format,
      sgd_file=os.path.join(FLAGS.test_srcdir, TEST_DIR, 'sgd_train.json'),
      schema_file=os.path.join(FLAGS.test_srcdir, TEST_DIR,
                               'sgd_train_schema.json'),
      output_file=temp_output,
      randomize_items=False,
      data_percent=data_percent,
      uniform_domain_distribution=uniform_domain_distribution):
    slots, item_desc = create_sgd_schemaless_data.load_schema()
    create_sgd_schemaless_data.generate_data(slots, item_desc)
    self.assertTrue(filecmp.cmp(temp_output, ref_output))
def test_recovery(self, distribution_strategy, flag_mode):
  loss_threshold = 1.0
  model_dir = self.get_temp_dir()
  flags_dict = dict(
      experiment='mock',
      mode=flag_mode,
      model_dir=model_dir,
      params_override=json.dumps(self._test_config))
  with flagsaver.flagsaver(**flags_dict):
    params = train_utils.parse_configuration(flags.FLAGS)
    params.trainer.loss_upper_bound = loss_threshold
    params.trainer.recovery_max_trials = 1
    train_utils.serialize_config(params, model_dir)
    with distribution_strategy.scope():
      task = task_factory.get_task(params.task, logging_dir=model_dir)

    # Saves a checkpoint for reference.
    model = task.build_model()
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, self.get_temp_dir(), max_to_keep=2)
    checkpoint_manager.save()
    before_weights = model.get_weights()

    def build_losses(labels, model_outputs, aux_losses=None):
      del labels, model_outputs
      return tf.constant([loss_threshold], tf.float32) + aux_losses

    task.build_losses = build_losses

    model, _ = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=flag_mode,
        params=params,
        model_dir=model_dir)
    after_weights = model.get_weights()
    for left, right in zip(before_weights, after_weights):
      self.assertAllEqual(left, right)
def test_main(self):
  input_pattern = os.path.join(self.test_subdirectory, '*.pbtxt')
  output_dir = os.path.join(self.test_subdirectory, 'tables')
  with flagsaver.flagsaver(
      input=input_pattern, output=output_dir, database=True, cleanup=False):
    build_database.main(())

  with open(os.path.join(output_dir, 'reactions.csv')) as f:
    df = pd.read_csv(f)
  # NOTE(kearnes): Map keys are not always serialized in the same order.
  df['deserialized'] = df.serialized.apply(
      lambda x: reaction_pb2.Reaction.FromString(bytes.fromhex(x)))
  del df['serialized']
  pd.testing.assert_frame_equal(
      df,
      pd.DataFrame({
          'reaction_id': ['test'],
          'reaction_smiles': ['reaction'],
          'deserialized': [self.dataset.reactions[0]],
      }))

  with open(os.path.join(output_dir, 'inputs.csv')) as f:
    df = pd.read_csv(f)
  pd.testing.assert_frame_equal(
      df,
      pd.DataFrame({
          'reaction_id': ['test', 'test', 'test'],
          'smiles': ['input1', 'input2a', 'input2b'],
      }))

  with open(os.path.join(output_dir, 'outputs.csv')) as f:
    df = pd.read_csv(f)
  pd.testing.assert_frame_equal(
      df,
      pd.DataFrame({
          'reaction_id': ['test'],
          'smiles': ['product'],
          'yield': [2.5]
      }))
def test_basic_functionality(self, mock_requests):
  # Set up the mock REST endpoint.
  def match_request_text(request):
    return 'instances' in (request.text or '')

  mock_requests.post(
      'http://foo.com',
      request_headers={'Content-Type': 'application/json; charset=utf-8'},
      additional_matcher=match_request_text,
      text='response')

  with flagsaver.flagsaver():
    FLAGS.target = 'http://foo.com'
    FLAGS.batch_size = 1
    FLAGS.duration_ms = 1000
    FLAGS.target_latency_percentile = 0.9
    FLAGS.target_latency_ns = 100000000
    FLAGS.performance_sample_count = 5
    FLAGS.query_count = 10
    FLAGS.total_sample_count = 5
    loadgen_rest_main.main(None)
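# Note on the idiom in test_basic_functionality above: unlike the other tests
# here, it calls `flagsaver.flagsaver()` with no arguments. That snapshots every
# defined flag, so the direct `FLAGS.x = ...` assignments inside the `with`
# block are rolled back when the block exits. A minimal sketch of that pattern,
# not taken from the original suite; it assumes flags are already parsed and
# that a string flag named `target_url` exists (both names are hypothetical):
def test_flag_override_is_restored(self):
  original = FLAGS.target_url
  with flagsaver.flagsaver():
    FLAGS.target_url = 'http://override'  # Visible to code run inside the block.
    self.assertEqual(FLAGS.target_url, 'http://override')
  # Exiting the block restores the saved snapshot of all flag values.
  self.assertEqual(FLAGS.target_url, original)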
def test_train_and_eval(self, listwise_inference):
  data_dir = tf.compat.v1.test.get_temp_dir()
  data_file = os.path.join(data_dir, "elwc.tfrecord")
  if tf.io.gfile.exists(data_file):
    tf.io.gfile.remove(data_file)

  with tf.io.TFRecordWriter(data_file) as writer:
    for elwc in [ELWC] * 10:
      writer.write(elwc.SerializeToString())

  model_dir = os.path.join(data_dir, "model")

  with flagsaver.flagsaver(
      train_path=data_file,
      eval_path=data_file,
      data_format="example_list_with_context",
      model_dir=model_dir,
      num_train_steps=10,
      listwise_inference=listwise_inference,
      group_size=1):
    tf_ranking_tfrecord.train_and_eval()

  if tf.io.gfile.exists(model_dir):
    tf.io.gfile.rmtree(model_dir)
def test_get_default_benchmark_logger(self):
  with flagsaver.flagsaver(benchmark_logger_type='foo'):
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BaseBenchmarkLogger)
def test_config_benchmark_bigquery_logger(self, mock_bigquery_client):
  with flagsaver.flagsaver(benchmark_logger_type='BenchmarkBigQueryLogger'):
    logger.config_benchmark_logger()
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BenchmarkBigQueryLogger)
def test_config_base_benchmark_logger(self):
  with flagsaver.flagsaver(benchmark_logger_type='BaseBenchmarkLogger'):
    logger.config_benchmark_logger()
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BaseBenchmarkLogger)
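# All of the tests above use `flagsaver.flagsaver` as a context manager. The
# same helper also works as a decorator, which can read more cleanly when the
# overrides apply to the whole test body. This is a hedged sketch rewriting the
# logger test above in decorator form, not a test taken from the original
# suite; it assumes the same `logger` module and flags as the tests above.
@flagsaver.flagsaver(benchmark_logger_type='BaseBenchmarkLogger')
def test_config_base_benchmark_logger_decorator_form(self):
  # Flag overrides are applied before the test body runs and restored after.
  logger.config_benchmark_logger()
  self.assertIsInstance(logger.get_benchmark_logger(),
                        logger.BaseBenchmarkLogger)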