Example #1
0
    def testGetDatasetsWithGetAllDatasetParams(self):
        """GetDatasets uses the keys returned by GetAllDatasetParams().

        The holder declares 'Train' and 'Dev' (in that order) and the test
        expects ['Dev', 'Train'] back, both when the class itself is passed and
        when an instance of it is passed.
        """

        class DummyDatasetHolder(base_model_params._BaseModelParams):
            def GetAllDatasetParams(self):
                return {'Train': None, 'Dev': None}

        expected_names = ['Dev', 'Train']

        # Class form.
        names_from_class = datasets.GetDatasets(DummyDatasetHolder)
        self.assertAllEqual(expected_names, names_from_class)

        # Instance form.
        names_from_instance = datasets.GetDatasets(DummyDatasetHolder())
        self.assertAllEqual(expected_names, names_from_instance)
Example #2
0
    def GetDatasetParams(self, dataset):
        """Convenience function that returns the param for the given dataset name.

        Lookup happens in two stages:

        1. If `self.GetAllDatasetParams()` is implemented, its result is
           treated as the authoritative mapping and `dataset` must be one of
           its keys.
        2. Otherwise (the call raises
           `datasets.GetAllDatasetParamsNotImplementedError`), fall back to the
           legacy convention of calling the attribute named `dataset` on
           `self`.

        Args:
          dataset: A python string. Typically, 'Dev', 'Test', etc.

        Returns:
          The entry for `dataset` in `GetAllDatasetParams()`, or, on the legacy
          path, the return value of `self.${dataset}()`.

        Raises:
          DatasetError: if `dataset` is not a key of `GetAllDatasetParams()`,
            or, on the legacy path, if there is no `${dataset}` attribute on
            `self` (the AttributeError is attached as the cause).
        """
        # Only the call that can signal "not implemented" is guarded; the
        # lookup itself lives in the `else` branch.
        try:
            all_datasets = self.GetAllDatasetParams()
        except datasets.GetAllDatasetParamsNotImplementedError:
            # Fall through to the legacy path below.
            pass
        else:
            if dataset not in all_datasets:
                raise DatasetError(
                    f'Dataset {dataset} not found; '
                    f'available datasets are: {all_datasets.keys()}')
            return all_datasets[dataset]

        try:
            f = getattr(self, dataset)
        except AttributeError as e:
            # Chain explicitly so the AttributeError is reported as the direct
            # cause of the DatasetError.
            raise DatasetError(
                str(e) + '; available datasets are: %s' %
                (datasets.GetDatasets(type(self)),)) from e
        return f()
Example #3
0
    def testGetDatasetsRaisesErrorOnInvalidDatasetsOnInstanceVar(self):
        """With warn_on_error=False, an invalid public method makes GetDatasets raise.

        `BadDataset` takes an extra positional argument; the test expects
        `datasets.DatasetFunctionError` when GetDatasets is given an instance
        of the holder.
        """

        class DummyDatasetHolder(base_model_params._BaseModelParams):
            def Train(self):
                pass

            def BadDataset(self, any_argument):
                pass

        expect_failure = self.assertRaises(datasets.DatasetFunctionError)
        with expect_failure:
            holder = DummyDatasetHolder()
            datasets.GetDatasets(holder, warn_on_error=False)
Example #4
0
    def testGetDatasetsFindsAllPublicMethodsOnInstanceVar(self):
        """Both public methods of the holder instance come back as dataset names."""

        class DummyDatasetHolder(base_model_params._BaseModelParams):
            def Train(self):
                pass

            def UnexpectedDatasetName(self):
                pass

        holder = DummyDatasetHolder()
        expected_names = ['Train', 'UnexpectedDatasetName']
        self.assertAllEqual(expected_names, datasets.GetDatasets(holder))
Example #5
0
    def testGetDatasetsWarnsOnErrorOnInstanceVar(self):
        """With warn_on_error=True, an invalid public method is logged, not raised.

        Only 'Train' is expected in the result, and the first captured log
        record must mention `BadDataset`.
        """

        class DummyDatasetHolder(base_model_params._BaseModelParams):
            def Train(self):
                pass

            def BadDataset(self, any_argument):
                pass

        with self.assertLogs() as captured_logs:
            holder = DummyDatasetHolder()
            found_datasets = datasets.GetDatasets(holder, warn_on_error=True)
            self.assertAllEqual(['Train'], found_datasets)

        first_record = captured_logs.output[0]
        self.assertIn('WARNING:absl:Found a public function BadDataset',
                      first_record)
Example #6
0
    def testGetDatasetsOnClassWithPositionalArgumentInit(self):
        """GetDatasets accepts a class whose __init__ needs a positional argument.

        The class (not an instance) is passed, so the test expects the dataset
        names to be found even though the constructor takes `model_spec`.
        """

        class DummyDatasetHolder(base_model_params._BaseModelParams):
            def __init__(self, model_spec):
                pass

            def Train(self):
                pass

            def Dev(self):
                pass

        found_datasets = datasets.GetDatasets(DummyDatasetHolder,
                                              warn_on_error=True)
        self.assertAllEqual(['Dev', 'Train'], found_datasets)
Example #7
0
    def _testOneModelParams(self, registry, name):
        """Instantiates the model registered as `name` and checks its input params.

        Steps:
          1. Collect the dataset params, either from `GetAllDatasetParams()` or,
             when that raises `GetAllDatasetParamsNotImplementedError`, by
             calling each method listed by `datasets.GetDatasets`.
          2. Build the model params with the 'Train' input, configure the
             cluster as a single sync decoder replica and instantiate the model.
          3. For single-task models, check that non-'Train' sequence inputs
             with a non-zero `num_samples` use exactly one batcher thread; any
             other model must be a `MultiTaskModel`.

        Args:
          registry: Object whose `GetClass(name)` returns the model params class.
          name: Name under which the model params class is registered.
        """
        with tf.Graph().as_default():
            model_params = registry.GetClass(name)()
            try:
                all_datasets = model_params.GetAllDatasetParams()
            except datasets.GetAllDatasetParamsNotImplementedError:
                # Legacy params classes: gather one entry per dataset method.
                all_datasets = {}
                for method_name in datasets.GetDatasets(model_params):
                    try:
                        all_datasets[method_name] = getattr(
                            model_params, method_name)()
                    except NotImplementedError:
                        # Datasets that are not implemented are left out.
                        continue

            p = model_params.Model()
            p.input = all_datasets['Train']
            self.assertTrue(issubclass(p.cls, base_model.BaseModel))
            self.assertIsNot(p.model, None)
            p.cluster.mode = 'sync'
            p.cluster.job = 'decoder'
            p.cluster.decoder.replicas = 1
            with p.cluster.Instantiate():
                # Building the model here surfaces misconfigurations that only
                # show up inside layer constructors.
                mdl = p.Instantiate()
                self._ValidateEMA(name, mdl)
                p = mdl.params

            for dataset_name, input_p in all_datasets.items():
                if not issubclass(p.cls, base_model.SingleTaskModel):
                    self.assertTrue(
                        issubclass(p.cls, base_model.MultiTaskModel))
                    continue

                looks_like_input_params = (
                    isinstance(input_p, hyperparams.InstantiableParams)
                    and issubclass(input_p.cls,
                                   base_input_generator.BaseInputGenerator))
                if not looks_like_input_params:
                    # Assume this entry is a helper rather than a dataset.
                    continue

                if dataset_name == 'Train':
                    continue
                if not issubclass(
                        input_p.cls,
                        base_input_generator.BaseSequenceInputGenerator):
                    continue
                if input_p.num_samples != 0:
                    self.assertEqual(
                        input_p.num_batcher_threads, 1,
                        f'num_batcher_threads too large in {dataset_name}. Decoder or eval '
                        f'runs over this set might not span exactly one epoch.'
                    )
Example #8
0
    def InspectParams(self):
        r"""Write the params of every dataset of the model to text files.

        For each dataset name reported by `datasets.GetDatasets` for the model
        class, the params obtained for the 'controller' job are written as text
        to `<FLAGS.logdir>/<dataset name lowercased>-params.txt`.

        An example to run this mode:

        bazel-bin/lingvo/trainer --logtostderr \
          --model=image.mnist.LeNet5 --mode=inspect_params --logdir=/tmp/lenet5 \
          --run_locally=cpu
        """
        FLAGS.mode = 'sync'
        model_cls = self.model_registry.GetClass(self._model_name)
        tf.io.gfile.makedirs(FLAGS.logdir)

        for dataset_name in datasets.GetDatasets(model_cls):
            params = self.GetParamsForDataset('controller', dataset_name)

            file_name = dataset_name.lower() + '-params.txt'
            out_path = os.path.join(FLAGS.logdir, file_name)

            message = 'Write all params for {} to {}'.format(
                dataset_name, out_path)
            tf.logging.info(message)

            with tf.io.gfile.GFile(out_path, 'w') as out_file:
                out_file.write(params.ToText())
Example #9
0
    def GetDatasetParams(self, dataset):
        """Convenience function that returns the param for the given dataset name.

        Args:
          dataset: A python string. Typically, 'Dev', 'Test', etc.

        Returns:
          The return value of calling the `${dataset}` attribute of `self` with
          no arguments.

        Raises:
          DatasetError: if there is no `${dataset}` attribute on `self`. The
            original AttributeError is attached as the cause.
        """
        try:
            f = getattr(self, dataset)
        except AttributeError as e:
            available = datasets.GetDatasets(type(self))
            # The one-element tuple keeps '%s' formatting correct even if the
            # listing is itself a tuple; `from e` records the direct cause.
            raise DatasetError(
                str(e) + '; available datasets are: %s' % (available,)) from e
        return f()
Example #10
0
 def InspectDatasets(self):
     """Prints the model's dataset names, lowercased and comma-separated."""
     model_cls = self.model_registry.GetClass(self._model_name)
     lowered_names = [
         name.lower() for name in datasets.GetDatasets(model_cls)
     ]
     print(','.join(lowered_names))