def test_run_bad():
    """DagRunner.run() must fail with a clear message when a node class cannot be instantiated.

    The writer class is registered while the configuration and runner are
    built, then removed from the factory's ``name_dict`` so that the run
    can no longer find it.
    """

    class TestWriterTmp(AbstractWriter):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("TestWriterTmp", TestWriterTmp)

    config = {
        "implementation_config": {
            "writer_config": {
                "mywriter": {
                    "class": "TestWriterTmp",
                    "destinations": []
                }
            }
        }
    }
    configuration = Configuration(None,
                                  is_dict_config=True,
                                  dict_config=config)

    runner = DagRunner(configuration)

    # unregister this class
    del NodeFactory().name_dict["TestWriterTmp"]

    with pytest.raises(Exception) as e:
        runner.run()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a (possibly truncated) repr, not the exception message.
    assert "Issue instantiating mywriter and class TestWriterTmp" in str(
        e.value)
Beispiel #2
0
    def _register_class(self, class_key, class_prefix):
        """Register a class specified in the config file.

        Args:
            class_key (str): class key to register
            class_prefix(str): the prefix of the class to register. Can be in `path.to.module` format,
                or a full path `path/to/module`.

        Returns:
            None - attempts to register the class with it's default name
        """
        # convert to string before checking if file
        if class_prefix is None:
            class_prefix = ""

        if os.path.isfile(class_prefix):
            modulename = self._import_file(class_key, class_prefix)

        # loading from module
        else:
            if self.config_metadata and "class_package" in self.config_metadata:
                class_package = self.config_metadata["class_package"]
                prefix = ".".join(filter(None, [class_package, class_prefix]))
            else:
                prefix = class_prefix

            modulename = importlib.import_module(prefix)

        clz = getattr(modulename, class_key)
        NodeFactory().register(None, clz)
    def test_run_node(self):
        """A node-level ClientNotification posts "Node Success!" exactly once."""
        client_factory_path = (
            "primrose.notifications.success_notification.get_notification_client"
        )
        with mock.patch(client_factory_path) as patched_factory:
            fake_client = mock.Mock()
            patched_factory.return_value = fake_client

            NodeFactory().register("SlackDataMock", SlackDataMock)

            configuration = Configuration(
                None,
                is_dict_config=True,
                dict_config=config_dict_node_message,
            )

            # run the upstream node first and keep whatever it returns
            upstream = SlackDataMock(configuration, "test_node")
            data_object = upstream.run(DataObject(configuration))

            notifier = ClientNotification(
                configuration=configuration,
                instance_name="node_notification",
            )
            # force the notifier to talk to the fake client
            notifier.client = fake_client

            notifier.run(data_object)

            notifier.client.post_message.assert_called_once_with(
                message="Node Success!")
def test_class_package(mock_env):
    """Configuration loads TestExtNode for three ways of spelling its location.

    The variants are: a `class_package` package name, a `class_package` file
    path, and a package name combined with a per-node `class_prefix`.
    """
    package_only = {
        "metadata": {"class_package": "test"},
        "implementation_config": {
            "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}}
        },
    }
    file_path = {
        "metadata": {"class_package": "test/ext_node_example.py"},
        "implementation_config": {
            "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}}
        },
    }
    package_and_prefix = {
        "metadata": {"class_package": "test"},
        "implementation_config": {
            "reader_config": {
                "read_data": {
                    "class": "TestExtNode",
                    "class_prefix": "ext_node_example",
                    "destinations": [],
                }
            }
        },
    }

    for raw_config in (file_path, package_only, package_and_prefix):
        loaded = Configuration(
            config_location=None, is_dict_config=True, dict_config=raw_config
        )
        assert loaded.config_string
        assert loaded.config_hash
        # drop the class again so the next variant has to register it itself
        NodeFactory().unregister("TestExtNode")
def test_run_bad2():
    """An exception raised inside a node's run() must surface from DagRunner.run()."""

    class TestWriterTmp(AbstractWriter):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            raise Exception("Deliberate error")

    NodeFactory().register("TestWriterTmp", TestWriterTmp)

    config = {
        "implementation_config": {
            "writer_config": {
                "mywriter": {
                    "class": "TestWriterTmp",
                    "destinations": []
                }
            }
        }
    }
    configuration = Configuration(None,
                                  is_dict_config=True,
                                  dict_config=config)

    runner = DagRunner(configuration)

    with pytest.raises(Exception) as e:
        runner.run()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a (possibly truncated) repr, not the exception message.
    assert "Deliberate error" in str(e.value)
def test_register_module_classes():
    """Registering this module's classes logs a discovery and a registration per class."""
    expected_records = [
        (
            "root",
            "INFO",
            "Discovered class CsvReader (<class "
            "'primrose.readers.csv_reader.CsvReader'>)",
        ),
        (
            "root",
            "DEBUG",
            "Registered CsvReader : <class 'primrose.readers.csv_reader.CsvReader'>",
        ),
        (
            "root",
            "INFO",
            "Discovered class DillWriter (<class "
            "'primrose.writers.dill_writer.DillWriter'>)",
        ),
        (
            "root",
            "DEBUG",
            "Registered DillWriter : <class 'primrose.writers.dill_writer.DillWriter'>",
        ),
    ]

    with LogCapture() as captured:
        NodeFactory().register_module_classes(__name__)

    captured.check(*expected_records)
def test_run4():
    """A node that returns a truthy terminate flag makes the runner stop early."""

    class TestWriter(AbstractFileWriter):
        def __init__(self, configuration, instance_name):
            pass

        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            # second element is the terminate signal
            return data_object, True

    NodeFactory().register("TestWriter", TestWriter)

    reader_section = {
        "csv_reader": {
            "class": "CsvReader",
            "filename": "test/minimal.csv",
            "destinations": ["csv_writer"],
        }
    }
    writer_section = {"csv_writer": {"class": "TestWriter"}}
    config = {
        "implementation_config": {
            "reader_config": reader_section,
            "writer_config": writer_section,
        }
    }

    runner = DagRunner(
        Configuration(None, is_dict_config=True, dict_config=config))

    expected_records = [
        ("root", "INFO", "Taking nodes to run from default"),
        (
            "root",
            "INFO",
            "received node csv_reader of type reader_config and class CsvReader",
        ),
        ("root", "INFO", "Reading test/minimal.csv from CSV"),
        (
            "root",
            "INFO",
            "received node csv_writer of type writer_config and class TestWriter",
        ),
        ("root", "INFO", "Terminating early due to signal from csv_writer"),
        ("root", "INFO", "All done. Bye bye!"),
    ]

    with LogCapture() as captured:
        runner.run(dry_run=False)
    captured.check(*expected_records)
def test_env_override_class_package(mock_env):
    """Configuration still loads TestExtNode when the metadata names a bogus class_package.

    NOTE(review): presumably the `mock_env` fixture supplies the real package
    through the environment - confirm against the fixture.
    """
    raw_config = {
        "metadata": {"class_package": "junk"},
        "implementation_config": {
            "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}}
        },
    }
    loaded = Configuration(
        config_location=None, is_dict_config=True, dict_config=raw_config
    )
    assert loaded.config_string
    assert loaded.config_hash
    # leave the factory without the class registered by this test
    NodeFactory().unregister("TestExtNode")
Beispiel #9
0
def test_all_nodes_to_prune2():
    """all_nodes_to_prune() must reject a pruned destination the node does not declare."""

    class TestConditionalNode(AbstractConditionalPath):
        @staticmethod
        def necessary_config(node_config):
            return set()

        def destinations_to_prune(self):
            # "junk" is not one of the configured destinations
            return ["junk"]

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("TestConditionalNode", TestConditionalNode)

    config = {
        "implementation_config": {
            "reader_config": {
                "conditional_node": {
                    "class": "TestConditionalNode",
                    "destinations": ["csv_writer"],
                }
            },
            "writer_config": {
                "csv_writer": {
                    "class": "CsvWriter",
                    "key": "test_data",
                    "dir": "cache",
                    "filename": "unittest_similar_recipes.csv",
                }
            },
        }
    }

    configuration = Configuration(None,
                                  is_dict_config=True,
                                  dict_config=config)
    node = TestConditionalNode(configuration, "conditional_node")

    with pytest.raises(Exception) as e:
        node.all_nodes_to_prune()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a (possibly truncated) repr, not the exception message.
    assert "Destination junk is not in destinations list" in str(e.value)
Beispiel #10
0
def test_run():
    """AbstractPipeline.run() picks fit_transform or transform based on `is_training`.

    Three runs are checked through their log output: training without an
    upstream TransformerSequence, training with one, and non-training without
    one.
    """

    class TestPipeline(AbstractPipeline):
        def transform(self, data_object):
            logging.info("TRANSFORM CALLED")
            return data_object

        def fit_transform(self, data_object):
            logging.info("FIT_TRANSFORM CALLED")
            return self.transform(data_object)

        @staticmethod
        def necessary_config(node_config):
            return set(["is_training"])

    NodeFactory().register("TestPipeline", TestPipeline)

    pipeline_settings = {"class": "TestPipeline", "is_training": True}
    config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mypipeline"],
                }
            },
            "pipeline_config": {"mypipeline": pipeline_settings},
        }
    }
    configuration = Configuration(config_location=None,
                                  is_dict_config=True,
                                  dict_config=config)

    corpus = pd.read_csv("test/minimal.csv")

    reader = CsvReader(configuration, "myreader")
    data_object = DataObject(configuration)
    data_object.add(reader, corpus)

    pipeline = TestPipeline(configuration, "mypipeline")

    # 1) training, nothing upstream
    with LogCapture() as captured:
        pipeline.run(data_object)
    captured.check(
        (
            "root",
            "INFO",
            "No upstream TransformerSequence found. Creating new TransformerSequence...",
        ),
        ("root", "INFO", "FIT_TRANSFORM CALLED"),
        ("root", "INFO", "TRANSFORM CALLED"),
    )

    # 2) training, with a TransformerSequence supplied by the reader
    data_object.add(reader, TransformerSequence(), "tsequence")
    with LogCapture() as captured:
        pipeline.run(data_object)
    captured.check(
        (
            "root",
            "INFO",
            "Upstream TransformerSequence found, initializing pipeline...",
        ),
        ("root", "INFO", "FIT_TRANSFORM CALLED"),
        ("root", "INFO", "TRANSFORM CALLED"),
    )

    # 3) not training: rebuild everything from the modified config
    pipeline_settings["is_training"] = False
    configuration = Configuration(config_location=None,
                                  is_dict_config=True,
                                  dict_config=config)
    reader = CsvReader(configuration, "myreader")
    data_object = DataObject(configuration)
    data_object.add(reader, corpus)
    pipeline = TestPipeline(configuration, "mypipeline")
    with LogCapture() as captured:
        pipeline.run(data_object)
    captured.check(
        (
            "root",
            "INFO",
            "No upstream TransformerSequence found. Creating new TransformerSequence...",
        ),
        ("root", "INFO", "TRANSFORM CALLED"),
    )
Beispiel #11
0
def test_execute_pipeline():
    """execute_pipeline() validates its preconditions and dispatches on the mode.

    Checks that it refuses to run before run() has been called, rejects a
    mode that is not a PipelineModeType, and calls the transformer's fit /
    fit_transform for the corresponding modes.
    """

    class TestTransformer(AbstractTransformer):
        def fit(self, data):
            logging.info("Transfer FIT CALLED")

        def transform(self, data):
            logging.info("Transfer TRANSFORM CALLED")
            return data

        def fit_transform(self, data):
            logging.info("Transfer FIT_TRANSFORM CALLED")
            self.fit(data)
            return self.transform(data)

    class TestPipeline2(AbstractPipeline):
        def transform(self, data_object):
            return data_object

        @staticmethod
        def necessary_config(node_config):
            return set(["is_training"])

    NodeFactory().register("TestPipeline2", TestPipeline2)

    config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mypipeline"],
                }
            },
            "pipeline_config": {
                "mypipeline": {
                    # Name the class this test registers, so the test does not
                    # depend on another test having registered "TestPipeline".
                    "class": "TestPipeline2",
                    "is_training": True
                }
            },
        }
    }
    configuration = Configuration(config_location=None,
                                  is_dict_config=True,
                                  dict_config=config)

    reference_file_path = "test/minimal.csv"
    corpus = pd.read_csv(reference_file_path)

    reader = CsvReader(configuration, "myreader")

    data_object = DataObject(configuration)
    data_object.add(reader, corpus)

    sequence = TransformerSequence()
    sequence.add(TestTransformer())
    data_object.add(reader, sequence, "tsequence")

    pipeline = TestPipeline2(configuration, "mypipeline")

    # Assertions below inspect e.value: str() of the ExceptionInfo wrapper is
    # a (possibly truncated) repr, not the exception message.
    with pytest.raises(Exception) as e:
        pipeline.execute_pipeline(corpus, PipelineModeType.FIT)
    assert "run() must be called to extract/create a TransformerSequence" in str(
        e.value)

    pipeline.run(data_object)

    with pytest.raises(Exception) as e:
        pipeline.execute_pipeline(corpus, "JUNK")
    assert "mode must be of type PipelineModeType Enum object." in str(
        e.value)

    with LogCapture() as l:
        pipeline.execute_pipeline(corpus, PipelineModeType.FIT)
    l.check(("root", "INFO", "Transfer FIT CALLED"))

    with LogCapture() as l:
        pipeline.execute_pipeline(corpus, PipelineModeType.FIT_TRANSFORM)
    l.check(
        ("root", "INFO", "Transfer FIT_TRANSFORM CALLED"),
        ("root", "INFO", "Transfer FIT CALLED"),
        ("root", "INFO", "Transfer TRANSFORM CALLED"),
    )
Beispiel #12
0
def test_plot_dag():
    """plot_dag() writes an image file for a six-section configuration.

    The configuration is written to a JSON file, loaded from there, and the
    resulting DAG is plotted. Both generated files are removed afterwards,
    even when a step in between fails.
    """

    class TestPostprocess(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set(["key1", "key2"])

        def run(self, data_object):
            return data_object

    NodeFactory().register("TestPostprocess", TestPostprocess)

    class Testpipeline(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object

    NodeFactory().register("Testpipeline", Testpipeline)

    class TestCleanup(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object

    NodeFactory().register("TestCleanup", TestCleanup)

    config = {
        "implementation_config": {
            "reader_config": {
                "csv_reader": {
                    "class": "CsvReader",
                    "filename": "some/path/to/file",
                    "destinations": ["pipeline1"],
                }
            },
            "pipeline_config": {
                "pipeline1": {
                    "class": "Testpipeline",
                    "destinations": ["decision_tree_model"],
                }
            },
            "model_config": {
                "decision_tree_model": {
                    "class": "SklearnClassifierModel",
                    "model_parameters": {},
                    "sklearn_classifier_name": "tree.DecisionTreeClassifier",
                    "grid_search_scoring": "roc_auc",
                    "cv_folds": 3,
                    "mode": "predict",
                    "destinations": ["nodename"],
                }
            },
            "postprocess_config": {
                "nodename": {
                    "class": "TestPostprocess",
                    "key1": "val1",
                    "key2": "val2",
                    "destinations": ["write_output"],
                }
            },
            "writer_config": {
                "write_output": {
                    "class": "CsvWriter",
                    "key": "read_data",
                    "dir": "cache",
                    "filename": "some/path/to/file.csv",
                    "destinations": ["donothingsuccess"],
                }
            },
            "cleanup_config": {
                "donothingsuccess": {
                    "class": "TestCleanup",
                }
            },
        }
    }

    cfilename = "test/test_dag_plotting.json"
    filename = "test/test_dag_plotting.png"

    try:
        with open(cfilename, "w") as f:
            jstyleson.dump(config, f)
        config = Configuration(config_location=cfilename)
        dag = config.dag
        dag.create_dag()

        # make sure a stale image cannot satisfy the existence check below
        if os.path.exists(filename):
            os.remove(filename)

        dag.plot_dag(filename, traverser=ConfigLayerTraverser(config))

        assert os.path.exists(filename)
    finally:
        # always remove the generated files, whether or not the checks passed
        for generated in (cfilename, filename):
            if os.path.exists(generated):
                os.remove(generated)
def test_cosine_similarity_matrix():
    """cosine_similarity_matrix() returns the expected 3x3 matrix for a three-document corpus.

    Also checks that the engine keeps the ids and documents it was given and
    has a `tfidf` attribute set after predict().
    """

    class Testpipeline(AbstractNode):
        def __init__(self, configuration, instance_name):
            self.configuration = configuration
            self.instance_name = instance_name

        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("Testpipeline", Testpipeline)

    class TestSimpleSearchEngine(AbstractSearchEngine):
        """
        simple TFIDF search engine
        """

        def __init__(self, configuration, instance_name):
            AbstractSearchEngine.__init__(self, configuration, instance_name)

        # immutable default: a list default would be shared across calls
        def tokenize(self, s, stopwords=(), add_ngrams=True):
            """Lower-case `s`, strip punctuation, drop stopwords, optionally append bigrams."""
            q = s.lower()
            tokens = (q.replace("-", " ").replace(",", "").replace(
                "(", "").replace(")", "").split(" "))
            tokens = [w for w in tokens if w not in stopwords]
            if add_ngrams:
                bigrams = list(ngrams(tokens, 2))
                strbigrams = ["_".join(t) for t in bigrams]
                tokens.extend(strbigrams)
            return tokens

        def eval_model(self, data_object):
            # instance method: needs `self` to be callable on an engine
            return data_object

    NodeFactory().register("TestSimpleSearchEngine", TestSimpleSearchEngine)

    config = {
        "implementation_config": {
            "pipeline_config": {
                "pipeline1": {
                    "class": "Testpipeline",
                    "destinations": ["recipe_name_model"],
                }
            },
            "model_config": {
                "recipe_name_model": {
                    "class": "TestSimpleSearchEngine",
                    "id_key": "id",
                    "doc_key": "name",
                    # NOTE(review): "precict" looks like a typo for "predict";
                    # left as-is because predict() is called directly below.
                    "mode": "precict",
                    "destinations": [],
                }
            },
        }
    }

    configuration = Configuration(None,
                                  is_dict_config=True,
                                  dict_config=config)

    # set that pipeline provided the corpus
    corpus = [
        {
            "id": 1,
            "name": "spinach omelet"
        },
        {
            "id": 2,
            "name": "kale omelet"
        },
        {
            "id": 3,
            "name": "cherry pie"
        },
    ]
    data_object = DataObject(configuration)
    pipeline = Testpipeline(configuration, "pipeline1")
    data_object.add(pipeline, pd.DataFrame(corpus))

    engine = TestSimpleSearchEngine(configuration, "recipe_name_model")
    engine.predict(data_object)

    m = engine.cosine_similarity_matrix()

    assert math.isclose(m[0, 0], 1.0, abs_tol=0.001)
    assert math.isclose(m[0, 1], 0.224325, abs_tol=0.001)
    assert math.isclose(m[0, 2], 0.0, abs_tol=0.001)

    assert math.isclose(m[1, 0], 0.224325, abs_tol=0.001)
    assert math.isclose(m[1, 1], 1.0, abs_tol=0.001)
    assert math.isclose(m[1, 2], 0.0, abs_tol=0.001)

    assert math.isclose(m[2, 0], 0.0, abs_tol=0.001)
    assert math.isclose(m[2, 1], 0.0, abs_tol=0.001)
    assert math.isclose(m[2, 2], 1.0, abs_tol=0.001)

    assert engine.ids == [1, 2, 3]
    assert engine.docs == ["spinach omelet", "kale omelet", "cherry pie"]
    assert engine.tfidf is not None
Beispiel #14
0
def test_init_ok():
    """S3Writer writes the data locally and to the bucket with the same content as the source CSV.

    NOTE(review): the comments below indicate the S3 calls go to Moto's
    virtual AWS account; the decorator/fixture that enables that is not
    visible here - confirm it is applied.
    """
    config = {
        "implementation_config": {
            "postprocess_config": {
                "nodename": {
                    "class": "TestPostprocess",
                    "key1": "val1",
                    "key2": "val2",
                    "destinations": ["recipe_s3_writer"],
                }
            },
            "writer_config": {
                "recipe_s3_writer": {
                    "class": "S3Writer",
                    "dir": "cache",
                    "key": DataObject.DATA_KEY,
                    "bucket_name": "does_not_exist_bucket_name",
                    "bucket_filename": "does_not_exist.csv",
                }
            },
        }
    }

    class TestPostprocess(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set(["key1", "key2"])

        def run(self, data_object):
            return data_object

    NodeFactory().register("TestPostprocess", TestPostprocess)

    # this is to mock out the boto connection
    os.environ["AWS_ACCESS_KEY_ID"] = "fake"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "fake"
    conn = boto3.resource("s3")
    # We need to create the bucket since this is all in Moto's 'virtual' AWS account
    conn.create_bucket(Bucket="does_not_exist_bucket_name")

    reference_file_path = "test/minimal.csv"

    corpus = pd.read_csv(reference_file_path)

    configuration = Configuration(None,
                                  is_dict_config=True,
                                  dict_config=config)

    data_object = DataObject(configuration)

    requestor = TestPostprocess(configuration, "nodename")

    data_object.add(requestor, corpus)

    writer = S3Writer(configuration, "recipe_s3_writer")
    node_config = {
        "class": "S3Writer",
        "dir": "cache",
        "key": DataObject.DATA_KEY,
        "bucket_name": "does_not_exist_bucket_name",
        "bucket_filename": "does_not_exist.csv",
    }
    keys = writer.necessary_config(node_config)
    assert keys is not None
    assert isinstance(keys, set)
    assert len(keys) > 0

    # write to file
    filename = writer._write_locally(data_object)
    assert os.path.exists(filename)

    # check it is same data as expected
    reference = pd.read_csv(reference_file_path)
    just_written = pd.read_csv(filename)

    assert reference.equals(just_written)

    os.remove(filename)

    data_object = writer.run(data_object)
    body = (conn.Object(
        "does_not_exist_bucket_name",
        "does_not_exist.csv").get()["Body"].read().decode("utf-8"))
    # read the reference through a context manager so the handle is closed
    with open(reference_file_path) as reference_file:
        assert body == reference_file.read()
def test_run():
    """AbstractModel.run() calls the model stages that match the configured mode.

    Checked through log output: mode "train" logs TRAIN, EVAL and PREDICT;
    mode "eval" logs EVAL and PREDICT.
    """

    class TestModel(AbstractModel):
        @staticmethod
        def necessary_config(node_config):
            return set(["mode"])

        def train_model(self, data_object):
            logging.info("TRAIN called")
            return data_object

        def eval_model(self, data_object):
            logging.info("EVAL called")
            return data_object

        def predict(self, data_object):
            logging.info("PREDICT called")
            return data_object

    NodeFactory().register("TestModel", TestModel)

    frame = pd.read_csv("test/minimal.csv")

    # --- mode: train ---
    train_config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mymodel"],
                }
            },
            "model_config": {
                "mymodel": {
                    "class": "TestModel",
                    "mode": "train"
                }
            },
        }
    }
    configuration = Configuration(config_location=None,
                                  is_dict_config=True,
                                  dict_config=train_config)
    data_object = DataObject(configuration)
    data_object.add(CsvReader(configuration, "myreader"), frame)
    model = TestModel(configuration, "mymodel")

    with LogCapture() as captured:
        model.run(data_object)
    captured.check(
        ("root", "INFO", "TRAIN called"),
        ("root", "INFO", "EVAL called"),
        ("root", "INFO", "PREDICT called"),
    )

    # --- mode: eval ---
    eval_config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mymodel"],
                }
            },
            "model_config": {
                "mymodel": {
                    "class": "TestModel",
                    "mode": "eval"
                }
            },
        }
    }
    configuration = Configuration(config_location=None,
                                  is_dict_config=True,
                                  dict_config=eval_config)
    data_object = DataObject(configuration)
    data_object.add(CsvReader(configuration, "myreader"), frame)
    model = TestModel(configuration, "mymodel")

    with LogCapture() as captured:
        model.run(data_object)
    captured.check(
        ("root", "INFO", "EVAL called"),
        ("root", "INFO", "PREDICT called"),
    )
Beispiel #16
0
# Step 3:
# ========
# *Importantly*, this factory registration has to occur *before* Configuration is instantiated
# in the run_primrose script.
# To that end, we suggest putting this registration code below into something
# like `src/__init__.py` in your project. Wherever you put it, you will need to reference
# it in the run_primrose script.
#
# That is, if you put this code into `src/__init__.py`, you will need to add
#
#  from src.__init__ import *
#
# at the head of the run_primrose script.
#
#########################################################################################

import logging

# Configure the root logger before the primrose/project imports below.
# NOTE(review): presumably done first so that messages logged while those
# modules are imported already use this format - confirm.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(filename)s %(funcName)s: %(message)s",
    level=logging.INFO,
)

from primrose.node_factory import NodeFactory

# Add your imports here
from src.yourpackage.awesome_reader import AwesomeReader
from src.yourpackage.awesome_model import AwesomeModel

# Hand this module's name to the factory so it can register the classes
# imported above; this has to run before Configuration is instantiated.
NodeFactory().register_module_classes(__name__)
Beispiel #17
0
    def create_dag(self):
        """Build the graphs that describe the configured nodes and their destinations.

        Returns:
            nothing. Side effects: sets `self.G` (undirected graph), `self.G2` (directed graph),
            `self.node_map` (node name -> section key) and `self.conditional_nodes`.

        Raises:
            ConfigurationError: if a destination is not a string or does not name a configured node

        """
        logging.info("Checking configuration DAG")
        undirected = nx.Graph()
        directed = nx.DiGraph()
        node_names = set()

        # bookkeeping for cleanup/postprocess nodes; collected here but not
        # consumed within this method
        cleanup_nodes = set()
        some_postprocess_node = None

        # node name -> section it was declared in
        node_map = {}

        self.conditional_nodes = set()

        # pass 1: one graph node per configured node
        for section_key in self.config.keys():
            section = self.config[section_key]

            for key in section.keys():
                logging.debug("Adding node '%s'" % key)
                undirected.add_node(key)
                directed.add_node(key)
                node_names.add(key)
                node_map[key] = section_key

                # remember nodes whose class is a conditional path
                registered_class = NodeFactory().name_dict[section[key]["class"]]
                if issubclass(registered_class, AbstractConditionalPath):
                    self.conditional_nodes.add(key)

                # cleanup section can be disconnected from the rest of the graph
                if section_key == OperationType.cleanup.value:
                    cleanup_nodes.add(key)

                # keep the first postprocess node encountered
                if (section_key == OperationType.postprocess.value
                        and some_postprocess_node is None):
                    some_postprocess_node = key

        # pass 2: one edge per declared destination
        for section_key in self.config.keys():
            section = self.config[section_key]

            for key in section.keys():
                node_config = section[key]
                if "destinations" not in node_config:
                    continue

                for destination in node_config["destinations"]:
                    if not isinstance(destination, str):
                        raise ConfigurationError(
                            "Unrecognized destination type: %s" % destination)

                    if destination not in node_map:
                        raise ConfigurationError(
                            "Did not find %s destination in %s.%s" %
                            (destination, section_key, key))

                    ConfigurationDag.add_edge(undirected, directed,
                                              node_names, key, destination)

        logging.info("OK: good referential integrity")

        self.G = undirected
        self.G2 = directed
        self.node_map = node_map
Beispiel #18
0
    def check_config(self):
        """check the configuration as much as we can as early as we can

        Runs the metadata and section checks, verifies that every node has a
        class key and no legacy destination keys, makes sure every referenced
        class is registered (attempting to register missing ones), validates
        each node's necessary config and finally checks the DAG.

        Raises:
            ConfigurationError: if a node has no class key or a referenced class cannot be registered
            Exception: if a node still uses one of the old `destination_*` keys
            various other exceptions if any of the delegated checks fail

        """
        self.check_metadata()

        self.check_sections()

        self.nodename_to_classname = {}

        unique_class_keys = set()

        self.instance_to_config = {}

        legacy_destination_keys = (
            "destination_pipeline",
            "destination_models",
            "destination_postprocesses",
            "destination_writer",
        )

        # check that all child nodes of each section have a Factory.CLASS_KEY field
        for section_key in self.config.keys():

            for child_key in self.config[section_key].keys():

                child = self.config[section_key][child_key]

                self.instance_to_config[child_key] = child

                if NodeFactory.CLASS_KEY not in child:
                    raise ConfigurationError("No class key found in %s.%s" %
                                             (section_key, child_key))

                self.nodename_to_classname[child_key] = child[
                    NodeFactory.CLASS_KEY]

                unique_class_keys.add((child[NodeFactory.CLASS_KEY],
                                       child.get(NodeFactory.CLASS_PREFIX)))

                for k in legacy_destination_keys:
                    if k in child:
                        # interpolate the offending key into the message
                        # (it used to be passed as a second exception argument)
                        raise Exception(
                            "Do you have a old config file? You have %s. Nodes just have 'destinations':[] now"
                            % k)

        logging.info("OK: all class keys are present")

        # get class_prefixes by traversing node package
        unique_class_keys = self._traverse_node_package(unique_class_keys)
        unique_nodes = {x[0] for x in unique_class_keys}

        # check that each referenced class is registered in NodeFactory
        # Regex pattern in `_traverse_node_package` should capture the right file (class prefix) to register
        # node (class_key). Here we attempt to register any class that is not already registered.
        for class_key, class_prefix in unique_class_keys:
            if not NodeFactory().is_registered(class_key):
                try:
                    logging.info(f"attempting to register {class_key}")
                    self._register_class(class_key, class_prefix)
                except Exception:
                    # best effort: a failed attempt is only logged here; the
                    # loop below raises if the class is still unregistered.
                    # `except Exception` lets KeyboardInterrupt/SystemExit through.
                    logging.error(
                        f"Cannot register node class {class_key} with prefix {class_prefix}"
                    )
        for class_key in unique_nodes:
            if not NodeFactory().is_registered(class_key):
                raise ConfigurationError(
                    f"Cannot register node class {class_key}")

        # check necessary_configs
        for instance_name in self.nodename_to_classname:
            class_key = self.nodename_to_classname[instance_name]

            configuration_dict = self.instance_to_config[instance_name]

            instance = NodeFactory().instantiate(class_key, self,
                                                 instance_name)

            NodeFactory().valid_configuration(instance, configuration_dict)

        logging.info("OK: all classes recognized")
        logging.info("OK: good necessary_configs")

        # run our DAG checks. Throws error if not OK
        self.dag.check_dag()
Beispiel #19
0
    def run(self, dry_run=False):
        """Run the whole DAG.

        Optionally, you can call with ``dry_run=True``, which will log what
        would be run and in what order but not actually run it.

        If the configuration metadata contains ``notify_on_error``, a
        notification client is created from it and used to post a message
        when a node cannot be instantiated, when a node fails while running,
        and when a node signals early termination.

        Args:
            dry_run: Boolean. Want to do a dry run?

        Returns:
            data_object: DataObject instance

        Raises:
            Exception: if the notification client cannot be created, or if a
                node cannot be instantiated.
            Any exception raised while running a node is logged, reported to
            the notification client (if any) and re-raised unchanged.

        """
        data_object = self.create_data_object()

        candidate_sequence = self.dag_traverser.traversal_list()

        sequence = self.filter_sequence(candidate_sequence)

        # only worth logging when filtering actually removed something
        if len(candidate_sequence) > len(sequence):
            logging.info("Sequence of nodes to be run: %s", sequence)

        pruned_nodes = set()

        client = None
        slack_exception_label = None
        metadata = self.configuration.config_metadata

        if metadata and "notify_on_error" in metadata:
            try:
                params = metadata["notify_on_error"]
                slack_exception_label = params.get("message", "Job error")
                client = get_notification_client(params)

            except Exception as error:
                msg = "Error trying to instantiate notification client. Check class name and parameters"
                logging.error(error)
                # raising the bare string would itself fail with a TypeError
                raise Exception(msg) from error

        for i, node in enumerate(sequence):

            if node in pruned_nodes:
                logging.info("Skipping pruned node %s", node)
                continue

            section = self.dag.node_map[node]
            class_name = self.configuration.nodename_to_classname[node]

            if dry_run:
                logging.info(
                    "DRY RUN %s: would run node %s of type %s and class %s", i, node, section, class_name,
                )
                continue

            logging.info(
                "received node %s of type %s and class %s", node, section, class_name,
            )

            try:
                node_instance = NodeFactory().instantiate(class_name, self.configuration, node)
            except Exception as e:
                msg = "Issue instantiating %s and class %s" % (node, class_name)
                logging.error(msg)
                if client:
                    client.post_message(f"{slack_exception_label}: {msg}")
                # chain the original error so the root cause stays in the traceback
                raise Exception(msg) from e

            try:
                data_object, terminate = node_instance.run(data_object)

                # conditional nodes may prune nodes that appear later in the sequence
                if isinstance(node_instance, AbstractConditionalPath):
                    to_prune = node_instance.all_nodes_to_prune()
                    if to_prune:
                        pruned_nodes.update(to_prune)

            except Exception:
                msg = "Issue with %s" % node
                logging.error(msg)
                if client:
                    client.post_message(f"{slack_exception_label}: {msg}\n{traceback.format_exc()}")
                raise

            if terminate:
                # build the message here: no `msg` exists yet on the success path
                msg = "Terminating early due to signal from %s" % node
                logging.info(msg)
                if client:
                    client.post_message(f"{slack_exception_label}: {msg}")
                break

        # reached after a dry run as well as after a real run
        self.cache_data_object(data_object)

        logging.info("All done. Bye bye!")
        return data_object
def test_init():
    """A postprocess node built from a dict config reports its necessary keys and stores data when run."""

    class TestPostprocess(AbstractPostprocess):
        @staticmethod
        def necessary_config(node_config):
            return {"key1"}

        def run(self, data_object):
            data_object.add(self, "some data")
            return data_object, False

    NodeFactory().register("TestPostprocess", TestPostprocess)

    class TestModel(AbstractNode):
        def __init__(self, configuration, instance_name):
            pass

        @staticmethod
        def necessary_config(node_config):
            return {"key1"}

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("TestModel", TestModel)

    # one model node feeding one postprocess node
    model_section = {
        "modelname": {
            "class": "TestModel",
            "key1": "val1",
            "destinations": ["nodename"],
        }
    }
    postprocess_section = {
        "nodename": {
            "class": "TestPostprocess",
            "key1": "val1",
            "key2": "val2",
            "destinations": [],
        }
    }
    dict_config = {
        "implementation_config": {
            "model_config": model_section,
            "postprocess_config": postprocess_section,
        }
    }

    configuration = Configuration(config_location=None,
                                  is_dict_config=True,
                                  dict_config=dict_config)

    data_object = DataObject(configuration)

    postprocess = TestPostprocess(configuration, "nodename")

    # standalone copy of the postprocess node settings, passed to necessary_config
    node_config = {
        "class": "TestPostprocess",
        "key1": "val1",
        "key2": "val2",
        "destinations": [],
    }
    assert postprocess.necessary_config(node_config) == {"key1"}

    data_object, terminate = postprocess.run(data_object)

    assert not terminate

    stored = data_object.get("nodename",
                             rtype=DataObjectResponseType.VALUE.value)
    assert stored == "some data"