Esempio n. 1
0
def train(model_name,
          X,
          true_dag,
          model_params,
          topology_matrix=None,
          plot=True):
    """Run a castle causal-discovery algorithm and optionally plot/evaluate it.

    Parameters
    ----------
    model_name: str
        Algorithm name. It is upper-cased before being compared with the
        special cases ('TTPM', 'NOTEARSLOWRANK') and before being looked up
        in ``INLINE_ALGORITHMS``, so matching is case-insensitive.
    X: pd.DataFrame
        train data
    true_dag: array or None
        True directed acyclic graph. When given (and ``plot`` is true) it is
        drawn next to the estimate and ``MetricsDAG`` metrics are printed.
    model_params: dict
        Constructor keyword arguments from the configuration file. For
        NOTEARSLOWRANK it must contain 'rank', which is passed to ``learn``
        instead of the constructor. The dict itself is never modified.
    topology_matrix: array, default None
        Topology graph matrix; only used by TTPM.
    plot: bool, default True
        Whether to show the graph (and print metrics when ``true_dag`` is
        given).

    Returns
    -------
    model: castle.algorithm
        The fitted castle algorithm instance.
    pre_dag: array
        Discovered causal matrix (``model.causal_matrix``).

    Raises
    ------
    ValueError
        If ``model_name`` is not a key of ``INLINE_ALGORITHMS`` (after
        upper-casing) and is not one of the special cases above.
    KeyError
        If NOTEARSLOWRANK is requested without a 'rank' entry.
    """
    name = model_name.upper()
    # Work on a copy so removing 'rank' below never mutates the caller's
    # configuration dict (which may be reused for later runs).
    params = dict(model_params)

    # Instantiation algorithm and learn dag
    if name == 'TTPM':
        model = INLINE_ALGORITHMS[name](topology_matrix, **params)
        model.learn(X)
    elif name == 'NOTEARSLOWRANK':
        rank = params.pop('rank')
        model = NotearsLowRank(**params)
        model.learn(X, rank=rank)
    else:
        # Guard only the lookup: an unknown name raises KeyError here, while
        # errors raised during construction/fitting must propagate unchanged.
        try:
            algorithm = INLINE_ALGORITHMS[name]
        except KeyError:
            raise ValueError(
                'Invalid algorithm name: {}.'.format(model_name)) from None
        model = algorithm(**params)
        model.learn(data=X)

    pre_dag = model.causal_matrix
    if plot:
        if true_dag is not None:
            GraphDAG(pre_dag, true_dag, show=plot)
            m = MetricsDAG(pre_dag, true_dag)
            print(m.metrics)
        else:
            GraphDAG(pre_dag, show=plot)

    return model, pre_dag
Esempio n. 2
0
def castle_experiment(model, x, y=None, show_graph=False, **kwargs):
    """Fit ``model`` on ``x`` and optionally score and plot the result.

    Extra keyword arguments are forwarded verbatim to ``model.learn``.
    Returns ``MetricsDAG(model.causal_matrix, y).metrics`` when a reference
    matrix ``y`` is supplied, otherwise ``None``. If ``show_graph`` is true,
    the learned matrix is drawn with ``GraphDAG`` alongside ``y`` (which may
    be ``None``).
    """
    model.learn(x, **kwargs)

    metrics = (MetricsDAG(model.causal_matrix, y).metrics
               if y is not None else None)

    if show_graph:
        GraphDAG(model.causal_matrix, y)

    return metrics
Esempio n. 3
0
`networkx` package, then import the modules as shown below.

Warnings: This script is used only for demonstration and cannot be directly
          imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import DirectLiNGAM


# ---------------------------------------------------------------------------
# DirectLiNGAM on simulated data
# ---------------------------------------------------------------------------
# Simulate IID samples (method='linear', sem_type='gauss') from a random
# weighted DAG with 10 nodes and 20 edges.
weighted_random_dag = DAG.erdos_renyi(
    n_nodes=10,
    n_edges=20,
    weight_range=(0.5, 2.0),
    seed=1,
)
dataset = IIDSimulation(
    W=weighted_random_dag,
    n=2000,
    method='linear',
    sem_type='gauss',
)
true_dag = dataset.B
X = dataset.X

# Fit DirectLiNGAM on the simulated samples.
g = DirectLiNGAM()
g.learn(X)

# Draw the estimated DAG next to the ground truth, then report accuracy.
GraphDAG(g.causal_matrix, true_dag)
met = MetricsDAG(g.causal_matrix, true_dag)
print(met.metrics)
Esempio n. 4
0
        imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import GraN_DAG, Parameters

# load data: IID samples (method='nonlinear', sem_type='mlp') generated from
# a random weighted DAG with 10 nodes and 20 edges
weighted_random_dag = DAG.erdos_renyi(n_nodes=10,
                                      n_edges=20,
                                      weight_range=(0.5, 2.0),
                                      seed=1)
dataset = IIDSimulation(W=weighted_random_dag,
                        n=2000,
                        method='nonlinear',
                        sem_type='mlp')
dag, x = dataset.B, dataset.X

# Initialize parameters for gran_dag (one input per observed variable)
parameters = Parameters(input_dim=x.shape[1])

# Instantiation algorithm
gnd = GraN_DAG(params=parameters)
gnd.learn(data=x, target=dag)

# plot predict_dag and true_dag and save the figure as 'result'.
# The name is passed by keyword: positionally it would be bound to
# GraphDAG's third parameter, which in castle's
# GraphDAG(est_dag, true_dag=None, show=True, save_name=None) is ``show``,
# so nothing would be saved.
# NOTE(review): confirm the signature of the installed castle version.
GraphDAG(gnd.causal_matrix, dag, save_name='result')
mm = MetricsDAG(gnd.causal_matrix, dag)
print(mm.metrics)
Esempio n. 5
0
how to use the TTPM algorithm in the `castle` package for causal inference.

Warnings: This script is used only for demonstration and cannot be directly
        imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, Topology, THPSimulation
from castle.algorithms import TTPM

# --- Simulate event data for TTPM --------------------------------------------
# A 10-node causal DAG between event types and a 20-node topology graph,
# fed to the THP simulator over T = 3600 * 24 time units.
true_causal_matrix = DAG.erdos_renyi(n_nodes=10, n_edges=10)
topology_matrix = Topology.erdos_renyi(n_nodes=20, n_edges=20)
simulator = THPSimulation(
    true_causal_matrix,
    topology_matrix,
    mu_range=(0.00005, 0.0001),
    alpha_range=(0.005, 0.007),
)
X = simulator.simulate(T=3600 * 24, max_hop=2)

# --- Fit TTPM with the same max_hop used for simulation ----------------------
ttpm = TTPM(topology_matrix, max_hop=2)
ttpm.learn(X)
print(ttpm.causal_matrix)

# --- Compare estimate with ground truth ---------------------------------------
# NOTE(review): depending on the castle version, causal_matrix may be a
# DataFrame rather than an ndarray; confirm GraphDAG/MetricsDAG accept it.
GraphDAG(ttpm.causal_matrix, true_causal_matrix)
ret_metrix = MetricsDAG(ttpm.causal_matrix, true_causal_matrix)
print(ret_metrix.metrics)
Esempio n. 6
0
        imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import PC

# simulation for pc: IID samples (method='linear', sem_type='gauss') from a
# random weighted DAG with 10 nodes and 20 edges
weighted_random_dag = DAG.erdos_renyi(n_nodes=10,
                                      n_edges=20,
                                      weight_range=(0.5, 2.0),
                                      seed=1)
dataset = IIDSimulation(W=weighted_random_dag,
                        n=2000,
                        method='linear',
                        sem_type='gauss')
true_dag, X = dataset.B, dataset.X

# PC learn
pc = PC()
pc.learn(X)

# plot predict_dag on its own, then against true_dag, saving the latter as
# 'result_pc'. The name is passed by keyword: positionally it would be bound
# to GraphDAG's third parameter, which in castle's
# GraphDAG(est_dag, true_dag=None, show=True, save_name=None) is ``show``,
# so nothing would be saved.
# NOTE(review): confirm the signature of the installed castle version.
GraphDAG(pc.causal_matrix)
GraphDAG(pc.causal_matrix, true_dag, save_name='result_pc')

# calculate accuracy
met = MetricsDAG(pc.causal_matrix, true_dag)
print(met.metrics)
Esempio n. 7
0
If you want to plot the causal graph, please make sure you have already installed
the `networkx` package, then import the modules as shown below.

Warnings: This script is used only for demonstration and cannot be directly
        imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import ANMNonlinear

# Simulate IID samples (method='nonlinear', sem_type='gp-add') from a random
# weighted DAG with 6 nodes and 10 edges.
weighted_random_dag = DAG.erdos_renyi(
    n_nodes=6,
    n_edges=10,
    weight_range=(0.5, 2.0),
    seed=1,
)
dataset = IIDSimulation(
    W=weighted_random_dag,
    n=1000,
    method='nonlinear',
    sem_type='gp-add',
)
true_dag = dataset.B
X = dataset.X

# Fit the nonlinear additive-noise-model learner.
anm = ANMNonlinear(alpha=0.05)
anm.learn(data=X)

# Draw the estimate next to the ground truth, then print accuracy metrics.
GraphDAG(anm.causal_matrix, true_dag)
mm = MetricsDAG(anm.causal_matrix, true_dag)
print(mm.metrics)
Esempio n. 8
0
how to use the TTPM algorithm in the `castle` package for causal inference.

Warnings: This script is used only for demonstration and cannot be directly
        imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, Topology, THPSimulation
from castle.algorithms import TTPM

# --- Simulate event data for TTPM --------------------------------------------
# Causal DAG between 10 event types plus a 20-node topology graph; the THP
# simulator generates events over T = 3600 * 24 time units.
true_causal_matrix = DAG.erdos_renyi(n_nodes=10, n_edges=10)
topology_matrix = Topology.erdos_renyi(n_nodes=20, n_edges=20)
simulator = THPSimulation(
    true_causal_matrix,
    topology_matrix,
    mu_range=(0.00005, 0.0001),
    alpha_range=(0.005, 0.007),
)
X = simulator.simulate(T=3600 * 24, max_hop=2)

# --- Fit TTPM using the same max_hop as the simulation -----------------------
ttpm = TTPM(topology_matrix, max_hop=2)
ttpm.learn(X)
print(ttpm.causal_matrix)

# --- Compare estimate with ground truth ---------------------------------------
# ``.values`` unwraps the learned causal matrix before plotting/scoring
# (presumably a pandas DataFrame in this castle version - confirm).
GraphDAG(ttpm.causal_matrix.values, true_causal_matrix)
ret_metrix = MetricsDAG(ttpm.causal_matrix.values, true_causal_matrix)
print(ret_metrix.metrics)
Esempio n. 9
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from castle.datasets import DAG, IIDSimulation
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.algorithms.ges.ges import GES

# Intended sweep over graph sizes, each with twice as many edges as nodes.
# NOTE: the unconditional ``break`` at the end of the body stops the loop
# after the first size (d=6); remove it to evaluate every size.
for d in (6, 8, 10, 15, 20):
    edges = 2 * d
    weighted_random_dag = DAG.erdos_renyi(
        n_nodes=d, n_edges=edges, weight_range=(0.5, 2.0), seed=1,
    )
    dataset = IIDSimulation(
        W=weighted_random_dag, n=1000, method='nonlinear', sem_type='gp-add',
    )
    true_dag = dataset.B
    X = dataset.X

    # GES scored with BIC
    algo = GES(criterion='bic', method='scatter')
    algo.learn(X)

    # Draw estimate vs. ground truth, then print accuracy metrics.
    GraphDAG(algo.causal_matrix, true_dag)
    m1 = MetricsDAG(algo.causal_matrix, true_dag)
    print(m1.metrics)
    break
Esempio n. 10
0
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import GraNDAG

# Load data: IID samples (method='nonlinear', sem_type='mlp') drawn from a
# random weighted DAG with 10 nodes and 20 edges.
weighted_random_dag = DAG.erdos_renyi(
    n_nodes=10,
    n_edges=20,
    weight_range=(0.5, 2.0),
    seed=1,
)
dataset = IIDSimulation(
    W=weighted_random_dag,
    n=2000,
    method='nonlinear',
    sem_type='mlp',
)
dag = dataset.B
x = dataset.X

# Candidate GraNDAG configuration.
# NOTE(review): ``d`` is built but never passed to GraNDAG below, so the
# model runs with its default settings. Passing ``**d`` would also request
# device_type='gpu' - confirm a GPU is available before wiring it in.
d = dict(
    model_name='NonLinGauss',
    nonlinear='leaky-relu',
    optimizer='sgd',
    norm_prod='paths',
    device_type='gpu',
)

# Instantiate the algorithm with one input per observed variable and fit it.
gnd = GraNDAG(input_dim=x.shape[1])
gnd.learn(data=x)

# Draw the estimate next to the ground truth, then print accuracy metrics.
GraphDAG(gnd.causal_matrix, dag)
mm = MetricsDAG(gnd.causal_matrix, dag)
print(mm.metrics)
Esempio n. 11
0
    from castle.algorithms import RL

    g = RL(**params_config['model_params'])
    g.learn(data=X, dag=true_dag)

elif args.model_name == 'ttpm':
    from castle.algorithms import TTPM

    g = TTPM(topology_matrix, **params_config['model_params'])
    g.learn(X)

else:
    raise ValueError('Invalid algorithm name: {}.'.format(args.model_name))

# plot and evaluate predict_dag and true_dag
# TTPM's causal matrix is unwrapped with ``.values`` before use; every other
# model's matrix is passed through as-is.
est_dag = (g.causal_matrix.values if args.model_name == 'ttpm'
           else g.causal_matrix)

if true_dag is None:
    GraphDAG(est_dag)
else:
    GraphDAG(est_dag, true_dag)
    m = MetricsDAG(est_dag, true_dag)
    print(m.metrics)