Example #1
    def minute(self) -> "ks.Series":
        """
        The minutes of the datetime.
        """
        return column_op(lambda c: F.minute(c).cast(LongType()))(self._data).alias(self._data.name)
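A minimal usage sketch of the accessor this property backs, assuming the databricks.koalas package and made-up sample data:

# Hypothetical usage sketch for the Koalas datetime minute accessor shown above.
import pandas as pd
import databricks.koalas as ks

kser = ks.from_pandas(pd.Series(pd.date_range("2021-01-01 10:15", periods=3, freq="T")))
print(kser.dt.minute)  # LongType-backed Series: 15, 16, 17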
Example #2
def make_get_step_udf(multi_steps: Optional[int]):
    """ Get step count by taking length of next_states_features array. """
    def get_step(col: List):
        return 1 if multi_steps is None else min(len(col), multi_steps)

    return udf(get_step, LongType())
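A hedged usage sketch of the factory above; it assumes an active SparkSession named `spark` and the pyspark imports that the snippet's truncated header would provide:

# Hypothetical usage sketch: apply the step-count UDF to an array column.
get_step = make_get_step_udf(multi_steps=3)
df = spark.createDataFrame([([1, 2, 3, 4],), ([5],)], ["next_states_features"])
df.select(get_step("next_states_features").alias("step")).show()
# min(len(col), 3) -> steps of 3 and 1
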
    else:
        b_t = datetime.fromtimestamp(b, ho_chi_minh_timezone)
    b_t_week_id = int(b_t.strftime("%Y%W"))

    date_item = a_t
    while int(date_item.strftime("%Y%W")) < b_t_week_id:
        weeks.append(int(date_item.strftime("%Y%W")))
        date_item += timedelta(7)

    if len(weeks) == 0:
        weeks = [week_fake]

    return weeks


get_weeks = f.udf(get_weeks, ArrayType(LongType()))


def get_df_student_package(glueContext):
    dyf_student_package = glueContext.create_dynamic_frame.from_options(
        connection_type="redshift",
        connection_options={
            "url": REDSHIFT_DATABASE,
            "user": REDSHIFT_USERNAME,
            "password": REDSHIFT_PASSWORD,
            "dbtable": "ad_student_package",
            "redshiftTmpDir":
Example #4
    if tf_name is None:
        tf_name = col_name
    # Use the python convention (None)
    shape = [x if x >= 0 else None for x in col_shape]
    if not block:
        shape = shape[1:]
    else:
        # The lead is always set to None, because otherwise it may choke on empty partitions.
        # (This happens when the dataset is too small)
        # TODO(tjh) add a test for this case.
        shape[0] = None
    return tf.placeholder(tfdtype, shape=shape, name=tf_name)

_dtypes = {DoubleType(): tf.double,
           IntegerType(): tf.int32,
           LongType(): tf.int64,
           FloatType(): tf.float32}
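A quick sanity-check sketch of the mapping above, assuming TensorFlow is importable as `tf` as the snippet implies:

# Spark DataType instances hash and compare by value, so fresh instances work as dict keys.
assert _dtypes[LongType()] is tf.int64
assert _dtypes[DoubleType()] is tf.double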

def _get_jgroup(grouped_data):
    """Get the JVM object that backs this grouped data, taking into account the different
    spark versions."""
    d = dir(grouped_data)
    if '_jdf' in d:
        return grouped_data._jdf
    if '_jgd' in d:
        return grouped_data._jgd
    raise ValueError('Could not find a dataframe for {}. All methods: {}'.format(grouped_data, d))

def _get_dtype(dtype):
    if isinstance(dtype, ArrayType):
        return _get_dtype(dtype.elementType)
Example #5
        for line in f:
            ary = line.split('|')
            movie_names[int(ary[0])] = ary[1]
    return movie_names


spark = SparkSession.builder.appName("BcastMovies").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

name_dict = spark.sparkContext.broadcast(load_movie_names())

schema = StructType([
    StructField("userID", IntegerType(), True),
    StructField("movieID", IntegerType(), True),
    StructField("rating", IntegerType(), True),
    StructField("timestamp", LongType(), True)
])

# Read in the u.data into df
df = spark.read.option("sep",
                       "\t").schema(schema).csv("../data/ml-100k/u.data")

# Get counts by movie id
movie_counts = df.groupby("movieID").count()


# Create UDF to get name from id
def get_name_from_id(_id):
    return name_dict.value[_id]
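The UDF registration is not shown in this truncated excerpt; a hedged continuation sketch that wraps the lookup and attaches titles to the counts:

# Hypothetical continuation: register the lookup as a UDF and add titles to the counts.
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

lookup_name_udf = udf(get_name_from_id, StringType())
movie_counts.withColumn("movieTitle", lookup_name_udf(col("movieID"))) \
    .orderBy(col("count").desc()).show(10, False)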
Example #6
from pyspark.sql.types import (
    DoubleType,
    LongType,
    StringType,
    StructField,
    StructType,
    BooleanType,
)

retention_schema = StructType([
    StructField("client_id", StringType(), True),
    StructField("subsession_start", StringType(), True),
    StructField("profile_creation", StringType(), True),
    StructField("days_since_creation", LongType(), True),
    StructField("channel", StringType(), True),
    StructField("app_version", StringType(), True),
    StructField("geo", StringType(), True),
    StructField("distribution_id", StringType(), True),
    StructField("is_funnelcake", BooleanType(), True),
    StructField("source", StringType(), True),
    StructField("medium", StringType(), True),
    StructField("campaign", StringType(), True),
    StructField("content", StringType(), True),
    StructField("sync_usage", StringType(), True),
    StructField("is_active", BooleanType(), True),
    StructField("usage_hours", DoubleType(), True),
    StructField("sum_squared_usage_hours", DoubleType(), True),
    StructField("total_uri_count", LongType(), True),
    StructField("unique_domains_count", LongType(), True),
])
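A minimal usage sketch for the schema above; the input path is an assumption:

# Hypothetical usage: enforce the retention schema when reading JSON records.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
retention = spark.read.schema(retention_schema).json("s3://example-bucket/retention/*.json")
retention.printSchema()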
Example #7
def xform(df):
    multiply = pandas_udf(multiply_func, returnType=LongType())
    return df.select(multiply(col("id"), col("id"))).withColumnRenamed(
        "multiply_func(id, id)", "squared")
])

schema_world_area_codes = StructType([
    StructField('CODE', IntegerType(), True),
    StructField('NAME', StringType(), True)
])

schema_state = StructType([
    StructField('STATE_ABR', StringType(), True),
    StructField('STATE_FIPS', IntegerType(), True),
    StructField('STATE_NAME', StringType(), True),
    StructField('WAC_CODE', IntegerType(), True)
])

schema_city = StructType([
    StructField('CITY_ID', LongType(), True),
    StructField('CITY_NAME', StringType(), True),
    StructField('STATE_ABR', StringType(), True)
])

schema_airport = StructType([
    StructField('AIRPORT_CODE', IntegerType(), True),
    StructField('AIRPORT_NAME', StringType(), True),
    StructField('CITY_ID', LongType(), True)
])

schema_airline = StructType([
    StructField("AIRLINE_ID", IntegerType(), True),
    StructField("AIRLINE_NAME", StringType(), True),
    StructField("AIRLINE_CODE", StringType(), True)
])
Example #9
def test_create_schema_view_fails_validate():
    """ Exercises code paths unischema.create_schema_view ValueError, and unischema.__str__."""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    with pytest.raises(ValueError, match='does not belong to the schema'):
        TestSchema.create_schema_view([UnischemaField('id', np.int64, (), ScalarCodec(LongType()), False)])
Example #10
    ['EXPIRED', 5],
    ['UNAVAILABLE', 6]
]


def is_ls_sc_lt_success(behavior_id, duration):
    if behavior_id in [BEHAVIOR_ID_LS, BEHAVIOR_ID_SC]:
        if duration >= DURATION_LS_SC_SUCCESS:
            return 1
    if behavior_id == BEHAVIOR_ID_LT:
        if duration >= DURATION_LT_SUCCESS:
            return 1
    return 0


is_ls_sc_lt_success = f.udf(is_ls_sc_lt_success, LongType())


def find(lists, key):
    for items in lists:
        if key.startswith(items[0]):
            return items[1]
    return 0


def get_student_level_id(student_level):
    return find(STUDENT_LEVEL_DATA, student_level)


get_student_level_id = f.udf(get_student_level_id, LongType())
Example #11
class CCSparkJob(object):
    """
    A simple Spark job definition to process Common Crawl data
    """

    name = 'CCSparkJob'

    output_schema = StructType([
        StructField("key", StringType(), True),
        StructField("val", LongType(), True)
    ])

    # description of input and output shown in --help
    input_descr = "Path to file listing input paths"
    output_descr = "Name of output table (saved in spark.sql.warehouse.dir)"

    warc_parse_http_header = True

    args = None
    records_processed = None
    warc_input_processed = None
    warc_input_failed = None
    log_level = 'INFO'
    logging.basicConfig(level=log_level, format=LOGGING_FORMAT)

    num_input_partitions = 400
    num_output_partitions = 10

    def parse_arguments(self):
        """ Returns the parsed arguments from the command line """

        description = self.name
        if self.__doc__ is not None:
            description += " - "
            description += self.__doc__
        arg_parser = argparse.ArgumentParser(prog=self.name,
                                             description=description,
                                             conflict_handler='resolve')

        arg_parser.add_argument("input", help=self.input_descr)
        arg_parser.add_argument("output", help=self.output_descr)

        arg_parser.add_argument("--num_input_partitions",
                                type=int,
                                default=self.num_input_partitions,
                                help="Number of input splits/partitions, "
                                "number of parallel tasks to process WARC "
                                "files/records")
        arg_parser.add_argument("--num_output_partitions",
                                type=int,
                                default=self.num_output_partitions,
                                help="Number of output partitions")
        arg_parser.add_argument("--output_format",
                                default="parquet",
                                help="Output format: parquet (default),"
                                " orc, json, csv")
        arg_parser.add_argument("--output_compression",
                                default="gzip",
                                help="Output compression codec: None,"
                                " gzip/zlib (default), snappy, lzo, etc.")
        arg_parser.add_argument(
            "--output_option",
            action='append',
            default=[],
            help="Additional output option pair"
            " to set (format-specific) output options, e.g.,"
            " `header=true` to add a header line to CSV files."
            " Option name and value are split at `=` and"
            " multiple options can be set by passing"
            " `--output_option <name>=<value>` multiple times")

        arg_parser.add_argument("--local_temp_dir",
                                default=None,
                                help="Local temporary directory, used to"
                                " buffer content from S3")

        arg_parser.add_argument("--log_level",
                                default=self.log_level,
                                help="Logging level")
        arg_parser.add_argument("--spark-profiler",
                                action='store_true',
                                help="Enable PySpark profiler and log"
                                " profiling metrics if job has finished,"
                                " cf. spark.python.profile")

        self.add_arguments(arg_parser)
        args = arg_parser.parse_args()
        if not self.validate_arguments(args):
            raise Exception("Arguments not valid")
        self.init_logging(args.log_level)

        return args

    def add_arguments(self, parser):
        pass

    def validate_arguments(self, args):
        if "orc" == args.output_format and "gzip" == args.output_compression:
            # gzip for Parquet, zlib for ORC
            args.output_compression = "zlib"
        return True

    def get_output_options(self):
        return {
            x[0]: x[1]
            for x in map(lambda x: x.split('=', 1), self.args.output_option)
        }

    def init_logging(self, level=None):
        if level is None:
            level = self.log_level
        else:
            self.log_level = level
        logging.basicConfig(level=level, format=LOGGING_FORMAT)

    def init_accumulators(self, sc):
        self.records_processed = sc.accumulator(0)
        self.warc_input_processed = sc.accumulator(0)
        self.warc_input_failed = sc.accumulator(0)

    def get_logger(self, spark_context=None):
        """Get logger from SparkContext or (if None) from logging module"""
        if spark_context is None:
            return logging.getLogger(self.name)
        return spark_context._jvm.org.apache.log4j.LogManager \
            .getLogger(self.name)

    def run(self):
        self.args = self.parse_arguments()

        conf = SparkConf()

        if self.args.spark_profiler:
            conf = conf.set("spark.python.profile", "true")

        sc = SparkContext(appName=self.name, conf=conf)
        sqlc = SQLContext(sparkContext=sc)

        self.init_accumulators(sc)

        self.run_job(sc, sqlc)

        if self.args.spark_profiler:
            sc.show_profiles()

        sc.stop()

    def log_aggregator(self, sc, agg, descr):
        self.get_logger(sc).info(descr.format(agg.value))

    def log_aggregators(self, sc):
        self.log_aggregator(sc, self.warc_input_processed,
                            'WARC/WAT/WET input files processed = {}')
        self.log_aggregator(sc, self.warc_input_failed,
                            'WARC/WAT/WET input files failed = {}')
        self.log_aggregator(sc, self.records_processed,
                            'WARC/WAT/WET records processed = {}')

    @staticmethod
    def reduce_by_key_func(a, b):
        return a + b

    def run_job(self, sc, sqlc):
        input_data = sc.textFile(self.args.input,
                                 minPartitions=self.args.num_input_partitions)

        output = input_data.mapPartitionsWithIndex(self.process_warcs) \
            .reduceByKey(self.reduce_by_key_func)

        sqlc.createDataFrame(output, schema=self.output_schema) \
            .coalesce(self.args.num_output_partitions) \
            .write \
            .format(self.args.output_format) \
            .option("compression", self.args.output_compression) \
            .options(**self.get_output_options()) \
            .saveAsTable(self.args.output)

        self.log_aggregators(sc)

    def process_warcs(self, id_, iterator):
        s3pattern = re.compile('^s3://([^/]+)/(.+)')
        base_dir = os.path.abspath(os.path.dirname(__file__))

        # S3 client (not thread-safe, initialize outside parallelized loop)
        no_sign_request = botocore.client.Config(
            signature_version=botocore.UNSIGNED)
        s3client = boto3.client('s3', config=no_sign_request)

        for uri in iterator:
            self.warc_input_processed.add(1)
            if uri.startswith('s3://'):
                self.get_logger().info('Reading from S3 {}'.format(uri))
                s3match = s3pattern.match(uri)
                if s3match is None:
                    self.get_logger().error("Invalid S3 URI: " + uri)
                    continue
                bucketname = s3match.group(1)
                path = s3match.group(2)
                warctemp = TemporaryFile(mode='w+b',
                                         dir=self.args.local_temp_dir)
                try:
                    s3client.download_fileobj(bucketname, path, warctemp)
                except botocore.client.ClientError as exception:
                    self.get_logger().error('Failed to download {}: {}'.format(
                        uri, exception))
                    self.warc_input_failed.add(1)
                    warctemp.close()
                    continue
                warctemp.seek(0)
                stream = warctemp
            elif uri.startswith('hdfs:/'):
                try:
                    import pydoop.hdfs as hdfs
                    self.get_logger().error("Reading from HDFS {}".format(uri))
                    stream = hdfs.open(uri)
                except RuntimeError as exception:
                    self.get_logger().error('Failed to open {}: {}'.format(
                        uri, exception))
                    self.warc_input_failed.add(1)
                    continue
            else:
                self.get_logger().info('Reading local stream {}'.format(uri))
                if uri.startswith('file:'):
                    uri = uri[5:]
                uri = os.path.join(base_dir, uri)
                try:
                    stream = open(uri, 'rb')
                except IOError as exception:
                    self.get_logger().error('Failed to open {}: {}'.format(
                        uri, exception))
                    self.warc_input_failed.add(1)
                    continue

            no_parse = (not self.warc_parse_http_header)
            try:
                archive_iterator = ArchiveIterator(stream,
                                                   no_record_parse=no_parse,
                                                   arc2warc=True)
                for res in self.iterate_records(uri, archive_iterator):
                    yield res
            except ArchiveLoadFailed as exception:
                self.warc_input_failed.add(1)
                self.get_logger().error('Invalid WARC: {} - {}'.format(
                    uri, exception))
            finally:
                stream.close()

    def process_record(self, record):
        raise NotImplementedError('Processing record needs to be customized')

    def iterate_records(self, _warc_uri, archive_iterator):
        """Iterate over all WARC records. This method can be customized
           and allows to access also values from ArchiveIterator, namely
           WARC record offset and length."""
        for record in archive_iterator:
            for res in self.process_record(record):
                yield res
            self.records_processed.add(1)
            # WARC record offset and length should be read after the record
            # has been processed, otherwise the record content is consumed
            # while offset and length are determined:
            #  warc_record_offset = archive_iterator.get_record_offset()
            #  warc_record_length = archive_iterator.get_record_length()

    @staticmethod
    def is_wet_text_record(record):
        """Return true if WARC record is a WET text/plain record"""
        return (record.rec_type == 'conversion'
                and record.content_type == 'text/plain')

    @staticmethod
    def is_wat_json_record(record):
        """Return true if WARC record is a WAT record"""
        return (record.rec_type == 'metadata'
                and record.content_type == 'application/json')

    @staticmethod
    def is_html(record):
        """Return true if (detected) MIME type of a record is HTML"""
        html_types = ['text/html', 'application/xhtml+xml']
        if (('WARC-Identified-Payload-Type' in record.rec_headers)
                and (record.rec_headers['WARC-Identified-Payload-Type']
                     in html_types)):
            return True
        content_type = record.http_headers.get_header('content-type', None)
        if content_type:
            for html_type in html_types:
                if html_type in content_type:
                    return True
        return False
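A hedged subclass sketch showing how `process_record` is meant to be customized; everything beyond the names defined in the class above is an assumption:

# Hypothetical subclass sketch: yields (key, count) pairs that match output_schema
# (key: StringType, val: LongType) and are summed by reduce_by_key_func.
class CountWetRecordsJob(CCSparkJob):
    """Count WET text records in the input"""

    name = "CountWetRecordsJob"

    def process_record(self, record):
        if self.is_wet_text_record(record):
            yield ("wet_records", 1)


if __name__ == "__main__":
    job = CountWetRecordsJob()
    job.run()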
Example #12
from pyspark.sql.types import ArrayType, StringType, LongType, StructType, StructField
import re
import sys
reload(sys)
sys.setdefaultencoding('utf-8')

spark = SparkSession\
        .builder\
        .appName('こんにちは')\
        .config('master', 'yarn')\
        .getOrCreate()

comment_schema = StructType([
    StructField('command', StringType()),
    StructField('content', StringType()),
    StructField('date', LongType()),
    StructField('vpos', LongType()),
    StructField('video_id', StringType())
    ])

meta_schema = StructType([
    StructField('category', StringType()),
    StructField('comment_num', LongType()),
    StructField('description', StringType()),
    StructField('file_type', StringType()),
    StructField('length', LongType()),
    StructField('mylist_num', LongType()),
    StructField('size_high', LongType()),
    StructField('size_low', LongType()),
    StructField('tags', ArrayType(StringType())),
    StructField('title', StringType()),
# These allow us to create a schema for our data
from pyspark.sql.types import StructField, StructType, StringType, LongType

# A Spark Session is how we interact with Spark SQL to create Dataframes
from pyspark.sql import SparkSession

# This will help catch some PySpark errors
from py4j.protocol import Py4JJavaError

# Create a SparkSession under the name "reddit". Viewable via the Spark UI
spark = SparkSession.builder.appName("reddit").getOrCreate()

# Create a two column schema consisting of a string and a long integer
fields = [
    StructField("subreddit", StringType(), True),
    StructField("count", LongType(), True)
]
schema = StructType(fields)

# Create an empty DataFrame. We will continuously union our output with this
subreddit_counts = spark.createDataFrame([], schema)

# Establish a set of years and months to iterate over
years = ['2017', '2018', '2019']
months = [
    '01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12'
]

# Keep track of all tables accessed via the job
tables_read = []
for year in years:
Example #14
    def second(self) -> "ks.Series":
        """
        The seconds of the datetime.
        """
        return column_op(lambda c: F.second(c).cast(LongType()))(self._data).alias(self._data.name)
Example #15
    def test_as_spark_type_pandas_on_spark_dtype(self):
        type_mapper = {
            # binary
            np.character: (np.character, BinaryType()),
            np.bytes_: (np.bytes_, BinaryType()),
            np.string_: (np.bytes_, BinaryType()),
            bytes: (np.bytes_, BinaryType()),
            # integer
            np.int8: (np.int8, ByteType()),
            np.byte: (np.int8, ByteType()),
            np.int16: (np.int16, ShortType()),
            np.int32: (np.int32, IntegerType()),
            np.int64: (np.int64, LongType()),
            np.int: (np.int64, LongType()),
            int: (np.int64, LongType()),
            # floating
            np.float32: (np.float32, FloatType()),
            np.float: (np.float64, DoubleType()),
            np.float64: (np.float64, DoubleType()),
            float: (np.float64, DoubleType()),
            # string
            np.str: (np.unicode_, StringType()),
            np.unicode_: (np.unicode_, StringType()),
            str: (np.unicode_, StringType()),
            # bool
            np.bool: (np.bool, BooleanType()),
            bool: (np.bool, BooleanType()),
            # datetime
            np.datetime64: (np.datetime64, TimestampType()),
            datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
            # DateType
            datetime.date: (np.dtype("object"), DateType()),
            # DecimalType
            decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
            # ArrayType
            np.ndarray: (np.dtype("object"), ArrayType(StringType())),
            # CategoricalDtype
            CategoricalDtype(categories=["a", "b", "c"]): (
                CategoricalDtype(categories=["a", "b", "c"]),
                LongType(),
            ),
        }

        for numpy_or_python_type, (dtype, spark_type) in type_mapper.items():
            self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
            self.assertEqual(pandas_on_spark_type(numpy_or_python_type),
                             (dtype, spark_type))

            if isinstance(numpy_or_python_type, CategoricalDtype):
                # Nested CategoricalDtype is not yet supported.
                continue

            self.assertEqual(as_spark_type(List[numpy_or_python_type]),
                             ArrayType(spark_type))
            self.assertEqual(
                pandas_on_spark_type(List[numpy_or_python_type]),
                (np.dtype("object"), ArrayType(spark_type)),
            )

            # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
            if sys.version_info >= (3, 8) and LooseVersion(
                    np.__version__) >= LooseVersion("1.21"):
                import numpy.typing as ntp

                self.assertEqual(
                    as_spark_type(ntp.NDArray[numpy_or_python_type]),
                    ArrayType(spark_type))
                self.assertEqual(
                    pandas_on_spark_type(ntp.NDArray[numpy_or_python_type]),
                    (np.dtype("object"), ArrayType(spark_type)),
                )

        with self.assertRaisesRegex(TypeError,
                                    "Type uint64 was not understood."):
            as_spark_type(np.dtype("uint64"))

        with self.assertRaisesRegex(TypeError,
                                    "Type object was not understood."):
            as_spark_type(np.dtype("object"))

        with self.assertRaisesRegex(TypeError,
                                    "Type uint64 was not understood."):
            pandas_on_spark_type(np.dtype("uint64"))

        with self.assertRaisesRegex(TypeError,
                                    "Type object was not understood."):
            pandas_on_spark_type(np.dtype("object"))
Example #16
    StructField("movie_title_year", StringType()),
    StructField("genres", StringType())
])
    
    movies = FileLoader.FileLoader(spark)._FileLoader__getFiles(cfg['movies']['format'],
                                                                movSchema,
                                                                cfg['movies']['delimit'],
                                                                cfg['movies']['path'])
    movies.show(truncate=False)
    logger.info("Movies Data loaded")
  
    ratSchema = StructType([
        StructField("user_id", IntegerType()),
        StructField("movie_id", IntegerType()),
        StructField("rating", IntegerType()),
        StructField("rating_tmp", LongType())
    ])
    ratings = FileLoader.FileLoader(spark)._FileLoader__getFiles(cfg['rating']['format'],
                                                                 ratSchema,
                                                                 cfg['rating']['delimit'],
                                                                 cfg['rating']['path'])
    ratings.show()
    logger.info("Rating Data loaded")
    
    usrSchema = StructType([
        StructField("userid", IntegerType()),
        StructField("twitterid", IntegerType())
    ])
    users = FileLoader.FileLoader(spark)._FileLoader__getFiles(cfg['user']['format'],
                                                               usrSchema,
                                                               cfg['user']['delimit'],
Example #17
    def test_infer_schema_from_pandas_instances(self):
        def func() -> pd.Series[int]:
            pass

        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtype, np.int64)
        self.assertEqual(inferred.spark_type, LongType())

        def func() -> pd.Series[float]:
            pass

        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtype, np.float64)
        self.assertEqual(inferred.spark_type, DoubleType())

        def func() -> "pd.DataFrame[np.float_, str]":
            pass

        expected = StructType(
            [StructField("c0", DoubleType()),
             StructField("c1", StringType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
        self.assertEqual(inferred.spark_type, expected)

        def func() -> "pandas.DataFrame[float]":
            pass

        expected = StructType([StructField("c0", DoubleType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.float64])
        self.assertEqual(inferred.spark_type, expected)

        def func() -> "pd.Series[int]":
            pass

        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtype, np.int64)
        self.assertEqual(inferred.spark_type, LongType())

        def func() -> pd.DataFrame[np.float64, str]:
            pass

        expected = StructType(
            [StructField("c0", DoubleType()),
             StructField("c1", StringType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
        self.assertEqual(inferred.spark_type, expected)

        def func() -> pd.DataFrame[np.float_]:
            pass

        expected = StructType([StructField("c0", DoubleType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.float64])
        self.assertEqual(inferred.spark_type, expected)

        pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})

        def func() -> pd.DataFrame[pdf.dtypes]:  # type: ignore[name-defined]
            pass

        expected = StructType(
            [StructField("c0", LongType()),
             StructField("c1", LongType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.int64, np.int64])
        self.assertEqual(inferred.spark_type, expected)

        pdf = pd.DataFrame({
            "a": [1, 2, 3],
            "b": pd.Categorical(["a", "b", "c"])
        })

        def func() -> pd.Series[pdf.b.dtype]:  # type: ignore[name-defined]
            pass

        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtype,
                         CategoricalDtype(categories=["a", "b", "c"]))
        self.assertEqual(inferred.spark_type, LongType())

        def func() -> pd.DataFrame[pdf.dtypes]:  # type: ignore[name-defined]
            pass

        expected = StructType(
            [StructField("c0", LongType()),
             StructField("c1", LongType())])
        inferred = infer_return_type(func)
        self.assertEqual(
            inferred.dtypes,
            [np.int64, CategoricalDtype(categories=["a", "b", "c"])])
        self.assertEqual(inferred.spark_type, expected)
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, LongType
from pyspark.sql.functions import col
"""
Tweet:
  tweet_id
  user_id
  language
  text
  created_at
"""

path = "/mnt/training/twitter/firehose/2018/01/08/18/twitterstream-1-2018-01-08-18-48-00-bcf3d615-9c04-44ec-aac9-25f966490aa4"

tweetSchema = StructType([
    StructField("id", LongType(), True),
    StructField("user", StructType([StructField("id", IntegerType(), True)]),
                True),
    StructField("lang", StringType(), True),
    StructField("text", StringType(), True),
    StructField("created_at", StringType(), True)
])
tweetDF = (spark.read.schema(tweetSchema).json(path))

display(tweetDF)

# COMMAND ----------

# TEST - Run this cell to test your solution
from pyspark.sql.functions import col
from pyspark.sql.types import StructField, StructType, StringType, LongType
from pyspark.sql import SparkSession

if __name__ == "__main__":
	spark = SparkSession.builder.master("local").appName("Enforce Schema").getOrCreate()
	myManualSchema = StructType([
		StructField("DEST_COUNTRY_NAME", StringType(), True),
		StructField("ORIGIN_COUNTRY_NAME", StringType(), True),
		StructField("count", LongType(), False)
	])

	df = spark.read.format("json").schema(myManualSchema) \
		.load("/Users/charlieohara/cbohara/spark_def_guide/data/flight-data/json/2015-summary.json")
	df.printSchema()
Example #20
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, LongType
from tqdm import tqdm
from itertools import permutations
from collections import defaultdict
import time
import gensim

spark = SparkSession.builder.appName("user cf on spark").master(
    "local[8]").getOrCreate()

sc = spark.sparkContext

schema = StructType([
    StructField('userId', IntegerType(), True),
    StructField('movieId', IntegerType(), True),
    StructField('rating', LongType(), True),
    StructField('timestamp', IntegerType(), True)
])
ratings = spark.read.csv(
    r'D:\Users\hao.guo\比赛代码提炼\推荐系统\movielen\ml-20m\ratings.csv', header=True)

ratings = ratings.withColumn('rating', ratings['rating'].cast('int'))
ratings_rdd = ratings.select(['userId', 'movieId', 'rating']).rdd
ratings_rdd = ratings_rdd.sample(withReplacement=False,
                                 fraction=0.5,
                                 seed=2020)

print('user cf start......')
s = time.perf_counter()

createCombiner = lambda v: [v]
#
# ## Calling it from a UDF
# ```python
# from pyspark.sql.functions import udf
# from pyspark.sql.types import LongType
# squared_udf = udf(squared, LongType())
# df = spark.table("test")
# display(df.select("id", squared_udf("id").alias("id_squared")))
# ```
#

# In[ ]:

from pyspark.sql.types import LongType


# A user-defined rounding function.
# Note: for float input, round(x, 2) returns a float, which does not match the
# LongType() declared at registration below; DoubleType would describe the result.
def selfRpund(x):
    return round(x, 2)


# Register the user-defined function
spark.udf.register("self_round", selfRpund, LongType())

usa_flight.createOrReplaceTempView('ua')

# Call it via SQL
spark.sql("""
select 
""")
Example #22
from json import loads
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import ArrayType, IntegerType, LongType, BooleanType, StructType, StructField
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf, col

sc = SparkSession.builder.appName("dotingestion").getOrCreate()

schema = StructType([StructField("dire_lineup", ArrayType(IntegerType(), False), False),
                    StructField("radiant_lineup", ArrayType(IntegerType(), False), False),
                    StructField("radiant_win", BooleanType(), False),
                    StructField("match_id", LongType(), False)])

path = "data.json"
df = sc.read.json(path, schema=schema).na.drop("all").distinct()

with open("heroes.json", 'r', encoding="utf-8") as f:
    heroes_dict = {hero['id']: i for i, hero in enumerate(loads(f.read()))}

def convert_heroes_to_lineup(df: DataFrame) -> DataFrame:

    def onehot(heroes: ArrayType):
        lineup = tuple(heroes_dict[hero] for hero in heroes)
        return Vectors.dense([1 if hero_slot in lineup else 0 for hero_slot in range(len(heroes_dict))])

    heros_to_lineup_udf = udf(onehot, VectorUDT())
    return df.withColumn("dire_lineup_vec", heros_to_lineup_udf(df.dire_lineup))\
             .withColumn("radiant_lineup_vec", heros_to_lineup_udf(df.radiant_lineup))

df = convert_heroes_to_lineup(df)
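A small follow-up sketch, with the output path assumed, that persists the featurized matches:

# Hypothetical follow-up: store the one-hot lineup vectors for later training.
df.select("match_id", "radiant_win", "dire_lineup_vec", "radiant_lineup_vec") \
  .write.mode("overwrite").parquet("dota_lineups.parquet")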
Example #23
    def attach_id_column(self, id_type: str, column: Name) -> "DataFrame":
        """
        Attach a column to be used as identifier of rows similar to the default index.

        See also `Default Index type
        <https://koalas.readthedocs.io/en/latest/user_guide/options.html#default-index-type>`_.

        Parameters
        ----------
        id_type : string
            The id type.

            - 'sequence' : a sequence that increases one by one.

              .. note:: this uses Spark's Window without specifying partition specification.
                  This leads to move all data into single partition in single machine and
                  could cause serious performance degradation.
                  Avoid this method against very large dataset.

            - 'distributed-sequence' : a sequence that increases one by one,
              by group-by and group-map approach in a distributed manner.
            - 'distributed' : a monotonically increasing sequence simply by using PySpark’s
              monotonically_increasing_id function in a fully distributed manner.

        column : string or tuple of string
            The column name.

        Returns
        -------
        DataFrame
            The DataFrame attached the column.

        Examples
        --------
        >>> df = ps.DataFrame({"x": ['a', 'b', 'c']})
        >>> df.pandas_on_spark.attach_id_column(id_type="sequence", column="id")
           x  id
        0  a   0
        1  b   1
        2  c   2

        >>> df.pandas_on_spark.attach_id_column(id_type="distributed-sequence", column=0)
           x  0
        0  a  0
        1  b  1
        2  c  2

        >>> df.pandas_on_spark.attach_id_column(id_type="distributed", column=0.0)
        ... # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
           x  0.0
        0  a  ...
        1  b  ...
        2  c  ...

        For multi-index columns:

        >>> df = ps.DataFrame({("x", "y"): ['a', 'b', 'c']})
        >>> df.pandas_on_spark.attach_id_column(id_type="sequence", column=("id-x", "id-y"))
           x id-x
           y id-y
        0  a    0
        1  b    1
        2  c    2

        >>> df.pandas_on_spark.attach_id_column(id_type="distributed-sequence", column=(0, 1.0))
           x   0
           y 1.0
        0  a   0
        1  b   1
        2  c   2
        """
        from pyspark.pandas.frame import DataFrame

        if id_type == "sequence":
            attach_func = InternalFrame.attach_sequence_column
        elif id_type == "distributed-sequence":
            attach_func = InternalFrame.attach_distributed_sequence_column
        elif id_type == "distributed":
            attach_func = InternalFrame.attach_distributed_column
        else:
            raise ValueError(
                "id_type should be one of 'sequence', 'distributed-sequence' and 'distributed'"
            )

        assert is_name_like_value(column, allow_none=False), column
        if not is_name_like_tuple(column):
            column = (column, )

        internal = self._psdf._internal

        if len(column) != internal.column_labels_level:
            raise ValueError(
                "The given column `{}` must be the same length as the existing columns."
                .format(column))
        elif column in internal.column_labels:
            raise ValueError("The given column `{}` already exists.".format(
                name_like_string(column)))

        # Make sure the underlying Spark column names are the form of
        # `name_like_string(column_label)`.
        sdf = internal.spark_frame.select([
            scol.alias(SPARK_INDEX_NAME_FORMAT(i))
            for i, scol in enumerate(internal.index_spark_columns)
        ] + [
            scol.alias(name_like_string(label)) for scol, label in zip(
                internal.data_spark_columns, internal.column_labels)
        ])
        sdf = attach_func(sdf, name_like_string(column))

        return DataFrame(
            InternalFrame(
                spark_frame=sdf,
                index_spark_columns=[
                    scol_for(sdf, SPARK_INDEX_NAME_FORMAT(i))
                    for i in range(internal.index_level)
                ],
                index_names=internal.index_names,
                index_fields=internal.index_fields,
                column_labels=internal.column_labels + [column],
                data_spark_columns=([
                    scol_for(sdf, name_like_string(label))
                    for label in internal.column_labels
                ] + [scol_for(sdf, name_like_string(column))]),
                data_fields=internal.data_fields + [
                    InternalField.from_struct_field(
                        StructField(name_like_string(column),
                                    LongType(),
                                    nullable=False))
                ],
                column_label_names=internal.column_label_names,
            ).resolved_copy)
Example #24
def process_log_data(spark, input_data, output_data):
    """Reads in log data and songs_table, creates users_table,
       time_table, and songplays_table. Writes all three to S3
       in parquet format.

       Parameters
       ----------
       spark: SparkSession instance
       input_data: str
           root location of input data in S3
       output_data: str
           root location of output data in S3
    """

    # get filepath to log data file
    log_data = input_data + 'log_data/*/*/*.json'

    log_schema = StructType([
        StructField("artist", StringType(), True),
        StructField("auth", StringType(), True),
        StructField("firstName", StringType(), True),
        StructField("gender", StringType(), True),
        StructField("itemInSession", LongType(), True),
        StructField("lastName", StringType(), True),
        StructField("length", DoubleType(), True),
        StructField("level", StringType(), True),
        StructField("location", StringType(), True),
        StructField("method", StringType(), True),
        StructField("page", StringType(), True),
        StructField("registration", DoubleType(), True),
        StructField("sessionId", LongType(), True),
        StructField("song", StringType(), True),
        StructField("status", LongType(), True),
        StructField("ts", LongType(), True),
        StructField("userAgent", StringType(), True),
        StructField("userId", StringType(), True),
    ])

    # read log data file
    log_cols = [
        'userId', 'firstName', 'lastName', 'gender', 'ts', 'level',
        'sessionId', 'location', 'userAgent', 'song'
    ]
    df = spark.read \
        .json(log_data, schema=log_schema) \
        .select(*log_cols) \
        .filter("page = 'NextSong'")

    # extract columns for users table
    users_columns = [
        col('userId').alias('user_id'),
        col('firstName').alias('first_name'),
        col('lastName').alias('last_name'),
        col('gender'),
        col('level')
    ]
    users_table = df.select(*users_columns) \
        .filter(col('user_id').isNotNull()) \
        .filter(col('first_name').isNotNull()) \
        .filter(col('last_name').isNotNull()) \
        .filter(col('level').isNotNull()) \
        .dropDuplicates()

    # write users table to parquet files
    users_path = output_data + 'users_table/users_table.parquet'
    users_table.write \
        .mode('overwrite') \
        .parquet(users_path)

    # create timestamp column from original timestamp column
    get_timestamp = udf(lambda x: round(x / 1000), LongType())
    df = df.withColumn('timestamp', get_timestamp(df.ts))

    # create datetime column from original timestamp column
    get_datetime = udf(
        lambda x: datetime.fromtimestamp(x).strftime("%Y-%m-%d %H:%M:%S"))
    df = df.withColumn('datetime', get_datetime(df.timestamp))

    # extract columns to create time table
    # start_time, hour, day, week, month, year, weekday
    time_cols = [
        col('ts').alias('start_time'),
        hour('datetime').alias('hour'),
        dayofmonth('datetime').alias('day'),
        weekofyear('datetime').alias('week'),
        month('datetime').alias('month'),
        year('datetime').alias('year'),
        dayofweek('datetime').alias('weekday')
    ]
    time_table = df.select(*time_cols) \
        .filter(col('ts').isNotNull()) \
        .dropDuplicates()

    # write time table to parquet files partitioned by year and month
    time_path = output_data + 'time_table/time_table.parquet'
    time_table.write \
        .mode('overwrite') \
        .partitionBy('year', 'month') \
        .parquet(time_path)

    # read in song data to use for songplays table
    # 'song_id', 'title', 'artist_id', 'year', 'duration'
    songs_path = output_data + 'songs_table/songs_table.parquet'
    song_df = spark.read.parquet(songs_path)

    # get columns from joined song and log datasets to create songplays table
    # add year and month for partitioning
    songplays_table = df.alias('l').join(song_df.alias('s'),
                                         on=col('l.song') == col('s.title')) \
        .select([monotonically_increasing_id().alias('songplay_id'),
                 col('l.ts').alias('start_time'),
                 col('l.userId').alias('user_id'),
                 col('l.level'),
                 col('s.song_id'),
                 col('s.artist_id'),
                 col('l.sessionId').alias('session_id'),
                 col('l.location'),
                 col('l.userAgent').alias('user_agent'),
                 year(col('l.datetime')).alias('year'),
                 month(col('l.datetime')).alias('month')])

    # write songplays table to parquet files partitioned by year and month
    songplays_path = output_data + 'songplays_table/songplays_table.parquet'
    songplays_table.write \
        .mode('overwrite') \
        .partitionBy('year', 'month') \
        .parquet(songplays_path)
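A hedged invocation sketch for the function above; the bucket paths are assumptions, and the songs table is assumed to have been written by a prior step as the docstring implies:

# Hypothetical invocation: run the log ETL against assumed S3 locations.
spark = SparkSession.builder.appName("sparkify-etl").getOrCreate()
process_log_data(spark,
                 input_data="s3a://example-input-bucket/",
                 output_data="s3a://example-output-bucket/")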
Example #25
def process_log_data(spark, input_data, output_data):
    '''
    Reads in log data from S3, filters for songplay records, chooses the most recent record associated with a user
        in order to create an up-to-date users table, extracts datetime data to create a timestamp table,
        and joins log and song data to create a songplays table. Duplicate and null values are dropped from users and time
        tables, and each table is written to S3 as a parquet file, partitioned by "year" and "month" columns in the time
        and songplays table.
        
        Parameters:
            spark: The current SparkSession
            input_data: The path of the source data bucket on S3, formatted with s3a protocol
            output_data: The path of the target bucket for writes on S3, formatted with s3a protocol
    '''
    log_data = input_data + "log_data/*/*/*.json"

    schema = StructType([
        StructField("artist", StringType()),
        StructField("auth", StringType()),
        StructField("firstName", StringType()),
        StructField("gender", StringType()),
        StructField("itemInSession", IntegerType()),
        StructField("lastName", StringType()),
        StructField("length", DoubleType()),
        StructField("level", StringType()),
        StructField("location", StringType()),
        StructField("method", StringType()),
        StructField("page", StringType()),
        StructField("registration", DoubleType()),
        StructField("sessionId", IntegerType()),
        StructField("song", StringType()),
        StructField("status", IntegerType()),
        StructField("ts", LongType()),
        StructField("userAgent", StringType()),
        StructField("userId", StringType())
    ])
    df = spark.read.json(log_data, schema)

    df = df.where(df.page == "NextSong")

    w = Window.partitionBy("userId")
    users_table = df.dropna(subset=["userId"]) \
                    .withColumn("userid_occurrence_num", row_number() \
                    .over(w.orderBy("ts"))) \
                    .withColumn("max_occurrence_num", max("userid_occurrence_num").over(w)) \
                    .where(col("userid_occurrence_num") == col("max_occurrence_num")) \
                    .select("userId", "firstName", "lastName", "gender", "level")

    users_table.write.parquet(output_data + "users.parquet")

    get_timestamp = udf(lambda ts: ts / 1000, DoubleType())
    df = df.withColumn("epoch_ts", get_timestamp(df.ts))

    get_datetime = udf(lambda ts: datetime.fromtimestamp(ts), TimestampType())
    df = df.withColumn("datetime", get_datetime(df.epoch_ts))

    time_table = df.select("datetime").dropna().dropDuplicates()
    time_table = time_table.withColumn("hour", hour(col("datetime"))) \
        .withColumn("day", dayofmonth(col("datetime"))) \
        .withColumn("week", weekofyear(col("datetime"))) \
        .withColumn("month", month(col("datetime"))) \
        .withColumn("year", year(col("datetime"))) \
        .withColumn("weekday", date_format(col("datetime"), "u"))

    time_table.write.partitionBy("year",
                                 "month").parquet(output_data + "time.parquet")

    song_df = spark.read.json(
        input_data + "song_data/*/*/*/*.json",
        StructType([
            StructField("num_songs", IntegerType()),
            StructField("artist_id", StringType()),
            StructField("artist_latitude", FloatType()),
            StructField("artist_longitude", FloatType()),
            StructField("artist_location", StringType()),
            StructField("artist_name", StringType()),
            StructField("song_id", StringType()),
            StructField("title", StringType()),
            StructField("duration", FloatType()),
            StructField("year", IntegerType())
        ]))

    df = df.withColumn("songplay_id", row_number() \
                .over(Window.partitionBy("page").orderBy("ts")))

    join_condition = [df.song == song_df.title]
    songplays_table = df.join(song_df, join_condition, "left").select("songplay_id",
                                                              col("datetime").alias("start_time"),
                                                              col("userId").alias("user_id"),
                                                              "level",
                                                              "song_id",
                                                              "artist_id",
                                                              col("sessionId").alias("session_id"),
                                                              "location",
                                                              col("userAgent").alias("user_agent")) \
                                                            .withColumn("year", year(col("start_time"))) \
                                                            .withColumn("month", month(col("start_time")))

    songplays_table.write.partitionBy(
        "year", "month").parquet(output_data + "songplays.parquet")
def main():
    get_unstructured_file()
    logger.success('Downloaded Files')

    schema = StructType([
        StructField('Last Accessed Url', StringType(), True),
        StructField('Page Category', StringType(), True),
        StructField('Page Category 1', StringType(), True),
        StructField('Page Category 2', StringType(), True),
        StructField('Page Category 3', StringType(), True),
        StructField('Page Name', StringType(), True),
        StructField('at', StringType(), True),
        StructField('browser', StringType(), True),
        StructField('carrier', StringType(), True),
        StructField('city_name', StringType(), True),
        StructField('clv_total', LongType(), True),
        StructField('country', StringType(), True),
        StructField('custom_1', StringType(), True),
        StructField('custom_2', StringType(), True),
        StructField('custom_3', StringType(), True),
        StructField('custom_4', StringType(), True),
        StructField('device_new', BooleanType(), True),
        StructField('first-accessed-page', StringType(), True),
        StructField('install_uuid', StringType(), True),
        StructField('language', StringType(), True),
        StructField('library_ver', StringType(), True),
        StructField('marketing_campaign', StringType(), True),
        StructField('marketing_medium', StringType(), True),
        StructField('marketing_source', StringType(), True),
        StructField('model', StringType(), True),
        StructField('name', StringType(), True),
        StructField('nth', LongType(), True),
        StructField('os_ver', StringType(), True),
        StructField('platform', StringType(), True),
        StructField('region', StringType(), True),
        StructField('session_uuid', StringType(), True),
        StructField('studentId_clientType', StringType(), True),
        StructField('type', StringType(), True),
        StructField('user_type', StringType(), True),
        StructField('uuid', StringType(), True)
    ])

    logger.debug('Creating DataFrame...')
    df = spark.read.schema(schema).json('*.json')
    logger.success('DataFrame created with {} rows'.format(df.count()))

    df = df.select('at',
                'browser',
                'country',
                'custom_4',
                'studentId_clientType',
                'Page Name',
                'Last Accessed Url') \
            .filter(df['studentId_clientType'].isNotNull())

    df_country = df.select(df.country) \
        .filter(df.country != 'br') \
        .groupBy('country').count()
    #df_country.repartition(1).write.format('csv').mode('overwrite').option('header', 'true').save('country')



    df_users = df.filter(df.custom_4.isNotNull()) \
                .select(df.custom_4) \
                .groupBy(df.custom_4).count()
    #df_users.repartition(1).write.format('csv').mode('overwrite').option('header', 'true').save('user.csv')

    df_result = df.withColumn('id',
                              clean_studentId(df['studentId_clientType']))
    df_result = df_result.drop('studentId_clientType')

    query = """SELECT fat.id, state, city, cou.name course
                                FROM "DM_PASSEI_DIRETO".fat_students fat
                                INNER JOIN "DM_PASSEI_DIRETO".dim_courses cou
                                ON fat.course_id = cou.id
                                INNER JOIN "DM_PASSEI_DIRETO".dim_sessions ds 
                                ON fat.id = ds.student_id 
                                WHERE CAST(ds.start_time as VARCHAR) LIKE '2017-11-16%'"""
    students = dw_get_data(query)

    students_schema = StructType([
        StructField('id', StringType(), True),
        StructField('state', StringType(), True),
        StructField('city', StringType(), True),
        StructField('course', StringType(), True)
    ])

    df_dim = spark.createDataFrame(students, students_schema)

    df_result = df_result.join(df_dim, 'id', how='inner').distinct()
    #df_result.repartition(1).write.format('csv').mode('overwrite').option('header', 'true').save('full.csv')

    df_country.toPandas().to_csv('country.csv')
    df_users.toPandas().to_csv('users.csv')
    df_result.toPandas().to_csv('full.csv')
    send_files('*.csv')
Example #27
def select_relevant_columns(df,
                            discrete_action: bool = True,
                            include_possible_actions: bool = True):
    """ Select all the relevant columns and perform type conversions. """
    if not discrete_action and include_possible_actions:
        raise NotImplementedError(
            "currently we don't support include_possible_actions")

    select_col_list = [
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("reward").cast(FloatType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("state_features").cast(ArrayType(FloatType())),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("state_features_presence").cast(ArrayType(BooleanType())),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("next_state_features").cast(ArrayType(FloatType())),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("next_state_features_presence").cast(ArrayType(BooleanType())),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("not_terminal").cast(BooleanType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("action_probability").cast(FloatType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("mdp_id").cast(LongType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("sequence_number").cast(LongType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("step").cast(LongType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("time_diff").cast(LongType()),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("metrics").cast(ArrayType(FloatType())),
        # pyre-fixme[16]: Module `functions` has no attribute `col`.
        col("metrics_presence").cast(ArrayType(BooleanType())),
    ]

    if discrete_action:
        select_col_list += [
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("action").cast(LongType()),
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("next_action").cast(LongType()),
        ]
    else:
        select_col_list += [
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("action").cast(ArrayType(FloatType())),
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("next_action").cast(ArrayType(FloatType())),
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("action_presence").cast(ArrayType(BooleanType())),
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("next_action_presence").cast(ArrayType(BooleanType())),
        ]

    if include_possible_actions:
        select_col_list += [
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("possible_actions_mask").cast(ArrayType(LongType())),
            # pyre-fixme[16]: Module `functions` has no attribute `col`.
            col("possible_next_actions_mask").cast(ArrayType(LongType())),
        ]

    return df.select(*select_col_list)
Example #28
    def test_infer_schema_with_names_pandas_instances(self):
        def func() -> 'pd.DataFrame["a" : np.float_, "b":str]':  # noqa: F405
            pass

        expected = StructType(
            [StructField("a", DoubleType()),
             StructField("b", StringType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
        self.assertEqual(inferred.spark_type, expected)

        def func() -> "pd.DataFrame['a': float, 'b': int]":  # noqa: F405
            pass

        expected = StructType(
            [StructField("a", DoubleType()),
             StructField("b", LongType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.float64, np.int64])
        self.assertEqual(inferred.spark_type, expected)

        pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})

        def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
            pass

        expected = StructType(
            [StructField("a", LongType()),
             StructField("b", LongType())])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.int64, np.int64])
        self.assertEqual(inferred.spark_type, expected)

        pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]})

        def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
            pass

        expected = StructType([
            StructField("(x, a)", LongType()),
            StructField("(y, b)", LongType())
        ])
        inferred = infer_return_type(func)
        self.assertEqual(inferred.dtypes, [np.int64, np.int64])
        self.assertEqual(inferred.spark_type, expected)

        pdf = pd.DataFrame({
            "a": [1, 2, 3],
            "b": pd.Categorical(["a", "b", "c"])
        })

        def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
            pass

        expected = StructType(
            [StructField("a", LongType()),
             StructField("b", LongType())])
        inferred = infer_return_type(func)
        self.assertEqual(
            inferred.dtypes,
            [np.int64, CategoricalDtype(categories=["a", "b", "c"])])
        self.assertEqual(inferred.spark_type, expected)
Example #29
    def week(self) -> "ps.Series":
        """
        The week ordinal of the year.
        """
        return self._data.spark.transform(
            lambda c: F.weekofyear(c).cast(LongType()))
Example #30
    def hour(self) -> "ks.Series":
        """
        The hours of the datetime.
        """
        return column_op(lambda c: F.hour(c).cast(LongType()))(self._data).alias(self._data.name)