Esempio n. 1
0
 def test_make_date(self):
     # SPARK-36554: expose make_date expression.
     # make_date must accept both Column objects and plain column names, and
     # both forms must build the same calendar date.
     expected = datetime.date(2020, 6, 26)
     df = self.spark.createDataFrame([(2020, 6, 26)], ["Y", "M", "D"])
     for year, month, day in ((df.Y, df.M, df.D), ("Y", "M", "D")):
         first_row = df.select(make_date(year, month, day)).first()
         self.assertEqual(first_row[0], expected)
Esempio n. 2
0
    def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
        """Map each timestamp in ``ts_scol`` to the label of its resampling bin.

        Bin edges are anchored at ``origin`` and spaced by the resampling offset
        ``self._offset`` (its ``rule_code`` and multiple ``n``). ``self._closed``
        selects which side of a bin contains its edge, and ``self._label`` selects
        whether a bin is labeled by its left or right edge. Both are compared
        against ``"left"`` / ``"right"`` (presumably validated by the caller).

        Supported rule codes:

        * ``"A-DEC"`` -- year-end bins; the label is December 31 of the label year.
        * ``"M"`` -- month-end bins; the label is the last day of the label month.
        * ``"D"`` -- day bins. With ``n == 1`` the time of day decides whether a
          row sits on an edge; with ``n > 1`` only the day count from ``origin``
          is used.
        * ``"H"``, ``"T"``, ``"S"`` -- hour, minute and second bins.

        Parameters
        ----------
        origin : pd.Timestamp
            Anchor of the bin edges. It must be aligned to the rule code (e.g.
            midnight of December 31 for ``"A-DEC"``).
        ts_scol : Column
            Spark column of timestamps to bin.

        Returns
        -------
        Column
            Spark column holding the bin label of every row.

        Raises
        ------
        AssertionError
            If ``origin`` is not aligned to the rule code. These checks use
            ``assert`` and are therefore skipped under ``python -O``.
        ValueError
            If ``self._offset.rule_code`` is not one of the supported codes.
        """
        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        origin_scol = F.lit(origin)
        # The ignore must sit on the line that performs the attribute access;
        # placed on a continuation line it does not silence the error.
        rule_code = self._offset.rule_code  # type: ignore[attr-defined]
        n = self._offset.n  # type: ignore[attr-defined]
        left_closed, right_closed = self._closed == "left", self._closed == "right"
        left_labeled, right_labeled = self._label == "left", self._label == "right"

        # NOTE(review): Spark's `%` keeps the sign of the dividend, so the `mod`
        # values below assume no timestamp precedes `origin` -- presumably
        # guaranteed by the caller; confirm.

        if rule_code == "A-DEC":
            assert (
                origin.month == 12
                and origin.day == 31
                and origin.hour == 0
                and origin.minute == 0
                and origin.second == 0
            )

            # Years elapsed since origin; `mod` is how far ts's year lies past the
            # last edge year (every year is an edge year when n == 1).
            diff = F.year(ts_scol) - F.year(origin_scol)
            mod = F.lit(0) if n == 1 else (diff % n)
            # A row is on an edge when it falls on Dec 31 of an edge year.
            edge_cond = (mod == 0) & (F.month(ts_scol) == 12) & (F.dayofmonth(ts_scol) == 31)

            # Edge rows: shift by one period when the closed side and the label
            # side differ.
            edge_label = F.year(ts_scol)
            if left_closed and right_labeled:
                edge_label += n
            elif right_closed and left_labeled:
                edge_label -= n

            # Non-edge rows: previous edge year (left label) or next edge year
            # (right label).
            if left_labeled:
                non_edge_label = F.when(mod == 0, F.year(ts_scol) - n).otherwise(
                    F.year(ts_scol) - mod
                )
            else:
                non_edge_label = F.when(mod == 0, F.year(ts_scol)).otherwise(
                    F.year(ts_scol) - (mod - n)
                )

            return F.to_timestamp(
                F.make_date(
                    F.when(edge_cond, edge_label).otherwise(non_edge_label),
                    F.lit(12),
                    F.lit(31),
                )
            )

        elif rule_code == "M":
            assert (
                origin.is_month_end
                and origin.hour == 0
                and origin.minute == 0
                and origin.second == 0
            )

            # Months elapsed since origin and ts's offset past the last edge month.
            diff = (
                (F.year(ts_scol) - F.year(origin_scol)) * 12
                + F.month(ts_scol)
                - F.month(origin_scol)
            )
            mod = F.lit(0) if n == 1 else (diff % n)
            # A row is on an edge when it is the last day of an edge month.
            edge_cond = (mod == 0) & (F.dayofmonth(ts_scol) == F.dayofmonth(F.last_day(ts_scol)))

            # Labels are computed on the first instant of the month and snapped to
            # the month end by `last_day` in the final expression.
            truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
            edge_label = truncated_ts_scol
            if left_closed and right_labeled:
                edge_label += sql_utils.makeInterval("MONTH", F.lit(n)._jc)
            elif right_closed and left_labeled:
                edge_label -= sql_utils.makeInterval("MONTH", F.lit(n)._jc)

            if left_labeled:
                non_edge_label = F.when(
                    mod == 0,
                    truncated_ts_scol - sql_utils.makeInterval("MONTH", F.lit(n)._jc),
                ).otherwise(truncated_ts_scol - sql_utils.makeInterval("MONTH", mod._jc))
            else:
                non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
                    truncated_ts_scol - sql_utils.makeInterval("MONTH", (mod - n)._jc)
                )

            return F.to_timestamp(
                F.last_day(F.when(edge_cond, edge_label).otherwise(non_edge_label))
            )

        elif rule_code == "D":
            assert origin.hour == 0 and origin.minute == 0 and origin.second == 0

            if n == 1:
                # NOTE: the logic to process '1D' is different from the cases with n>1,
                # since hour/minute/second parts are taken into account to determine edges!
                edge_cond = (
                    (F.hour(ts_scol) == 0) & (F.minute(ts_scol) == 0) & (F.second(ts_scol) == 0)
                )

                if left_closed and left_labeled:
                    return F.date_trunc("DAY", ts_scol)
                elif left_closed and right_labeled:
                    return F.date_trunc("DAY", F.date_add(ts_scol, 1))
                elif right_closed and left_labeled:
                    return F.when(
                        edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))
                    ).otherwise(F.date_trunc("DAY", ts_scol))
                else:
                    return F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise(
                        F.date_trunc("DAY", F.date_add(ts_scol, 1))
                    )

            else:
                # Whole days since origin; every instant of an edge day (any time
                # of day) is treated as lying on the edge.
                diff = F.datediff(end=ts_scol, start=origin_scol)
                mod = diff % n

                edge_cond = mod == 0

                # NOTE(review): date_add/date_sub produce date values while
                # date_trunc produces timestamps; the result type of the `when`
                # below is left to Spark's type coercion -- confirm it is intended.
                truncated_ts_scol = F.date_trunc("DAY", ts_scol)
                edge_label = truncated_ts_scol
                if left_closed and right_labeled:
                    edge_label = F.date_add(truncated_ts_scol, n)
                elif right_closed and left_labeled:
                    edge_label = F.date_sub(truncated_ts_scol, n)

                if left_labeled:
                    non_edge_label = F.date_sub(truncated_ts_scol, mod)
                else:
                    non_edge_label = F.date_sub(truncated_ts_scol, mod - n)

                return F.when(edge_cond, edge_label).otherwise(non_edge_label)

        elif rule_code in ["H", "T", "S"]:
            unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"}
            unit_str = unit_mapping[rule_code]

            truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
            # `timestampDiff` returns a raw JVM column. Wrap it so `%` dispatches
            # through Column.__mod__ rather than relying on Column.__rmod__
            # accepting a Py4J object; the resulting JVM expression is the same.
            diff = Column(
                sql_utils.timestampDiff(unit_str, origin_scol._jc, truncated_ts_scol._jc)
            )
            mod = F.lit(0) if n == 1 else (diff % F.lit(n))

            # Finer-grained fields must be zero for a row to sit exactly on an edge.
            if rule_code == "H":
                assert origin.minute == 0 and origin.second == 0
                edge_cond = (mod == 0) & (F.minute(ts_scol) == 0) & (F.second(ts_scol) == 0)
            elif rule_code == "T":
                assert origin.second == 0
                edge_cond = (mod == 0) & (F.second(ts_scol) == 0)
            else:
                edge_cond = mod == 0

            edge_label = truncated_ts_scol
            if left_closed and right_labeled:
                edge_label += sql_utils.makeInterval(unit_str, F.lit(n)._jc)
            elif right_closed and left_labeled:
                edge_label -= sql_utils.makeInterval(unit_str, F.lit(n)._jc)

            if left_labeled:
                non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
                    truncated_ts_scol - sql_utils.makeInterval(unit_str, mod._jc)
                )
            else:
                non_edge_label = F.when(
                    mod == 0,
                    truncated_ts_scol + sql_utils.makeInterval(unit_str, F.lit(n)._jc),
                ).otherwise(truncated_ts_scol - sql_utils.makeInterval(unit_str, (mod - n)._jc))

            return F.when(edge_cond, edge_label).otherwise(non_edge_label)

        else:
            raise ValueError("Got the unexpected unit {}".format(rule_code))