Esempio n. 1
0
class BayesianRegression:
    """Bayesian linear regression ``y ~ N(w . x, 1 / tau)``.

    Built from GaussianARD, Gamma, SumMultiply and VB (presumably BayesPy
    nodes / inference). The weights ``w`` get a Gaussian prior with one mean
    and one precision per feature; the noise precision ``tau`` gets a Gamma
    prior (created in ``fit``).

    Args:
        n_feature: Number of features, i.e. the length of ``w``.
        prior_mean: Prior mean of the weights. A value whose shape is not
            ``(n_feature,)`` (e.g. a scalar) is broadcast against
            ``np.ones(n_feature)``.
        prior_precision: Prior precision of the weights, normalized the same
            way as ``prior_mean``.
        prior_a: First parameter passed to the Gamma prior on ``tau``.
        prior_b: Second parameter passed to the Gamma prior on ``tau``.
    """

    def __init__(self,
                 n_feature,
                 prior_mean=0,
                 prior_precision=1e-6,
                 prior_a=10,
                 prior_b=1):
        super().__init__()
        self.n_feature = n_feature
        # Both prior parameters are normalized to per-feature vectors so they
        # match the weights' shape (n_feature,).
        self.prior_mean = self._as_feature_vector(prior_mean, n_feature)
        self.prior_precision = self._as_feature_vector(prior_precision,
                                                       n_feature)
        self.prior_a = prior_a
        self.prior_b = prior_b
        self._init_weights()

    @staticmethod
    def _as_feature_vector(value, n_feature):
        """Return ``value`` unchanged if shaped ``(n_feature,)``, else broadcast it."""
        if np.shape(value) == (n_feature, ):
            return value
        return value * np.ones(n_feature)

    @staticmethod
    def _weight_covariance(mean, second_moment):
        """Return the covariance ``E[ww^T] - E[w] E[w]^T`` from the first two moments."""
        return second_moment - np.outer(mean, mean)

    def _init_weights(self):
        """(Re)create the weight node from the stored prior mean/precision."""
        self.weights = GaussianARD(self.prior_mean,
                                   self.prior_precision,
                                   shape=(self.n_feature, ))

    def fit(self, X, y):
        """Run variational inference for the weights given data ``(X, y)``.

        Args:
            X: Feature array; contracted with the weights via ``'i,i'``
                (assumed to be (n_samples, n_feature) — TODO confirm).
            y: Observed targets, one per row of ``X`` (assumed).
        """
        # Fresh weight node each call: a fit starts from the prior rather
        # than from the posterior of a previous fit.
        self._init_weights()

        self.tau = Gamma(self.prior_a, self.prior_b)
        F = SumMultiply('i,i', self.weights, X)
        y_obs = GaussianARD(F, self.tau)
        y_obs.observe(y)

        # NOTE(review): self.tau is not passed to VB, so it is likely never
        # updated and the noise precision stays at its Gamma prior. If it is
        # meant to be learned, use VB(y_obs, self.weights, self.tau) — confirm.
        Q = VB(y_obs, self.weights)
        Q.update(repeat=10, tol=1e-4, verbose=False)

    def predict(self, x, return_var=False):
        """Predict ``w . x`` under the current weight posterior.

        Args:
            x: Feature vector(s), contracted with the weights via ``'i,i'``.
            return_var: If true, also return the variance of ``w . x``.

        Returns:
            ``y_hat``, or ``(y_hat, var)`` when ``return_var`` is true. ``var``
            reflects uncertainty in the weights only; the observation noise
            ``tau`` is not added.
        """
        y = SumMultiply('i,i', self.weights, x)
        y_hat, second_moment, *_ = y.get_moments()
        if not return_var:
            return y_hat
        # The moments are E[y] and E[y^2]; Var[y] = E[y^2] - E[y]^2.
        # Clip tiny negatives produced by floating-point cancellation.
        var = np.maximum(second_moment - y_hat**2, 0)
        return y_hat, var

    def sample(self, x):
        """Return ``x @ w`` for one weight vector drawn by ``self.weights.random()``."""
        w = self.weights.random()
        return x @ w

    def print(self, diagonal=True):
        """Print the weight means and their variances (or full covariance).

        Args:
            diagonal: If true, print one variance per weight; otherwise print
                the full ``n_feature x n_feature`` covariance matrix.
        """
        mean, second_moment = self.weights.get_moments()[:2]
        cov = self._weight_covariance(mean, second_moment)
        var = np.diagonal(cov) if diagonal else cov
        bar = '_' * 40
        print(f'{bar}\n{mean.round(3)}\n{var.round(3)}\n{bar}')