def _get_filter_info(self, expr_to_parse, X) -> Tuple[str, Any, Optional[str]]:
    col_list = X.columns
    if isinstance(expr_to_parse, ast.Call):
        op = expr_to_parse.func
        # for now, we only support single argument predicates
        if len(expr_to_parse.args) != 1:
            raise ValueError(
                "Filter predicate functions currently only support a single argument"
            )
        arg = expr_to_parse.args[0]
        if _is_ast_subscript(arg):
            lhs = _get_subscript_value(arg)
        elif _is_ast_attribute(arg):
            lhs = arg.attr  # type: ignore
        else:
            raise ValueError(
                "Filter predicate functions only support subscript or dot notation for the argument. For example, it.col_name or it['col_name']"
            )
        if lhs not in col_list:
            raise ValueError(
                "Cannot perform filter predicate operation as {} is not a column of input dataframe X.".format(
                    lhs
                )
            )
        return lhs, op, None
    if _is_ast_subscript(expr_to_parse.left):
        lhs = _get_subscript_value(expr_to_parse.left)
    elif _is_ast_attribute(expr_to_parse.left):
        lhs = expr_to_parse.left.attr
    else:
        raise ValueError(
            "Filter predicate only supports subscript or dot notation for the left-hand side. For example, it.col_name or it['col_name']"
        )
    if lhs not in col_list:
        raise ValueError(
            "Cannot perform filter operation as {} is not a column of input dataframe X.".format(
                lhs
            )
        )
    op = expr_to_parse.ops[0]
    if _is_ast_subscript(expr_to_parse.comparators[0]):
        rhs = _get_subscript_value(expr_to_parse.comparators[0])
    elif _is_ast_attribute(expr_to_parse.comparators[0]):
        rhs = expr_to_parse.comparators[0].attr
    elif _is_ast_constant(expr_to_parse.comparators[0]):
        rhs = expr_to_parse.comparators[0].value
    else:
        raise ValueError(
            "Filter predicate only supports subscript or dot notation for the right-hand side. For example, it.col_name or it['col_name'] or a constant value"
        )
    if not _is_ast_constant(expr_to_parse.comparators[0]) and rhs not in col_list:
        raise ValueError(
            "Cannot perform filter operation as {} is not a column of input dataframe X.".format(
                rhs
            )
        )
    return lhs, op, rhs
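# Hedged sketch, not library code: a minimal demo of the ast.Compare shape that
# _get_filter_info parses for a predicate like it.age >= 21. The `it` name and
# the "age" column are assumptions for illustration only.
def _demo_filter_predicate_ast():
    import ast

    expr = ast.parse("it.age >= 21", mode="eval").body
    assert isinstance(expr, ast.Compare)
    assert expr.left.attr == "age"            # left-hand side column name
    assert isinstance(expr.ops[0], ast.GtE)   # the comparison op returned as `op`
    assert expr.comparators[0].value == 21    # constant right-hand side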
def identity(df: Any, column: Expr, new_column_name: str):
    if _is_ast_subscript(column._expr):  # type: ignore
        # use the shared helper so both pre- and post-Python 3.9 AST shapes work
        column_name = _get_subscript_value(column._expr)  # type: ignore
    elif _is_ast_attribute(column._expr):  # type: ignore
        column_name = column._expr.attr  # type: ignore
    else:
        raise ValueError(
            "Expression type not supported. Formats supported: it.column_name or it['column_name']."
        )
    if column_name is None or not column_name.strip():
        raise ValueError("Name of the column to be renamed cannot be None or empty.")
    if new_column_name is None or not new_column_name.strip():
        raise ValueError(
            "New name of the column to be renamed cannot be None or empty."
        )
    if _is_pandas_df(df):
        df = df.rename(columns={column_name: new_column_name})
    elif spark_installed and _is_spark_df(df):
        df = df.withColumnRenamed(column_name, new_column_name)
    else:
        raise ValueError(
            "Function identity supports only pandas or Spark dataframes."
        )
    return new_column_name, df
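# Hedged sketch, not library code: the pandas rename that identity() performs,
# assuming a frame with an "old" column (column names illustrative only).
def _demo_identity_rename():
    import pandas as pd

    df = pd.DataFrame({"old": [1, 2]})
    renamed = df.rename(columns={"old": "new"})  # the same call identity() makes
    assert list(renamed.columns) == ["new"]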
def get_map_function_output(X, column, new_column_name):
    # X is the input dataframe that the selected helper operates on; it must be
    # passed in explicitly since it is not otherwise bound in this scope.
    functions_module = importlib.import_module("lale.lib.lale.functions")
    # bare column references (it.col or it['col']) dispatch to identity;
    # otherwise the call's function name selects the helper to invoke
    if _is_ast_subscript(column._expr) or _is_ast_attribute(column._expr):
        function_name = "identity"
    else:
        function_name = column._expr.func.id
    map_func_to_be_called = getattr(functions_module, function_name)
    return map_func_to_be_called(X, column, new_column_name)
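# Hedged sketch, not library code: for a Map expression such as
# replace(it.gender, {...}), the wrapped AST is an ast.Call whose func.id names
# the helper looked up in lale.lib.lale.functions (names illustrative only).
def _demo_map_dispatch_ast():
    import ast

    call = ast.parse("replace(it.gender, {'m': 0})", mode="eval").body
    assert isinstance(call, ast.Call)
    assert call.func.id == "replace"  # key used with getattr(functions_module, ...)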
def _get_group_key(self, expr_to_parse):
    if _is_ast_subscript(expr_to_parse):
        # use the shared helper so both pre- and post-Python 3.9 AST shapes work
        return _get_subscript_value(expr_to_parse)
    elif _is_ast_attribute(expr_to_parse):
        return expr_to_parse.attr
    else:
        raise ValueError(
            "GroupBy by parameter only supports subscript or dot notation for the key columns. For example, it.col_name or it['col_name']."
        )
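# Hedged sketch, not library code: the two AST shapes _get_group_key accepts,
# it.state and it['state'] (the column name is illustrative only).
def _demo_group_key_ast():
    import ast

    dot = ast.parse("it.state", mode="eval").body
    assert isinstance(dot, ast.Attribute) and dot.attr == "state"
    sub = ast.parse("it['state']", mode="eval").body
    assert isinstance(sub, ast.Subscript)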
def _get_join_info(cls, expr_to_parse):
    left_key = []
    right_key = []
    if _is_ast_subscript(expr_to_parse.left.value):
        left_name = _get_subscript_value(expr_to_parse.left.value)
    elif _is_ast_attribute(expr_to_parse.left.value):
        left_name = expr_to_parse.left.value.attr
    else:
        raise ValueError(
            "ERROR: Expression type not supported! Formats supported: it.table_name.column_name or it['table_name'].column_name"
        )
    if _is_ast_subscript(expr_to_parse.left):
        left_key.append(_get_subscript_value(expr_to_parse.left))
    elif _is_ast_attribute(expr_to_parse.left):
        left_key.append(expr_to_parse.left.attr)
    else:
        raise ValueError(
            "ERROR: Expression type not supported! Formats supported: it.table_name.column_name or it.table_name['column_name']"
        )
    if _is_ast_subscript(expr_to_parse.comparators[0].value):
        right_name = _get_subscript_value(expr_to_parse.comparators[0].value)
    elif _is_ast_attribute(expr_to_parse.comparators[0].value):
        right_name = expr_to_parse.comparators[0].value.attr
    else:
        raise ValueError(
            "ERROR: Expression type not supported! Formats supported: it.table_name.column_name or it['table_name'].column_name"
        )
    if _is_ast_subscript(expr_to_parse.comparators[0]):
        right_key.append(_get_subscript_value(expr_to_parse.comparators[0]))
    elif _is_ast_attribute(expr_to_parse.comparators[0]):
        right_key.append(expr_to_parse.comparators[0].attr)
    else:
        raise ValueError(
            "ERROR: Expression type not supported! Formats supported: it.table_name.column_name or it.table_name['column_name']"
        )
    return left_name, left_key, right_name, right_key
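# Hedged sketch, not library code: how a join predicate such as
# it.main.id == it.other.id nests table and column names in the AST
# (table and column names here are illustrative only).
def _demo_join_predicate_ast():
    import ast

    pred = ast.parse("it.main.id == it.other.id", mode="eval").body
    assert pred.left.value.attr == "main"             # left table name
    assert pred.left.attr == "id"                     # left key column
    assert pred.comparators[0].value.attr == "other"  # right table name
    assert pred.comparators[0].attr == "id"           # right key column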
def _get_order_key(self, expr_to_parse) -> Tuple[str, bool]:
    order_asc: bool = True
    col: str
    if isinstance(expr_to_parse, Expr):
        expr_to_parse = expr_to_parse._expr
        if isinstance(expr_to_parse, ast.Call):
            op = expr_to_parse.func
            if isinstance(op, ast.Name):
                name = op.id
                if name == "asc":
                    order_asc = True
                elif name == "desc":
                    order_asc = False
                else:
                    raise ValueError(
                        "OrderBy descriptor expressions must be either asc or desc"
                    )
            else:
                raise ValueError(
                    "OrderBy expressions must be a string or an order descriptor (asc, desc)"
                )
            # for now, we only support single argument predicates
            if len(expr_to_parse.args) != 1:
                raise ValueError(
                    "OrderBy predicates do not support multiple arguments",
                )
            arg = expr_to_parse.args[0]
        else:
            arg = expr_to_parse
    else:
        arg = expr_to_parse
    if isinstance(arg, str):
        col = arg
    elif isinstance(arg, ast.Name):
        col = arg.id  # type: ignore
    elif hasattr(ast, "Constant") and isinstance(arg, ast.Constant):
        col = arg.value  # type: ignore
    elif hasattr(ast, "Str") and isinstance(arg, ast.Str):
        col = arg.s
    elif _is_ast_subscript(arg):
        col = _get_subscript_value(arg)  # type: ignore
    elif _is_ast_attribute(arg):
        col = arg.attr  # type: ignore
    else:
        raise ValueError(
            "OrderBy parameters only support string, subscript or dot notation for the column name. For example, it.col_name or it['col_name']."
        )
    return col, order_asc
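# Hedged sketch, not library code: the ast.Call shape for a descending sort key,
# desc('price'); the column name is illustrative only.
def _demo_order_key_ast():
    import ast

    call = ast.parse("desc('price')", mode="eval").body
    assert call.func.id == "desc"         # flips order_asc to False
    assert call.args[0].value == "price"  # the single string argument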
def infer_new_name(expr):
    # for a call like month(it.date_col), reuse the argument's column name as
    # the name of the output column
    if (
        _is_ast_call(expr._expr)
        and _is_ast_name(expr._expr.func)
        and expr._expr.func.id
        in [
            "replace",
            "day_of_month",
            "day_of_week",
            "day_of_year",
            "hour",
            "minute",
            "month",
        ]
        and _is_ast_attribute(expr._expr.args[0])
    ):
        return expr._expr.args[0].attr
    else:
        raise ValueError(
            """New name of the column to be renamed cannot be None or empty.
            You may want to use a dictionary to specify the new column name
            as the key, and the expression as the value."""
        )
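# Hedged sketch, not library code: how infer_new_name recovers the output
# column name from a call like day_of_week(it.date) (names illustrative only).
def _demo_infer_new_name_ast():
    import ast

    expr = ast.parse("day_of_week(it.date)", mode="eval").body
    assert expr.func.id == "day_of_week"
    assert expr.args[0].attr == "date"  # becomes the inferred output column name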
def transform(self, X):
    agg_info = {}
    agg_expr = {}

    def create_spark_agg_expr(new_col_name, agg_col_func):
        functions_module = importlib.import_module("lale.lib.lale.functions")

        def get_spark_agg_method(agg_method_name):
            return getattr(functions_module, "grouped_" + agg_method_name)

        agg_method = get_spark_agg_method(agg_col_func[1])()  # type: ignore
        return agg_method(agg_col_func[0]).alias(new_col_name)

    if not isinstance(self.columns, dict):
        raise ValueError("Aggregate 'columns' parameter should be of dictionary type.")
    for new_col_name, expr in self.columns.items():
        agg_func = expr._expr.func.id
        expr_to_parse = expr._expr.args[0]
        if _is_ast_subscript(expr_to_parse):
            agg_col = _get_subscript_value(expr_to_parse)
        elif _is_ast_attribute(expr_to_parse):
            agg_col = expr_to_parse.attr
        else:
            raise ValueError(
                "Aggregate 'columns' parameter only supports subscript or dot notation for the key columns. For example, it.col_name or it['col_name']."
            )
        agg_info[new_col_name] = (agg_col, agg_func)
    # sort by (column, function) so multiple aggregations of the same column
    # are adjacent, matching pandas' dict-of-lists agg output order
    agg_info_sorted = {
        k: v for k, v in sorted(agg_info.items(), key=lambda item: item[1])
    }
    if _is_pandas_grouped_df(X):
        for agg_col_func in agg_info_sorted.values():
            if agg_col_func[0] in agg_expr:
                agg_expr[agg_col_func[0]].append(agg_col_func[1])
            else:
                agg_expr[agg_col_func[0]] = [agg_col_func[1]]
        aggregated_df = X.agg(agg_expr)
        aggregated_df.columns = list(agg_info_sorted.keys())
    elif _is_spark_grouped_df(X):
        agg_expr = [
            create_spark_agg_expr(new_col_name, agg_col_func)
            for new_col_name, agg_col_func in agg_info_sorted.items()
        ]
        aggregated_df = X.agg(*agg_expr)
    else:
        raise ValueError(
            "Only pandas and Spark dataframes are supported by the Aggregate operator."
        )
    named_aggregated_df = lale.datasets.data_schemas.add_table_name(
        aggregated_df, lale.datasets.data_schemas.get_table_name(X)
    )
    return named_aggregated_df
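# Hedged sketch, not library code: the pandas path of transform() boils down to
# a dict-of-lists agg on a grouped frame plus a column rename (data illustrative).
def _demo_pandas_grouped_agg():
    import pandas as pd

    df = pd.DataFrame({"g": ["a", "a", "b"], "v": [1, 2, 3]})
    out = df.groupby("g").agg({"v": ["min", "max"]})  # same shape as agg_expr
    out.columns = ["min_v", "max_v"]                  # mirrors agg_info_sorted keys
    assert list(out.loc["a"]) == [1, 2]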