예제 #1
0
  def testGetClass(self):
    """GetClass resolves 'test.DummyModel' and rejects an unknown name."""
    model_params = model_registry.GetClass('test.DummyModel')

    # Each dataset accessor must return params named after its dataset.
    for dataset_name in ('Train', 'Dev', 'Test'):
      dataset_params = getattr(model_params, dataset_name)()
      self.assertEqual(dataset_name, dataset_params.name)

    task_params = model_params.Task()
    self.assertIsNotNone(task_params)
    full_model_params = model_params.Model()
    self.assertIsNotNone(full_model_params)

    # A name that was never registered must surface as a LookupError.
    self.assertRaises(LookupError, model_registry.GetClass,
                      'something.does.not.exist')
예제 #2
0
 def _testOneModelParams(self, name):
   """Checks the model params and dataset input params registered as `name`.

   The model class must derive from `base_model.BaseModel`. For every dataset
   in ('Train', 'Dev', 'Test') the input params (or, for multi-task models,
   each entry yielded by `IterParams()`) must point at a
   `base_input_generator.BaseInputGenerator` subclass.

   Args:
     name: Registry key handed to `model_registry.GetClass`.
   """
   cls = model_registry.GetClass(name)
   p = cls.Model()
   self.assertTrue(issubclass(p.cls, base_model.BaseModel))
   self.assertIsNotNone(p.model)
   for dataset in ('Train', 'Dev', 'Test'):
     input_p = cls.GetDatasetParams(dataset)
     if issubclass(p.cls, base_model.SingleTaskModel):
       self.assertTrue(
           issubclass(input_p.cls, base_input_generator.BaseInputGenerator),
           'Error in %s' % dataset)
       # Non-training sequence inputs with a non-zero sample count must use a
       # single batcher thread; otherwise decoder/eval runs might not span
       # exactly one epoch (see the assertion message below).
       if (dataset != 'Train') and issubclass(
           input_p.cls, base_input_generator.BaseSequenceInputGenerator) and (
               input_p.num_samples != 0):
         # assertEqual, not the deprecated assertEquals alias, which was
         # removed from unittest.TestCase in Python 3.12.
         self.assertEqual(
             input_p.num_batcher_threads, 1,
             'num_batcher_threads too large in %s. Decoder '
             'or eval runs over this set might not span '
             'exactly one epoch.' % dataset)
     else:
       # Anything that is not single-task is required to be multi-task, in
       # which case every per-task input params entry is validated.
       self.assertTrue(issubclass(p.cls, base_model.MultiTaskModel))
       for _, v in input_p.IterParams():
         self.assertTrue(
             issubclass(v.cls, base_input_generator.BaseInputGenerator),
             'Error in %s' % dataset)
예제 #3
0
 def testGetModelParamsClass(self):
   """The class registered as 'test.DummyModel' is a SingleTaskModelParams."""
   registered_cls = model_registry.GetClass('test.DummyModel')
   is_single_task_params = issubclass(
       registered_cls, base_model_params.SingleTaskModelParams)
   self.assertTrue(is_single_task_params)