def _get_run_args(self):
    """Get the args and kwargs of this Node's run() method"""
    positionals = OrderedDict()
    keywords = OrderedDict()
    sig = signature(self.run)

    for i, param_name in enumerate(sig.parameters):
        param = sig.parameters[param_name]

        if i == 0 and self.run_requires_data:
            # The first param is the data to process which is passed
            # directly in process()
            continue

        raiseif(
            param.name in RESERVED_ARG_NAMES,
            "Reserved arg name '%s' used in run()" % param.name,
        )

        if param.kind == param.POSITIONAL_ONLY:
            positionals[param.name] = None
        elif (
            param.default == Parameter.empty
            and param.kind == param.POSITIONAL_OR_KEYWORD
        ):
            positionals[param.name] = None
        elif param.kind == param.POSITIONAL_OR_KEYWORD:
            keywords[param.name] = param.default
        elif param.kind == param.VAR_KEYWORD:
            pass
        else:
            raise AssertionError(
                "%s params are not allowed in run()" % param.kind
            )

    return positionals, keywords
def run(self, df, conn, schema=None, dry_run=False, **kwargs):
    """Use Pandas to_sql to output a DataFrame to a temporary table. Push a
    reference to the temp table forward.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame to load to a SQL table
    conn
        Database connection
    schema : str, optional
        Schema to create the temp table in
    dry_run : bool, optional
        If true, skip actually loading the data
    **kwargs
        Keyword arguments passed to DataFrame.to_sql

    """
    raiseifnot(pd, "Please install Pandas to use this class")
    raiseif(
        isinstance(conn, sqlite3.Connection),
        "sqlite3 connections not supported due to bug in Pandas' has_table()",
    )

    table = get_temp_table(conn, df, schema=schema, create=True)
    if dry_run:
        warn("dry_run=True, skipping load in %s.run" % self.__class__.__name__)
    else:
        df.to_sql(table.name, conn, if_exists="append", **kwargs)

    self.push(table.name)
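# A minimal usage sketch of this run() method, assuming it belongs to a
# DataFrame temp-table load node (the class and node names below are
# placeholders, not shown in this section) driven by a pipeline with a
# SQLAlchemy connection:
#
#     conn = sqlalchemy.create_engine("postgresql://...").connect()
#     glider.consume([df], load=dict(conn=conn, schema="staging"))
#
# On success the node pushes the generated temp table name downstream rather
# than the DataFrame itself.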
def __init__(self, *args, **kwargs):
    set_missing_key(kwargs, "global_state", GlobalState())  # Ensure our version is default
    self.pipeline = GlidePipeline(*args, **kwargs)
    node_lookup = self.get_node_lookup()
    for key in node_lookup:
        raiseif(
            key in RESERVED_NODE_NAMES,
            "Can not use reserved node name: %s" % key,
        )
def get_bulk_statement(self, conn, stmt_type, table, rows, odku=False):
    """Get a bulk execution SQL statement

    Parameters
    ----------
    conn
        A SQL database connection object
    stmt_type : str
        Type of SQL statement to use (REPLACE, INSERT, etc.)
    table : str
        Name of a SQL table
    rows
        An iterable of dict rows. The first row is used to determine
        column names.
    odku : bool or list, optional
        If true, add ON DUPLICATE KEY UPDATE clause for all columns. If a
        list then only add it for the specified columns. **Note:** Backend
        support for this varies.

    Returns
    -------
    A SQL bulk load query of the given stmt_type

    """
    if is_sqlalchemy_conn(conn):
        return get_bulk_statement(
            stmt_type, table, rows[0].keys(), dicts=False, odku=odku
        )

    if isinstance(conn, sqlite3.Connection):
        raiseifnot(
            isinstance(rows[0], sqlite3.Row),
            "Only sqlite3.Row rows are supported",
        )
        return get_bulk_statement(
            stmt_type,
            table,
            rows[0].keys(),
            dicts=False,
            value_string="?",
            odku=odku,
        )

    raiseif(
        isinstance(rows[0], tuple),
        "Dict rows expected, got tuple. Please use a dict cursor.",
    )
    return get_bulk_statement(stmt_type, table, rows[0].keys(), odku=odku)
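# A rough sketch of how this is typically used by a bulk SQL load, assuming
# dict-style rows (column names are taken from the first row). The table
# name, cursor, and executemany call are illustrative placeholders:
#
#     rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
#     stmt = self.get_bulk_statement(conn, "REPLACE", "mytable", rows)
#     cursor.executemany(stmt, rows)
#
# The sqlite3 branch emits "?" value placeholders, while other backends use
# whatever placeholder style the underlying get_bulk_statement helper produces.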
def _convert_kwargs(self, kwargs):
    """Convert flat kwargs to node contexts and remaining kwargs"""
    nodes = self.glider.get_node_lookup()
    node_contexts = {}
    add_to_final = set()

    for key, value in kwargs.items():
        raiseif(
            key in nodes,
            "Invalid keyword arg '%s', can not be a node name" % key,
        )

        node_name = self._get_arg_node_name(key)
        if node_name not in nodes:
            add_to_final.add(key)
            continue

        arg_name = key[len(node_name) + 1:]
        node_contexts.setdefault(node_name, {})[arg_name] = value

    injected_node_contexts = self._get_injected_node_contexts(kwargs)
    for node_name, injected_args in injected_node_contexts.items():
        if node_name in node_contexts:
            node_contexts[node_name].update(injected_args)
        else:
            node_contexts[node_name] = injected_args

    arg_node_map = self._get_arg_name_node_map()
    for custom_arg_dest in self._get_custom_arg_dests():
        if custom_arg_dest not in kwargs:
            continue
        for node_name in arg_node_map.get(custom_arg_dest, []):
            custom_arg_dict = {custom_arg_dest: kwargs[custom_arg_dest]}
            if node_name in node_contexts:
                node_contexts[node_name].update(custom_arg_dict)
            else:
                node_contexts[node_name] = custom_arg_dict

    final_kwargs = dict(node_contexts=node_contexts)
    for key in add_to_final:
        final_kwargs[key] = kwargs[key]

    return final_kwargs
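# A minimal sketch of the flat-kwarg convention handled above, assuming a
# pipeline containing a node named "extract" whose run() takes a "url"
# argument (both names are hypothetical examples):
#
#     self._convert_kwargs(dict(extract_url="http://example.com/data.csv",
#                               other_flag=True))
#
# would return roughly:
#
#     {"node_contexts": {"extract": {"url": "http://example.com/data.csv"}},
#      "other_flag": True}
#
# since "other_flag" does not start with a known node name and is passed
# through untouched.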
def _push(self, data):
    """Push data to downstream nodes by wrapping their _process calls in
    dask delayed tasks and computing them"""
    raiseifnot(delayed, "Please install dask (delayed) to use DaskDelayedPush")

    if self._logging == "output":
        self._write_log(data)

    raiseif(
        "executor_kwargs" in self.context,
        "%s does not currently support executor_kwargs" % self.__class__,
    )

    lazy = []
    if self.context.get("split", False):
        splits = np.array_split(data, len(self._downstream_nodes))
        for i, downstream in enumerate(self._downstream_nodes):
            lazy.append(delayed(downstream._process)(splits[i]))
    else:
        for downstream in self._downstream_nodes:
            lazy.append(delayed(downstream._process)(data))

    compute(lazy)
def __init__(self, filename=None, var=None, key=None):
    """Populate context values at runtime from a config file. One of
    filename or var must be specified.

    Parameters
    ----------
    filename : str, optional
        Name of a file to read the config from. The config parser used
        will be inferred from the file extension.
    var : str, optional
        The name of an environment variable that points to a config file
        to read.
    key : str or callable, optional
        A key to extract from the config, or a callable that takes the
        config and returns an extracted value

    """
    raiseifnot(filename or var, "Either filename or var must be specified")
    raiseif(filename and var, "Only one of filename or var should be specified")

    if var:
        filename = os.environ[var]

    ext = filename.split(".")[-1]
    supported = ["json", "yaml", "ini"]
    raiseifnot(
        ext in supported, "Invalid extension, only %s supported" % supported
    )

    if ext == "json":
        func = load_json_config
    elif ext == "yaml":
        func = load_yaml_config
    elif ext == "ini":
        func = load_ini_config

    super().__init__(func, filename, key=key)
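# A minimal usage sketch, assuming the enclosing class is a config-based
# context helper (referred to here as ConfigContext; the class name is not
# shown in this section) and the file paths/keys are placeholders:
#
#     db_conf = ConfigContext("/path/to/config.yaml", key="db")
#
# or, reading the filename from an environment variable and extracting a
# value with a callable:
#
#     db_url = ConfigContext(var="MY_APP_CONFIG",
#                            key=lambda cfg: cfg["db"]["url"])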
def run(
    self,
    f,
    compression=None,
    open_flags="r",
    chunksize=None,
    push_lines=False,
    limit=None,
):
    """Extract raw data from a file or buffer and push contents

    Parameters
    ----------
    f : file path or buffer
        File path or buffer to read
    compression : str, optional
        Param passed to pandas get_filepath_or_buffer
    open_flags : str, optional
        Flags to pass to open() if f is not already an opened buffer
    chunksize : int, optional
        Push lines in chunks of this size
    push_lines : bool, optional
        Push each line as it's read instead of reading the entire file and
        pushing
    limit : int, optional
        Limit to first N lines

    """
    raiseif(
        chunksize and push_lines,
        "Only one of chunksize and push_lines may be specified",
    )

    is_text = True
    if "b" in open_flags:
        is_text = False

    f, _, close = open_filepath_or_buffer(
        f, open_flags=open_flags, compression=compression, is_text=is_text
    )

    try:
        data = []
        count = 0
        for line in f:
            count += 1
            if push_lines:
                self.push(line)
            else:
                data.append(line)
                if chunksize and (count % chunksize == 0):
                    if is_text:
                        self.push("".join(data))
                    else:
                        self.push(b"".join(data))
                    data = []

            if limit and count >= limit:
                break

        if ((not push_lines) and data) or count == 0:
            if is_text:
                self.push("".join(data))
            else:
                self.push(b"".join(data))
    finally:
        if close:
            try:
                f.close()
            except ValueError:
                pass
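# A minimal usage sketch, assuming this run() belongs to a file-extraction
# node (referred to here as FileExtract; the class, node names, and file path
# below are placeholders) wired into a pipeline:
#
#     glider = Glider(FileExtract("extract") | Print("load"))
#     glider.consume(["/path/to/data.txt"], extract=dict(chunksize=1000))
#
# With chunksize set, downstream nodes receive the file contents in blocks of
# 1000 lines; with push_lines=True they would receive one line at a time.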
def run(
    self,
    data,
    sql,
    conn,
    cursor=None,
    cursor_type=None,
    params=None,
    data_check=None,
    **kwargs
):
    """Run a SQL query to check data.

    Parameters
    ----------
    data
        Data to pass through on success
    sql : str
        SQL query to run. Should return a single row with an "assert"
        column to indicate success. Truthy values for "assert" will be
        considered successful, unless data_check is passed in which case
        it will be compared for equality to the result of that callable.
    conn
        SQL connection object
    cursor : optional
        SQL connection cursor object
    cursor_type : optional
        SQL connection cursor type when creating a cursor is necessary
    params : tuple or dict, optional
        A tuple or dict of params to pass to the execute method
    data_check : callable, optional
        A callable that will be passed the node and data as arguments and
        is expected to return a value to be compared to the SQL result.
    **kwargs
        Keyword arguments pushed to the execute method

    """
    if not cursor:
        cursor = self.get_sql_executor(conn, cursor_type=cursor_type)

    params = params or ()
    fetcher = self.execute(conn, cursor, sql, params=params, **kwargs)
    result = fetcher.fetchone()

    if isinstance(conn, sqlite3.Connection):
        raiseifnot(
            isinstance(result, sqlite3.Row),
            "Only sqlite3.Row rows are supported for sqlite3 connections",
        )
    raiseif(
        isinstance(result, tuple),
        "Dict rows expected, got tuple. Please use a dict cursor.",
    )
    raiseifnot("assert" in result.keys(), "Result is missing 'assert' column")
    result = result["assert"]

    if data_check:
        check = data_check(self, data)
        raiseifnot(
            result == check,
            (
                "SQL assertion failed\nnode: %s\nsql: %s\nvalue: %s\ndata_check: %s"
                % (self.name, sql, result, check)
            ),
        )
    else:
        raiseifnot(
            result,
            (
                "SQL assertion failed\nnode: %s\nsql: %s\nvalue: %s"
                % (self.name, sql, result)
            ),
        )

    self.push(data)
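# A minimal sketch of the kind of query this assertion expects. The table
# name and expected count are placeholders:
#
#     sql = 'SELECT (COUNT(*) = 100) AS "assert" FROM mytable'
#
# A truthy "assert" value passes. Alternatively, with a data_check callable
# the "assert" value is compared for equality, e.g.:
#
#     sql = 'SELECT COUNT(*) AS "assert" FROM mytable'
#     data_check = lambda node, data: len(data)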
def get_results(self, futures, timeout=None):
    """Gather results from Dask futures as they complete and concatenate
    them into a single DataFrame"""
    raiseif(timeout, "timeout argument is not supported for Dask Client")
    dfs = []
    for _, result in dask_as_completed(futures, with_results=True):
        dfs.append(result)
    return pd.concat(dfs)
def _get_script_args(self):
    """Generate all tlbx Args for this Glider"""
    node_lookup = self.glider.get_node_lookup()
    script_args = OrderedDict()  # Map of arg names to Args
    arg_dests = {}  # Map of arg dests back to names
    node_arg_names = defaultdict(set)

    requires_data = not isinstance(self.glider.top_node, NoInputNode)
    if requires_data and not self.blacklisted("", SCRIPT_DATA_ARG):
        script_args[SCRIPT_DATA_ARG] = Arg(SCRIPT_DATA_ARG, nargs="+")

    def add_script_arg(node, arg_name, **kwargs):
        script_arg = self._get_script_arg(node, arg_name, **kwargs)
        if not script_arg:
            return
        script_args[script_arg.name] = script_arg
        arg_dests[script_arg.dest] = script_arg.name
        node_arg_names[arg_name].add(script_arg.name)

    for node in node_lookup.values():
        node_help = {}
        if FunctionDoc:
            try:
                # Only works if run() has docs in numpydoc format
                docs = FunctionDoc(node.run)
                node_help = {
                    v.name: "\n".join(v.desc) for v in docs["Parameters"]
                }
            except Exception as e:
                info(
                    "failed to parse node '%s' run() docs: %s"
                    % (node.name, str(e))
                )

        for arg_name, _ in node.run_args.items():
            add_script_arg(
                node,
                arg_name,
                required=True,
                arg_help=node_help.get(arg_name, None),
            )

        for kwarg_name, kwarg_default in node.run_kwargs.items():
            add_script_arg(
                node,
                kwarg_name,
                required=False,
                default=kwarg_default,
                arg_help=node_help.get(kwarg_name, None),
            )

    def assert_arg_present(custom_arg, arg_name):
        raiseifnot(
            arg_name in script_args,
            (
                "Custom arg %s with dest=%s maps to node arg=%s "
                "which is not in the script arg list. Check for "
                "conflicting args that cover the same node arg."
                % (custom_arg.name, custom_arg.dest, arg_name)
            ),
        )

    for custom_arg in self.custom_args:
        raiseif(
            self.blacklisted("", custom_arg.name),
            "Blacklisted arg '%s' passed as a custom arg" % custom_arg.name,
        )

        if custom_arg.dest in node_arg_names:
            # Find and delete all node-based args this will cover
            for arg_name in node_arg_names[custom_arg.dest]:
                assert_arg_present(custom_arg, arg_name)
                del script_args[arg_name]

        if custom_arg.dest in arg_dests:
            # Remove the original arg that this custom arg will satisfy
            arg_name = arg_dests[custom_arg.dest]
            assert_arg_present(custom_arg, arg_name)
            del script_args[arg_name]

        script_args[custom_arg.name] = custom_arg
        arg_dests[custom_arg.dest] = custom_arg.name

    return script_args.values()
def _get_script_arg(
    self, node, arg_name, required=False, default=None, arg_help=None
):
    """Generate a tlbx Arg"""
    if self.blacklisted(node.name, arg_name):
        return None

    dest_arg_name = self._get_script_arg_name(node.name, arg_name)

    if arg_name in self.inject:
        required = False
        default = None
    elif arg_name in node.context:
        required = False
        default = node.context[arg_name]
    elif arg_name in self.glider.global_state:
        required = False
        default = self.glider.global_state[arg_name]
    elif dest_arg_name in self.glider.global_state:
        required = False
        default = self.glider.global_state[dest_arg_name]

    arg_type = str
    if default is not None:
        arg_type = type(default)

    if arg_type == bool:
        raiseif(required, "Required bool args don't make sense")
        base_arg_name = arg_name
        if default:
            action = "store_false"
            base_arg_name = "no_" + arg_name
        else:
            action = "store_true"
        dest = self._get_script_arg_name(node.name, arg_name)
        arg_name = self._get_script_arg_name(node.name, base_arg_name)
        script_arg = Arg(
            "--" + arg_name,
            required=required,
            action=action,
            default=default,
            help=arg_help,
            dest=dest,
        )
    else:
        arg_name = self._get_script_arg_name(node.name, arg_name)
        # TODO: argparse puts required args with "--" in the "optional"
        # section. There are workarounds, but it's unclear how to use them
        # with tlbx.Arg which is based on the climax library.
        # https://stackoverflow.com/q/24180527/10682164
        script_arg = Arg(
            "--" + arg_name,
            required=required,
            type=arg_type,
            default=default,
            help=arg_help,
        )

    return script_arg
def _check_arg_conflicts(self):
    """Raise an error if any custom arg dest conflicts with an injected arg"""
    for dest in self._get_custom_arg_dests():
        raiseif(
            dest in self.inject,
            "Arg dest '%s' conflicts with injected arg" % dest,
        )
def run(
    self,
    data,
    url,
    data_param="data",
    session=None,
    skip_raise=False,
    dry_run=False,
    **kwargs
):
    """Load data to URL using requests and push response.content. The url
    may be a string (POST to that url) or a dictionary of args to
    requests.request:

    http://2.python-requests.org/en/master/api/?highlight=get#requests.request

    Parameters
    ----------
    data
        Data to load to the URL
    url : str or dict
        If str, a URL to POST to. If a dict, args to requests.request
    data_param : str, optional
        Parameter to stuff data in when calling requests methods
    session : optional
        A requests Session to use to make the request
    skip_raise : bool, optional
        If False, raise exceptions for bad response status
    dry_run : bool, optional
        If true, skip actually loading the data
    **kwargs
        Keyword arguments to pass to the request method. If a dict is
        passed for the url parameter it overrides values here.

    """
    requestor = requests
    if session:
        requestor = session

    if dry_run:
        warn("dry_run=True, skipping load in %s.run" % self.__class__.__name__)
    else:
        if isinstance(url, str):
            raiseif(
                "data" in kwargs or "json" in kwargs,
                "Overriding data/json params is not allowed",
            )
            kwargs[data_param] = data
            resp = requestor.post(url, **kwargs)
        elif isinstance(url, dict):
            kwargs_copy = deepcopy(kwargs)
            kwargs_copy.update(url)
            raiseif(
                "data" in kwargs_copy or "json" in kwargs_copy,
                "Overriding data/json params is not allowed",
            )
            kwargs_copy[data_param] = data
            resp = requestor.request(**kwargs_copy)
        else:
            raise AssertionError(
                "Input url must be a str or dict type, got %s" % type(url)
            )

        if not skip_raise:
            resp.raise_for_status()

    self.push(data)
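# A minimal sketch of the dict form of url, where the dict supplies explicit
# requests.request arguments (the URL, method, and headers below are
# placeholders):
#
#     node_context = dict(
#         url={
#             "method": "PUT",
#             "url": "https://api.example.com/rows",
#             "headers": {"Content-Type": "application/json"},
#         },
#         data_param="json",
#     )
#
# With the str form, the data is POSTed to that URL and extra kwargs (e.g.
# headers, timeout) are forwarded to requests.post.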