def main(args):
    spark = (sql.SparkSession.builder.appName(
        'Spark SQL Query').enableHiveSupport().getOrCreate())
    for name, (fmt, options) in args.table_metadata:
        logging.info('Loading %s', name)
        spark.read.format(fmt).options(**options).load().createTempView(name)

    results = []
    for script in args.sql_scripts:
        # Read script from object storage using rdd API
        query = '\n'.join(spark.sparkContext.textFile(script).collect())

        try:
            logging.info('Running %s', script)
            start = time.time()
            # spark-sql does not limit its output. Replicate that here by setting
            # the limit to the maximum Java integer. Hopefully you limited the
            # output in SQL, or you are going to have a bad time. Note that not
            # all TPC-DS or TPC-H queries limit their output, and they may crash
            # with small JVMs.
            # pylint: disable=protected-access
            spark.sql(query).show(spark._jvm.java.lang.Integer.MAX_VALUE)
            # pylint: enable=protected-access
            duration = time.time() - start
            results.append(sql.Row(script=script, duration=duration))
        # These correspond to errors in low-level Spark execution.
        # Let ParseException and AnalysisException fail the job.
        except (sql.utils.QueryExecutionException,
                py4j.protocol.Py4JJavaError) as e:
            logging.error('Script %s failed', script, exc_info=e)

    logging.info('Writing results to %s', args.report_dir)
    spark.createDataFrame(results).coalesce(1).write.json(args.report_dir)
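For context on the loop above: each entry in args.table_metadata pairs a view name with a (format, options) tuple that spark.read can consume (Example #8 below loads the same structure from a JSON file). A minimal sketch of what that metadata might look like; the table names, paths, and option keys here are illustrative, not taken from the original:

# Hypothetical metadata: view name -> (data source format, reader options),
# consumed by spark.read.format(fmt).options(**options).load() as above.
table_metadata = {
    'store_sales': ('parquet', {'path': 'gs://bucket/tpcds/store_sales'}),
    'customer': ('csv', {'path': 'gs://bucket/customer', 'header': 'true'}),
}.items()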
Example #2
def _to_rows(df, by, k):
    """Convert to a list of Row objects."""
    if isinstance(by, str):
        by, k = [by], [k]
    if isinstance(df, pd.Series):
        df = df.to_frame().T
    if isinstance(df, pd.DataFrame):
        for col, val in zip(by, k):
            df[col] = val
        df.columns = map(str, df.columns)  # Else Row() will fail
        rows = (sql.Row(**OrderedDict(sorted(x.items())))
                for x in df.to_dict(orient='records'))
    else:
        # Otherwise assume df is a single scalar value.
        items = zip(by + ['value'], k + [df])
        rows = (sql.Row(**OrderedDict(sorted(items))), )
    return rows
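A brief usage sketch for _to_rows, assuming the aliases the excerpt relies on (import pandas as pd, from pyspark import sql, from collections import OrderedDict); the sample frame and group key are made up for illustration:

import pandas as pd

pdf = pd.DataFrame({'count': [3, 5]})
rows = list(_to_rows(pdf, by='city', k='NYC'))
# -> [Row(city='NYC', count=3), Row(city='NYC', count=5)]

scalar_rows = _to_rows(7, by='city', k='NYC')
# -> (Row(city='NYC', value=7),)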
Example #3
def rowwise_function(row, metadata_dict):
    import pyspark.sql as r
    row_dict = row.asDict()
    err_col, err_message, flag = check_function(row_dict, metadata_dict)
    row_dict.update({'Error_Column': err_col})
    row_dict.update({'Error_Message': err_message})
    row_dict.update({'Flag': flag})
    newrow = r.Row(**row_dict)
    return newrow
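A sketch of how a row-wise validator like this is usually wired up: map it over df.rdd and rebuild a DataFrame from the returned Rows. The check_function stub and the sample data below are hypothetical; only the mapping pattern is taken from the excerpt:

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([Row(name='a', age=30), Row(name='b', age=-1)])

def check_function(row_dict, metadata_dict):
    # Hypothetical stand-in for the real validation logic.
    if row_dict.get('age', 0) < 0:
        return 'age', 'age must be non-negative', 'FAIL'
    return '', '', 'PASS'

metadata_dict = {}  # placeholder for whatever metadata the checks need

checked_df = df.rdd.map(
    lambda row: rowwise_function(row, metadata_dict)).toDF()
checked_df.show()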
Example #4
    def test_import_twitter_data(self):
        file_name = __name__ + '.csv'

        input_df = self.ss.createDataFrame(
            [
                sql.Row(id=14580,
                        timestamp=8888888,
                        postalCode='95823',
                        lon=34.256,
                        lat=34.258,
                        tweet='something something',
                        user_id=44438487,
                        application='instagram',
                        source='instagram?')
            ],
            schema=types.StructType([
                types.StructField('id', types.LongType()),
                types.StructField('timestamp',
                                  types.LongType(),
                                  nullable=False),
                types.StructField('postalCode', types.StringType()),
                types.StructField('lon', types.DoubleType(), nullable=False),
                types.StructField('lat', types.DoubleType(), nullable=False),
                types.StructField('tweet', types.StringType(), nullable=False),
                types.StructField('user_id', types.LongType()),
                types.StructField('application', types.StringType()),
                types.StructField('source', types.StringType())
            ]))

        try:
            input_df.coalesce(1).write.option('header', 'true').csv(file_name)
            actual = run_job.import_twitter_data(self.ss,
                                                 file_name).first().asDict()
            expected = sql.Row(timestamp=8888888,
                               lon=34.256,
                               lat=34.258,
                               tweet='something something').asDict()
            self.assertDictEqual(expected, actual)
        finally:
            subprocess.call(['hadoop', 'fs', '-rm', '-r', '-f', file_name])
 def _initialize_results(self, scaffolds):
     data = [
         ps.Row(smiles=scaffold, scaffold=scaffold, decorations={}, count=1)
         for scaffold in scaffolds
     ]
     data_schema = pst.StructType([
         pst.StructField("smiles", pst.StringType()),
         pst.StructField("scaffold", pst.StringType()),
         pst.StructField("decorations",
                         pst.MapType(pst.IntegerType(), pst.StringType())),
         pst.StructField("count", pst.IntegerType())
     ])
     return SPARK.createDataFrame(data, schema=data_schema)
Example #6
    def test_filter_by_dates(self):
        input_tweets_schema = types.StructType(
            [types.StructField('timestamp', types.LongType())])

        output_tweets_schema = types.StructType(
            [types.StructField('date', types.DateType())])

        EPOCH = datetime.date(1970, 1, 1)
        start_date = datetime.date(2016, 3, 3)
        start_timestamp = int((start_date - EPOCH).total_seconds())
        end_date = datetime.date(2016, 3, 5)
        end_timestamp = int((end_date - EPOCH).total_seconds())

        input_timestamps = [
            start_timestamp - 1, start_timestamp, start_timestamp + 1,
            end_timestamp - 1, end_timestamp, end_timestamp + 24 * 60 * 60
        ]

        input_tweets_df = self.ss.createDataFrame(
            [sql.Row(timestamp=t) for t in input_timestamps],
            schema=input_tweets_schema)

        output_tweets_df = self.ss.createDataFrame(
            [
                sql.Row(date=datetime.datetime.utcfromtimestamp(t).date())
                for t in input_timestamps
                if start_timestamp <= t < end_timestamp
            ],
            schema=output_tweets_schema)

        actual_df = run_job.filter_by_dates(self.ss, input_tweets_df,
                                            start_date, end_date)
        actual_list = sorted(actual_df.collect(), key=lambda d: d['date'])
        expected_list = sorted(output_tweets_df.collect(),
                               key=lambda d: d['date'])

        self.assertListEqual(expected_list, actual_list)
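The test above pins down what run_job.filter_by_dates must do: interpret each timestamp as a UTC instant, reduce it to a date, and keep rows with start_date <= date < end_date. The project's implementation is not shown in these excerpts; a sketch consistent with the test, assuming the Spark session timezone is UTC:

from pyspark.sql import functions as F

def filter_by_dates(spark, tweets_df, start_date, end_date):
    # Sketch only: convert epoch seconds to a date and keep [start_date, end_date).
    # Assumes spark.sql.session.timeZone is UTC, matching utcfromtimestamp above.
    dates_df = tweets_df.select(
        F.to_date(F.from_unixtime(F.col('timestamp'))).alias('date'))
    return dates_df.where(
        (F.col('date') >= F.lit(start_date)) & (F.col('date') < F.lit(end_date)))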
Example #7
 def enumerate(row: ps.Row, enumerator: FragmentReactionSliceEnumerator, max_cuts: int) -> List[ps.Row]:
     attachments = AttachmentPoints()
     fields = row.split("\t")
     smiles = fields[0]
     mol = uc.to_mol(smiles)
     out_rows = []
     if mol:
         for sliced_mol in enumerator.enumerate(mol, cuts=max_cuts):
             row_dict = {
                 DataframeColumnsEnum.SCAFFOLDS:
                     attachments.remove_attachment_point_numbers(sliced_mol.scaffold_smiles),
                 DataframeColumnsEnum.DECORATIONS: sliced_mol.decorations_smiles,
                 DataframeColumnsEnum.ORIGINAL: sliced_mol.original_smiles,
                 DataframeColumnsEnum.MAX_CUTS: max_cuts}
             out_rows.append(ps.Row(**row_dict))
     return out_rows
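This example (like #9 and the _enumerate variant further down) follows one pattern: parse a raw text line, emit zero or more ps.Row objects, and let Spark flatten them into a DataFrame with flatMap(...).toDF(). A self-contained sketch of that pattern with a toy parser; the enumerator and attachment-point classes above are project-specific and not reproduced:

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

def parse_line(line):
    # Toy stand-in for enumerate(): one Row per '.'-separated fragment.
    smiles = line.split("\t")[0]
    return [Row(original=smiles, fragment=part) for part in smiles.split(".")]

lines = spark.sparkContext.parallelize(["CCO.CC\tmol_1", "c1ccccc1\tmol_2"])
fragments_df = lines.flatMap(parse_line).toDF()
fragments_df.show()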
Example #8
def main(args):
  builder = sql.SparkSession.builder.appName('Spark SQL Query')
  if args.enable_hive:
    builder = builder.enableHiveSupport()
  spark = builder.getOrCreate()
  if args.database:
    spark.catalog.setCurrentDatabase(args.database)
  table_metadata = []
  if args.table_metadata:
    table_metadata = json.loads(load_file(spark, args.table_metadata)).items()
  for name, (fmt, options) in table_metadata:
    logging.info('Loading %s', name)
    spark.read.format(fmt).options(**options).load().createTempView(name)
  if args.table_cache:
    # This captures both tables in args.database and views from table_metadata
    for table in spark.catalog.listTables():
      spark.sql('CACHE {lazy} TABLE {name}'.format(
          lazy='LAZY' if args.table_cache == 'lazy' else '',
          name=table.name))

  results = []
  for script in args.sql_scripts:
    # Read script from object storage using rdd API
    query = load_file(spark, script)

    try:
      logging.info('Running %s', script)
      start = time.time()
      # spark-sql does not limit its output. Replicate that here by setting
      # the limit to the maximum Java integer. Hopefully you limited the
      # output in SQL, or you are going to have a bad time. Note that not
      # all TPC-DS or TPC-H queries limit their output, and they may crash
      # with small JVMs.
      # pylint: disable=protected-access
      spark.sql(query).show(spark._jvm.java.lang.Integer.MAX_VALUE)
      # pylint: enable=protected-access
      duration = time.time() - start
      results.append(sql.Row(script=script, duration=duration))
    # These correspond to errors in low-level Spark execution.
    # Let ParseException and AnalysisException fail the job.
    except (sql.utils.QueryExecutionException,
            py4j.protocol.Py4JJavaError) as e:
      logging.error('Script %s failed', script, exc_info=e)

  logging.info('Writing results to %s', args.report_dir)
  spark.createDataFrame(results).coalesce(1).write.json(args.report_dir)
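This example and Example #11 below rely on a load_file(spark, path) helper that the excerpts do not show. Example #1 reads its scripts inline with the same RDD trick, so the helper presumably reduces to something like this sketch (not the original definition):

def load_file(spark, object_path):
    """Read a whole object-storage file into one string via the RDD API."""
    return '\n'.join(spark.sparkContext.textFile(object_path).collect())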
Example #9
 def collect_failures(
         self, row: ps.Row,
         enumerator: FailingReactionsEnumerator) -> List[ps.Row]:
     fields = row.split("\t")
     smiles = fields[0]
     mol = uc.to_mol(smiles)
     out_rows = []
     if mol:
         for failed_reaction in enumerator.enumerate(
                 mol, failures_limit=self.configuration.failures_limit):
             row_dict = {
                 self._columns.REACTION: failed_reaction.reaction_smirks,
                 self._columns.ORIGINAL: failed_reaction.molecule_smiles
             }
             print("found failed reaction")
             out_rows.append(ps.Row(**row_dict))
             if self.configuration.failures_limit <= len(out_rows):
                 break
     return out_rows
 def _enumerate(row,
                max_cuts=self.max_cuts,
                enumerator=self.enumerator):
     fields = row.split("\t")
     smiles = fields[0]
     mol = uc.to_mol(smiles)
     out_rows = []
     if mol:
         for cuts in range(1, max_cuts + 1):
             for sliced_mol in enumerator.enumerate(mol, cuts=cuts):
                 # normalize scaffold and decorations
                 scaff_smi, dec_smis = sliced_mol.to_smiles()
                 dec_smis = [
                     smi for num, smi in sorted(dec_smis.items())
                 ]
                 out_rows.append(
                     ps.Row(scaffold=scaff_smi,
                            decorations=dec_smis,
                            smiles=uc.to_smiles(mol),
                            cuts=cuts))
     return out_rows
Example #11
def run_sql_script(spark_session, script):
  """Runs a SQL script, returns a pyspark.sql.Row with its duration."""

  # Read script from object storage using rdd API
  query = load_file(spark_session, script)

  try:
    logging.info('Running %s', script)
    start = time.time()
    # spark-sql does not limit its output. Replicate that here by setting
    # the limit to the maximum Java integer. Hopefully you limited the
    # output in SQL, or you are going to have a bad time. Note that not
    # all TPC-DS or TPC-H queries limit their output, and they may crash
    # with small JVMs.
    # pylint: disable=protected-access
    df = spark_session.sql(query)
    df.show(spark_session._jvm.java.lang.Integer.MAX_VALUE)
    # pylint: enable=protected-access
    duration = time.time() - start
    return sql.Row(script=script, duration=duration)
  # These correspond to errors in low-level Spark execution.
  # Let ParseException and AnalysisException fail the job.
  except (sql.utils.QueryExecutionException,
          py4j.protocol.Py4JJavaError) as e:
    logging.error('Script %s failed', script, exc_info=e)
 def _read_rows(row):
     idx, _, dec = row.split("\t")
     return ps.Row(id=idx, decoration_smi=dec)
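A usage sketch for a line parser like _read_rows, assuming import pyspark.sql as ps as in the excerpt; the tab-separated sample lines are made up:

import pyspark.sql as ps
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical input lines: id <TAB> (ignored) <TAB> decoration SMILES.
lines = spark.sparkContext.parallelize(["0\tx\tC[*]", "1\ty\tO[*]"])
decorations_df = lines.map(_read_rows).toDF()
decorations_df.show()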
Example #13
def test_remote_java(sqlCtx):
    sqlCtx.createDataFrame([sql.Row(x=1), sql.Row(x=2), sql.Row(x=3)])
 def _read_rows(row):
     scaff, dec = row.split("\t")
     return ps.Row(randomized_scaffold=scaff, decoration_smi=dec)
import time

from pyspark import SparkContext, sql
from pyspark.sql import SQLContext

sc = SparkContext(appName="DF spark vs RDD spark")
sc.setLogLevel('WARN')
sqlContext = SQLContext(sc)

file_csv = "file.csv"
df_spark = sqlContext.read.format("csv").options(
    header='true', inferschema='true').load(file_csv)
rdd_spark = sqlContext.read.format("csv").options(
    header='true', inferschema='true').load(file_csv).rdd

t = time.time()

# where
df_spark_tmp = df_spark.where("price > 100")
rdd_spark_tmp = rdd_spark.filter(lambda l: l.price > 100)
print(df_spark_tmp.count())
print(rdd_spark_tmp.count())

# Add column
df_spark_tmp = df_spark.withColumn('newPrice', df_spark['price'] * 2)
rdd_spark_tmp = rdd_spark.map(
    lambda row: sql.Row(newPrice=row.price * 2, **row.asDict()))
print(df_spark_tmp.first())
print(rdd_spark_tmp.first())

print("sec : %s " % (time.time() - t))

rdd = rdd_spark_tmp.map(lambda x: (
    (x.hotel_id, x.room_id, x.checkin, x.date_extract),
    (x.price, 1))).reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1]))
print(rdd.take(20))
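For completeness, the same per-(hotel_id, room_id, checkin, date_extract) price aggregation can be written on the DataFrame side; a sketch that mirrors the reduceByKey above, assuming those columns exist in file.csv:

from pyspark.sql import functions as F

df_agg = (df_spark
          .groupBy('hotel_id', 'room_id', 'checkin', 'date_extract')
          .agg(F.sum('price').alias('total_price'),
               F.count('price').alias('num_rows')))
print(df_agg.take(20))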