def test_make_date(self):
    # SPARK-36554: expose make_date expression
    # make_date is exercised with its year/month/day arguments given both as
    # Column objects and as column-name strings; each form must yield the
    # same date.
    expected = datetime.date(2020, 6, 26)
    df = self.spark.createDataFrame([(2020, 6, 26)], ["Y", "M", "D"])

    by_column = (df.Y, df.M, df.D)
    by_name = ("Y", "M", "D")
    for year, month, day in (by_column, by_name):
        result = df.select(make_date(year, month, day)).first()
        self.assertEqual(result[0], expected)
def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
    """Build a Column expression giving the bin label of each value in ``ts_scol``.

    Bins are ``n`` units wide. The unit and ``n`` come from
    ``self._offset.rule_code`` and ``self._offset.n``, and bin edges are
    counted from ``origin``. ``self._closed`` (compared against "left" /
    "right") decides which bin a value lying exactly on an edge belongs to.
    ``self._label`` (compared against "left" / "right") decides whether a
    bin is labeled by its start edge or its end edge.

    Supported rule codes and the alignment asserted on ``origin``:

    * ``"A-DEC"``: yearly bins with edges on Dec 31. ``origin`` must be
      Dec 31 at 00:00:00.
    * ``"M"``: monthly bins with edges on month ends. ``origin`` must be a
      month end at 00:00:00.
    * ``"D"``: daily bins. ``origin`` must be at 00:00:00.
    * ``"H"`` / ``"T"`` / ``"S"``: hour / minute / second bins. ``origin``
      must have zero minutes and seconds (``"H"``) or zero seconds
      (``"T"``).

    Parameters
    ----------
    origin : pd.Timestamp
        Anchor from which bin edges are counted.
    ts_scol : Column
        Timestamp column to bin.

    Returns
    -------
    Column
        The label of the bin each value of ``ts_scol`` falls into. For
        ``"A-DEC"`` and ``"M"`` the result is wrapped in ``F.to_timestamp``.
        For the other codes it is built from ``date_trunc`` / ``date_add`` /
        ``date_sub`` and interval arithmetic.

    Raises
    ------
    AssertionError
        If ``origin`` is not aligned as listed above.
    ValueError
        If ``rule_code`` is not one of the supported codes.
    """
    # JVM-side helper object. Its makeInterval / timestampDiff results are
    # combined with Python Columns below.
    sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
    origin_scol = F.lit(origin)
    (rule_code, n) = (self._offset.rule_code, self._offset.n )  # type: ignore[attr-defined]
    left_closed, right_closed = (self._closed == "left", self._closed == "right")
    left_labeled, right_labeled = (self._label == "left", self._label == "right")
    if rule_code == "A-DEC":
        assert (origin.month == 12 and origin.day == 31 and origin.hour == 0
                and origin.minute == 0 and origin.second == 0)
        # Years between origin and ts. `mod` is the year's offset inside its
        # n-year bin, so 0 means the year is aligned with origin. With n == 1
        # every year is aligned, hence the constant 0.
        diff = F.year(ts_scol) - F.year(origin_scol)
        mod = F.lit(0) if n == 1 else (diff % n)
        # ts is on a bin edge when its year is aligned and it is Dec 31.
        # Time of day is not checked here.
        edge_cond = (mod == 0) & (F.month(ts_scol) == 12) & (F.dayofmonth(ts_scol) == 31)
        # An edge value ends one bin and starts the next.
        # - Left-closed puts it in the bin it starts.
        # - Right-closed puts it in the bin it ends.
        # When the label side differs from the closed side, the label moves
        # one bin (n years) away from the edge's own year.
        edge_label = F.year(ts_scol)
        if left_closed and right_labeled:
            edge_label += n
        elif right_closed and left_labeled:
            edge_label -= n
        # Edges sit at the END of a year (Dec 31). A non-edge ts in an
        # aligned year (mod == 0) therefore belongs to the bin ending this
        # year: its left edge is n years back and its right edge is this
        # year. Otherwise the bin starts `mod` years back and ends
        # `n - mod` years ahead.
        if left_labeled:
            non_edge_label = F.when(mod == 0, F.year(ts_scol) - n).otherwise(F.year(ts_scol) - mod)
        else:
            non_edge_label = F.when(
                mod == 0, F.year(ts_scol)).otherwise(F.year(ts_scol) - (mod - n))
        # Turn the chosen label year into a Dec 31 timestamp.
        return F.to_timestamp(
            F.make_date(
                F.when(edge_cond, edge_label).otherwise(non_edge_label), F.lit(12), F.lit(31)))
    elif rule_code == "M":
        assert (origin.is_month_end and origin.hour == 0
                and origin.minute == 0 and origin.second == 0)
        # Months between origin and ts. `mod` is the month's offset inside
        # its n-month bin (constant 0 when n == 1).
        diff = ((F.year(ts_scol) - F.year(origin_scol)) * 12
                + F.month(ts_scol) - F.month(origin_scol))
        mod = F.lit(0) if n == 1 else (diff % n)
        # ts is on a bin edge when its month is aligned and it is the last
        # day of that month.
        edge_cond = (mod == 0) & (F.dayofmonth(ts_scol) == F.dayofmonth(
            F.last_day(ts_scol)))
        # Labels are computed on the month-truncated value. F.last_day below
        # then moves the chosen label to the end of its month.
        truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
        # Same edge rule as the yearly case. The one-bin shift is presumably
        # an n-month interval built by the JVM helper.
        edge_label = truncated_ts_scol
        if left_closed and right_labeled:
            edge_label += sql_utils.makeInterval("MONTH", F.lit(n)._jc)
        elif right_closed and left_labeled:
            edge_label -= sql_utils.makeInterval("MONTH", F.lit(n)._jc)
        # Edges sit at month ENDS, so the non-edge reasoning matches the
        # yearly case: an aligned month (mod == 0) belongs to the bin ending
        # this month.
        if left_labeled:
            non_edge_label = F.when(
                mod == 0,
                truncated_ts_scol - sql_utils.makeInterval("MONTH", F.lit(n)._jc),
            ).otherwise(truncated_ts_scol - sql_utils.makeInterval("MONTH", mod._jc))
        else:
            non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
                truncated_ts_scol - sql_utils.makeInterval("MONTH", (mod - n)._jc))
        return F.to_timestamp(
            F.last_day(
                F.when(edge_cond, edge_label).otherwise(non_edge_label)))
    elif rule_code == "D":
        assert origin.hour == 0 and origin.minute == 0 and origin.second == 0
        if n == 1:
            # NOTE: the logic to process '1D' is different from the cases with n>1,
            # since hour/minute/second parts are taken into account to determine edges!
            # ts is on an edge only when it is exactly at midnight (to the second).
            edge_cond = ((F.hour(ts_scol) == 0) & (F.minute(ts_scol) == 0)
                         & (F.second(ts_scol) == 0))
            if left_closed and left_labeled:
                # Bin [day, day+1): label is the day itself.
                return F.date_trunc("DAY", ts_scol)
            elif left_closed and right_labeled:
                # Bin [day, day+1): label is the next day.
                return F.date_trunc("DAY", F.date_add(ts_scol, 1))
            elif right_closed and left_labeled:
                # Bin (day-1, day]: a midnight value belongs to the bin that
                # started the previous day.
                return F.when(edge_cond, F.date_trunc("DAY", F.date_sub(
                    ts_scol, 1))).otherwise(
                    F.date_trunc("DAY", ts_scol))
            else:
                # Remaining case, expected to be right-closed / right-labeled:
                # a midnight value closes the bin labeled by its own day;
                # anything later belongs to the next day's bin.
                return F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise(
                    F.date_trunc("DAY", F.date_add(ts_scol, 1)))
        else:
            # Days between origin and ts, reduced to the day's offset inside
            # its n-day bin.
            diff = F.datediff(end=ts_scol, start=origin_scol)
            mod = diff % n
            # NOTE(review): unlike n == 1, only day alignment is checked here,
            # not time of day. Every ts on an aligned day is therefore treated
            # as an edge value. Confirm this is the intended semantics.
            edge_cond = mod == 0
            truncated_ts_scol = F.date_trunc("DAY", ts_scol)
            # Same one-bin shift rule for edge values as the other units.
            edge_label = truncated_ts_scol
            if left_closed and right_labeled:
                edge_label = F.date_add(truncated_ts_scol, n)
            elif right_closed and left_labeled:
                edge_label = F.date_sub(truncated_ts_scol, n)
            # Non-edge: the bin starts `mod` days back and ends `n - mod`
            # days ahead.
            if left_labeled:
                non_edge_label = F.date_sub(truncated_ts_scol, mod)
            else:
                non_edge_label = F.date_sub(truncated_ts_scol, mod - n)
            # NOTE(review): edge_label may be a date_trunc result while
            # non_edge_label comes from date_sub. Spark presumably reconciles
            # the two types inside F.when; confirm the resulting column type.
            return F.when(edge_cond, edge_label).otherwise(non_edge_label)
    elif rule_code in ["H", "T", "S"]:
        # Unit name passed to date_trunc, timestampDiff and makeInterval.
        unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"}
        unit_str = unit_mapping[rule_code]
        truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
        # Unit count between origin and the truncated ts, computed by the JVM
        # helper. `mod` is the unit's offset inside its n-unit bin
        # (constant 0 when n == 1).
        diff = sql_utils.timestampDiff(unit_str, origin_scol._jc, truncated_ts_scol._jc)
        mod = F.lit(0) if n == 1 else (diff % F.lit(n))
        # ts is on an edge when its unit is aligned and every finer field
        # checked here is zero. Sub-second parts are not checked.
        if rule_code == "H":
            assert origin.minute == 0 and origin.second == 0
            edge_cond = (mod == 0) & (F.minute(ts_scol) == 0) & (F.second(ts_scol) == 0)
        elif rule_code == "T":
            assert origin.second == 0
            edge_cond = (mod == 0) & (F.second(ts_scol) == 0)
        else:
            edge_cond = mod == 0
        # Same one-bin shift rule for edge values as the other units.
        edge_label = truncated_ts_scol
        if left_closed and right_labeled:
            edge_label += sql_utils.makeInterval(unit_str, F.lit(n)._jc)
        elif right_closed and left_labeled:
            edge_label -= sql_utils.makeInterval(unit_str, F.lit(n)._jc)
        # Edges sit at the START of a unit (the truncated value), unlike the
        # yearly/monthly cases. A non-edge ts in an aligned unit (mod == 0)
        # therefore belongs to the bin starting at that unit: its left label
        # is the truncated value and its right label is n units later.
        if left_labeled:
            non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
                truncated_ts_scol - sql_utils.makeInterval(unit_str, mod._jc))
        else:
            non_edge_label = F.when(
                mod == 0,
                truncated_ts_scol + sql_utils.makeInterval(unit_str, F.lit(n)._jc),
            ).otherwise(truncated_ts_scol - sql_utils.makeInterval(unit_str, (mod - n)._jc))
        return F.when(edge_cond, edge_label).otherwise(non_edge_label)
    else:
        raise ValueError("Got the unexpected unit {}".format(rule_code))