def run_SE(alpha, rho):
    """Run state evolution for a Gauss-Bernoulli prior with a gaussian output.

    The model is built by ``glm_state_evolution`` with ``output_var=1e-10``.
    No initializer is passed to ``iterate``, so whatever default the
    ``StateEvolution`` class applies is used (the original author labelled
    this the "uninformed initialization").

    Parameters
    ----------
    alpha
        Forwarded as ``alpha`` to ``glm_state_evolution``.
    rho
        Forwarded as ``prior_rho`` to ``glm_state_evolution``.

    Returns
    -------
    dict
        ``{"source": "SE", "v": ...}`` where ``v`` is the ``"v"`` entry of
        the data reported for variable ``"x"`` after iterating.
    """
    # Original author's note: analytical linear channel (Marcenko Pastur).
    glm = glm_state_evolution(
        alpha=alpha,
        prior_type="gauss_bernoulli",
        output_type="gaussian",
        prior_rho=rho,
        output_var=1e-10,
    )

    evolution = StateEvolution(glm)
    evolution.iterate(max_iter=200)  # at most 200 iterations requested

    return {"source": "SE", "v": evolution.get_variable_data(id="x")["v"]}
def run_se(a0, alpha, prior_rho, prior_mean):
    """Run state evolution for a relu output channel from a custom start.

    Parameters
    ----------
    a0
        Initial value placed on the ``("x", "bwd")`` entry of the custom
        initializer.
    alpha
        Forwarded as ``alpha`` to ``glm_state_evolution``.
    prior_rho
        Forwarded as ``prior_rho`` to ``glm_state_evolution``.
    prior_mean
        Forwarded as ``prior_mean`` to ``glm_state_evolution``.

    Returns
    -------
    Whatever ``run_state_evolution`` returns for the variables ``"x"`` and
    ``"z"`` (its exact structure is not visible from this function).
    """
    # NOTE(review): the prior is spelled "gauss_bernouilli" here, while other
    # call sites spell it "gauss_bernoulli". Confirm which spelling the
    # installed library version accepts before changing it.
    glm = glm_state_evolution(
        alpha=alpha,
        prior_type="gauss_bernouilli",
        output_type="relu",
        prior_rho=prior_rho,
        prior_mean=prior_mean,
    )

    return run_state_evolution(
        x_ids=["x", "z"],
        model=glm,
        max_iter=200,  # at most 200 iterations requested
        initializer=CustomInit(a_init=[("x", "bwd", a0)]),
    )
def run_BO(alpha, rho):
    """Run state evolution for an abs output channel from an informed start.

    The model uses a Gauss-Bernoulli prior with ``prior_mean=0`` and
    ``output_type="abs"``. The iteration starts from a ``CustomInit`` that
    sets the ``("x", "bwd")`` entry to ``10**3`` (the original author
    labelled this "Bayes optimal : informed initialization").

    Parameters
    ----------
    alpha
        Forwarded as ``alpha`` to ``glm_state_evolution``.
    rho
        Forwarded as ``prior_rho`` to ``glm_state_evolution``.

    Returns
    -------
    dict
        ``{"source": "BO", "v": ...}`` where ``v`` is the ``"v"`` entry of
        the data reported for variable ``"x"`` after iterating.
    """
    glm = glm_state_evolution(
        alpha=alpha,
        prior_type="gauss_bernoulli",
        output_type="abs",
        prior_rho=rho,
        prior_mean=0,
    )

    # Informed initialization: large starting value on the backward x message.
    informed_start = CustomInit(a_init=[("x", "bwd", 10**3)])

    evolution = StateEvolution(glm)
    evolution.iterate(max_iter=200, initializer=informed_start)

    return {"source": "BO", "v": evolution.get_variable_data(id="x")["v"]}
def run_BO(alpha, rho):
    """Run state evolution for a gaussian output from an alpha-scaled start.

    The model uses a Gauss-Bernoulli prior and a gaussian output with
    ``output_var=1e-10``. The iteration starts from a ``CustomInit`` whose
    ``("x", "bwd")`` entry is ``10 ** (3 * exp(alpha))``; the original author
    notes this scaling is there "to avoid issues at low alpha".

    Parameters
    ----------
    alpha
        Forwarded as ``alpha`` to ``glm_state_evolution`` and also used to
        scale the exponent of the initial value.
    rho
        Forwarded as ``prior_rho`` to ``glm_state_evolution``.

    Returns
    -------
    dict
        ``{"source": "BO", "v": ...}`` where ``v`` is the ``"v"`` entry of
        the data reported for variable ``"x"`` after iterating.
    """
    # Original author's note: analytical linear channel (Marcenko Pastur).
    glm = glm_state_evolution(
        alpha=alpha,
        prior_type="gauss_bernoulli",
        output_type="gaussian",
        prior_rho=rho,
        output_var=1e-10,
    )

    # BO: informative initialization whose exponent grows with alpha.
    informed_start = CustomInit(
        a_init=[("x", "bwd", 10 ** (3 * np.exp(alpha)))]
    )

    evolution = StateEvolution(glm)
    evolution.iterate(max_iter=200, initializer=informed_start)

    return {"source": "BO", "v": evolution.get_variable_data(id="x")["v"]}
def run_se(a0, alpha, output_width, prior_p_pos):
    """Run state evolution for a binary prior with a door output channel.

    Parameters
    ----------
    a0
        Initial value placed on the ``("x", "bwd")`` entry of the custom
        initializer.
    alpha
        Forwarded as ``alpha`` to ``glm_state_evolution``.
    output_width
        Forwarded as ``output_width`` to ``glm_state_evolution``.
    prior_p_pos
        Forwarded as ``prior_p_pos`` to ``glm_state_evolution``.

    Returns
    -------
    Whatever ``run_state_evolution`` returns for the variable ``"x"`` (its
    exact structure is not visible from this function).
    """
    glm = glm_state_evolution(
        alpha=alpha,
        prior_type="binary",
        output_type="door",
        output_width=output_width,
        prior_p_pos=prior_p_pos,
    )

    start = CustomInit(a_init=[("x", "bwd", a0)])
    # Passed as the callback; presumably halts the run when the tracked
    # quantity increases by more than 0.1 -- confirm against EarlyStopping.
    stopper = EarlyStopping(max_increase=0.1)

    return run_state_evolution(
        x_ids=["x"],
        model=glm,
        max_iter=200,
        initializer=start,
        callback=stopper,
    )