def testGetDatasetsWithGetAllDatasetParams(self):
  """GetDatasets returns the sorted GetAllDatasetParams keys, for a class or an instance."""

  class _HolderWithAllParams(base_model_params._BaseModelParams):

    def GetAllDatasetParams(self):
      return {'Train': None, 'Dev': None}

  # The same (sorted) dataset list must come back whether we pass the class
  # itself or an instance of it.
  for holder in (_HolderWithAllParams, _HolderWithAllParams()):
    self.assertAllEqual(['Dev', 'Train'], datasets.GetDatasets(holder))
def GetDatasetParams(self, dataset):
  """Returns the input params for the named dataset.

  The `GetAllDatasetParams()` dictionary is consulted first when the
  subclass implements it; otherwise this falls back to the legacy
  behavior of calling the `self.${dataset}()` method.

  Args:
    dataset: A python string. Typically, 'Dev', 'Test', etc.

  Returns:
    The hyperparams for the input data of the named dataset.

  Raises:
    DatasetError: if the named dataset is defined neither in
      `GetAllDatasetParams()` nor as a `${dataset}` method on this class.
  """
  try:
    all_datasets = self.GetAllDatasetParams()
  except datasets.GetAllDatasetParamsNotImplementedError:
    # Subclass does not provide the dictionary; fall through to the
    # legacy method-lookup path below.
    pass
  else:
    if dataset not in all_datasets:
      raise DatasetError(
          f'Dataset {dataset} not found; '
          f'available datasets are: {all_datasets.keys()}')
    return all_datasets.get(dataset)
  # Legacy path: the dataset is a public method on the params class.
  try:
    dataset_fn = getattr(self, dataset)
  except AttributeError as e:
    raise DatasetError(
        str(e) + '; available datasets are: %s' %
        datasets.GetDatasets(type(self)))
  return dataset_fn()
def testGetDatasetsRaisesErrorOnInvalidDatasetsOnInstanceVar(self):
  """A public method taking extra args raises when warn_on_error=False."""

  class _BadHolder(base_model_params._BaseModelParams):

    def Train(self):
      pass

    def BadDataset(self, any_argument):
      pass

  # With warnings disabled, the malformed dataset function is an error.
  with self.assertRaises(datasets.DatasetFunctionError):
    datasets.GetDatasets(_BadHolder(), warn_on_error=False)
def testGetDatasetsFindsAllPublicMethodsOnInstanceVar(self):
  """Every public zero-arg method of an instance counts as a dataset."""

  class _Holder(base_model_params._BaseModelParams):

    def Train(self):
      pass

    def UnexpectedDatasetName(self):
      pass

  # Even non-conventional names are picked up, as long as they are public.
  self.assertAllEqual(['Train', 'UnexpectedDatasetName'],
                      datasets.GetDatasets(_Holder()))
def testGetDatasetsWarnsOnErrorOnInstanceVar(self):
  """A malformed dataset function is skipped with a warning when warn_on_error=True."""

  class _Holder(base_model_params._BaseModelParams):

    def Train(self):
      pass

    def BadDataset(self, any_argument):
      pass

  with self.assertLogs() as captured:
    result = datasets.GetDatasets(_Holder(), warn_on_error=True)
  # The bad function is dropped from the result but logged as a warning.
  self.assertAllEqual(['Train'], result)
  self.assertIn('WARNING:absl:Found a public function BadDataset',
                captured.output[0])
def testGetDatasetsOnClassWithPositionalArgumentInit(self):
  """Dataset discovery works on a class whose __init__ needs arguments."""

  class _Holder(base_model_params._BaseModelParams):

    def __init__(self, model_spec):
      pass

    def Train(self):
      pass

    def Dev(self):
      pass

  # The class cannot be trivially instantiated, but the class object itself
  # can still be inspected for dataset methods.
  self.assertAllEqual(
      ['Dev', 'Train'],
      datasets.GetDatasets(_Holder, warn_on_error=True))
def _testOneModelParams(self, registry, name):
  """Smoke-tests one registered model: builds every dataset's input params,
  instantiates the model, and sanity-checks eval/decode input configs.

  Args:
    registry: A model registry object exposing `GetClass(name)`.
    name: The registered model name to test.
  """
  with tf.Graph().as_default():
    model_params = registry.GetClass(name)()
    try:
      all_datasets = model_params.GetAllDatasetParams()
    except datasets.GetAllDatasetParamsNotImplementedError:
      # Legacy path: collect dataset params by calling each public dataset
      # method discovered on the params object.
      all_datasets = {}
      for dataset_name in datasets.GetDatasets(model_params):
        try:
          all_datasets[dataset_name] = getattr(
              model_params, dataset_name)()
        except NotImplementedError:
          # A dataset method may be intentionally unimplemented; skip it.
          pass
    p = model_params.Model()
    p.input = all_datasets['Train']
    self.assertTrue(issubclass(p.cls, base_model.BaseModel))
    self.assertIsNot(p.model, None)
    # Configure a minimal single-replica sync decoder cluster so the model
    # can be instantiated without a real job setup.
    p.cluster.mode = 'sync'
    p.cluster.job = 'decoder'
    p.cluster.decoder.replicas = 1
    with p.cluster.Instantiate():
      # Instantiate the params class, to help catch errors in layer
      # constructors due to misconfigurations.
      mdl = p.Instantiate()
      self._ValidateEMA(name, mdl)
      p = mdl.params
      for dataset, input_p in all_datasets.items():
        if issubclass(p.cls, base_model.SingleTaskModel):
          if (not isinstance(input_p, hyperparams.InstantiableParams) or
              not issubclass(
                  input_p.cls, base_input_generator.BaseInputGenerator)):
            # Assume this function is not a dataset function but some helper.
            continue
          if (dataset != 'Train' and
              issubclass(input_p.cls,
                         base_input_generator.BaseSequenceInputGenerator) and
              input_p.num_samples != 0):
            # Eval/decode sets must use a single batcher thread so a run
            # spans exactly one epoch.
            self.assertEqual(
                input_p.num_batcher_threads, 1,
                f'num_batcher_threads too large in {dataset}. Decoder or eval '
                f'runs over this set might not span exactly one epoch.')
        else:
          self.assertTrue(issubclass(p.cls, base_model.MultiTaskModel))
def InspectParams(self):
  r"""Writes the full params text of every dataset to the log directory.

  An example to run this mode:

  bazel-bin/lingvo/trainer --logtostderr \
      --model=image.mnist.LeNet5 --mode=inspect_params \
      --logdir=/tmp/lenet5 --run_locally=cpu
  """
  FLAGS.mode = 'sync'
  model_cls = self.model_registry.GetClass(self._model_name)
  tf.io.gfile.makedirs(FLAGS.logdir)
  for dataset in datasets.GetDatasets(model_cls):
    params = self.GetParamsForDataset('controller', dataset)
    # One text file per dataset, e.g. <logdir>/train-params.txt.
    out_path = os.path.join(FLAGS.logdir, dataset.lower() + '-params.txt')
    tf.logging.info('Write all params for {} to {}'.format(dataset, out_path))
    with tf.io.gfile.GFile(out_path, 'w') as f:
      f.write(params.ToText())
def GetDatasetParams(self, dataset):
  """Returns the input params produced by the `self.${dataset}()` method.

  Args:
    dataset: A python string. Typically, 'Dev', 'Test', etc.

  Returns:
    The hyperparams for the input data, generated by calling the
    `${dataset}` method defined on this class.

  Raises:
    DatasetError: if no `${dataset}` method is defined on this class.
  """
  try:
    dataset_fn = getattr(self, dataset)
  except AttributeError as e:
    # Include the discoverable datasets in the error to aid debugging.
    raise DatasetError(
        str(e) + '; available datasets are: %s' %
        datasets.GetDatasets(type(self)))
  return dataset_fn()
def InspectDatasets(self):
  """Prints out datasets configured for the model."""
  model_cls = self.model_registry.GetClass(self._model_name)
  lowered = [dataset.lower() for dataset in datasets.GetDatasets(model_cls)]
  print(','.join(lowered))