Пример #1
0
    def test_serialization(self):
        """Round-trip randomly initialised parameters through a tar archive.

        Two parameters ("param_0", "param_1") are appended, filled with
        uniform noise drawn from [-1.0, 1.0), written with
        ``Parameters.to_tar`` into an in-memory buffer and read back with
        ``Parameters.from_tar``. The copy must report the same names, the
        same shapes and (within floating-point tolerance) the same values.
        """
        # io.BytesIO is a binary in-memory buffer available on both Python 2
        # and 3; the original cStringIO module exists only on Python 2.
        import io

        params = parameters.Parameters()
        params.__append_config__(__rand_param_config__("param_0"))
        params.__append_config__(__rand_param_config__("param_1"))

        for name in params.names():
            param = params.get(name)
            param[:] = numpy.random.uniform(-1.0,
                                            1.0,
                                            size=params.get_shape(name))
            params.set(name, param)

        tmp_file = io.BytesIO()
        params.to_tar(tmp_file)
        # Rewind so from_tar reads the archive from its first byte.
        tmp_file.seek(0)
        params_dup = parameters.Parameters.from_tar(tmp_file)

        self.assertEqual(params_dup.names(), params.names())

        for name in params.names():
            self.assertEqual(params.get_shape(name),
                             params_dup.get_shape(name))
            p0 = params.get(name)
            p1 = params_dup.get(name)
            # Same criterion as numpy.isclose(p0, p1) with its default
            # tolerances, but a failure names the parameter and shows which
            # elements differ instead of a bare "False is not true".
            numpy.testing.assert_allclose(p0, p1,
                                          rtol=1e-05,
                                          atol=1e-08,
                                          equal_nan=False,
                                          err_msg="parameter %s" % name)
Пример #2
0
 def get_param(names, size):
     """Return a ``parameters.Parameters`` holding randomly filled parameters.

     One config is appended per ``(name, size)`` pair via
     ``__rand_param_config__``; ``zip`` stops at the shorter of ``names`` and
     ``size``. Every parameter is then overwritten with uniform noise drawn
     from [-1.0, 1.0) and stored back with ``set``.

     NOTE(review): ``size`` is forwarded unchanged as the second argument of
     ``__rand_param_config__`` -- presumably the per-parameter size; confirm
     against that helper.
     """
     result = parameters.Parameters()
     for param_name, param_size in zip(names, size):
         config = __rand_param_config__(param_name, param_size)
         result.__append_config__(config)

     # Second pass: replace each parameter's contents with random values.
     for param_name in result.names():
         values = result.get(param_name)
         noise = numpy.random.uniform(-1.0,
                                      1.0,
                                      size=result.get_shape(param_name))
         values[:] = noise
         result.set(param_name, values)

     return result