Beispiel #1
0
    def test_ModelTest_CheckGetTrainableParametersWithSubmodels(self):
        m1 = Model()
        m2 = Model()
        m3 = Model()
        p1 = Parameter()
        p2 = Parameter()
        p3 = Parameter()
        m1.add("p", p1)
        m2.add("p", p2)
        m3.add("p", p3)
        m1.add("sm", m2)
        m2.add("sm", m3)

        params1 = m1.get_trainable_parameters()
        self.assertEqual(3, len(params1))
        self.assertIsInstance(params1, dict)
        self.assertIs(p1, params1[("p",)])
        self.assertIs(p2, params1[("sm", "p",)])
        self.assertIs(p3, params1[("sm", "sm", "p",)])

        params2 = m2.get_trainable_parameters()
        self.assertEqual(2, len(params2))
        self.assertIsInstance(params2, dict)
        self.assertIs(p2, params2[("p",)])
        self.assertIs(p3, params2[("sm", "p",)])

        params3 = m3.get_trainable_parameters()
        self.assertEqual(1, len(params3))
        self.assertIsInstance(params3, dict)
        self.assertIs(p3, params3[("p",)])
Beispiel #2
0
 def test_ModelTest_CheckGetTrainableParameters(self):
     m = Model()
     p1 = Parameter()
     p2 = Parameter()
     p3 = Parameter()
     m.add("p1", p1)
     m.add("p2", p2)
     m.add("p3", p3)
     params = m.get_trainable_parameters()
     self.assertEqual(3, len(params))
     self.assertIsInstance(params, dict)
     self.assertIs(p1, params[("p1",)]);
     self.assertIs(p2, params[("p2",)]);
     self.assertIs(p3, params[("p3",)]);