コード例 #1
0
ファイル: example.py プロジェクト: kelizhang/trustworthyAI
def run_simulate(config):
    """Generate a simulated dataset as described by a task configuration.

    Parameters
    ----------
    config: dict
        Configuration info. ``config['task_params']['algorithm']`` selects
        the kind of simulation and ``config['algorithm_params']`` holds the
        generator settings.

    Returns
    -------
    out: tuple
        ``(X, true_dag, topology_matrix)`` when the algorithm is ``'EVENT'``,
        otherwise ``(X, true_dag)`` with ``X`` wrapped in a
        ``pandas.DataFrame``.
    """

    params = config['algorithm_params']
    event_task = config['task_params']['algorithm'] == 'EVENT'

    # Both kinds of task start from the same random DAG.
    random_dag = DAG.erdos_renyi(n_nodes=params['n_nodes'],
                                 n_edges=params['n_edges'],
                                 weight_range=params['weight_range'],
                                 seed=params['seed'])

    if not event_task:
        # IID data: sample from the DAG and return the simulator's own
        # `B` matrix as the ground truth.
        iid = IIDSimulation(W=random_dag,
                            n=params['n'],
                            method=params['method'],
                            sem_type=params['sem_type'],
                            noise_scale=params['noise_scale'])
        return pd.DataFrame(iid.X), iid.B

    # Event data: a topology is needed in addition to the DAG.
    topology = Topology.erdos_renyi(
        n_nodes=params['Topology_n_nodes'],
        n_edges=params['Topology_n_edges'],
        seed=params['Topology_seed'])
    thp = THPSimulation(random_dag,
                        topology,
                        mu_range=params['mu_range'],
                        alpha_range=params['alpha_range'])
    events = thp.simulate(
        T=params['THPSimulation_simulate_T'],
        max_hop=params['THPSimulation_simulate_max_hop'],
        beta=params['THPSimulation_simulate_beta'])

    return events, random_dag, topology
コード例 #2
0
def simulate_data(method='nonlinear', sem_type='mlp',
                  n_nodes=6, n_edges=15, n=1000):
    """Simulate an IID dataset from a random weighted DAG.

    The DAG is drawn with ``DAG.erdos_renyi`` using a fixed seed of 1 and
    edge weights in ``(0.5, 2.0)``.

    Parameters
    ----------
    method: str
        Forwarded to ``IIDSimulation`` as ``method``.
    sem_type: str
        Forwarded to ``IIDSimulation`` as ``sem_type``.
    n_nodes: int
        Number of nodes requested for the random DAG.
    n_edges: int
        Number of edges requested for the random DAG.
    n: int
        Forwarded to ``IIDSimulation`` as ``n``.

    Returns
    -------
    out: tuple
        ``(X, true_dag)`` taken from the simulator's ``X`` and ``B``
        attributes.
    """
    random_dag = DAG.erdos_renyi(
        n_nodes=n_nodes,
        n_edges=n_edges,
        weight_range=(0.5, 2.0),
        seed=1,
    )
    simulation = IIDSimulation(
        W=random_dag,
        n=n,
        method=method,
        sem_type=sem_type,
    )
    ground_truth = simulation.B
    samples = simulation.X

    return samples, ground_truth
コード例 #3
0
 def setUp(self) -> None:
     """Prepare the simulated event data used by the TTPM tests."""
     rule = '=' * 20
     print(f"{rule}Testing TTPM{rule}")

     # Random causal graph plus a random topology feed the THP simulator.
     self.dag = DAG.erdos_renyi(n_nodes=10, n_edges=10)
     topology = Topology.erdos_renyi(n_nodes=20, n_edges=20)
     thp = THPSimulation(
         self.dag,
         topology,
         mu_range=(0.00005, 0.0001),
         alpha_range=(0.005, 0.007),
     )

     # T is 3600 * 24, i.e. 86400.
     self.x = thp.simulate(T=3600 * 24, max_hop=2)
     self.error_params = []
コード例 #4
0
`networkx` package, then like the following import method.

Warnings: This script is used only for demonstration and cannot be directly
          imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import DirectLiNGAM


# ---------------------------------------------------------------
# DirectLiNGAM demo on simulated data
# ---------------------------------------------------------------
# Draw a weighted random DAG, then sample an IID dataset from it.
weighted_random_dag = DAG.erdos_renyi(
    n_nodes=10,
    n_edges=20,
    weight_range=(0.5, 2.0),
    seed=1,
)
dataset = IIDSimulation(
    W=weighted_random_dag,
    n=2000,
    method='linear',
    sem_type='gauss',
)
true_dag = dataset.B
X = dataset.X

# Fit DirectLiNGAM to the samples.
g = DirectLiNGAM()
g.learn(X)

# Plot the estimated graph next to the ground truth.
GraphDAG(g.causal_matrix, true_dag)

# Score the estimate against the ground truth and print the metrics.
met = MetricsDAG(g.causal_matrix, true_dag)
print(met.metrics)
コード例 #5
0
# limitations under the License.
"""
This demo script aims to demonstrate
how to use the TTPM algorithm in the `castle` package for causal inference.

Warnings: This script is used only for demonstration and cannot be directly
        imported.
"""

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, Topology, THPSimulation
from castle.algorithms import TTPM

# Simulate event data for TTPM: a random causal graph plus a random
# topology are fed to the THP simulator.
true_causal_matrix = DAG.erdos_renyi(n_nodes=10, n_edges=10)
topology_matrix = Topology.erdos_renyi(n_nodes=20, n_edges=20)
simulator = THPSimulation(true_causal_matrix,
                          topology_matrix,
                          mu_range=(0.00005, 0.0001),
                          alpha_range=(0.005, 0.007))
X = simulator.simulate(T=3600 * 24, max_hop=2)

# TTPM modeling; max_hop matches the value used for the simulation.
ttpm = TTPM(topology_matrix, max_hop=2)
ttpm.learn(X)
print(ttpm.causal_matrix)

# plot est_dag and true_dag
GraphDAG(ttpm.causal_matrix, true_causal_matrix)

# calculate accuracy
# NOTE(review): this step was announced by the comment (and MetricsDAG was
# imported) but never executed; it is passed the same matrices as the
# GraphDAG call above -- confirm MetricsDAG accepts them in this form.
met = MetricsDAG(ttpm.causal_matrix, true_causal_matrix)
print(met.metrics)
コード例 #6
0
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import DAG_GNN

# Graph family used for the ground-truth DAG.  Named `graph_type` rather
# than `type` so the builtin `type()` is not shadowed for the rest of the
# module.
graph_type = 'ER'  # or `SF`
h = 2  # edges per node: h=2 gives ER2, h=5 gives ER5
n_nodes = 10
n_edges = h * n_nodes
method = 'linear'
sem_type = 'gauss'

# Pick the generator for the requested family; both take the same
# arguments, so the call itself is written once below.
if graph_type == 'ER':
    make_dag = DAG.erdos_renyi
elif graph_type == 'SF':
    make_dag = DAG.scale_free
else:
    raise ValueError('Just supported `ER` or `SF`.')

weighted_random_dag = make_dag(n_nodes=n_nodes,
                               n_edges=n_edges,
                               weight_range=(0.5, 2.0),
                               seed=300)

# Sample the IID dataset and keep the simulator's `B` matrix as ground truth.
dataset = IIDSimulation(W=weighted_random_dag,
                        n=2000,
                        method=method,
                        sem_type=sem_type)
true_dag, X = dataset.B, dataset.X
コード例 #7
0
def _remove_if_exists(path):
    """Delete the file at *path*; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def simulate_data(data, alg, task_id, parameters):
    """
    Simulation Data Generation Entry.

    Generates a random DAG and a dataset sampled from it, stores the results
    under ``data`` and reports progress through ``TaskApi``.

    Parameters
    ----------
    data: str
        Path for storing generated data files. Created if it is missing.
    alg: str
        Generating operator string. ``"EVENT"`` selects the THP (event)
        simulation; any other value selects the IID simulation.
    task_id: int
        task key in the database.
    parameters: dict
        Data generation parameters (passed through
        ``translation_parameters`` before use).

    Returns
    -------
    bool
        True when the data was generated and saved; False when the
        generation step raised, in which case the error text is stored as
        the task status and any existing output files for this task are
        removed.

    Raises
    ------
    FileExistsError
        If ``data`` exists but is not a directory.
    """
    parameters = translation_parameters(parameters)
    task_api = TaskApi()
    start_time = datetime.datetime.now()
    task_api.update_task_status(task_id, 0.1)
    task_api.update_consumed_time(task_id, start_time)
    task_api.update_update_time(task_id, start_time)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data, exist_ok=True)

    task_name = task_api.get_task_name(task_id)
    file_stem = str(task_id) + "_" + task_name
    sample_path = os.path.join(data, "datasets", file_stem + ".csv")
    true_dag_path = os.path.join(data, "true", file_stem + ".npz")
    node_relationship_path = os.path.join(
        data, "node_relationship_" + file_stem + ".csv")
    topo_path = os.path.join(data, "topo_" + file_stem + ".npz")
    task_api.update_task_status(task_id, 0.2)
    task_api.update_consumed_time(task_id, start_time)

    topo = None
    try:
        # Both simulation kinds start from the same random weighted DAG.
        weighted_random_dag = DAG.erdos_renyi(
            n_nodes=parameters['n_nodes'],
            n_edges=parameters['n_edges'],
            weight_range=parameters['weight_range'],
            seed=parameters['seed'])

        if alg == "EVENT":
            true_dag = weighted_random_dag
            topo = Topology.erdos_renyi(
                n_nodes=parameters['Topology_n_nodes'],
                n_edges=parameters['Topology_n_edges'],
                seed=parameters['Topology_seed'])
            simulator = THPSimulation(true_dag,
                                      topo,
                                      mu_range=parameters['mu_range'],
                                      alpha_range=parameters['alpha_range'])
            sample = simulator.simulate(
                T=parameters['THPSimulation_simulate_T'],
                max_hop=parameters['THPSimulation_simulate_max_hop'],
                beta=parameters['THPSimulation_simulate_beta'])
        else:
            dataset = IIDSimulation(W=weighted_random_dag,
                                    n=parameters['n'],
                                    method=parameters['method'],
                                    sem_type=parameters['sem_type'],
                                    noise_scale=parameters['noise_scale'])

            # The ground truth is the simulator's `B` matrix, not the
            # weighted DAG that was passed in.
            true_dag, sample = dataset.B, dataset.X
            sample = pd.DataFrame(sample)

        task_api.update_task_status(task_id, 0.5)
        task_api.update_consumed_time(task_id, start_time)
    except Exception as error:
        # Task boundary: record the failure on the task instead of raising.
        task_api.update_task_status(task_id, str(error))
        task_api.update_consumed_time(task_id, start_time)
        logger.warning('Generating simulation data failed, exp=%s', error)
        # Remove any output files already present for this task.
        for stale_path in (sample_path, true_dag_path,
                           node_relationship_path, topo_path):
            _remove_if_exists(stale_path)
        return False

    # A topology file is only rewritten for EVENT data, so drop any existing
    # one first to avoid pairing an old topology with new samples.
    _remove_if_exists(topo_path)

    task_api.update_task_status(task_id, 0.6)
    task_api.update_consumed_time(task_id, start_time)

    save_to_file(sample, sample_path)
    save_to_file(true_dag, true_dag_path)
    if isinstance(topo, np.ndarray):
        save_to_file(topo, topo_path)

    # Save the edges of the true DAG to the node-relationship file.
    save_gragh_edges(true_dag, node_relationship_path)
    task_api.update_task_status(task_id, 1.0)
    task_api.update_consumed_time(task_id, start_time)
    return True