def test_run_bad():
    class TestWriterTmp(AbstractWriter):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("TestWriterTmp", TestWriterTmp)

    config = {
        "implementation_config": {
            "writer_config": {
                "mywriter": {"class": "TestWriterTmp", "destinations": []}
            }
        }
    }
    configuration = Configuration(None, is_dict_config=True, dict_config=config)

    runner = DagRunner(configuration)

    # unregister this class
    del NodeFactory().name_dict["TestWriterTmp"]

    with pytest.raises(Exception) as e:
        runner.run()
    assert "Issue instantiating mywriter and class TestWriterTmp" in str(e)
def _register_class(self, class_key, class_prefix):
    """Register a class specified in the config file.

    Args:
        class_key (str): class key to register
        class_prefix (str): the prefix of the class to register. Can be in
            `path.to.module` format, or a full path `path/to/module`.

    Returns:
        None - attempts to register the class with its default name

    """
    # convert to string before checking if file
    if class_prefix is None:
        class_prefix = ""

    if os.path.isfile(class_prefix):
        modulename = self._import_file(class_key, class_prefix)
    # loading from module
    else:
        if self.config_metadata and "class_package" in self.config_metadata:
            class_package = self.config_metadata["class_package"]
            prefix = ".".join(filter(None, [class_package, class_prefix]))
        else:
            prefix = class_prefix

        modulename = importlib.import_module(prefix)

    clz = getattr(modulename, class_key)
    NodeFactory().register(None, clz)
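# For orientation, a minimal sketch of the two config shapes this method handles.
# The package and class names here ("src.yourpackage", "AwesomeReader") are
# illustrative, borrowed from the registration example further below: a dotted
# `class_prefix` is joined onto `metadata.class_package` and imported with
# importlib, whereas a `class_prefix` that is an existing file path goes through
# self._import_file(...) instead.
config_dotted = {
    "metadata": {"class_package": "src.yourpackage"},
    "implementation_config": {
        "reader_config": {
            "read_data": {
                "class": "AwesomeReader",
                # joined with class_package -> import_module("src.yourpackage.awesome_reader")
                "class_prefix": "awesome_reader",
                "destinations": [],
            }
        }
    },
}

config_file_path = {
    "implementation_config": {
        "reader_config": {
            "read_data": {
                "class": "AwesomeReader",
                # an existing file path triggers the os.path.isfile branch
                "class_prefix": "src/yourpackage/awesome_reader.py",
                "destinations": [],
            }
        }
    },
}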
def test_run_node(self):
    path = "primrose.notifications.success_notification.get_notification_client"
    with mock.patch(path) as get_client_mock:
        get_client_mock.return_value = mock.Mock()

        NodeFactory().register("SlackDataMock", SlackDataMock)
        config = Configuration(
            None, is_dict_config=True, dict_config=config_dict_node_message
        )
        data_object = DataObject(config)
        reader = SlackDataMock(config, "test_node")
        data_object = reader.run(data_object)

        success_instance = ClientNotification(
            configuration=config,
            instance_name="node_notification",
        )
        success_instance.client = get_client_mock.return_value
        success_instance.run(data_object)
        success_instance.client.post_message.assert_called_once_with(
            message="Node Success!"
        )
def test_class_package(mock_env):
    config_path = {
        "metadata": {"class_package": "test"},
        "implementation_config": {
            "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}}
        },
    }
    config_full_path = {
        "metadata": {"class_package": "test/ext_node_example.py"},
        "implementation_config": {
            "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}}
        },
    }
    config_full_dot = {
        "metadata": {"class_package": "test"},
        "implementation_config": {
            "reader_config": {
                "read_data": {
                    "class": "TestExtNode",
                    "class_prefix": "ext_node_example",
                    "destinations": [],
                }
            }
        },
    }

    for config in [config_full_path, config_path, config_full_dot]:
        config = Configuration(
            config_location=None, is_dict_config=True, dict_config=config
        )
        assert config.config_string
        assert config.config_hash
        NodeFactory().unregister("TestExtNode")
def test_run_bad2():
    class TestWriterTmp(AbstractWriter):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            raise Exception("Deliberate error")
            # return data_object, False

    NodeFactory().register("TestWriterTmp", TestWriterTmp)

    config = {
        "implementation_config": {
            "writer_config": {
                "mywriter": {"class": "TestWriterTmp", "destinations": []}
            }
        }
    }
    configuration = Configuration(None, is_dict_config=True, dict_config=config)

    runner = DagRunner(configuration)

    with pytest.raises(Exception) as e:
        runner.run()
    assert "Deliberate error" in str(e)
def test_register_module_classes():
    with LogCapture() as l:
        NodeFactory().register_module_classes(__name__)
    l.check(
        (
            "root",
            "INFO",
            "Discovered class CsvReader (<class "
            "'primrose.readers.csv_reader.CsvReader'>)",
        ),
        (
            "root",
            "DEBUG",
            "Registered CsvReader : <class 'primrose.readers.csv_reader.CsvReader'>",
        ),
        (
            "root",
            "INFO",
            "Discovered class DillWriter (<class "
            "'primrose.writers.dill_writer.DillWriter'>)",
        ),
        (
            "root",
            "DEBUG",
            "Registered DillWriter : <class 'primrose.writers.dill_writer.DillWriter'>",
        ),
    )
def test_run4():
    class TestWriter(AbstractFileWriter):
        def __init__(self, configuration, instance_name):
            pass

        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            terminate = True
            return data_object, terminate

    NodeFactory().register("TestWriter", TestWriter)

    config = {
        "implementation_config": {
            "reader_config": {
                "csv_reader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["csv_writer"],
                }
            },
            "writer_config": {"csv_writer": {"class": "TestWriter"}},
        }
    }
    configuration = Configuration(None, is_dict_config=True, dict_config=config)

    runner = DagRunner(configuration)
    with LogCapture() as l:
        runner.run(dry_run=False)
    l.check(
        ("root", "INFO", "Taking nodes to run from default"),
        (
            "root",
            "INFO",
            "received node csv_reader of type reader_config and class CsvReader",
        ),
        ("root", "INFO", "Reading test/minimal.csv from CSV"),
        (
            "root",
            "INFO",
            "received node csv_writer of type writer_config and class TestWriter",
        ),
        ("root", "INFO", "Terminating early due to signal from csv_writer"),
        ("root", "INFO", "All done. Bye bye!"),
    )
def test_env_override_class_package(mock_env):
    config = {
        "metadata": {"class_package": "junk"},
        "implementation_config": {
            "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}}
        },
    }
    config = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )
    assert config.config_string
    assert config.config_hash
    NodeFactory().unregister("TestExtNode")
def test_all_nodes_to_prune2():
    class TestConditionalNode(AbstractConditionalPath):
        @staticmethod
        def necessary_config(node_config):
            return set()

        def destinations_to_prune(self):
            return ["junk"]

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("TestConditionalNode", TestConditionalNode)

    config = {
        "implementation_config": {
            "reader_config": {
                "conditional_node": {
                    "class": "TestConditionalNode",
                    "destinations": ["csv_writer"],
                }
            },
            "writer_config": {
                "csv_writer": {
                    "class": "CsvWriter",
                    "key": "test_data",
                    "dir": "cache",
                    "filename": "unittest_similar_recipes.csv",
                }
            },
        }
    }
    configuration = Configuration(None, is_dict_config=True, dict_config=config)

    node = TestConditionalNode(configuration, "conditional_node")
    with pytest.raises(Exception) as e:
        node.all_nodes_to_prune()
    assert "Destination junk is not in destinations list" in str(e)
def test_run():
    class TestPipeline(AbstractPipeline):
        def transform(self, data_object):
            logging.info("TRANSFORM CALLED")
            return data_object

        def fit_transform(self, data_object):
            logging.info("FIT_TRANSFORM CALLED")
            return self.transform(data_object)

        @staticmethod
        def necessary_config(node_config):
            return set(["is_training"])

    NodeFactory().register("TestPipeline", TestPipeline)

    config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mypipeline"],
                }
            },
            "pipeline_config": {
                "mypipeline": {"class": "TestPipeline", "is_training": True}
            },
        }
    }
    configuration = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )

    reference_file_path = "test/minimal.csv"
    corpus = pd.read_csv(reference_file_path)
    reader = CsvReader(configuration, "myreader")
    data_object = DataObject(configuration)
    data_object.add(reader, corpus)

    pipeline = TestPipeline(configuration, "mypipeline")

    with LogCapture() as l:
        pipeline.run(data_object)
    l.check(
        (
            "root",
            "INFO",
            "No upstream TransformerSequence found. Creating new TransformerSequence...",
        ),
        ("root", "INFO", "FIT_TRANSFORM CALLED"),
        ("root", "INFO", "TRANSFORM CALLED"),
    )

    data_object.add(reader, TransformerSequence(), "tsequence")
    with LogCapture() as l:
        pipeline.run(data_object)
    l.check(
        (
            "root",
            "INFO",
            "Upstream TransformerSequence found, initializing pipeline...",
        ),
        ("root", "INFO", "FIT_TRANSFORM CALLED"),
        ("root", "INFO", "TRANSFORM CALLED"),
    )

    config["implementation_config"]["pipeline_config"]["mypipeline"][
        "is_training"
    ] = False
    configuration = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )
    reader = CsvReader(configuration, "myreader")
    data_object = DataObject(configuration)
    data_object.add(reader, corpus)
    pipeline = TestPipeline(configuration, "mypipeline")
    with LogCapture() as l:
        pipeline.run(data_object)
    l.check(
        (
            "root",
            "INFO",
            "No upstream TransformerSequence found. Creating new TransformerSequence...",
        ),
        ("root", "INFO", "TRANSFORM CALLED"),
    )
def test_execute_pipeline():
    class TestTransformer(AbstractTransformer):
        def fit(self, data):
            logging.info("Transfer FIT CALLED")

        def transform(self, data):
            logging.info("Transfer TRANSFORM CALLED")
            return data

        def fit_transform(self, data):
            logging.info("Transfer FIT_TRANSFORM CALLED")
            self.fit(data)
            return self.transform(data)

    class TestPipeline2(AbstractPipeline):
        def transform(self, data_object):
            return data_object

        @staticmethod
        def necessary_config(node_config):
            return set(["is_training"])

    NodeFactory().register("TestPipeline2", TestPipeline2)

    config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mypipeline"],
                }
            },
            "pipeline_config": {
                "mypipeline": {"class": "TestPipeline", "is_training": True}
            },
        }
    }
    configuration = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )

    reference_file_path = "test/minimal.csv"
    corpus = pd.read_csv(reference_file_path)
    reader = CsvReader(configuration, "myreader")
    data_object = DataObject(configuration)
    data_object.add(reader, corpus)

    sequence = TransformerSequence()
    sequence.add(TestTransformer())
    data_object.add(reader, sequence, "tsequence")

    pipeline = TestPipeline2(configuration, "mypipeline")

    with pytest.raises(Exception) as e:
        pipeline.execute_pipeline(corpus, PipelineModeType.FIT)
    assert "run() must be called to extract/create a TransformerSequence" in str(e)

    pipeline.run(data_object)

    with pytest.raises(Exception) as e:
        pipeline.execute_pipeline(corpus, "JUNK")
    assert "mode must be of type PipelineModeType Enum object." in str(e)

    with LogCapture() as l:
        pipeline.execute_pipeline(corpus, PipelineModeType.FIT)
    l.check(("root", "INFO", "Transfer FIT CALLED"))

    with LogCapture() as l:
        pipeline.execute_pipeline(corpus, PipelineModeType.FIT_TRANSFORM)
    l.check(
        ("root", "INFO", "Transfer FIT_TRANSFORM CALLED"),
        ("root", "INFO", "Transfer FIT CALLED"),
        ("root", "INFO", "Transfer TRANSFORM CALLED"),
    )
def test_plot_dag():
    class TestPostprocess(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set(["key1", "key2"])

        def run(self, data_object):
            return data_object

    NodeFactory().register("TestPostprocess", TestPostprocess)

    class Testpipeline(AbstractNode):
        # def __init__(self, configuration, instance_name):
        #     pass

        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object

    NodeFactory().register("Testpipeline", Testpipeline)

    class TestCleanup(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object

    NodeFactory().register("TestCleanup", TestCleanup)

    config = {
        "implementation_config": {
            "reader_config": {
                "csv_reader": {
                    "class": "CsvReader",
                    "filename": "some/path/to/file",
                    "destinations": ["pipeline1"],
                }
            },
            "pipeline_config": {
                "pipeline1": {
                    "class": "Testpipeline",
                    "destinations": ["decision_tree_model"],
                }
            },
            "model_config": {
                "decision_tree_model": {
                    "class": "SklearnClassifierModel",
                    "model_parameters": {},
                    "sklearn_classifier_name": "tree.DecisionTreeClassifier",
                    "grid_search_scoring": "roc_auc",
                    "cv_folds": 3,
                    "mode": "predict",
                    "destinations": ["nodename"],
                }
            },
            "postprocess_config": {
                "nodename": {
                    "class": "TestPostprocess",
                    "key1": "val1",
                    "key2": "val2",
                    "destinations": ["write_output"],
                }
            },
            "writer_config": {
                "write_output": {
                    "class": "CsvWriter",
                    "key": "read_data",
                    "dir": "cache",
                    "filename": "some/path/to/file.csv",
                    "destinations": ["donothingsuccess"],
                }
            },
            "cleanup_config": {
                "donothingsuccess": {
                    "class": "TestCleanup",
                }
            },
        }
    }

    cfilename = "test/test_dag_plotting.json"
    with open(cfilename, "w") as f:
        jstyleson.dump(config, f)

    config = Configuration(config_location=cfilename)
    dag = config.dag
    dag.create_dag()

    filename = "test/test_dag_plotting.png"
    if os.path.exists(filename):
        os.remove(filename)

    dag.plot_dag(filename, traverser=ConfigLayerTraverser(config))
    assert os.path.exists(filename)

    if os.path.exists(cfilename):
        os.remove(cfilename)

    if os.path.exists(filename):
        os.remove(filename)
def test_cosine_similarity_matrix():
    class Testpipeline(AbstractNode):
        def __init__(self, configuration, instance_name):
            self.configuration = configuration
            self.instance_name = instance_name

        @staticmethod
        def necessary_config(node_config):
            return set([])

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("Testpipeline", Testpipeline)

    class TestSimpleSearchEngine(AbstractSearchEngine):
        """simple TFIDF search engine"""

        def __init__(self, configuration, instance_name):
            AbstractSearchEngine.__init__(self, configuration, instance_name)

        def tokenize(self, s, stopwords=[], add_ngrams=True):
            q = s.lower()
            tokens = (
                q.replace("-", " ")
                .replace(",", "")
                .replace("(", "")
                .replace(")", "")
                .split(" ")
            )
            tokens = [w for w in tokens if w not in stopwords]
            if add_ngrams:
                bigrams = list(ngrams(tokens, 2))
                strbigrams = ["_".join(t) for t in bigrams]
                tokens.extend(strbigrams)
            return tokens

        def eval_model(self, data_object):
            return data_object

    NodeFactory().register("TestSimpleSearchEngine", TestSimpleSearchEngine)

    config = {
        "implementation_config": {
            "pipeline_config": {
                "pipeline1": {
                    "class": "Testpipeline",
                    "destinations": ["recipe_name_model"],
                }
            },
            "model_config": {
                "recipe_name_model": {
                    "class": "TestSimpleSearchEngine",
                    "id_key": "id",
                    "doc_key": "name",
                    "mode": "predict",
                    "destinations": [],
                }
            },
        }
    }
    configuration = Configuration(None, is_dict_config=True, dict_config=config)

    # set that pipeline provided the corpus
    corpus = [
        {"id": 1, "name": "spinach omelet"},
        {"id": 2, "name": "kale omelet"},
        {"id": 3, "name": "cherry pie"},
    ]
    data_object = DataObject(configuration)
    pipeline = Testpipeline(configuration, "pipeline1")
    data_object.add(pipeline, pd.DataFrame(corpus))

    engine = TestSimpleSearchEngine(configuration, "recipe_name_model")
    engine.predict(data_object)

    m = engine.cosine_similarity_matrix()

    assert math.isclose(m[0, 0], 1.0, abs_tol=0.001)
    assert math.isclose(m[0, 1], 0.224325, abs_tol=0.001)
    assert math.isclose(m[0, 2], 0.0, abs_tol=0.001)
    assert math.isclose(m[1, 0], 0.224325, abs_tol=0.001)
    assert math.isclose(m[1, 1], 1.0, abs_tol=0.001)
    assert math.isclose(m[1, 2], 0.0, abs_tol=0.001)
    assert math.isclose(m[2, 0], 0.0, abs_tol=0.001)
    assert math.isclose(m[2, 1], 0.0, abs_tol=0.001)
    assert math.isclose(m[2, 2], 1.0, abs_tol=0.001)

    assert engine.ids == [1, 2, 3]
    assert engine.docs == ["spinach omelet", "kale omelet", "cherry pie"]
    assert engine.tfidf is not None
def test_init_ok():
    config = {
        "implementation_config": {
            "postprocess_config": {
                "nodename": {
                    "class": "TestPostprocess",
                    "key1": "val1",
                    "key2": "val2",
                    "destinations": ["recipe_s3_writer"],
                }
            },
            "writer_config": {
                "recipe_s3_writer": {
                    "class": "S3Writer",
                    "dir": "cache",
                    "key": DataObject.DATA_KEY,
                    "bucket_name": "does_not_exist_bucket_name",
                    "bucket_filename": "does_not_exist.csv",
                }
            },
        }
    }

    class TestPostprocess(AbstractNode):
        @staticmethod
        def necessary_config(node_config):
            return set(["key1", "key2"])

        def run(self, data_object):
            return data_object

    NodeFactory().register("TestPostprocess", TestPostprocess)

    # this is to mock out the boto connection
    os.environ["AWS_ACCESS_KEY_ID"] = "fake"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "fake"

    conn = boto3.resource("s3")
    # We need to create the bucket since this is all in Moto's 'virtual' AWS account
    conn.create_bucket(Bucket="does_not_exist_bucket_name")

    reference_file_path = "test/minimal.csv"
    corpus = pd.read_csv(reference_file_path)

    configuration = Configuration(None, is_dict_config=True, dict_config=config)

    data_object = DataObject(configuration)
    requestor = TestPostprocess(configuration, "nodename")
    data_object.add(requestor, corpus)

    writer = S3Writer(configuration, "recipe_s3_writer")

    node_config = {
        "class": "S3Writer",
        "dir": "cache",
        "key": DataObject.DATA_KEY,
        "bucket_name": "does_not_exist_bucket_name",
        "bucket_filename": "does_not_exist.csv",
    }
    keys = writer.necessary_config(node_config)
    assert keys is not None
    assert isinstance(keys, set)
    assert len(keys) > 0

    # write to file
    filename = writer._write_locally(data_object)
    assert os.path.exists(filename)

    # check it is same data as expected
    reference = pd.read_csv(reference_file_path)
    just_written = pd.read_csv(filename)
    assert reference.equals(just_written)

    os.remove(filename)

    data_object = writer.run(data_object)

    body = (
        conn.Object("does_not_exist_bucket_name", "does_not_exist.csv")
        .get()["Body"]
        .read()
        .decode("utf-8")
    )
    assert body == open(reference_file_path).read()
def test_run():
    class TestModel(AbstractModel):
        @staticmethod
        def necessary_config(node_config):
            return set(["mode"])

        def train_model(self, data_object):
            logging.info("TRAIN called")
            return data_object

        def eval_model(self, data_object):
            logging.info("EVAL called")
            return data_object

        def predict(self, data_object):
            logging.info("PREDICT called")
            return data_object

    NodeFactory().register("TestModel", TestModel)

    config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mymodel"],
                }
            },
            "model_config": {"mymodel": {"class": "TestModel", "mode": "train"}},
        }
    }
    configuration = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )

    data_object = DataObject(configuration)
    reader = CsvReader(configuration, "myreader")
    df = pd.read_csv("test/minimal.csv")
    data_object.add(reader, df)

    model = TestModel(configuration, "mymodel")

    with LogCapture() as l:
        model.run(data_object)
    l.check(
        ("root", "INFO", "TRAIN called"),
        ("root", "INFO", "EVAL called"),
        ("root", "INFO", "PREDICT called"),
    )

    config = {
        "implementation_config": {
            "reader_config": {
                "myreader": {
                    "class": "CsvReader",
                    "filename": "test/minimal.csv",
                    "destinations": ["mymodel"],
                }
            },
            "model_config": {"mymodel": {"class": "TestModel", "mode": "eval"}},
        }
    }
    configuration = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )

    data_object = DataObject(configuration)
    reader = CsvReader(configuration, "myreader")
    data_object.add(reader, df)

    model = TestModel(configuration, "mymodel")

    with LogCapture() as l:
        model.run(data_object)
    l.check(("root", "INFO", "EVAL called"), ("root", "INFO", "PREDICT called"))
# Step 3:
# ========
# *Importantly*, this factory registration has to occur *before* Configuration is instantiated
# in the run_primrose script.
# To that end, we suggest putting the registration code below into something
# like `src/__init__.py` in your project. Wherever you put it, you will need to reference
# it in the run_primrose script.
#
# That is, if you put this code into `src/__init__.py`, you will need to add
#
#   from src.__init__ import *
#
# at the head of the run_primrose script.
#
#########################################################################################
import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(filename)s %(funcName)s: %(message)s",
    level=logging.INFO,
)

from primrose.node_factory import NodeFactory

# Add your imports here
from src.yourpackage.awesome_reader import AwesomeReader
from src.yourpackage.awesome_model import AwesomeModel

NodeFactory().register_module_classes(__name__)
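# For illustration only: a minimal sketch of the top of such a run_primrose
# script, assuming the registration code above lives in `src/__init__.py` and
# the usual primrose module layout (Configuration in
# primrose.configuration.configuration, DagRunner in primrose.dag_runner).
# The config path is a placeholder; adjust everything to your project.
from src.__init__ import *  # must run before Configuration is instantiated

from primrose.configuration.configuration import Configuration
from primrose.dag_runner import DagRunner

configuration = Configuration(config_location="config/my_config.json")  # placeholder path
DagRunner(configuration).run()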
def create_dag(self):
    """Create the DAG

    Returns:
        nothing. Side effect is to set up graphs and node map

    """
    logging.info("Checking configuration DAG")

    G = nx.Graph()
    G2 = nx.DiGraph()

    node_names = set()
    cleanup_nodes = set()
    some_postprocess_node = None

    # key to section type
    node_map = {}

    self.conditional_nodes = set()

    # add the nodes to the graph:
    for section_key in self.config.keys():
        for key in self.config[section_key].keys():
            logging.debug("Adding node '%s'" % key)
            G.add_node(key)
            G2.add_node(key)
            node_names.add(key)
            node_map[key] = section_key

            # root out conditional nodes...
            node_config = self.config[section_key][key]
            node_class = node_config["class"]
            class_obj = NodeFactory().name_dict[node_class]
            if issubclass(class_obj, AbstractConditionalPath):
                self.conditional_nodes.add(key)

            # cleanup section can be disconnected from rest of graph so let's keep track of these nodes
            if section_key == OperationType.cleanup.value:
                cleanup_nodes.add(key)

            # hack: we are going to add an edge from a postprocess step (any one) to cleanup nodes so they
            # are not a separate connected component
            if (
                section_key == OperationType.postprocess.value
                and some_postprocess_node is None
            ):
                some_postprocess_node = key

    # add the edges
    for section_key in self.config.keys():
        for key in self.config[section_key].keys():
            d = self.config[section_key][key]

            if "destinations" in d:
                for destination in d["destinations"]:
                    if not isinstance(destination, str):
                        raise ConfigurationError(
                            "Unrecognized destination type: %s" % destination
                        )
                    if destination in node_map:
                        ConfigurationDag.add_edge(G, G2, node_names, key, destination)
                    else:
                        raise ConfigurationError(
                            "Did not find %s destination in %s.%s"
                            % (destination, section_key, key)
                        )

    logging.info("OK: good referential integrity")

    self.G = G
    self.G2 = G2
    self.node_map = node_map
def check_config(self):
    """Check the configuration as much as we can, as early as we can

    Raises:
        various exceptions if any checks fail

    """
    self.check_metadata()
    self.check_sections()

    self.nodename_to_classname = {}
    unique_class_keys = set()
    self.instance_to_config = {}

    # check that all child nodes of each section have a Factory.CLASS_KEY field
    for section_key in self.config.keys():
        for child_key in self.config[section_key].keys():
            child = self.config[section_key][child_key]

            self.instance_to_config[child_key] = child

            if NodeFactory.CLASS_KEY not in child:
                raise ConfigurationError(
                    "No class key found in %s.%s" % (section_key, child_key)
                )

            self.nodename_to_classname[child_key] = child[NodeFactory.CLASS_KEY]
            unique_class_keys.add(
                (child[NodeFactory.CLASS_KEY], child.get(NodeFactory.CLASS_PREFIX))
            )

            for k in [
                "destination_pipeline",
                "destination_models",
                "destination_postprocesses",
                "destination_writer",
            ]:
                if k in child:
                    raise Exception(
                        "Do you have an old config file? You have %s. Nodes just have 'destinations':[] now"
                        % k
                    )

    logging.info("OK: all class keys are present")

    # get class_prefixes by traversing node package
    unique_class_keys = self._traverse_node_package(unique_class_keys)
    unique_nodes = set([x[0] for x in unique_class_keys])

    # check that each referenced class is registered in NodeFactory.
    # Regex pattern in `_traverse_node_package` should capture the right file (class prefix) to register
    # the node (class_key). Here we attempt to register any class that is not already registered.
    for class_key, class_prefix in unique_class_keys:
        if not NodeFactory().is_registered(class_key):
            try:
                logging.info(f"attempting to register {class_key}")
                self._register_class(class_key, class_prefix)
            except Exception:
                logging.error(
                    f"Cannot register node class {class_key} with prefix {class_prefix}"
                )

    for class_key in unique_nodes:
        if not NodeFactory().is_registered(class_key):
            raise ConfigurationError(f"Cannot register node class {class_key}")

    # check necessary_configs
    for instance_name in self.nodename_to_classname:
        class_key = self.nodename_to_classname[instance_name]
        configuration_dict = self.instance_to_config[instance_name]
        instance = NodeFactory().instantiate(class_key, self, instance_name)
        NodeFactory().valid_configuration(instance, configuration_dict)

    logging.info("OK: all classes recognized")
    logging.info("OK: good necessary_configs")

    # run our DAG checks. Throws error if not OK
    self.dag.check_dag()
def run(self, dry_run=False):
    """Run the whole DAG. Optionally, you can pass dry_run=True, which will log
    what would be run, and in what order, without actually running it.

    Args:
        dry_run: Boolean. Want to do a dry run?

    Returns:
        data_object: DataObject instance

        node (Node): last node run

    """
    data_object = self.create_data_object()

    candidate_sequence = self.dag_traverser.traversal_list()

    sequence = self.filter_sequence(candidate_sequence)
    if len(candidate_sequence) > len(sequence):
        logging.info("Sequence of nodes to be run: %s", sequence)

    pruned_nodes = set()

    if (
        self.configuration.config_metadata
        and "notify_on_error" in self.configuration.config_metadata
    ):
        try:
            params = self.configuration.config_metadata["notify_on_error"]
            slack_exception_label = params.get("message", "Job error")
            client = get_notification_client(params)
        except Exception as error:
            msg = (
                "Error trying to instantiate notification client. "
                "Check class name and parameters"
            )
            logging.error(error)
            raise Exception(msg)
    else:
        client = None

    for i, node in enumerate(sequence):

        if node in pruned_nodes:
            logging.info("Skipping pruned node " + node)
            continue

        section = self.dag.node_map[node]

        class_name = self.configuration.nodename_to_classname[node]

        if dry_run:
            logging.info(
                "DRY RUN %s: would run node %s of type %s and class %s",
                i,
                node,
                section,
                class_name,
            )
            continue
        else:
            logging.info(
                "received node %s of type %s and class %s",
                node,
                section,
                class_name,
            )

        try:
            node_instance = NodeFactory().instantiate(
                class_name, self.configuration, node
            )
        except Exception:
            msg = "Issue instantiating %s and class %s" % (node, class_name)
            logging.error(msg)
            if client:
                client.post_message(f"{slack_exception_label}: {msg}")
            raise Exception(msg)

        try:
            data_object, terminate = node_instance.run(data_object)

            if isinstance(node_instance, AbstractConditionalPath):
                to_prune = node_instance.all_nodes_to_prune()
                if to_prune:
                    pruned_nodes.update(to_prune)

        except Exception as e:
            msg = "Issue with %s" % node
            logging.error(msg)
            if client:
                client.post_message(
                    f"{slack_exception_label}: {msg}\n{traceback.format_exc()}"
                )
            raise e

        if terminate:
            msg = "Terminating early due to signal from %s" % node
            logging.info(msg)
            if client:
                client.post_message(f"{slack_exception_label}: {msg}")
            break

    self.cache_data_object(data_object)

    logging.info("All done. Bye bye!")

    return data_object
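# For reference, a sketch of the metadata block that activates the
# notify_on_error path in run() above. Only the "message" key is taken directly
# from that code (it becomes slack_exception_label); the remaining keys are
# assumptions about what your get_notification_client(params) implementation
# might expect for a Slack-style client, so adjust them to your setup.
notify_on_error_config = {
    "metadata": {
        "notify_on_error": {
            "message": "Primrose job error",  # label prefixed to posted messages
            "client": "SlackClient",          # hypothetical client selector
            "channel": "#my-alerts",          # hypothetical
            "token": "$SLACK_TOKEN",          # hypothetical
        }
    },
    "implementation_config": {
        # your reader/model/writer sections as usual
    },
}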
def test_init():
    class TestPostprocess(AbstractPostprocess):
        # def __init__(self, configuration, instance_name):
        #     super(TestPostprocess, self).__init__(configuration, instance_name)

        @staticmethod
        def necessary_config(node_config):
            return set(["key1"])

        def run(self, data_object):
            data_object.add(self, "some data")
            return data_object, False

        # def process(self, data):
        #     return "some data"

    NodeFactory().register("TestPostprocess", TestPostprocess)

    class TestModel(AbstractNode):
        def __init__(self, configuration, instance_name):
            pass

        @staticmethod
        def necessary_config(node_config):
            return set(["key1"])

        def run(self, data_object):
            return data_object, False

    NodeFactory().register("TestModel", TestModel)

    config = {
        "implementation_config": {
            "model_config": {
                "modelname": {
                    "class": "TestModel",
                    "key1": "val1",
                    "destinations": ["nodename"],
                }
            },
            "postprocess_config": {
                "nodename": {
                    "class": "TestPostprocess",
                    "key1": "val1",
                    "key2": "val2",
                    "destinations": [],
                }
            },
        }
    }
    configuration = Configuration(
        config_location=None, is_dict_config=True, dict_config=config
    )
    data_object = DataObject(configuration)

    tp = TestPostprocess(configuration, "nodename")

    node_config = {
        "class": "TestPostprocess",
        "key1": "val1",
        "key2": "val2",
        "destinations": [],
    }
    assert set(["key1"]) == tp.necessary_config(node_config)

    # assert tp.process(None) == "some data"

    data_object, terminate = tp.run(data_object)
    assert not terminate

    assert (
        data_object.get("nodename", rtype=DataObjectResponseType.VALUE.value)
        == "some data"
    )