Ejemplo n.º 1
0
    def get_model_by_name(name):
        """
        Get a model record from the database and rebuild the model it describes.

        The record's ``type`` selects the model class; ``model_params`` and
        ``extra_params`` from the record are passed to that class's
        ``new_from_json`` to reconstruct the instance.

        :param name: str - name of the model
        :return: one of
            - (None, None, str): the record could not be fetched; str is the DB error.
            - (None, record, str): the record was fetched but the model could not be
              rebuilt (unknown type or ``new_from_json`` reported an error).
            - (model, record, None): success; model is the rebuilt instance
              (ConvolutionalNeuralNetwork, GaussianProcesses or SparseGaussianProcesses).
            ``record`` is the object returned by ``DBManager.get_model_by_name``.
        """
        model_record, err = DBManager.get_model_by_name(name)
        if model_record is None:
            return None, None, err

        # Map the persisted type tag to the class able to rebuild it from JSON.
        model_classes = {
            'CNN': ConvolutionalNeuralNetwork,
            'FullGP': GaussianProcesses,
            'SparseGP': SparseGaussianProcesses,
        }
        model_class = model_classes.get(model_record.type)
        if model_class is None:
            # Previously this returned the (usually None) DB error, hiding the failure.
            return None, model_record, err or 'Unknown model type: {}'.format(model_record.type)

        model, err = model_class.new_from_json(
            model_record.model_params, model_record.extra_params)
        if err is not None:
            # Propagate the reconstruction error instead of silently dropping it.
            return None, model_record, err
        return model, model_record, None
Ejemplo n.º 2
0
    def test_model_to_json_load_from_json(self):
        """Serialize the module-level ``full_gp`` to JSON, rebuild it, and check nothing was lost."""
        # ``full_gp`` is only read here, so it resolves to the module global without a
        # ``global`` declaration.
        serialized_model, serialized_extra = full_gp.model_to_json()
        params = json.loads(serialized_model)
        extras = json.loads(serialized_extra)

        # The serialized payload must mirror the in-memory kernel and parameter vector.
        assert params['data']['kernel'] == full_gp.kernel.to_dict()
        assert params['data']['params'] == full_gp.model.param_array.tolist()

        restored, message = GaussianProcesses.new_from_json(params, extras)

        # A successful rebuild reports no error message ...
        assert message is None
        # ... and reproduces the original's stats, kernel and parameters.
        assert full_gp.stats == restored.stats
        assert full_gp.kernel.to_dict() == restored.kernel.to_dict()
        assert full_gp.model.param_array.tolist() == restored.model.param_array.tolist()