def test_serialization(self):
    """Round-trip a randomly filled Parameters object through a tar stream.

    Two parameters ("param_0", "param_1") are created via
    ``__rand_param_config__``, filled with uniform random values in
    [-1.0, 1.0), written with ``to_tar`` into an in-memory buffer and read
    back with ``Parameters.from_tar``. The restored object must expose the
    same names, the same per-parameter shapes, and numerically close values.
    """
    original = parameters.Parameters()
    for cfg_name in ("param_0", "param_1"):
        original.__append_config__(__rand_param_config__(cfg_name))

    # Overwrite each parameter's contents in place, then store it back.
    for pname in original.names():
        values = original.get(pname)
        values[:] = numpy.random.uniform(
            -1.0, 1.0, size=original.get_shape(pname))
        original.set(pname, values)

    buf = cStringIO.StringIO()
    original.to_tar(buf)
    # Rewind so from_tar reads the archive from its beginning.
    buf.seek(0)
    restored = parameters.Parameters.from_tar(buf)

    self.assertEqual(restored.names(), original.names())
    for pname in original.names():
        self.assertEqual(original.get_shape(pname),
                         restored.get_shape(pname))
        expected = original.get(pname)
        actual = restored.get(pname)
        self.assertTrue(numpy.isclose(expected, actual).all())
def get_param(names, size):
    """Build a ``parameters.Parameters`` whose values are random.

    Args:
        names: parameter names, one per parameter to create.
        size: per-parameter sizes. Each one is passed to
            ``__rand_param_config__`` together with the matching name.
            Must have the same length as ``names``.

    Returns:
        A ``parameters.Parameters`` in which every parameter has been
        overwritten with values from ``numpy.random.uniform(-1.0, 1.0)``.

    Raises:
        ValueError: if ``names`` and ``size`` have different lengths.
    """
    # Materialise both inputs so that len() works on any iterable.
    names = list(names)
    size = list(size)
    # zip() would silently drop the tail of the longer sequence. That
    # builds fewer parameters than the caller asked for, so fail loudly.
    if len(names) != len(size):
        raise ValueError(
            "names and size must have the same length (got %d and %d)"
            % (len(names), len(size)))
    p = parameters.Parameters()
    for k, v in zip(names, size):
        p.__append_config__(__rand_param_config__(k, v))
    for name in p.names():
        param = p.get(name)
        param[:] = numpy.random.uniform(
            -1.0, 1.0, size=p.get_shape(name))
        p.set(name, param)
    return p