Example #1
0
 def test_returns_model_weights_for_model_callable(self):
     """Passing TestModel to parameter_count_from_model yields the expected counts.

     The expected result has 2 tensors, 4 parameters and 0 unspecified
     tensors. The expected value is an OrderedDict; the comparison is made
     with assertEqual against whatever mapping parameter_count_from_model
     returns.
     """
     # NOTE(review): TestModel is defined elsewhere in this file; the test
     # name suggests it is a model *callable* rather than an instance — confirm.
     expected = collections.OrderedDict([
         ('num_tensors', 2),
         ('parameters', 4),
         ('num_unspecified_tensors', 0),
     ])
     actual = model_utils.parameter_count_from_model(TestModel)
     self.assertEqual(expected, actual)
Example #2
0
 def test_fails_not_model(self):
     """parameter_count_from_model raises TypeError for non-model inputs.

     Two inputs are checked in order: a plain integer, and a lambda that
     returns an integer instead of a model.
     """
     non_models = (0, lambda: 0)
     for bad_input in non_models:
         with self.assertRaises(TypeError):
             model_utils.parameter_count_from_model(bad_input)