コード例 #1
0
def main():
    """Run the three two-variable experiment configurations via ``run_expt``.

    All runs share the same graph/SEM settings (``d=2``, ``s0=1``, ``'ER'``,
    ``'gauss'``) and differ only in the edge-weight ranges and the per-node
    noise scales. The runs execute in the order
    ``equal_var``, ``large_a``, ``small_a``.
    """
    utils.set_random_seed(123)

    num_graph = 1000
    num_data_per_graph = 1

    # n=np.inf is forwarded unchanged to run_expt; presumably it selects an
    # infinite-sample / population setting there -- TODO confirm in run_expt.
    n, d, s0, graph_type, sem_type = np.inf, 2, 1, 'ER', 'gauss'

    # (expt_name, w_ranges, noise_scale) per experiment.
    #   equal_var: |w| in [0.5, 2.0], equal noise scales
    #   large_a:   |w| in [1.1, 2.0], second noise scale 0.15
    #   small_a:   |w| in [0.5, 0.9], second noise scale 0.15
    # All noise scales are written as float literals so every config passes
    # the same element type to run_expt.
    experiments = (
        ('equal_var', ((-2.0, -0.5), (0.5, 2.0)), [1.0, 1.0]),
        ('large_a', ((-2.0, -1.1), (1.1, 2.0)), [1.0, 0.15]),
        ('small_a', ((-0.9, -0.5), (0.5, 0.9)), [1.0, 0.15]),
    )
    for expt_name, w_ranges, noise_scale in experiments:
        run_expt(num_graph, num_data_per_graph, n, d, s0, graph_type,
                 sem_type, w_ranges, noise_scale, expt_name)
コード例 #2
0
def main():
    """Fit a ``NotearsMLP`` on simulated nonlinear-SEM data and print accuracy.

    Side effects: sets torch's default dtype to double, sets NumPy print
    precision to 3, seeds the RNG via ``notears.utils.set_random_seed(123)``,
    and writes ``W_true.csv``, ``X.csv`` and ``W_est.csv`` to the current
    working directory.
    """
    torch.set_default_dtype(torch.double)
    np.set_printoptions(precision=3)

    import notears.utils as ut
    ut.set_random_seed(123)

    # Simulation settings.
    n = 200
    d = 5
    s0 = 9
    graph_type = 'ER'
    sem_type = 'mim'

    # Shared keyword arguments for every CSV dump below.
    save_kw = {'delimiter': ','}

    B_true = ut.simulate_dag(d, s0, graph_type)
    np.savetxt('W_true.csv', B_true, **save_kw)

    X = ut.simulate_nonlinear_sem(B_true, n, sem_type)
    np.savetxt('X.csv', X, **save_kw)

    mlp = NotearsMLP(dims=[d, 10, 1], bias=True)
    W_est = notears_nonlinear(mlp, X, lambda1=0.01, lambda2=0.01)
    # Sanity check on the estimate (note: skipped when Python runs with -O).
    assert ut.is_dag(W_est)
    np.savetxt('W_est.csv', W_est, **save_kw)

    # Score the nonzero pattern of the estimate against the true graph.
    print(ut.count_accuracy(B_true, W_est != 0))
コード例 #3
0
def main():
    """Fit an ``EnsembleNotearsMLP`` on simulated nonlinear-SEM data.

    Side effects: sets torch's default dtype to double, sets NumPy print
    precision to 3, seeds the RNG via ``notears.utils.set_random_seed(123)``,
    and writes ``W_true.csv`` and ``X.csv`` to the current working directory.
    The fitted result is bound to ``W_est`` but is not evaluated or saved here.
    """
    torch.set_default_dtype(torch.double)
    np.set_printoptions(precision=3)

    import notears.utils as ut
    ut.set_random_seed(123)

    # Simulation and ensemble settings.
    n = 200
    d = 5
    s0 = 9
    graph_type = 'ER'
    sem_type = 'mim'
    ensemble_size = 7

    save_kw = {'delimiter': ','}

    B_true = ut.simulate_dag(d, s0, graph_type)
    np.savetxt('W_true.csv', B_true, **save_kw)

    X = ut.simulate_nonlinear_sem(B_true, n, sem_type)
    np.savetxt('X.csv', X, **save_kw)

    # Insert a new axis at position 1 and repeat the data ensemble_size times
    # along it. For a 2-D X of shape (n, d) this gives (n, ensemble_size, d);
    # presumably axis 1 is the ensemble axis ensemble_notears_mlp expects --
    # TODO confirm against its implementation.
    X_ensemble = np.tile(np.expand_dims(X, axis=1), (1, ensemble_size, 1))

    ensemble = EnsembleNotearsMLP(dims=[d, 10, 1],
                                  ensemble_size=ensemble_size,
                                  bias=True)
    # NOTE(review): W_est is not used after this point.
    W_est = ensemble_notears_mlp(ensemble, X_ensemble,
                                 lambda1=0.01, lambda2=0.01)
コード例 #4
0
            if h_new > 0.25 * h:
                rho *= 10
            else:
                break
        w_est, h = w_new, h_new
        alpha += rho * h
        if h <= h_tol or rho >= rho_max:
            break
    W_est = _adj(w_est)
    W_est[np.abs(W_est) < w_threshold] = 0
    return W_est


if __name__ == '__main__':
    # Linear-SEM demo: simulate a graph and data, run notears_linear, write
    # W_true.csv / X.csv / W_est.csv to the working directory, print accuracy.
    from notears import utils

    utils.set_random_seed(1)

    # Simulation settings.
    n = 100
    d = 20
    s0 = 20
    graph_type = 'ER'
    sem_type = 'gauss'

    save_kw = {'delimiter': ','}

    # B_true is the simulated DAG; W_true is drawn from it by simulate_parameter.
    B_true = utils.simulate_dag(d, s0, graph_type)
    W_true = utils.simulate_parameter(B_true)
    np.savetxt('W_true.csv', W_true, **save_kw)

    X = utils.simulate_linear_sem(W_true, n, sem_type)
    np.savetxt('X.csv', X, **save_kw)

    W_est = notears_linear(X, lambda1=0.1, loss_type='l2')
    # Sanity check on the estimate (note: skipped when Python runs with -O).
    assert utils.is_dag(W_est)
    np.savetxt('W_est.csv', W_est, **save_kw)

    # Compare the estimate's nonzero pattern with the true DAG.
    print(utils.count_accuracy(B_true, W_est != 0))
コード例 #5
0
ファイル: linear.py プロジェクト: yaolezju/notears
            if h_new > 0.25 * h:
                rho *= 10
            else:
                break
        w_est, h = w_new, h_new
        alpha += rho * h
        if h <= h_tol or rho >= rho_max:
            break
    W_est = _adj(w_est)
    W_est[np.abs(W_est) < w_threshold] = 0
    return W_est


if __name__ == '__main__':
    # Linear-SEM demo: simulate a graph and data, run notears_linear, dump the
    # true weights, data and estimate as CSV files, then print accuracy.
    import notears.utils as ut

    ut.set_random_seed(1)

    n = 100
    d = 20
    s0 = 20
    graph_type = 'ER'
    sem_type = 'gauss'
    save_kw = {'delimiter': ','}

    # Ground truth: DAG from simulate_dag, parameters from simulate_parameter.
    B_true = ut.simulate_dag(d, s0, graph_type)
    W_true = ut.simulate_parameter(B_true)
    np.savetxt('W_true.csv', W_true, **save_kw)

    # Observations generated from the linear SEM defined by W_true.
    X = ut.simulate_linear_sem(W_true, n, sem_type)
    np.savetxt('X.csv', X, **save_kw)

    W_est = notears_linear(X, lambda1=0.1, loss_type='l2')
    assert ut.is_dag(W_est)  # demo sanity check; stripped under -O
    np.savetxt('W_est.csv', W_est, **save_kw)

    print(ut.count_accuracy(B_true, W_est != 0))