def test_gp_module_save_and_load(self):
        """Round-trip a MAP inference on a GP regression model through save/load.

        Runs MAP on a model built by ``make_gpregr_model``, saves the inference
        under ``self.PREFIX``, then loads those files into an inference over a
        second, freshly built model. Every parameter and constant of the
        original inference is compared (via ``infr2._uuid_map``) with its
        reloaded counterpart, and the reloaded inference is run once as a smoke
        test. The saved files are removed even when an assertion fails.
        """
        from mxfusion.inference import MAP, Inference

        np.random.seed(0)
        X = np.random.rand(10, 3)
        # This draw is not used by the test, but it is kept so that the seeded
        # values drawn below stay identical to the previous data.
        np.random.rand(20, 3)
        Y = np.random.rand(10, 1)
        noise_var = np.random.rand(1)
        lengthscale = np.random.rand(3)
        variance = np.random.rand(1)
        dtype = 'float64'

        m = self.make_gpregr_model(lengthscale, variance, noise_var)
        infr = Inference(MAP(model=m, observed=[m.X, m.Y]), dtype=dtype)
        infr.run(X=mx.nd.array(X, dtype=dtype),
                 Y=mx.nd.array(Y, dtype=dtype))

        infr.save(prefix=self.PREFIX)
        # Cleanup is guarded only once save() has returned, so a failing save is
        # not masked by an error from removing files that were never written.
        try:
            m2 = self.make_gpregr_model(lengthscale, variance, noise_var)
            infr2 = Inference(MAP(model=m2, observed=[m2.X, m2.Y]),
                              dtype=dtype)
            # initialize() precedes load(); presumably load() needs infr2's
            # parameters to exist before it can overwrite them — TODO confirm.
            infr2.initialize(X=mx.nd.array(X, dtype=dtype),
                             Y=mx.nd.array(Y, dtype=dtype))

            # Load previous parameters
            infr2.load(
                graphs_file=self.PREFIX + '_graphs.json',
                parameters_file=self.PREFIX + '_params.json',
                inference_configuration_file=self.PREFIX + '_configuration.json',
                mxnet_constants_file=self.PREFIX + '_mxnet_constants.json',
                variable_constants_file=self.PREFIX + '_variable_constants.json')

            # UUIDs differ between m and m2; _uuid_map translates original
            # UUIDs to the ones used by the reloaded inference.
            for original_uuid, original_param in infr.params.param_dict.items():
                reloaded_param = infr2.params.param_dict[
                    infr2._uuid_map[original_uuid]]
                assert np.allclose(original_param.data().asnumpy(),
                                   reloaded_param.data().asnumpy())

            for original_uuid, original_const in infr.params.constants.items():
                reloaded_const = infr2.params.constants[
                    infr2._uuid_map[original_uuid]]
                # Constants may be MXNet arrays or plain values; only the
                # former need converting before a numpy comparison.
                if isinstance(original_const, mx.ndarray.ndarray.NDArray):
                    original_const = original_const.asnumpy()
                    reloaded_const = reloaded_const.asnumpy()
                assert np.allclose(original_const, reloaded_const)

            # Smoke test: the reloaded inference must still be runnable.
            infr2.run(X=mx.nd.array(X, dtype=dtype),
                      Y=mx.nd.array(Y, dtype=dtype))
        finally:
            self.remove_saved_files(self.PREFIX)
Beispiel #2
0
    def test_gluon_func_save_and_load(self):
        """Round-trip a forward-sampling inference over a Gluon function.

        Saves ``infr`` to ``self.ZIPNAME``, loads it into an inference over a
        freshly built model, and checks that every parameter listed in
        ``m.f.parameter_names`` matches its reloaded counterpart. The saved
        file is removed even when an assertion fails.
        """
        m = self.make_simple_gluon_model()
        infr = Inference(ForwardSamplingAlgorithm(m, observed=[m.x]))
        infr.run(x=mx.nd.ones((1, 1)))
        infr.save(self.ZIPNAME)
        # Guard cleanup only after save() succeeded, so os.remove cannot raise
        # FileNotFoundError and hide the real failure.
        try:
            m2 = self.make_simple_gluon_model()
            infr2 = Inference(ForwardSamplingAlgorithm(m2, observed=[m2.x]))
            # A run before load(); presumably this initializes infr2's
            # parameters so load() has something to overwrite — TODO confirm.
            infr2.run(x=mx.nd.ones((1, 1)))
            infr2.load(self.ZIPNAME)
            infr2.run(x=mx.nd.ones((1, 1)))

            for name in m.f.parameter_names:
                original = infr.params[getattr(m.y.factor, name)].asnumpy()
                reloaded = infr2.params[getattr(m2.y.factor, name)].asnumpy()
                assert np.allclose(original, reloaded)
        finally:
            os.remove(self.ZIPNAME)