Example #1
 def test_config_benchmark_file_logger(self):
   # Set benchmark_log_dir first, since benchmark_logger_type needs that value
   # to be set when it runs validation.
   with flagsaver.flagsaver(benchmark_log_dir='/tmp'):
     with flagsaver.flagsaver(benchmark_logger_type='BenchmarkFileLogger'):
       logger.config_benchmark_logger()
       self.assertIsInstance(logger.get_benchmark_logger(),
                             logger.BenchmarkFileLogger)
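For orientation, here is a minimal, self-contained sketch of the two usual ways to apply absl's flagsaver: as a context manager and as a decorator. It is not taken from any of the projects listed here, and the flag name my_flag is hypothetical.

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver

FLAGS = flags.FLAGS
flags.DEFINE_string('my_flag', 'default', 'Hypothetical flag for illustration.')


class FlagsaverSketchTest(absltest.TestCase):

  def test_context_manager(self):
    # Overrides apply inside the `with` block and are restored on exit.
    with flagsaver.flagsaver(my_flag='override'):
      self.assertEqual(FLAGS.my_flag, 'override')
    self.assertEqual(FLAGS.my_flag, 'default')

  @flagsaver.flagsaver(my_flag='decorated')
  def test_decorator(self):
    # Overrides apply for the duration of the decorated test method.
    self.assertEqual(FLAGS.my_flag, 'decorated')


if __name__ == '__main__':
  absltest.main()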
Example #2
  def test_run_training(self, targets):
    """Tests whether the training loop can be run successfully.

    Generates test input files and runs the main driving code.

    Args:
      targets: the targets to train on.
    """
    # Create test input and metadata files.
    num_examples, read_len = 20, 5
    train_file = test_utils.create_tmp_train_file(num_examples, read_len)
    metadata_path = test_utils.create_tmp_metadata(num_examples, read_len)

    # Check that the training loop runs as expected.
    logdir = os.path.join(FLAGS.test_tmpdir, 'train:{}'.format(len(targets)))
    with flagsaver.flagsaver(
        train_files=train_file,
        metadata_path=metadata_path,
        targets=targets,
        logdir=logdir,
        hparams='train_steps=10,min_read_length=5',
        batch_size=10):
      run_training.main(FLAGS)
      # Check training loop ran by confirming existence of a checkpoint file.
      self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.logdir))
      # Check training loop ran by confirming existence of a measures file.
      self.assertTrue(
          os.path.exists(os.path.join(FLAGS.logdir, 'measures.pbtxt')))
Example #3
  def test_run(self):
    hparams = deep_q_networks.get_hparams(
        replay_buffer_size=100,
        num_episodes=10,
        batch_size=10,
        update_frequency=1,
        save_frequency=1,
        dense_layers=[32],
        fingerprint_length=128,
        fingerprint_radius=2,
        num_bootstrap_heads=12,
        prioritized=True,
        double_q=True)
    hparams_file = os.path.join(self.mount_point, 'config.json')
    core.write_hparams(hparams, hparams_file)

    with flagsaver.flagsaver(model_dir=self.model_dir, hparams=hparams_file):
      optimize_logp_of_800_molecules.main(None)
Example #4
  def testCentos7OnStartup(self, flag_disable_yum_cron, additional_command):
    vm = CreateCentos7Vm()
    mock_remote = mock.Mock(return_value=('', ''))
    vm.RemoteHostCommand = mock_remote  # pylint: disable=invalid-name

    if flag_disable_yum_cron is not None:
      with flagsaver.flagsaver(disable_yum_cron=flag_disable_yum_cron):
        vm.OnStartup()
    else:
      # tests the default value of the flag
      vm.OnStartup()

    common_call = ("echo 'Defaults:perfkit !requiretty' | "
                   'sudo tee /etc/sudoers.d/pkb')
    calls = [mock.call(common_call, login_shell=True)]
    if additional_command:
      calls.append(additional_command)
    mock_remote.assert_has_calls(calls)
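The if/else above exists only so the flag's default value can also be exercised. A hedged sketch of the same conditional override written as a single with statement, using contextlib.nullcontext (Python 3.7+); the helper and the flag name are illustrative, not part of PerfKitBenchmarker.

import contextlib
from absl import flags
from absl.testing import flagsaver

FLAGS = flags.FLAGS
# Hypothetical flag, declared only for this sketch.
flags.DEFINE_boolean('disable_yum_cron_sketch', True, 'Illustrative flag.')


def run_with_optional_override(override, on_startup):
  # Apply the flagsaver override only when a value was supplied; otherwise
  # run against the flag's default, all in a single `with` statement.
  ctx = (flagsaver.flagsaver(disable_yum_cron_sketch=override)
         if override is not None else contextlib.nullcontext())
  with ctx:
    on_startup()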
Example #5
    def test_multi_objective_dqn(self):
        hparams = deep_q_networks.get_hparams(replay_buffer_size=100,
                                              num_episodes=10,
                                              batch_size=10,
                                              update_frequency=1,
                                              save_frequency=1,
                                              dense_layers=[32],
                                              fingerprint_length=128,
                                              num_bootstrap_heads=0,
                                              prioritized=False,
                                              double_q=False,
                                              fingerprint_radius=2)
        hparams_file = os.path.join(self.mount_point, 'config.json')
        core.write_hparams(hparams, hparams_file)

        with flagsaver.flagsaver(model_dir=self.model_dir,
                                 hparams=hparams_file):
            run_dqn.run_dqn(True)
Example #6
    def testWget_NoAdminURLSpecified(self):
        data_store.REL_DB.WriteClientMetadata(client_id=VfsTest.FAKE_CLIENT_ID,
                                              fleetspeak_enabled=False)

        api_client = self._get_fake_api_client()
        client = grr_colab.Client(api_client)
        vfs_obj = vfs.VFS(api_client, jobs_pb2.PathSpec.OS)

        with flagsaver.flagsaver(grr_admin_ui_url=''):
            with temp.AutoTempFilePath() as temp_file:
                with io.open(temp_file, 'wb') as filedesc:
                    filedesc.write(b'foo bar')

                with client.open(temp_file):
                    pass

                with self.assertRaises(ValueError):
                    vfs_obj.wget(temp_file)
Example #7
 def _run_main(self, **kwargs):
     subprocess.run(['git', 'add', '*.pbtxt', 'data/*/*.pbtxt'], check=True)
     changed = subprocess.run(['git', 'diff', '--name-status', '--staged'],
                              check=True,
                              capture_output=True)
     with open('changed.txt', 'wb') as f:
         f.write(changed.stdout)
     subprocess.run(['git', 'commit', '-m', 'Submission'], check=True)
     run_flags = {
         'input_file': 'changed.txt',
         'update': True,
         'cleanup': True
     }
     run_flags.update(kwargs)
     with flagsaver.flagsaver(**run_flags):
         process_dataset.main(())
     return glob.glob(os.path.join(self.test_subdirectory, '**/*.pbtxt'),
                      recursive=True)
Example #8
  def setUpClass(cls):
    # create training data
    output_path = FLAGS.test_tmpdir
    output_name = 'temp'
    with flagsaver.flagsaver(
        dataset_path=output_path,
        dataset_name=output_name,
        dataset_type='time_evolution',
        num_shards=1,
        total_time_steps=5,
        example_num_time_steps=8,
        time_step_interval=1,
        num_seeds=4):
      create_training_data.main([], runner=beam.runners.DirectRunner())

    metadata_path = '{}/{}.metadata.json'.format(output_path, output_name)
    cls.metadata = readers.load_metadata(metadata_path)
    super(IntegrationTest, cls).setUpClass()
Example #9
    def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
        model_dir = self.get_temp_dir()
        flags_dict = dict(experiment='mock',
                          mode=flag_mode,
                          model_dir=model_dir,
                          params_override=json.dumps(self._test_config))
        with flagsaver.flagsaver(**flags_dict):
            params = train_utils.parse_configuration(flags.FLAGS)
            train_utils.serialize_config(params, model_dir)
            with distribution_strategy.scope():
                task = task_factory.get_task(params.task,
                                             logging_dir=model_dir)

            _, logs = train_lib.run_experiment(
                distribution_strategy=distribution_strategy,
                task=task,
                mode=flag_mode,
                params=params,
                model_dir=model_dir,
                run_post_eval=run_post_eval)

        if 'eval' in flag_mode:
            self.assertTrue(
                tf.io.gfile.exists(
                    os.path.join(model_dir,
                                 params.trainer.validation_summary_subdir)))
        if run_post_eval:
            self.assertNotEmpty(logs)
        else:
            self.assertEmpty(logs)
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
        if flag_mode == 'eval':
            return
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
        # Tests continuous evaluation.
        _, logs = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='continuous_eval',
            params=params,
            model_dir=model_dir,
            run_post_eval=run_post_eval)
Example #10
    def test_evaluate_repeat(self, mock_eval):
        """Tests evaluate with repeats."""
        mock_eval.train_and_evaluate.return_value = 'unused_output'

        # Create extra directories not created in setUp for repeat_2
        for model_id in self.toy_data:
            eval_dir = os.path.join(self.output_dir, 'ab', model_id,
                                    'repeat_2')
            tf.io.gfile.makedirs(eval_dir)

        with flagsaver.flagsaver(num_repeats=2):
            evaluator = run_evaluation.Evaluator(self.models_file,
                                                 self.output_dir)
            evaluator.run_evaluation()

        expected_dir = os.path.join(self.output_dir, 'ab')
        mock_eval.train_and_evaluate.assert_has_calls([
            tf.compat.v1.test.mock.call(
                tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
                os.path.join(expected_dir, 'abc', 'repeat_1')),
            tf.compat.v1.test.mock.call(
                tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
                os.path.join(expected_dir, 'abd', 'repeat_1')),
            tf.compat.v1.test.mock.call(
                tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
                os.path.join(expected_dir, 'abe', 'repeat_1')),
            tf.compat.v1.test.mock.call(
                tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
                os.path.join(expected_dir, 'abc', 'repeat_2')),
            tf.compat.v1.test.mock.call(
                tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
                os.path.join(expected_dir, 'abd', 'repeat_2')),
            tf.compat.v1.test.mock.call(
                tf.compat.v1.test.mock.ANY, tf.compat.v1.test.mock.ANY,
                os.path.join(expected_dir, 'abe', 'repeat_2'))
        ])

        for model_id in self.toy_data:
            for repeat in range(2):
                self.assertTrue(
                    tf.io.gfile.exists(
                        os.path.join(expected_dir, model_id,
                                     'repeat_%d' % (repeat + 1),
                                     'results.json')))
Example #11
  def test_train_and_eval(self, use_document_interaction):
    # `create_tempdir` creates an isolated directory for this test, and
    # will be properly cleaned up by the test.
    data_dir = self.create_tempdir()
    data_file = os.path.join(data_dir, "elwc.tfrecord")
    vocab_file = os.path.join(data_dir, "vocab.txt")
    model_dir = os.path.join(data_dir, "model")

    # Save data.
    with tf.io.TFRecordWriter(data_file) as writer:
      for _ in range(10):
        writer.write(ELWC.SerializeToString())

    with tf.io.gfile.GFile(vocab_file, "w") as writer:
      writer.write("\n".join(VOCAB) + "\n")

    with flagsaver.flagsaver(
        train_input_pattern=data_file,
        eval_input_pattern=data_file,
        test_input_pattern=data_file,
        model_dir=model_dir,
        vocab_file_path=vocab_file,
        train_batch_size=2,
        eval_batch_size=2,
        num_epochs=1,
        vocab_size=len(VOCAB),
        hidden_layer_dims=[16, 8],
        num_train_steps=2,
        use_document_interaction=use_document_interaction):
      antique_kpl_din.train_and_eval()

      # Check SavedModel is exported correctly.
      saved_model_path = os.path.join(model_dir, "export/")
      self.assertTrue(tf.saved_model.contains_saved_model(saved_model_path))

    # Check SavedModel can be loaded and called correctly.
    saved_model = tf.keras.models.load_model(saved_model_path)
    elwc_predictor = saved_model.signatures[tf.saved_model.PREDICT_METHOD_NAME]
    listwise_logits = elwc_predictor(
        tf.convert_to_tensor(
            [ELWC.SerializeToString(),
             ELWC.SerializeToString()]))[tf.saved_model.PREDICT_OUTPUTS]
    self.assertAllEqual([2, 2], listwise_logits.get_shape().as_list())
Example #12
    def testEndToEndSuccess(self, use_tf_data_input):
        if use_tf_data_input and (LooseVersion(tf.__version__) <
                                  LooseVersion("2.5.0")):
            return
        logging.info("Using testdata in %s", self.get_temp_dir())
        avg_model_dir = self._export_global_average_model()
        image_dir = self._write_cmy_dataset()
        saved_model_dir = os.path.join(self.get_temp_dir(),
                                       "final_saved_model")
        saved_model_expected_file = os.path.join(saved_model_dir,
                                                 "saved_model.pb")
        tflite_output_file = os.path.join(self.get_temp_dir(),
                                          "final_model.tflite")
        labels_output_file = os.path.join(self.get_temp_dir(), "labels.txt")
        # Make sure we don't test for pre-existing files.
        self.assertFalse(os.path.isfile(saved_model_expected_file))
        self.assertFalse(os.path.isfile(tflite_output_file))
        self.assertFalse(os.path.isfile(labels_output_file))

        with flagsaver.flagsaver(
                image_dir=image_dir,
                tfhub_module=avg_model_dir,
                # This dataset is expected to be fit perfectly.
                assert_accuracy_at_least=0.9,
                use_tf_data_input=use_tf_data_input,
                saved_model_dir=saved_model_dir,
                tflite_output_file=tflite_output_file,
                labels_output_file=labels_output_file,
                **self.DEFAULT_FLAGS):
            make_image_classifier.main([])

        # Test that the SavedModel was written.
        self.assertTrue(os.path.isfile(saved_model_expected_file))

        # Test that the TFLite model works.
        labels = self._load_labels(labels_output_file)
        lite_model = self._load_lite_model(tflite_output_file)
        for class_name, rgb in self.CMY_NAMES_AND_RGB_VALUES:
            input_batch = (_fill_image(rgb, self.IMAGE_SIZE)[None, ...] /
                           np.array(255., dtype=np.float32))
            output_batch = lite_model(input_batch)
            prediction = labels[np.argmax(output_batch[0])]
            self.assertEqual(class_name, prediction)
Example #13
  def test(self):
    input_path = self.create_tempdir('source').full_path
    output_path = self.create_tempdir('destination').full_path

    input_ds = test_util.dummy_era5_surface_dataset(times=90*24, freq='1H')
    input_ds.chunk({'time': 31}).to_zarr(input_path)

    expected = input_ds.groupby('time.month').apply(
        lambda x: x.groupby('time.hour').mean('time')
    )

    with flagsaver.flagsaver(
        input_path=input_path,
        output_path=output_path,
    ):
      era5_climatology.main([])

    actual = xarray.open_zarr(output_path)
    xarray.testing.assert_allclose(actual, expected)
Example #14
    def test_train_multiple_clients(self, disable_parallel,
                                    expected_num_examples):
        with flagsaver.flagsaver(
                fedjax_experimental_disable_parallel=disable_parallel):
            init_state = self.init_state()
            states = client_trainer.train_multiple_clients(
                federated_data=self._federated_data,
                client_ids=self._federated_data.client_ids,
                client_trainer=self._trainer,
                init_client_trainer_state=init_state,
                rng_seq=self._federated_algorithm._rng_seq,
                client_data_hparams=self._client_data_hparams)
            states = list(states)

            self.assertLen(states, 5)
            for s in states:
                self.assertEqual(s.num_examples, expected_num_examples)
                jax.tree_multimap(self.assertAllEqual, s.control_variate,
                                  init_state.control_variate)
Example #15
 def test_train_with_config_file(self, hparams_config_filename):
   with flagsaver.flagsaver(
       num_eval_steps=1,
       eval_batch_size=1,
       max_target_length=1,
       eval_dataset_name='mock_data',
       max_eval_target_length=1,
       max_predict_length=1,
       model_dir=FLAGS.test_tmpdir,
   ):
     hparams = os_hparams_utils.load_dataclass_from_config_dict(
         training_hparams.TrainingHParams,
         hparams_config_filename.get_config())
     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as io_executor:
       train.run_training(
           datasets=self.datasets,
           hparams=hparams,
           io_executor=io_executor,
       )
Example #16
    def test(self):
        output_path = os.path.join(FLAGS.test_tmpdir, 'temp.h5')

        # run the beam job
        with flagsaver.flagsaver(output_path=output_path,
                                 equation_name='burgers',
                                 equation_kwargs='{"num_points": 400}',
                                 num_tasks=2,
                                 time_max=1.0,
                                 time_delta=0.1,
                                 warmup=0):
            create_training_data.main([])

        # verify the results
        with utils.read_h5py(output_path) as f:
            data = f['v'][...]
            metadata = dict(f.attrs)
        self.assertEqual(data.shape, (20, 400))
        self.assertEqual(metadata, {'num_points': 400})
Example #17
 def test_experiment_run(self):
     # Stores all flags defined in bert_ranking.py.
     with flagsaver.flagsaver(train_input_pattern=self._train_file,
                              eval_input_pattern=self._eval_file,
                              learning_rate=0.001,
                              train_batch_size=2,
                              eval_batch_size=2,
                              model_dir=self._model_dir,
                              num_train_steps=5,
                              num_eval_steps=2,
                              loss="softmax_loss",
                              local_training=True,
                              list_size=5,
                              dropout_rate=0.1,
                              bert_config_file=self._bert_config_file,
                              bert_init_ckpt=self._bert_init_ckpt,
                              bert_max_seq_length=self._bert_max_seq_length,
                              bert_num_warmup_steps=1):
         tfrbert_example.train_and_eval()
Example #18
    def test(self):
        output_path = os.path.join(FLAGS.test_tmpdir, 'temp.nc')

        # run the beam job
        with flagsaver.flagsaver(output_path=output_path,
                                 equation_name='burgers',
                                 equation_kwargs='{"num_points": 400}',
                                 num_samples=2,
                                 accuracy_orders=[1, 3, 5],
                                 time_max=1.0,
                                 time_delta=0.1,
                                 warmup=0):
            create_baseline_data.main([])

        # verify the results
        with xarray.open_dataset(output_path) as ds:
            self.assertEqual(ds['y'].dims,
                             ('sample', 'accuracy_order', 'time', 'x'))
            self.assertEqual(ds['y'].shape, (2, 3, 10, 400))
Example #19
 def test_get(self):
     with self.subTest('default'):
         self.assertEqual(
             self.FEDERATED_EXPERIMENT_CONFIG.get(),
             federated_experiment.FederatedExperimentConfig(
                 root_dir='foo', num_rounds=1234))
     with self.subTest('custom'):
         with flagsaver.flagsaver(root_dir='bar',
                                  num_rounds=567,
                                  checkpoint_frequency=2,
                                  num_checkpoints_to_keep=3,
                                  eval_frequency=4):
             self.assertEqual(
                 self.FEDERATED_EXPERIMENT_CONFIG.get(),
                 federated_experiment.FederatedExperimentConfig(
                     root_dir='bar',
                     num_rounds=567,
                     checkpoint_frequency=2,
                     num_checkpoints_to_keep=3,
                     eval_frequency=4))
Example #20
  def test_train_mnist(self):
    # Create the random data and write it to the disk.
    test_subdirectory = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)

    # Create the model parameters.
    model_path = os.path.join(test_subdirectory, 'temp_model')

    with flagsaver.flagsaver(
        model_path=model_path,
        save_period=1,
        num_dense_units='4,4',
        epochs=1,
        learning_rate=0.1,
        dropout=0.0,
        batch_size=32):
      train_mnist.main(argv=())

    # Verify that the trained model was saved.
    self.assertTrue(gfile.Exists(os.path.join(model_path, 'test_accuracy.txt')))
    self.assertLen(gfile.Glob(os.path.join(model_path, 'weights_epoch*')), 1)
Example #21
    def testWget_LinkWorksWithOfflineClient(self):
        data_store.REL_DB.WriteClientMetadata(client_id=VfsTest.FAKE_CLIENT_ID,
                                              fleetspeak_enabled=False)

        api_client = self._get_fake_api_client()
        client = grr_colab.Client(api_client)
        vfs_obj = vfs.VFS(api_client, jobs_pb2.PathSpec.OS)

        content = b'foo bar'
        with flagsaver.flagsaver(grr_admin_ui_url=self.endpoint):
            with temp.AutoTempFilePath() as temp_file:
                with io.open(temp_file, 'wb') as filedesc:
                    filedesc.write(content)

                with client.open(temp_file):
                    pass

                link = vfs_obj.wget(temp_file)

            self.assertEqual(requests.get(link).content, content)
Example #22
    def testProgramSchedule(self,
                            dataset_list_override=None,
                            train_executions_per_eval_override=None,
                            eval_steps_per_loop_override=None,
                            decode_steps_per_loop_override=None):
        with flagsaver.flagsaver(
                executor_datasets_to_eval=dataset_list_override,
                executor_train_executions_per_eval=
                train_executions_per_eval_override,
                executor_eval_steps_per_loop=eval_steps_per_loop_override,
                executor_decode_steps_per_loop=decode_steps_per_loop_override):
            ps_params = model_registry.GetProgramSchedule('test.DummyModel')

            if dataset_list_override is not None:
                self.assertAllEqual(ps_params.dataset_names,
                                    dataset_list_override.split(';'))
            else:
                self.assertAllEqual(ps_params.dataset_names, ['Dev', 'Test'])

            if train_executions_per_eval_override is not None:
                self.assertEqual(ps_params.train_executions_per_eval,
                                 train_executions_per_eval_override)
            else:
                self.assertEqual(ps_params.train_executions_per_eval, 0)

            # Assume only Dev and Test are available eval datasets.
            eval_dev, eval_test, decode_dev, decode_test = 0, 0, 0, 0
            if dataset_list_override is None or 'Dev' in dataset_list_override:
                if eval_steps_per_loop_override != 0:
                    eval_dev += 1
                if decode_steps_per_loop_override != 0:
                    decode_dev += 1
            if dataset_list_override is None or 'Test' in dataset_list_override:
                if eval_steps_per_loop_override != 0:
                    eval_test += 1
                if decode_steps_per_loop_override != 0:
                    decode_test += 1
            self.assertLen(ps_params.eval_programs,
                           eval_dev + decode_dev + eval_test + decode_test)
            self._CheckProgramParams(ps_params.eval_programs, eval_dev,
                                     eval_test, decode_dev, decode_test)
Example #23
    def test_restore_text(self):
        testdata_dir = (pathlib.Path(absltest.get_default_test_srcdir()) /
                        TESTDATA_DIR)
        input_file = testdata_dir / "example_mentions_no_text.jsonl"
        expected_file = testdata_dir / "example_mentions.jsonl"

        output_dir = self.create_tempdir()
        output_file = pathlib.Path(output_dir.full_path) / "restored.jsonl"
        with flagsaver.flagsaver(input=str(input_file),
                                 index_dir=str(testdata_dir),
                                 output=str(output_file)):
            restore_text.main([])

        self.assertTrue(output_file.exists(), msg=str(output_file))

        # Compare the dataclass representations of the output and expected files
        # rather than their string contents to be robust against fluctuations in the
        # serialization (e.g. due to dictionary order, etc).
        got = schema.load_jsonl(output_file, schema.ContextualMentions)
        expected = schema.load_jsonl(expected_file, schema.ContextualMentions)
        self.assertEqual(got, expected)
Example #24
    def test_train_and_eval(self, listwise_inference):
        tmp_dir = self.create_tempdir()
        data_file = os.path.join(tmp_dir, "elwc.tfrecord")
        if tf.io.gfile.exists(data_file):
            tf.io.gfile.remove(data_file)

        with tf.io.TFRecordWriter(data_file) as writer:
            for elwc in [ELWC] * 10:
                writer.write(elwc.SerializeToString())

        model_dir = os.path.join(tmp_dir, "model")

        with flagsaver.flagsaver(train_path=data_file,
                                 eval_path=data_file,
                                 data_format="example_list_with_context",
                                 model_dir=model_dir,
                                 num_train_steps=10,
                                 listwise_inference=listwise_inference,
                                 group_size=1,
                                 weights_feature_name="doc_weight"):
            tf_ranking_tfrecord.train_and_eval()
Example #25
    def test_train_and_eval(self):
        data_dir = tf.compat.v1.test.get_temp_dir()
        data_file = os.path.join(data_dir, "elwc.tfrecord")
        if tf.io.gfile.exists(data_file):
            tf.io.gfile.remove(data_file)

        with tf.io.TFRecordWriter(data_file) as writer:
            for elwc in [ELWC] * 10:
                writer.write(elwc.SerializeToString())

        model_dir = os.path.join(data_dir, "model")

        with flagsaver.flagsaver(train_input_pattern=data_file,
                                 eval_input_pattern=data_file,
                                 model_dir=model_dir,
                                 num_train_steps=10,
                                 list_size=2):
            pipeline_example.train_and_eval()

        if tf.io.gfile.exists(model_dir):
            tf.io.gfile.rmtree(model_dir)
Example #26
    def test_get_final_aggregated_documents(self):
        new_flags = get_updated_default_flags(num_documents_to_retrieve=3,
                                              use_aggregated_documents=1)
        with flagsaver.flagsaver(**new_flags):
            # stop_after_seeing_new_results with multiple steps in the history
            nqenv = create_env(max_history_entries=3)
            nqenv.stop_after_seeing_new_results = True

            final_documents = nqenv.get_final_document_list()
            self.assertLen(final_documents, 3)
            self.assertEqual(final_documents[0].content, '4')
            self.assertEqual(final_documents[1].content, '3')
            self.assertEqual(final_documents[2].content, '2')

            # stop_after_seeing_new_results=False with multiple steps in the history
            nqenv.stop_after_seeing_new_results = False
            final_documents = nqenv.get_final_document_list()
            self.assertLen(final_documents, 3)
            self.assertEqual(final_documents[0].content, '5')
            self.assertEqual(final_documents[1].content, '4')
            self.assertEqual(final_documents[2].content, '3')
Example #27
    def test(self):
        input_path = self.create_tempdir('source').full_path
        output_path = self.create_tempdir('destination').full_path

        input_ds = test_util.dummy_era5_surface_dataset(times=365)
        input_ds.chunk({'time': 31}).to_zarr(input_path)

        with flagsaver.flagsaver(
                input_path=input_path,
                output_path=output_path,
        ):
            era5_rechunk.main([])

        output_ds = xarray.open_zarr(output_path)
        self.assertEqual(
            {k: v[0] for k, v in output_ds.chunks.items()},
            {'latitude': 5, 'longitude': 5, 'time': 365})
        xarray.testing.assert_identical(input_ds, output_ds)
Example #28
 def testGenerateDataSample(self, level, data_format, data_percent,
                            uniform_domain_distribution):
     temp_output = os.path.join(self.create_tempdir(), 'output')
     ref_output = os.path.join(
         FLAGS.test_srcdir, TEST_DIR,
         f'sgd_text_v2_uniform_{uniform_domain_distribution}_{data_percent}'
     )
     with flagsaver.flagsaver(
             level=level,
             delimiter='=',
             data_format=data_format,
             sgd_file=os.path.join(FLAGS.test_srcdir, TEST_DIR,
                                   'sgd_train.json'),
             schema_file=os.path.join(FLAGS.test_srcdir, TEST_DIR,
                                      'sgd_train_schema.json'),
             output_file=temp_output,
             randomize_items=False,
             data_percent=data_percent,
             uniform_domain_distribution=uniform_domain_distribution):
         slots, item_desc = create_sgd_schemaless_data.load_schema()
         create_sgd_schemaless_data.generate_data(slots, item_desc)
         self.assertTrue(filecmp.cmp(temp_output, ref_output))
Example #29
    def test_recovery(self, distribution_strategy, flag_mode):
        loss_threshold = 1.0
        model_dir = self.get_temp_dir()
        flags_dict = dict(experiment='mock',
                          mode=flag_mode,
                          model_dir=model_dir,
                          params_override=json.dumps(self._test_config))
        with flagsaver.flagsaver(**flags_dict):
            params = train_utils.parse_configuration(flags.FLAGS)
            params.trainer.loss_upper_bound = loss_threshold
            params.trainer.recovery_max_trials = 1
            train_utils.serialize_config(params, model_dir)
            with distribution_strategy.scope():
                task = task_factory.get_task(params.task,
                                             logging_dir=model_dir)

            # Saves a checkpoint for reference.
            model = task.build_model()
            checkpoint = tf.train.Checkpoint(model=model)
            checkpoint_manager = tf.train.CheckpointManager(
                checkpoint, self.get_temp_dir(), max_to_keep=2)
            checkpoint_manager.save()
            before_weights = model.get_weights()

            def build_losses(labels, model_outputs, aux_losses=None):
                del labels, model_outputs
                return tf.constant([loss_threshold], tf.float32) + aux_losses

            task.build_losses = build_losses

            model, _ = train_lib.run_experiment(
                distribution_strategy=distribution_strategy,
                task=task,
                mode=flag_mode,
                params=params,
                model_dir=model_dir)
            after_weights = model.get_weights()
            for left, right in zip(before_weights, after_weights):
                self.assertAllEqual(left, right)
Example #30
 def test_main(self):
     input_pattern = os.path.join(self.test_subdirectory, '*.pbtxt')
     output_dir = os.path.join(self.test_subdirectory, 'tables')
     with flagsaver.flagsaver(input=input_pattern,
                              output=output_dir,
                              database=True,
                              cleanup=False):
         build_database.main(())
     with open(os.path.join(output_dir, 'reactions.csv')) as f:
         df = pd.read_csv(f)
     # NOTE(kearnes): Map keys are not always serialized in the same order.
     df['deserialized'] = df.serialized.apply(
         lambda x: reaction_pb2.Reaction.FromString(bytes.fromhex(x)))
     del df['serialized']
     pd.testing.assert_frame_equal(
         df,
         pd.DataFrame({
             'reaction_id': ['test'],
             'reaction_smiles': ['reaction'],
             'deserialized': [self.dataset.reactions[0]],
         }))
     with open(os.path.join(output_dir, 'inputs.csv')) as f:
         df = pd.read_csv(f)
     pd.testing.assert_frame_equal(
         df,
         pd.DataFrame({
             'reaction_id': ['test', 'test', 'test'],
             'smiles': ['input1', 'input2a', 'input2b'],
         }))
     with open(os.path.join(output_dir, 'outputs.csv')) as f:
         df = pd.read_csv(f)
     pd.testing.assert_frame_equal(
         df,
         pd.DataFrame({
             'reaction_id': ['test'],
             'smiles': ['product'],
             'yield': [2.5]
         }))
Example #31
    def test_basic_functionality(self, mock_requests):
        # Set up the mock REST endpoint.
        def match_request_text(request):
            return 'instances' in (request.text or '')

        mock_requests.post('http://foo.com',
                           request_headers={
                               'Content-Type':
                               'application/json; charset=utf-8'
                           },
                           additional_matcher=match_request_text,
                           text='response')

        with flagsaver.flagsaver():
            FLAGS.target = 'http://foo.com'
            FLAGS.batch_size = 1
            FLAGS.duration_ms = 1000
            FLAGS.target_latency_percentile = 0.9
            FLAGS.target_latency_ns = 100000000
            FLAGS.performance_sample_count = 5
            FLAGS.query_count = 10
            FLAGS.total_sample_count = 5
            loadgen_rest_main.main(None)
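Note that the example above calls flagsaver.flagsaver() with no overrides: the empty call snapshots the current flag values, so the direct FLAGS assignments inside the block are rolled back when it exits. A minimal, self-contained sketch of that pattern, with a hypothetical flag name:

import sys
from absl import flags
from absl.testing import flagsaver

FLAGS = flags.FLAGS
# Hypothetical flag, declared only for this sketch.
flags.DEFINE_integer('sketch_batch_size', 8, 'Illustrative flag.')


def main():
  with flagsaver.flagsaver():
    # Direct assignments inside the saved scope take effect immediately...
    FLAGS.sketch_batch_size = 64
    assert FLAGS.sketch_batch_size == 64
  # ...and are rolled back when the block exits.
  assert FLAGS.sketch_batch_size == 8


if __name__ == '__main__':
  FLAGS(sys.argv)  # Parse flags before reading or writing FLAGS values.
  main()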
Example #32
    def test_train_and_eval(self, listwise_inference):
        data_dir = tf.compat.v1.test.get_temp_dir()
        data_file = os.path.join(data_dir, "elwc.tfrecord")
        if tf.io.gfile.exists(data_file):
            tf.io.gfile.remove(data_file)

        with tf.io.TFRecordWriter(data_file) as writer:
            for elwc in [ELWC] * 10:
                writer.write(elwc.SerializeToString())

        model_dir = os.path.join(data_dir, "model")

        with flagsaver.flagsaver(train_path=data_file,
                                 eval_path=data_file,
                                 data_format="example_list_with_context",
                                 model_dir=model_dir,
                                 num_train_steps=10,
                                 listwise_inference=listwise_inference,
                                 group_size=1):
            tf_ranking_tfrecord.train_and_eval()

        if tf.io.gfile.exists(model_dir):
            tf.io.gfile.rmtree(model_dir)
Example #33
 def test_get_default_benchmark_logger(self):
   with flagsaver.flagsaver(benchmark_logger_type='foo'):
     self.assertIsInstance(logger.get_benchmark_logger(),
                           logger.BaseBenchmarkLogger)
Example #34
 def test_config_benchmark_bigquery_logger(self, mock_bigquery_client):
   with flagsaver.flagsaver(benchmark_logger_type='BenchmarkBigQueryLogger'):
     logger.config_benchmark_logger()
     self.assertIsInstance(logger.get_benchmark_logger(),
                           logger.BenchmarkBigQueryLogger)
Example #35
 def test_config_base_benchmark_logger(self):
   with flagsaver.flagsaver(benchmark_logger_type='BaseBenchmarkLogger'):
     logger.config_benchmark_logger()
     self.assertIsInstance(logger.get_benchmark_logger(),
                           logger.BaseBenchmarkLogger)