def testGetClass(self):
  """Looks up 'test.DummyModel' in the registry and exercises its accessors.

  Checks that the Train/Dev/Test dataset params carry matching names, that
  Task() and Model() return non-None params, and that looking up a name that
  was never registered raises LookupError.
  """
  model_params = model_registry.GetClass('test.DummyModel')
  for split in ('Train', 'Dev', 'Test'):
    # Each dataset accessor is a method named after the split it returns.
    self.assertEqual(split, getattr(model_params, split)().name)
  self.assertIsNotNone(model_params.Task())
  self.assertIsNotNone(model_params.Model())
  # Not yet registered.
  self.assertRaises(LookupError, model_registry.GetClass,
                    'something.does.not.exist')
def _testOneModelParams(self, name):
  """Checks that the registered model params class `name` is well formed.

  Verifies that `cls.Model()` describes a `base_model.BaseModel` subclass with
  a non-None `model` field. It then checks the input params returned by
  `cls.GetDatasetParams(dataset)` for each of 'Train', 'Dev' and 'Test':

  * For a `base_model.SingleTaskModel`, the input params class must be a
    `base_input_generator.BaseInputGenerator`. A non-'Train' input that is a
    `BaseSequenceInputGenerator` with non-zero `num_samples` must also use
    exactly one batcher thread.
  * Otherwise the model must be a `base_model.MultiTaskModel`, and every value
    yielded by `input_p.IterParams()` must describe a `BaseInputGenerator`.

  Args:
    name: Registry key of the model params class, e.g. 'test.DummyModel'.
  """
  cls = model_registry.GetClass(name)
  p = cls.Model()
  self.assertTrue(issubclass(p.cls, base_model.BaseModel))
  self.assertIsNotNone(p.model)
  # The model type does not depend on the dataset, so decide it once.
  is_single_task = issubclass(p.cls, base_model.SingleTaskModel)
  for dataset in ('Train', 'Dev', 'Test'):
    input_p = cls.GetDatasetParams(dataset)
    if is_single_task:
      self.assertTrue(
          issubclass(input_p.cls, base_input_generator.BaseInputGenerator),
          'Error in %s' % dataset)
      if (dataset != 'Train') and issubclass(
          input_p.cls, base_input_generator.BaseSequenceInputGenerator) and (
              input_p.num_samples != 0):
        # More than one batcher thread means decoder/eval runs over this set
        # might not span exactly one epoch (see the failure message).
        # assertEqual, not the assertEquals alias removed in Python 3.12.
        self.assertEqual(
            input_p.num_batcher_threads, 1,
            'num_batcher_threads too large in %s. Decoder '
            'or eval runs over this set might not span '
            'exactly one epoch.' % dataset)
    else:
      self.assertTrue(
          issubclass(p.cls, base_model.MultiTaskModel),
          'Error in %s' % dataset)
      for _, v in input_p.IterParams():
        self.assertTrue(
            issubclass(v.cls, base_input_generator.BaseInputGenerator),
            'Error in %s' % dataset)
def testGetModelParamsClass(self):
  """The class registered as 'test.DummyModel' is a SingleTaskModelParams."""
  params_cls = model_registry.GetClass('test.DummyModel')
  is_single_task_params = issubclass(
      params_cls, base_model_params.SingleTaskModelParams)
  self.assertTrue(is_single_task_params)