class DataSetRepoFactory(object):
    """Factory producing ``DataSetRepo`` instances for one SMV application.

    The nested ``Java`` class declares that this object implements the Java
    interface ``org.tresamigos.smv.IDataSetRepoFactoryPy4J``.
    """

    def __init__(self, smvApp):
        # The application object is handed unchanged to every repo created.
        self.smvApp = smvApp

    def createRepo(self):
        """Build a fresh ``DataSetRepo`` bound to this factory's ``smvApp``.

        Returns:
            (DataSetRepo): a newly constructed repository
        """
        repo = DataSetRepo(self.smvApp)
        return repo

    getCreateRepo = create_py4j_interface_method("getCreateRepo", "createRepo")

    class Java:
        implements = ['org.tresamigos.smv.IDataSetRepoFactoryPy4J']
class SmvOutput(object):
    """Mixin which marks an SmvModule as one of the outputs of its stage

    SmvOutputs are distinct from other SmvDataSets in that
        * SmvModuleLinks can *only* link to SmvOutputs
        * The -s and --run-app options of smv-pyrun only run SmvOutputs
          and their dependencies.
    """

    # Class-level marker attribute identifying output modules.
    IsSmvOutput = True

    def tableName(self):
        """The user-specified table name used when exporting data to Hive (optional)

        The default implementation returns None; override to supply a name.

        Returns:
            (string)
        """
        return None

    getTableName = create_py4j_interface_method("getTableName", "tableName")
class SmvDataSet(object):
    """Abstract base class for all SmvDataSets

    Subclasses must implement ``requiresDS``, ``doRun`` and ``dsType``.
    The ``get*`` attributes are built by ``create_py4j_interface_method``
    (presumably Py4J-facing wrappers around the named Python methods -
    confirm against that helper).  The nested ``Java`` class declares the
    implemented Java interface ``org.tresamigos.smv.ISmvModule``.
    """

    # Python's issubclass() check does not work well with dynamically
    # loaded modules.  In addition, there are some issues with the
    # check, when the `abc` module is used as a metaclass, that we
    # don't yet quite understand.  So for a workaround we add the
    # typecheck in the Smv hierarchies themselves.
    IsSmvDataSet = True

    __metaclass__ = abc.ABCMeta

    def __init__(self, smvApp):
        self.smvApp = smvApp

    def description(self):
        """Return the docstring of this dataset's class as its description"""
        return self.__doc__

    # this doesn't need stack trace protection
    @abc.abstractmethod
    def requiresDS(self):
        """User-specified list of dependencies

        Override this method to specify the SmvDataSets needed as inputs.

        Returns:
            (list(SmvDataSet)): a list of dependencies
        """

    # this doesn't need stacktrace protection
    def dqm(self):
        """DQM policy

        Override this method to define your own DQM policy (optional).
        Default is an empty policy.

        Returns:
            (SmvDQM): a DQM policy
        """
        return SmvDQM()

    @abc.abstractmethod
    def doRun(self, validator, known):
        """Compute this dataset, and return the dataframe"""

    getDoRun = create_py4j_interface_method("getDoRun", "doRun")

    def assert_result_is_dataframe(self, result):
        """Raise SmvRuntimeError unless ``result`` is a DataFrame

        Arguments:
            result: the object produced by running this dataset

        Raises:
            SmvRuntimeError: if ``result`` is not a DataFrame instance
        """
        if not isinstance(result, DataFrame):
            raise SmvRuntimeError(
                self.fqn() + " produced " + type(result).__name__ + " in place of a DataFrame"
            )

    def version(self):
        """Version number

        Each SmvDataSet is versioned with a numeric string, so it and its
        result can be tracked together.

        Returns:
            (str): version number of this SmvDataSet
        """
        return "0"

    def isOutput(self):
        """Return True if this dataset also inherits from SmvOutput"""
        return isinstance(self, SmvOutput)

    getIsOutput = create_py4j_interface_method("getIsOutput", "isOutput")

    # Note that the Scala SmvDataSet will combine sourceCodeHash and instanceValHash
    # to compute datasetHash
    def sourceCodeHash(self):
        """Hash computed based on the source code of the dataset's class

        The result combines the hash of this class's comment-stripped source,
        the ``sourceCodeHash`` of every non-``smv.`` SmvDataSet ancestor that
        can be instantiated, and the run-config hash when the instance has a
        ``_smvGetRunConfigHash`` attribute.

        Returns:
            (int): non-negative hash masked to 31 bits

        Raises:
            Exception: if the class source cannot be retrieved or hashed
                (e.g. the class was defined in the REPL)
        """
        cls = self.__class__
        try:
            src = inspect.getsource(cls)
            src_no_comm = _stripComments(src)
            # DO NOT use the compiled byte code for the hash computation as
            # it doesn't change when constant values are changed.  For example,
            # "a = 5" and "a = 6" compile to same byte code.
            # co_code = compile(src, inspect.getsourcefile(cls), 'exec').co_code
            res = _smvhash(src_no_comm)
        except Exception as err:
            # `inspect` will raise error for classes defined in the REPL
            # Instead of handle the case that module defined in REPL, just raise Exception here
            # res = _smvhash(_disassemble(cls))
            traceback.print_exc()
            message = "{0}({1!r})".format(type(err).__name__, err.args)
            raise Exception(
                message + "\n" + "SmvDataSet " + self.urn() + " defined in shell can't be persisted"
            )

        # include sourceCodeHash of parent classes
        for m in inspect.getmro(cls):
            try:
                if m.IsSmvDataSet and m != cls and not m.fqn().startswith("smv."):
                    res += m(self.smvApp).sourceCodeHash()
            except Exception:
                # Best effort: ancestors without IsSmvDataSet (e.g. `object`)
                # or ones that cannot be instantiated are skipped.  Only
                # Exception is caught so that KeyboardInterrupt/SystemExit
                # still propagate.
                pass

        # if module inherits from SmvRunConfig, then add hash of all config values to module hash
        if hasattr(self, "_smvGetRunConfigHash"):
            res += self._smvGetRunConfigHash()

        # ensure python's numeric type can fit in a java.lang.Integer
        return res & 0x7fffffff

    getSourceCodeHash = create_py4j_interface_method("getSourceCodeHash", "sourceCodeHash")

    def instanceValHash(self):
        """Hash computed based on instance values of the dataset, such as the timestamp of an input file

        The default implementation returns 0.
        """
        return 0

    getInstanceValHash = create_py4j_interface_method("getInstanceValHash", "instanceValHash")

    @classmethod
    def fqn(cls):
        """Returns the fully qualified name (``<module>.<class name>``)"""
        return cls.__module__ + "." + cls.__name__

    getFqn = create_py4j_interface_method("getFqn", "fqn")

    @classmethod
    def urn(cls):
        """Returns the fully qualified name prefixed with ``mod:``"""
        return "mod:" + cls.fqn()

    def isEphemeral(self):
        """Should this SmvDataSet skip persisting its data?

        Returns:
            (bool): True if this SmvDataSet should not persist its data, false otherwise
        """
        return False

    getIsEphemeral = create_py4j_interface_method("getIsEphemeral", "isEphemeral")

    def publishHiveSql(self):
        """An optional sql query to run to publish the results of this module when the
        --publish-hive command line is used.  The DataFrame result of running this
        module will be available to the query as the "dftable" table.

        Example:
            >>> return "insert overwrite table mytable select * from dftable"

        Note:
            If this method is not specified, the default is to just create the table
            specified by tableName() with the results of the module.

        Returns:
            (string): the query to run.
        """
        return None

    getPublishHiveSql = create_py4j_interface_method("getPublishHiveSql", "publishHiveSql")

    @abc.abstractmethod
    def dsType(self):
        """Return SmvDataSet's type"""

    getDsType = create_py4j_interface_method("getDsType", "dsType")

    def dqmWithTypeSpecificPolicy(self):
        """Return the DQM policy; by default this is simply ``dqm()``"""
        return self.dqm()

    getDqmWithTypeSpecificPolicy = create_py4j_interface_method(
        "getDqmWithTypeSpecificPolicy", "dqmWithTypeSpecificPolicy")

    def dependencies(self):
        """Can be overridden when a module has non-SmvDataSet dependencies (see SmvModelExec)

        Returns:
            the result of ``requiresDS()`` by default
        """
        return self.requiresDS()

    def dependencyUrns(self):
        """Return the URNs of all dependencies, passed through ``smv_copy_array``"""
        arr = [x.urn() for x in self.dependencies()]
        return smv_copy_array(self.smvApp.sc, *arr)

    getDependencyUrns = create_py4j_interface_method("getDependencyUrns", "dependencyUrns")

    @classmethod
    def df2result(cls, df):
        """Given a datasets's persisted DataFrame, get the result object

        In most cases, this is just the DataFrame itself. See SmvResultModule for the exception.
        """
        return df

    class Java:
        implements = ['org.tresamigos.smv.ISmvModule']
class SmvDataSet(ABC):
    """Abstract base class for all SmvDataSets

    Subclasses must implement ``requiresDS``, ``doRun`` and ``dsType``.
    The ``get*`` attributes are built by ``create_py4j_interface_method``
    (presumably Py4J-facing wrappers around the named Python methods -
    confirm against that helper).  The nested ``Java`` class declares the
    implemented Java interface ``org.tresamigos.smv.ISmvModule``.
    """

    # Python's issubclass() check does not work well with dynamically
    # loaded modules.  In addition, there are some issues with the
    # check, when the `abc` module is used as a metaclass, that we
    # don't yet quite understand.  So for a workaround we add the
    # typecheck in the Smv hierarchies themselves.
    IsSmvDataSet = True

    def __init__(self, smvApp):
        self.smvApp = smvApp

    def smvGetRunConfig(self, key):
        """return the current user run configuration value for the given key."""
        return self.smvApp.getConf(key)

    def smvGetRunConfigAsInt(self, key):
        """Return the run configuration value for ``key`` converted with ``int``

        Returns:
            (int): the converted value, or None when the key has no value
        """
        runConfig = self.smvGetRunConfig(key)
        if runConfig is None:
            return None
        return int(runConfig)

    def smvGetRunConfigAsBool(self, key):
        """Return the run configuration value for ``key`` interpreted as a boolean

        After stripping whitespace and lower-casing, "1" and "true" are
        True; any other value is False.

        Returns:
            (bool): the interpreted value, or None when the key has no value
        """
        runConfig = self.smvGetRunConfig(key)
        if runConfig is None:
            return None
        return runConfig.strip().lower() in ("1", "true")

    def config_hash(self):
        """Integer value representing the SMV config's contribution to the dataset hash

        Only the keys declared in requiresConfig will be considered.
        """
        kvs = [(k, self.smvGetRunConfig(k)) for k in self.requiresConfig()]
        # the config_hash should change IFF the config changes
        # sort keys to ensure config hash is independent from key order
        sorted_kvs = sorted(kvs)
        # we need a unique string representation of sorted_kvs to hash
        # repr should change iff sorted_kvs changes
        kv_str = repr(sorted_kvs)
        return _smvhash(kv_str)

    def description(self):
        """Return the docstring of this dataset's class as its description"""
        return self.__doc__

    getDescription = create_py4j_interface_method("getDescription", "description")

    @abc.abstractmethod
    def requiresDS(self):
        """User-specified list of dependencies

        Override this method to specify the SmvDataSets needed as inputs.

        Returns:
            (list(SmvDataSet)): a list of dependencies
        """
        pass

    def requiresConfig(self):
        """User-specified list of config keys this module depends on

        The given keys and their values will influence the dataset hash
        """
        return []

    def requiresLib(self):
        """User-specified list of 'library' dependencies. These are code, other than
        the DataSet's run method that impact its output or behaviour.

        Override this method to assist in re-running this module based on changes
        in other python objects (functions, classes, packages).

        Limitations: For python modules and packages, the 'requiresLib()' method is
        limited to registering changes on the main file of the package (for module
        'foo', that's 'foo.py', for package 'bar', that's 'bar/__init__.py'). This
        means that if a module or package imports other modules, the imported
        module's changes will not impact DataSet hashes.

        Returns:
            (list(module)): a list of library dependencies
        """
        return []

    def dqm(self):
        """DQM policy

        Override this method to define your own DQM policy (optional).
        Default is an empty policy.

        Returns:
            (SmvDQM): a DQM policy
        """
        return SmvDQM()

    @abc.abstractmethod
    def doRun(self, validator, known):
        """Compute this dataset, and return the dataframe"""

    getDoRun = create_py4j_interface_method("getDoRun", "doRun")

    def assert_result_is_dataframe(self, result):
        """Raise SmvRuntimeError unless ``result`` is a DataFrame

        Arguments:
            result: the object produced by running this dataset

        Raises:
            SmvRuntimeError: if ``result`` is not a DataFrame instance
        """
        if not isinstance(result, DataFrame):
            raise SmvRuntimeError(
                self.fqn() + " produced " + type(result).__name__ + " in place of a DataFrame"
            )

    def version(self):
        """Version number

        Each SmvDataSet is versioned with a numeric string, so it and its
        result can be tracked together.

        Returns:
            (str): version number of this SmvDataSet
        """
        return "0"

    def isOutput(self):
        """Return True if this dataset also inherits from SmvOutput"""
        return isinstance(self, SmvOutput)

    getIsOutput = create_py4j_interface_method("getIsOutput", "isOutput")

    # Note that the Scala SmvDataSet will combine sourceCodeHash and instanceValHash
    # to compute datasetHash
    def sourceCodeHash(self):
        """Hash computed based on the source code of the dataset's class

        Contributions are summed in this order: the class's own source hash,
        the ``sourceCodeHash`` of each non-``smv.`` SmvDataSet ancestor that
        can be instantiated, ``config_hash()``, the source hash of each
        entry of ``requiresLib()``, the deprecated run-config hash when
        available, and the hashes of any historical validator keys.

        Returns:
            (int): non-negative hash masked to 31 bits

        Raises:
            Exception: if the class source cannot be hashed (e.g. the class
                was defined in the REPL)
        """
        cls = self.__class__
        # get hash of module's source code text
        try:
            res = _sourceHash(cls)
        except Exception as err:
            # `inspect` will raise error for classes defined in the REPL
            # Instead of handle the case that module defined in REPL, just raise Exception here
            # res = _smvhash(_disassemble(cls))
            traceback.print_exc()
            message = "{0}({1!r})".format(type(err).__name__, err.args)
            raise Exception(
                message + "\n" + "SmvDataSet " + self.urn() + " defined in shell can't be persisted"
            )

        # incorporate source code hash of module's parent classes
        for m in inspect.getmro(cls):
            try:
                # TODO: it probably shouldn't matter if the upstream class is an SmvDataSet - it could be a mixin
                # whose behavior matters but which doesn't inherit from SmvDataSet
                if m.IsSmvDataSet and m != cls and not m.fqn().startswith("smv."):
                    res += m(self.smvApp).sourceCodeHash()
            except Exception:
                # Best effort: ancestors without IsSmvDataSet (e.g. `object`)
                # or ones that cannot be instantiated are skipped.  Only
                # Exception is caught so that KeyboardInterrupt/SystemExit
                # still propagate.
                pass

        # NOTE: Until SmvRunConfig (now deprecated) is removed entirely, we consider 2 source code hashes,
        # config_hash and _smvGetRunConfigHash. The former is influenced by KVs for all keys listed in requiresConfig
        # while latter is influenced by KVs for all keys listed in smv.config.keys.
        # TODO: Is the config really a component of the "source code"? This method is called `sourceCodeHash`, after all.

        # incorporate hash of KVs for config keys listed in requiresConfig
        res += self.config_hash()

        # iterate through libs/modules that this DataSet depends on and use their source towards hash as well
        for lib in self.requiresLib():
            lib_src_hash = _sourceHash(lib)
            res += lib_src_hash

        # if module inherits from SmvRunConfig, then add hash of all config values to module hash
        try:
            res += self._smvGetRunConfigHash()
        except Exception:
            # Best effort: a missing or failing _smvGetRunConfigHash adds
            # nothing.  Narrowed from a bare except so that
            # KeyboardInterrupt/SystemExit are not swallowed.
            pass

        # if module has high order historical validation rules, add their hash to sum.
        # they key() of a validator should change if its parameters change.
        if hasattr(cls, "_smvHistoricalValidatorsList"):
            keys_hash = [_smvhash(v._key()) for v in cls._smvHistoricalValidatorsList]
            res += sum(keys_hash)

        # ensure python's numeric type can fit in a java.lang.Integer
        return res & 0x7fffffff

    getSourceCodeHash = create_py4j_interface_method("getSourceCodeHash", "sourceCodeHash")

    def instanceValHash(self):
        """Hash computed based on instance values of the dataset, such as the timestamp of an input file

        The default implementation returns 0.
        """
        return 0

    getInstanceValHash = create_py4j_interface_method("getInstanceValHash", "instanceValHash")

    @classmethod
    def fqn(cls):
        """Returns the fully qualified name (``<module>.<class name>``)"""
        return cls.__module__ + "." + cls.__name__

    getFqn = create_py4j_interface_method("getFqn", "fqn")

    @classmethod
    def urn(cls):
        """Returns the fully qualified name prefixed with ``mod:``"""
        return "mod:" + cls.fqn()

    def isEphemeral(self):
        """Should this SmvDataSet skip persisting its data?

        Returns:
            (bool): True if this SmvDataSet should not persist its data, false otherwise
        """
        return False

    getIsEphemeral = create_py4j_interface_method("getIsEphemeral", "isEphemeral")

    def publishHiveSql(self):
        """An optional sql query to run to publish the results of this module when the
        --publish-hive command line is used.  The DataFrame result of running this
        module will be available to the query as the "dftable" table.

        Example:
            >>> return "insert overwrite table mytable select * from dftable"

        Note:
            If this method is not specified, the default is to just create the table
            specified by tableName() with the results of the module.

        Returns:
            (string): the query to run.
        """
        return None

    getPublishHiveSql = create_py4j_interface_method("getPublishHiveSql", "publishHiveSql")

    @abc.abstractmethod
    def dsType(self):
        """Return SmvDataSet's type"""

    getDsType = create_py4j_interface_method("getDsType", "dsType")

    def dqmWithTypeSpecificPolicy(self):
        """Return the DQM policy; by default this is simply ``dqm()``"""
        return self.dqm()

    getDqmWithTypeSpecificPolicy = create_py4j_interface_method(
        "getDqmWithTypeSpecificPolicy", "dqmWithTypeSpecificPolicy"
    )

    def dependencies(self):
        """Can be overridden when a module has non-SmvDataSet dependencies (see SmvModelExec)

        Returns:
            the result of ``requiresDS()`` by default
        """
        return self.requiresDS()

    def dependencyUrns(self):
        """Return the URNs of all dependencies, passed through ``smv_copy_array``"""
        arr = [x.urn() for x in self.dependencies()]
        return smv_copy_array(self.smvApp.sc, *arr)

    getDependencyUrns = create_py4j_interface_method("getDependencyUrns", "dependencyUrns")

    @classmethod
    def df2result(cls, df):
        """Given a datasets's persisted DataFrame, get the result object

        In most cases, this is just the DataFrame itself. See SmvResultModule for the exception.
        """
        return df

    def metadata(self, df):
        """User-defined metadata

        Override this method to define metadata that will be logged with your module's results.
        Defaults to empty dictionary.

        Arguments:
            df (DataFrame): result of running the module, used to generate metadata

        Returns:
            (dict): dictionary of serializable metadata
        """
        return {}

    def metadataJson(self, jdf):
        """Get user's metadata and jsonify it for py4j transport

        Arguments:
            jdf: object wrapped as ``DataFrame(jdf, self.smvApp.sqlContext)``
                before being handed to ``metadata``

        Returns:
            (str): the metadata dict serialized with ``json.dumps``

        Raises:
            SmvRuntimeError: if ``metadata`` does not return a dict
        """
        df = DataFrame(jdf, self.smvApp.sqlContext)
        metadata = self.metadata(df)
        if not isinstance(metadata, dict):
            raise SmvRuntimeError("User metadata {} is not a dict".format(repr(metadata)))
        return json.dumps(metadata)

    getMetadataJson = create_py4j_interface_method("getMetadataJson", "metadataJson")

    def validateMetadata(self, current, history):
        """User-defined metadata validation

        Override this method to define validation rules for metadata given the current
        metadata and historical metadata.

        Arguments:
            current (dict): current metadata kv
            history (list(dict)): list of historical metadata kv's

        Returns:
            (str): Validation failure message. Return None (or omit a return statement) if successful.
        """
        return None

    def validateMetadataJson(self, currentJson, historyJson):
        """Load metadata (jsonified for py4j transport) and run user's validation on it

        Arguments:
            currentJson (str): JSON text of the current metadata
            historyJson: iterable of JSON texts, one per historical metadata entry

        Returns:
            (str): the validation failure message, or None on success

        Raises:
            SmvRuntimeError: if ``validateMetadata`` returns something that
                is neither None nor a string
        """
        current = json.loads(currentJson)
        history = [json.loads(j) for j in historyJson]
        res = self.validateMetadata(current, history)
        if res is not None and not is_string(res):
            raise SmvRuntimeError("Validation failure message {} is not a string".format(repr(res)))
        return res

    getValidateMetadataJson = create_py4j_interface_method("getValidateMetadataJson", "validateMetadataJson")

    def metadataHistorySize(self):
        """Override to define the maximum size of the metadata history for this module

        Return:
            (int): size
        """
        return 5

    getMetadataHistorySize = create_py4j_interface_method("getMetadataHistorySize", "metadataHistorySize")

    class Java:
        implements = ['org.tresamigos.smv.ISmvModule']
class DataSetRepo(object):
    """Loads SmvDataSet classes from the stages configured on an SMV app

    Constructing a repo removes the user's stage and library modules from
    ``sys.modules`` so that later imports pick up the current source.  The
    nested ``Java`` class declares the implemented Java interface
    ``org.tresamigos.smv.IDataSetRepoPy4J``.
    """

    def __init__(self, smvApp):
        self.smvApp = smvApp

        # Remove client modules from sys.modules to force reload of all client
        # code in the new transaction
        self._clear_sys_modules()

    def _clear_sys_modules(self):
        """Clear all client modules from sys.modules

        If modules have names like 'stage1.stage2.file.mod', then we have to clear all of
        set(
            'stage1',
            'stage1.stage2',
            'stage1.stage2.file',
            'stage1.stage2.file.mod'
        )
        from the sys.modules dictionary to avoid getting cached modules from python when
        we construct a new DSR.
        """
        # The set of all user-defined code that needs to be decached
        # { 'stage1' } from our example
        user_code_fqns = set(self.smvApp.stages()).union(self.smvApp.userLibs())
        fqn_stubs_to_remove = {fqn.split('.')[0] for fqn in user_code_fqns}

        # A stub never contains a dot, so "name == stub or name starts with
        # stub + '.'" is exactly "the first component of name is a stub".
        for loaded_mod_fqn in list(sys.modules.keys()):
            if loaded_mod_fqn.split('.')[0] in fqn_stubs_to_remove:
                sys.modules.pop(loaded_mod_fqn)

    def _iter_submodules(self, stages):
        """Yield the names of all submodules of the packages corresponding to the given stages

        Packages themselves are excluded; only non-package module names are produced.
        """
        file_iters_by_stage = (self._iter_submodules_in_stage(stage) for stage in stages)
        file_iter = itertools.chain(*file_iters_by_stage)
        return (name for (_, name, is_pkg) in file_iter if not is_pkg)

    def _iter_submodules_in_stage(self, stage):
        """Yield info on the submodules of the package corresponding with a given stage

        Returns an empty list (after logging a warning) when the stage
        cannot be imported.
        """
        try:
            stagemod = __import__(stage)
        except Exception:
            # Best effort, as before; narrowed from a bare except so that
            # KeyboardInterrupt/SystemExit are not swallowed.
            self.smvApp.log.warn("Package does not exist for stage: " + stage)
            return []

        # `walk_packages` can generate AttributeError if the system has
        # Gtk modules, which are not designed to use with reflection or
        # introspection. Best action to take in this situation is probably
        # to simply suppress the error.
        def onerror(name):
            self.smvApp.log.error("Skipping due to error during walk_packages: " + name)

        return pkgutil.walk_packages(stagemod.__path__, stagemod.__name__ + '.', onerror=onerror)

    def _for_name(self, name):
        """Dynamically load a module in a stage by its name.

        Similar to Java's Class.forName, but only looks in configured stages.

        Arguments:
            name (str): dotted name ``<python file module>.<attribute>``

        Returns:
            the named attribute, or None when the file module is not found
            in the configured stages or has no such attribute
        """
        # rpartition yields an empty file name (which never matches a
        # submodule) when `name` contains no dot.
        file_name, _, mod_name = name.rpartition('.')

        # if file doesnt exist, module doesn't exist
        if file_name not in self._iter_submodules(self.smvApp.stages()):
            return None

        # __import__ instantiates the module hierarchy but returns the root module
        f = __import__(file_name)
        # iterate to get the file that should contain the desired module
        for subname in file_name.split('.')[1:]:
            f = getattr(f, subname)
        # None if the file exists but doesnt have an attribute with that name
        return getattr(f, mod_name, None)

    # Implementation of IDataSetRepoPy4J loadDataSet, which loads the dataset
    # from the most recent source. If the dataset does not exist, returns None.
    # However, if there is an error (such as a SyntaxError) which prevents the
    # user's file from being imported, the error will propagate back to the
    # DataSetRepoPython.
    def loadDataSet(self, fqn):
        """Instantiate the dataset class named ``fqn``, or return None if it is not found"""
        ds = None
        ds_class = self._for_name(fqn)

        if ds_class is not None:
            ds = ds_class(self.smvApp)

            # Python issue https://bugs.python.org/issue1218234
            # need to invalidate inspect.linecache to make dataset hash work
            srcfile = inspect.getsourcefile(ds_class)
            if srcfile:
                inspect.linecache.checkcache(srcfile)

        return ds

    getLoadDataSet = create_py4j_interface_method("getLoadDataSet", "loadDataSet")

    def _dataSetsForStage(self, stageName):
        """Collect the URNs of all concrete SmvDataSets declared in a stage

        Arguments:
            stageName (str): name of the stage package to search

        Returns:
            (list(str)): URNs of the datasets found
        """
        # abc labels methods as abstract via the attribute __isabstractmethod__
        def is_abstract_method(attr):
            return getattr(attr, "__isabstractmethod__", False)

        urns = []

        self.smvApp.log.debug("Searching for SmvDataSets in stage " + stageName)
        self.smvApp.log.debug("sys.path=" + repr(sys.path))

        for pymod_name in self._iter_submodules([stageName]):
            # The additional "." is necessary to prevent false positive, e.g. stage_2.M1 matches stage
            if not pymod_name.startswith(stageName + "."):
                continue

            # __import__('a.b.c') returns the module a, just like import a.b.c
            pymod = __import__(pymod_name)
            # After import a.b.c we got a. Now we traverse from a to b to c
            for c in pymod_name.split('.')[1:]:
                pymod = getattr(pymod, c)

            self.smvApp.log.debug("Searching for SmvDataSets in " + repr(pymod))

            # iterate over the attributes of the module, looking for SmvDataSets
            for obj_name in dir(pymod):
                obj = getattr(pymod, obj_name)
                self.smvApp.log.debug("Inspecting {} ({})".format(obj_name, type(obj)))

                # We try to access the IsSmvDataSet attribute of the object.
                # if it does not exist, we will catch the the AttributeError
                # and skip the object, as it is not an SmvDataSet. We
                # specifically check that IsSmvDataSet is identical to
                # True, because some objects like Py4J's JavaObject override
                # __getattr__ to **always** return something (so IsSmvDataSet
                # maybe truthy even though the object is not an SmvDataSet).
                try:
                    obj_is_smv_dataset = (obj.IsSmvDataSet is True)
                except AttributeError:
                    obj_is_smv_dataset = False

                if not obj_is_smv_dataset:
                    self.smvApp.log.debug("Ignoring {} because it is not an "
                                          "SmvDataSet".format(obj_name))
                    continue

                # Class should have an fqn which begins with the module name.
                # Each package will contain all of the modules, classes, etc.
                # that were imported into it, and we need to exclude these
                # (so that we only count each module once).  The trailing "."
                # prevents a class from e.g. stage.mod10 being attributed to
                # stage.mod1 when it has been imported there.
                obj_declared_in_stage = obj.fqn().startswith(pymod_name + ".")

                if not obj_declared_in_stage:
                    self.smvApp.log.debug("Ignoring {} because it was not "
                                          "declared in {}. (Note: it may "
                                          "be collected from another stage)"
                                          .format(obj_name, pymod_name))
                    continue

                # Class should not be an ABC
                if inspect.isabstract(obj):
                    abstract_methods = [name for name, _ in inspect.getmembers(obj, is_abstract_method)]
                    self.smvApp.log.debug("Ignoring {} because it is abstract ({} undefined)"
                                          .format(obj_name, ", ".join(abstract_methods)))
                    continue

                self.smvApp.log.debug("Collecting " + obj_name)
                urns.append(obj.urn())

        return urns

    def dataSetsForStage(self, stageName):
        """Return the dataset URNs of a stage, passed through ``smv_copy_array``"""
        urns = self._dataSetsForStage(stageName)
        return smv_copy_array(self.smvApp.sc, *urns)

    getDataSetsForStage = create_py4j_interface_method("getDataSetsForStage", "dataSetsForStage")

    def notFound(self, modUrn, msg):
        """Raise ValueError reporting that ``modUrn`` was not found in this repo"""
        raise ValueError("dataset [{0}] is not found in {1}: {2}".format(modUrn, self.__class__.__name__, msg))

    class Java:
        implements = ['org.tresamigos.smv.IDataSetRepoPy4J']
class DataSetRepo(object):
    """Loads SmvDataSet classes from the stages configured on an SMV app

    Constructing a repo removes the user's stage modules from
    ``sys.modules`` so that later imports pick up the current source.  The
    nested ``Java`` class declares the implemented Java interface
    ``org.tresamigos.smv.IDataSetRepoPy4J``.
    """

    def __init__(self, smvApp):
        self.smvApp = smvApp

        # Remove client modules from sys.modules to force reload of all client
        # code in the new transaction
        self._clear_sys_modules()

    def _clear_sys_modules(self):
        """Clear all client modules from sys.modules

        A module is removed when its name equals a stage name or lies
        underneath one (``<stage>.``).
        """
        for fqn in list(sys.modules.keys()):
            for stage_name in self.smvApp.stages:
                if fqn == stage_name or fqn.startswith(stage_name + "."):
                    sys.modules.pop(fqn)
                    break

    # Implementation of IDataSetRepoPy4J loadDataSet, which loads the dataset
    # from the most recent source. If the dataset does not exist, returns None.
    # However, if there is an error (such as a SyntaxError) which prevents the
    # user's file from being imported, the error will propagate back to the
    # DataSetRepoPython.
    def loadDataSet(self, fqn):
        """Instantiate the dataset class named ``fqn``, or return None if it is not found"""
        ds = None
        ds_class = for_name(fqn, self.smvApp.stages)

        if ds_class is not None:
            ds = ds_class(self.smvApp)

            # Python issue https://bugs.python.org/issue1218234
            # need to invalidate inspect.linecache to make dataset hash work
            srcfile = inspect.getsourcefile(ds_class)
            if srcfile:
                inspect.linecache.checkcache(srcfile)

        return ds

    getLoadDataSet = create_py4j_interface_method("getLoadDataSet", "loadDataSet")

    def dataSetsForStage(self, stageName):
        """Return the URNs of the SmvDataSets declared in ``stageName``"""
        # Compare with `is True` rather than relying on truthiness: objects
        # whose __getattr__ answers every lookup (e.g. Py4J's JavaObject)
        # would otherwise look like datasets.
        return self._moduleUrnsForStage(stageName, lambda obj: obj.IsSmvDataSet is True)

    getDataSetsForStage = create_py4j_interface_method("getDataSetsForStage", "dataSetsForStage")

    def _moduleUrnsForStage(self, stageName, fn):
        """Collect URNs of objects in a stage that satisfy a predicate

        Arguments:
            stageName (str): name of the stage package to search
            fn: one-argument predicate applied to every module attribute;
                an AttributeError raised by it means "not a match"

        Returns:
            the matching URNs, passed through ``smv_copy_array``
        """
        buf = []
        # import the stage and only walk the packages in the path of that stage, recursively
        for name in iter_submodules([stageName]):
            # The additional "." is necessary to prevent false positive, e.g. stage_2.M1 matches stage
            if not name.startswith(stageName + "."):
                continue

            # __import__ returns the root package; walk down to the leaf module
            pymod = __import__(name)
            for c in name.split('.')[1:]:
                pymod = getattr(pymod, c)

            for n in dir(pymod):
                obj = getattr(pymod, n)
                try:
                    # Class should have an fqn which begins with the module name.
                    # Each package will contain among other things all of
                    # the modules that were imported into it, and we need
                    # to exclude these (so that we only count each module once).
                    # The trailing "." prevents a class from e.g. stage.mod10
                    # being attributed to stage.mod1 when imported there.
                    if fn(obj) and obj.fqn().startswith(name + "."):
                        buf.append(obj.urn())
                except AttributeError:
                    continue

        return smv_copy_array(self.smvApp.sc, *buf)

    def notFound(self, modUrn, msg):
        """Raise ValueError reporting that ``modUrn`` was not found in this repo"""
        raise ValueError("dataset [{0}] is not found in {1}: {2}".format(modUrn, self.__class__.__name__, msg))

    class Java:
        implements = ['org.tresamigos.smv.IDataSetRepoPy4J']