def test_ModelTest_CheckSaveLoad_Same(self):
    # Round-trip check: values saved from a model (with one nested submodel)
    # must come back unchanged when loaded into a freshly built model that
    # registers its parameters under the same names.
    shape = Shape([2, 2])
    values1 = [1, 2, 3, 4]
    values2 = [5, 6, 7, 8]

    # The context manager closes the temporary file even if save/load
    # raises, instead of leaking the handle until garbage collection.
    with tempfile.NamedTemporaryFile() as tmp:
        # Source hierarchy: m1 holds "p" and the submodel "sm" (m2), which
        # holds its own "p".
        m1 = Model()
        m2 = Model()
        p1 = Parameter(shape, I.Constant(0))
        p1.value += tF.raw_input(shape, values1)
        p2 = Parameter(shape, I.Constant(0))
        p2.value += tF.raw_input(shape, values2)
        m1.add("p", p1)
        m2.add("p", p2)
        m1.add("sm", m2)
        m1.save(tmp.name)

        # Destination hierarchy: same names, but the parameters are created
        # without arguments so that everything checked below must originate
        # from the file.
        m1 = Model()
        m2 = Model()
        p1 = Parameter()
        p2 = Parameter()
        m1.add("p", p1)
        m2.add("p", p2)
        m1.add("sm", m2)
        m1.load(tmp.name)

    self.assertTrue(p1.valid())
    self.assertTrue(p2.valid())
    self.assertEqual(shape, p1.shape())
    self.assertEqual(shape, p2.shape())
    self.assertEqual(values1, p1.value.to_list())
    self.assertEqual(values2, p2.value.to_list())
def test_Parameter_argument(self):
    # Exercises the two ways of constructing a Parameter used by this test:
    # without arguments, and with a shape plus an initializer.

    # Constructed without arguments, the parameter must report itself as
    # not valid.
    param = Parameter()
    is_valid = param.valid()
    self.assertFalse(is_valid)

    # Constructed from a 4x3 shape and a constant initializer of 1, the
    # parameter must expose that shape and hold twelve ones.
    param = Parameter(Shape([4, 3]), I.Constant(1))
    actual_shape = param.shape()
    expected_shape = Shape([4, 3])
    self.assertEqual(actual_shape, expected_shape)

    actual_values = param.value.to_list()
    expected_values = [1] * 12
    self.assertEqual(actual_values, expected_values)
def test_ModelTest_CheckSaveLoadWithStats(self):
    # Round-trip check including per-parameter statistics: both the values
    # and the named stats entries saved from one model hierarchy must be
    # restored into a freshly built hierarchy with the same names.
    shape = Shape([2, 2])
    values1 = [1, 2, 3, 4]
    values2 = [5, 6, 7, 8]
    stats1 = [10, 20, 30, 40]
    stats2 = [50, 60, 70, 80]

    # The context manager closes the temporary file even if save/load
    # raises, instead of leaking the handle until garbage collection.
    with tempfile.NamedTemporaryFile() as tmp:
        # Source hierarchy: m1 holds "p" and the submodel "sm" (m2), which
        # holds its own "p".
        m1 = Model()
        m2 = Model()
        p1 = Parameter(shape, I.Constant(0))
        p1.value += tF.raw_input(shape, values1)
        p2 = Parameter(shape, I.Constant(0))
        p2.value += tF.raw_input(shape, values2)

        # Each parameter gets one differently named stats entry, filled
        # with distinct data so a mix-up between the two would be detected.
        p1.add_stats("a", shape)
        p2.add_stats("b", shape)
        p1.stats["a"].reset_by_vector(stats1)
        p2.stats["b"].reset_by_vector(stats2)

        m1.add("p", p1)
        m2.add("p", p2)
        m1.add("sm", m2)
        m1.save(tmp.name)

        # Destination hierarchy: same names, but the parameters are created
        # without arguments (and without stats) so that everything checked
        # below must originate from the file.
        m1 = Model()
        m2 = Model()
        p1 = Parameter()
        p2 = Parameter()
        m1.add("p", p1)
        m2.add("p", p2)
        m1.add("sm", m2)
        m1.load(tmp.name)

    self.assertTrue(p1.valid())
    self.assertTrue(p2.valid())
    self.assertEqual(shape, p1.shape())
    self.assertEqual(shape, p2.shape())
    self.assertEqual(values1, p1.value.to_list())
    self.assertEqual(values2, p2.value.to_list())

    # assertIn reports the missing key and the container on failure, unlike
    # assertTrue on a pre-evaluated membership expression.
    self.assertIn("a", p1.stats)
    self.assertIn("b", p2.stats)
    self.assertEqual(stats1, p1.stats["a"].to_list())
    self.assertEqual(stats2, p2.stats["b"].to_list())