def init_row_chain(data_matrix, num_iter=200):
    """Assemble a row-chain model node for ``data_matrix``.

    Fits a chain model with ``chains.fit_model`` and combines the fitted
    quantities into a ``recursive.SumNode`` with two terms:

    * a ``recursive.ProductNode`` of
      - an integration node over the rows of
        ``chains.integration_matrix(data_matrix.m_orig)`` selected by
        ``data_matrix.row_ids``, and
      - a ``'scalar'`` Gaussian node over the state increments, with
        variance parameter ``sigma_sq_D``;
    * a ``'scalar'`` Gaussian node over the residual between the values
      returned by ``data_matrix.sample_latent_values`` and the fitted states
      of the selected rows, with variance parameter ``sigma_sq_N``.

    Args:
        data_matrix: Project data object; must expose ``m_orig``,
            ``row_ids`` and ``sample_latent_values``.
        num_iter: Iteration count forwarded to ``chains.fit_model``.

    Returns:
        The assembled ``recursive.SumNode``.
    """
    fitted_states, sigma_sq_D, sigma_sq_N = chains.fit_model(
        data_matrix, num_iter=num_iter)

    # Integration operator restricted to the rows listed in row_ids.
    row_ids = data_matrix.row_ids
    integration_rows = chains.integration_matrix(data_matrix.m_orig)[row_ids, :]
    integration_node = recursive.IntegrationNode(integration_rows)

    # Increments: the first state row, followed by the differences between
    # consecutive rows. A cumulative sum along axis 0 recovers fitted_states.
    increments = np.vstack([fitted_states[:1, :], np.diff(fitted_states, axis=0)])
    increment_node = recursive.GaussianNode(increments, 'scalar', sigma_sq_D)

    # Residual between the latent values and the fitted states of the
    # selected rows.
    predicted = fitted_states[row_ids, :]
    latent = data_matrix.sample_latent_values(predicted, sigma_sq_N)
    noise_node = recursive.GaussianNode(latent - predicted, 'scalar', sigma_sq_N)

    chain_node = recursive.ProductNode([integration_node, increment_node])
    return recursive.SumNode([chain_node, noise_node])
def dummy():
    """Build an ``IntegrationTNode`` from the transpose of ``chains.integration_matrix(5)``.

    NOTE(review): the size 5 is hard-coded, so this looks like a
    placeholder/fixture — confirm intended use. ``IntegrationTNode`` is
    referenced unqualified, whereas ``init_row_chain`` uses
    ``recursive.``-qualified node classes; confirm the name is imported into
    this module's namespace.

    Returns:
        The constructed ``IntegrationTNode``.
    """
    base_matrix = chains.integration_matrix(5)
    transposed = base_matrix.T
    return IntegrationTNode(transposed)