def __determine_candidate_dataframe(self):
        """
        Determine the candidate of each tweet, add two columns, biden and trump,
        contains boolean value indicate whether the tweet related to each candidate
        """

        # Register the user-defined functions
        udf_function_biden = udf(
            lambda col: ('joe' in col.lower()) or ('biden' in col.lower()),
            BooleanType())
        udf_function_trump = udf(
            lambda col: ('donald' in col.lower()) or ('trump' in col.lower()),
            BooleanType())
        # Apply the UDFs to the 'text' column to produce the 'biden' and 'trump' columns
        self.tweets_dataframe = self.tweets_dataframe \
                                    .withColumn('biden', udf_function_biden(col('text'))) \
                                    .withColumn('trump', udf_function_trump(col('text')))
        return self
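The same two flags can also be derived without a Python UDF by using Spark's built-in column functions, which avoids serialization overhead. This is only a minimal sketch, assuming the tweets live in a DataFrame named tweets_df with a 'text' column:

from pyspark.sql.functions import col, lower

# Case-insensitive keyword flags built from native column expressions (no UDF needed).
tweets_df = tweets_df \
    .withColumn('biden', lower(col('text')).contains('joe') | lower(col('text')).contains('biden')) \
    .withColumn('trump', lower(col('text')).contains('donald') | lower(col('text')).contains('trump'))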
Example #2
def get_columns(df, col_type="relevant"):
    """Return the columns of df whose names contain (case-insensitively) any of the
    substrings associated with the given col_type."""
    cols = []
    dfcols = list(df.columns)
    if col_type == "relevant":
        subs_to_check = ['time', 'location', 'passenger', 'distance',
                         'ratecode', 'fare', "longitude", "latitude"]
    elif col_type == "geolocation":
        subs_to_check = ["location", "longitude", "latitude"]
    else:
        return cols

    # Case-insensitive substring match against each column name
    for sub in subs_to_check:
        for col in dfcols:
            if sub.lower() in col.lower():
                cols.append(col)
    return cols
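A quick usage sketch, assuming a SparkSession named spark and hypothetical taxi-trip column names:

trip_df = spark.createDataFrame(
    [(1, 2.5, 40.75, -73.99)],
    ["passenger_count", "trip_distance", "pickup_latitude", "pickup_longitude"])

print(get_columns(trip_df, "relevant"))
# ['passenger_count', 'trip_distance', 'pickup_longitude', 'pickup_latitude']
print(get_columns(trip_df, "geolocation"))
# ['pickup_longitude', 'pickup_latitude']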
 def run(self, database=None, churn_table=None):
     """
     Load data, and run classification
     :param database: string (name of Hive database)
     :param churn_table: string (name of Hive table)
     :return: pyspark.sql.DataFrame with list of partners and their level of churn risk
     """
     if self.hive:
         cust_at_risk = pd.read_csv("data/DD_CUSTATRISK_WKYEAR_V.csv")
         cust_at_risk.rename(
             columns={col: col.lower()
                      for col in cust_at_risk.columns},
             inplace=True)
         cust_at_risk = self.sql_context.createDataFrame(
             cust_at_risk.dropna())
         return self._run_spark(cust_at_risk)
     else:
         cust_at_risk = load_data_from_hive(self.sql_context, database,
                                            churn_table)
         return self._run_spark(cust_at_risk)
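A usage sketch only; the owning class is not shown above, so the class name and constructor arguments below are hypothetical:

# 'ChurnClassifier' and its constructor are placeholders for the class this method belongs to.
model = ChurnClassifier(sql_context=spark, hive=False)
risk_df = model.run(database="analytics_db", churn_table="churn_weekly")
risk_df.show(5)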

def get_nested_keys(a):
    # Return the keys of dict 'a' whose values look like nested mappings,
    # detected by the presence of "{" in the value's string representation.
    key_list = []
    for i in a.keys():
        b = a[i]
        if "{" in str(b):
            key_list = key_list + [i]
    return key_list
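A small usage sketch with an illustrative payload:

payload = {"messageid": "m-1", "entitydata": {"name": "Acme", "city": "Austin"}, "sourcesystem": "crm"}
print(get_nested_keys(payload))   # ['entitydata']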


ignore_list_1 = [
    'entitylogicalname', 'SinkCreatedOn', 'SinkModifiedOn', 'messageid',
    'sourcesystem', 'Id', 'entitydata'
]
ignore_list = [col.lower() for col in ignore_list_1]

sqlContext.sql("use dsc60263_fsm_tz_db")

for m in sqlContext.tables("dsc60263_fsm_tz_db").select('tableName').where(
        "tableName not like '%dup%'").where(
            "tableName not like '%bkp%'").where(
                "tableName not like '%temptz%'").collect():
    k = m['tableName']
    v1 = sqlContext.table(k).filter(
        "to_date(lastupdatedatetime)  = date_sub(CAST(current_timestamp() as DATE), 1)"
    )
    y1 = v1.select("message__id").distinct()
    if y1.count() > 0:
        v2 = list(v1.select("message__id").toPandas()["message__id"])
        for msg_id in v2:
Example #5
 def _col_in_df(self, col: str, df: DataFrame):
     lower_df_cols = [c.lower() for c in df.columns]
     return col.lower() in lower_df_cols

def get_nested_keys(a):
    key_list = []
    for i in a.keys():
        b = a[i]
        if "{" in str(b):
            key_list = key_list + [i]
    return key_list


date_to_investigate = datetime.strptime(sys.argv[1], '%Y-%m-%d')
today_ = datetime.strptime(datetime.today().strftime('%Y-%m-%d'), '%Y-%m-%d')
days_to_go_back = abs((today_ - date_to_investigate).days)

ignore_list_1 = [
    'entitylogicalname', 'SinkCreatedOn', 'SinkModifiedOn', 'messageid',
    'sourcesystem', 'Id', 'entitydata'
]
ignore_list = [col.lower() for col in ignore_list_1]
g1 = 0
g2 = 0
sqlContext.sql("use dsc10742_gcctmsd_lz_db")

tables_gcct_except_bad_rec = [
    'account', 'activityparty', 'annotation', 'appointment', 'businessunit',
    'calendar', 'calendarrule', 'contact', 'cxlvhlp_chatactivity',
    'cxlvhlp_chatqueuestatistic', 'cxlvhlp_surveyitem', 'email', 'fax',
    'gcct_accountresponsibleagent', 'gcct_additionalsymptomcodes',
    'gcct_addportalmessage', 'gcct_arbitrationclaimprocessing',
    'gcct_buybackevaluationmilestones', 'gcct_caseassignment',
    'gcct_caseclassification', 'gcct_casedispositiontype', 'gcct_coachback',
    'gcct_country', 'gcct_customersatisfactiontools',
    'gcct_delegationofauthority', 'gcct_demandltrtpsmclaimsprocessing',
    'gcct_doaprogramcode', 'gcct_documentcustomerrecontact', 'gcct_engine',
Example #7
def spot_variance(expected_dataframe, available_dataframe, primary_key):
    """
    Function to find column level variance between two dataframes.
    # Assumptions:
    #  1. All columns in the first dataframe should be present in the second dataframe
    #  2. Such common columns must have the same column names in both dataframes
    #  3. The second dataframe may contain extra columns or records; they will be ignored

    Parameters:
        Name: expected_dataframe
        Type:    pyspark.sql.dataframe.DataFrame
        Name: available_dataframe
        Type:    pyspark.sql.dataframe.DataFrame
        Name: primary_key
        Type:    tuple or list

    Return Type:
        pyspark.sql.dataframe.DataFrame
        Format:

    +------------------------+---------+------------------------+-----------------------------+
    |  <<primary_column_1>>  |  . . .  |  <<primary_column_N>>  |  EFFULGE_VARIANCE_PROVOKER  |
    +------------------------+---------+------------------------+-----------------------------+
    |                        |         |                        |                             |
    |                        |         |                        |                             |
    +------------------------+---------+------------------------+-----------------------------+

        EFFULGE_VARIANCE_PROVOKER
            Type:    List of Strings
            Values:  Column Names, Constant Message
                Constant Message can be -
                    "MISSING_PRIMARY_KEY"
                        => when 'expected_dataframe' contains a primary key value
                        => but 'available_dataframe' does not have a corresponding match
                    "DUPLICATE_PRIMARY_KEY"
                        => when 'expected_dataframe' contains a unique primary key value
                        => but 'available_dataframe' has many matches for the same primary key value

    Throws Exception:
        - if any input parameter value is not of the acceptable type
        - if 'primary_key' parameter is an empty tuple
        - if any column present in 'expected_dataframe' is not present in 'available_dataframe'
        - if 'expected_dataframe' is empty
        - if 'primary_key' does not prove uniqueness for 'expected_dataframe'
    """
    # To Do list:
    #
    # validate parameter types                 :: Done
    # handle empty parameter values            :: Done
    # check all columns exist                  :: Done
    # strict check given primary key is valid  :: Done
    # spot & report missing primary_key        :: Done
    # spot & report duplicated primary_key     :: Done
    # spot & report mismatching attributes     :: Done
    # cache if necessary                       :: Done
    # repartition if necessary                 :: Done
    # have a way to clean up resource          :: Done

    # validate input parameter types
    _check_instance_type(expected_dataframe, DataFrame, 'expected_dataframe')
    _check_instance_type(available_dataframe, DataFrame, 'available_dataframe')
    _check_instance_type(primary_key, (tuple, list), 'primary_key')

    # handle empty parameter
    if not primary_key:
        raise Exception("Please provide valid non-empty tuple as Primary Key")

    # list down primary and non primary attributes
    primary_attributes = [a.lower() for a in primary_key]
    non_primary_attributes = [
        col.lower() for col in expected_dataframe.columns
        if col.lower() not in primary_attributes
    ]

    # check all columns exist
    # all columns on expected_dataframe, must be present in available_dataframe
    missing_attributes = set(
        (*primary_attributes, *non_primary_attributes)).difference(
            {c.lower()
             for c in available_dataframe.columns})
    if len(missing_attributes) > 0:
        raise Exception("Few attributes are missing in second data frame : {}"\
                            .format(missing_attributes))

    # repartition both data frames on primary keys
    expected_dataframe = _repartition_df_with_keys(expected_dataframe,
                                                   primary_attributes)
    available_dataframe = _repartition_df_with_keys(available_dataframe,
                                                    primary_attributes)

    # cache the data frames if they were not already cached
    uncache_list = []  # to clear only those data frames that were cached by this function
    for d_f in (expected_dataframe, available_dataframe):
        if not d_f.is_cached:
            d_f.persist()
            uncache_list.append(d_f)

    # handle empty parameter
    if not expected_dataframe.count():
        raise Exception("Input 'expected_dataframe' can not be empty")

    # strict check given primary key is valid
    if expected_dataframe.groupBy(
            *primary_attributes).count().where("count > 1").count():
        raise Exception(
            "Given primary key is not unique for the first data frame. "
            "Please retry with a valid primary key")

    # spot & report missing primary key
    df_missing_primary_key = _spot_missing_primary_key(expected_dataframe,
                                                       available_dataframe,
                                                       primary_attributes)

    # spot & report duplicated primary key
    df_duplicated_primary_key = _spot_duplicated_primary_key(
        expected_dataframe, available_dataframe, primary_attributes)

    temp_view_list = []  # to clear only the temp views that were created by this function

    # detect variance
    if len(non_primary_attributes) > 0:
        # when non-primary columns exist, explicitly identify their variances
        #
        # create temporary views
        expected_dataframe.createOrReplaceTempView("effulge_expected_view")
        temp_view_list.append("effulge_expected_view")
        available_dataframe.createOrReplaceTempView("effulge_available_view")
        temp_view_list.append("effulge_available_view")

        # spot & report mismatching attributes
        df_variance = _spot_mismatch_variance("effulge_expected_view",
                                              "effulge_available_view",
                                              primary_attributes,
                                              non_primary_attributes)
    else:
        # when no non-primary columns exist,
        # i.e. we only have primary columns,
        # the variances are caught implicitly by the MISSING_PRIMARY_KEY check
        #
        df_variance = _get_empty_result_df(expected_dataframe,
                                           primary_attributes)
    #
    #
    # merge all variances
    effulge_variance_dataframe = df_variance.union(df_missing_primary_key)\
                                             .union(df_duplicated_primary_key)\
                                             .repartition(*primary_attributes)

    # persist output dataframe
    effulge_variance_dataframe.persist()
    effulge_variance_dataframe.count()

    # clear cache and temp views
    _clean_cache_and_view(uncache_list, temp_view_list)

    # Return the output dataframe
    return effulge_variance_dataframe
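A hedged usage sketch, assuming a SparkSession named spark; the private helper functions referenced above are not shown here:

# Toy data for illustration only.
expected = spark.createDataFrame(
    [(1, "alice", 10.0), (2, "bob", 20.0)], ["id", "name", "amount"])
available = spark.createDataFrame(
    [(1, "alice", 10.0), (2, "bobby", 20.0)], ["id", "name", "amount"])

variance_df = spot_variance(expected, available, primary_key=("id",))
variance_df.show(truncate=False)
# Expect one row for id=2 whose EFFULGE_VARIANCE_PROVOKER lists the 'name' column.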

import re

def reNameColumns(cols):
    reNamedCols = []
    for col in cols:
        # Lowercase the name and replace every non-alphanumeric character with '_'
        reNamedCols.append(re.sub("[^a-zA-Z0-9]", "_", col.lower()))
    return reNamedCols
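A short usage sketch, assuming a DataFrame df: the normalized names can be applied back in one pass.

# Rename all columns of 'df' using the normalized names.
df = df.toDF(*reNameColumns(df.columns))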
def clean():
    #====================================================================
    # Task: Lowercase the column names
    #====================================================================
    meta_clean = []  # list to collect the metadata tuples
    df = get_raw_data()

    # Initialize the metadata-collection class
    MiLinaje_clean = Linaje_clean_data()

    # Collect date, user, IP and task name for the metadata
    MiLinaje_clean.fecha = datetime.now()
    MiLinaje_clean.nombre_task = "Colnames_to_lower"
    MiLinaje_clean.usuario = getpass.getuser()
    MiLinaje_clean.ip_ec2 = str(socket.gethostbyname(socket.gethostname()))

    MiLinaje_clean.variables_limpias = "All_from_raw data"

    counting_cols = 0

    for col in df.columns:
        counting_cols = counting_cols + 1
        df = df.withColumnRenamed(col, col.lower())

    # Metadata about modified columns or records
    MiLinaje_clean.num_columnas_modificadas = counting_cols
    MiLinaje_clean.variables_limpias = counting_cols

    # Upload the metadata to RDS
    #clean_metadata_rds(MiLinaje_clean.to_upsert())
    meta_clean.append(MiLinaje_clean.to_upsert())

    #====================================================================
    # Task: Select the non-empty columns
    #====================================================================

    # Initialize the metadata-collection class
    MiLinaje_clean = Linaje_clean_data()

    # Collect date, user, IP and task name for the metadata
    MiLinaje_clean.fecha = datetime.now()
    MiLinaje_clean.nombre_task = "Colnames_selection"
    MiLinaje_clean.usuario = getpass.getuser()
    MiLinaje_clean.ip_ec2 = str(socket.gethostbyname(socket.gethostname()))

    # Column selection
    n0 = len(df.columns)

    base = df.select(
        df.year, df.quarter, df.month, df.dayofmonth, df.dayofweek,
        df.flightdate, df.reporting_airline, df.dot_id_reporting_airline,
        df.iata_code_reporting_airline, df.tail_number,
        df.flight_number_reporting_airline, df.originairportid,
        df.originairportseqid, df.origincitymarketid, df.origin,
        df.origincityname, df.originstate, df.originstatefips,
        df.originstatename, df.originwac, df.destairportid,
        df.destairportseqid, df.destcitymarketid, df.dest, df.destcityname,
        df.deststate, df.deststatefips, df.deststatename, df.destwac,
        df.crsdeptime, df.deptime, df.depdelay, df.depdelayminutes,
        df.depdel15, df.departuredelaygroups, df.deptimeblk, df.taxiout,
        df.wheelsoff, df.wheelson, df.taxiin, df.crsarrtime, df.arrtime,
        df.arrdelay, df.arrdelayminutes, df.arrdel15, df.arrivaldelaygroups,
        df.arrtimeblk, df.cancelled, df.diverted, df.crselapsedtime,
        df.actualelapsedtime, df.airtime, df.flights, df.distance,
        df.distancegroup, df.divairportlandings)

    n1 = len(base.columns)

    # Metadata about modified columns or records
    MiLinaje_clean.num_columnas_modificadas = n1 - n0
    MiLinaje_clean.variables_limpias = "year,quarter, month, dayofmonth, dayofweek,\
     flightdate, reporting_airline, dot_id_reporting_airline, iata_code_reporting_airline,\
     tail_number, flight_number_reporting_airline, originairportid, originairportseqid,\
     origincitymarketid, origin, origincityname, originstate, originstatefips, originstatename,\
     originwac, destairportid, destairportseqid, destcitymarketid, dest, destcityname, deststate,\
     deststatefips, deststatename, destwac, crsdeptime, deptime, depdelay, depdelayminutes,\
     depdel15, departuredelaygroups, deptimeblk, taxiout, wheelsoff, wheelson, taxiin, crsarrtime,\
     arrtime, arrdelay, arrdelayminutes, arrdel15, arrivaldelaygroups, arrtimeblk, cancelled,\
     diverted, crselapsedtime, actualelapsedtime, airtime, flights, distance, distancegroup,\
     divairportlandings"

    # Upload the metadata to RDS
    #clean_metadata_rds(MiLinaje_clean.to_upsert())
    meta_clean.append(MiLinaje_clean.to_upsert())

    #========================================================================================================
    # Add a column that buckets the flight departure delay (in hours): 0-1.5, 1.5-3.5, 3.5-, cancelled
    #========================================================================================================

    # Initialize the metadata-collection class
    MiLinaje_clean = Linaje_clean_data()

    # Collect date, user, IP and task name for the metadata
    MiLinaje_clean.fecha = datetime.now()
    MiLinaje_clean.nombre_task = "creation_of_categories"
    MiLinaje_clean.usuario = getpass.getuser()
    MiLinaje_clean.ip_ec2 = str(socket.gethostbyname(socket.gethostname()))

    from pyspark.sql import functions as f

    # Column selection
    n0 = len(df.columns)

    base = base.withColumn(
        'rangoatrasohoras',
        f.when(f.col('cancelled') == 1,
               "cancelled").when(f.col('depdelayminutes') < 90, "0-1.5").when(
                   (f.col('depdelayminutes') >= 90) &
                   (f.col('depdelayminutes') < 210),
                   "1.5-3.5").otherwise("3.5-"))

    n1 = len(base.columns)

    # Metadata about modified columns
    MiLinaje_clean.num_columnas_modificadas = n1 - n0

    # Metadata about modified rows
    MiLinaje_clean.num_filas_modificadas = df.count()

    MiLinaje_clean.variables_limpias = "year,quarter, month, dayofmonth, dayofweek,\
         flightdate, reporting_airline, dot_id_reporting_airline, iata_code_reporting_airline,\
         tail_number, flight_number_reporting_airline, originairportid, originairportseqid,\
         origincitymarketid, origin, origincityname, originstate, originstatefips, originstatename,\
         originwac, destairportid, destairportseqid, destcitymarketid, dest, destcityname, deststate,\
         deststatefips, deststatename, destwac, crsdeptime, deptime, depdelay, depdelayminutes,\
         depdel15, departuredelaygroups, deptimeblk, taxiout, wheelsoff, wheelson, taxiin, crsarrtime,\
         arrtime, arrdelay, arrdelayminutes, arrdel15, arrivaldelaygroups, arrtimeblk, cancelled,\
         diverted, crselapsedtime, actualelapsedtime, airtime, flights, distance, distancegroup,\
         divairportlandings,rangoatrasohoras,cancelled,0-1.5,1.5-3.5,3.5-"

    # Upload the metadata to RDS
    #clean_metadata_rds(MiLinaje_clean.to_upsert())
    meta_clean.append(MiLinaje_clean.to_upsert())

    #===================================================================
    # Apply the text-cleaning function
    #===================================================================

    # Text-cleaning function: lowercase, replace spaces with underscores, keep the text before the first comma
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType
    from pyspark.sql.functions import col, lower, regexp_replace, split

    def clean_text(c):
        c = lower(c)
        c = regexp_replace(c, " ", "_")
        c = f.split(c, ',')[0]  # keep only the part before the first comma
        return c

    # Initialize the metadata-collection class
    MiLinaje_clean = Linaje_clean_data()

    # Collect date, user, IP and task name for the metadata
    MiLinaje_clean.fecha = datetime.now()
    MiLinaje_clean.nombre_task = "cleaning_text_spaces_and_others"
    MiLinaje_clean.usuario = getpass.getuser()
    MiLinaje_clean.ip_ec2 = str(socket.gethostbyname(socket.gethostname()))

    string_cols = [
        item[0] for item in base.dtypes if item[1].startswith('string')
    ]
    for x in string_cols:
        base = base.withColumn(x, clean_text(col(x)))

    # Metadata about modified columns or records
    MiLinaje_clean.num_filas_modificadas = df.count()
    MiLinaje_clean.num_columnas_modificadas = len(df.columns)

    MiLinaje_clean.variables_limpias = "year,quarter, month, dayofmonth, dayofweek,\
             flightdate, reporting_airline, dot_id_reporting_airline, iata_code_reporting_airline,\
             tail_number, flight_number_reporting_airline, originairportid, originairportseqid,\
             origincitymarketid, origin, origincityname, originstate, originstatefips, originstatename,\
             originwac, destairportid, destairportseqid, destcitymarketid, dest, destcityname, deststate,\
             deststatefips, deststatename, destwac, crsdeptime, deptime, depdelay, depdelayminutes,\
             depdel15, departuredelaygroups, deptimeblk, taxiout, wheelsoff, wheelson, taxiin, crsarrtime,\
             arrtime, arrdelay, arrdelayminutes, arrdel15, arrivaldelaygroups, arrtimeblk, cancelled,\
             diverted, crselapsedtime, actualelapsedtime, airtime, flights, distance, distancegroup,\
             divairportlandings,rangoatrasohoras,cancelled,0-1.5,1.5-3.5,3.5-"

    # Upload the metadata to RDS
    #clean_metadata_rds(MiLinaje_clean.to_upsert())
    meta_clean.append(MiLinaje_clean.to_upsert())

    datos_clean = pd.DataFrame(meta_clean, columns=["fecha",\
    "nombre_task","usuario","ip_ec2","num_columnas_modificadas","num_filas_modificadas",\
    "variables_limpias","task_status"])
    datos_clean.to_csv("metadata/clean_metadata.csv",
                       index=False,
                       header=False)

    base.show(2)
    print((base.count(), len(base.columns)))
    # Save the data
    save_rds(base, "clean.rita")
    return base
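A minimal sketch of the delay-bucketing expression used above, applied to a toy DataFrame; it assumes a SparkSession named spark:

from pyspark.sql import functions as f

toy = spark.createDataFrame(
    [(0, 30.0), (0, 120.0), (0, 400.0), (1, 10.0)],
    ["cancelled", "depdelayminutes"])

toy.withColumn(
    'rangoatrasohoras',
    f.when(f.col('cancelled') == 1, "cancelled")
     .when(f.col('depdelayminutes') < 90, "0-1.5")
     .when((f.col('depdelayminutes') >= 90) & (f.col('depdelayminutes') < 210), "1.5-3.5")
     .otherwise("3.5-")).show()
# Expected labels per input row: 0-1.5, 1.5-3.5, 3.5-, cancelled.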
Example #10
def processdata():

    githubdata = readgithubdata()

    # Handle latitude column name variants and insert the column if not present
    if 'latitude' in githubdata.columns:
        pass
    elif 'Latitude' in githubdata.columns:
        githubdata = githubdata.withColumnRenamed('Latitude', 'latitude')
    elif 'Lat' in githubdata.columns:
        githubdata = githubdata.withColumnRenamed('Lat', 'latitude')
    else:
        githubdata = githubdata.withColumn('latitude',
                                           f.lit(None).cast(StringType()))

    # Handle longitude column name variants and insert the column if not present
    if 'longitude' in githubdata.columns:
        pass
    elif 'Long_' in githubdata.columns:
        githubdata = githubdata.withColumnRenamed('Long_', 'longitude')
    elif 'Longitude' in githubdata.columns:
        githubdata = githubdata.withColumnRenamed('Longitude', 'longitude')
    else:
        githubdata = githubdata.withColumn('longitude',
                                           f.lit(None).cast(StringType()))

    #Insert columns if not present
    columns = ['FIPS', 'Admin2', 'Combined_Key', 'Active']
    for i in columns:
        if not has_column(githubdata, i):
            githubdata = githubdata.withColumn(i, lit(None).cast(StringType()))

    #Rename Columns
    for col in githubdata.columns:
        githubdata = githubdata.withColumnRenamed(
            col,
            col.lower().replace(' ', '_').replace('/', '_'))

    #Drop Columns
    drop_lst = ['latitude', 'longitude']
    df = githubdata.drop(*drop_lst)
    df = df.select('fips', 'combined_key', 'province_state', 'country_region',
                   'last_update', 'confirmed', 'recovered', 'deaths', 'active')
    print("Showing GITHUB DATA SCHEMA after dropping columns")
    df.printSchema()
    df.show(5)

    #insert month column and format date
    df = df.withColumn("date", formatdate("last_update"))
    df = df.withColumn('date', regexp_replace('date', "00", "20"))
    df = df.withColumn('month', f.month(df.date))

    #Column value modification
    df = df.withColumn(
        'country_region',
        regexp_replace('country_region', "Mainland China", "China"))
    df = df.withColumn(
        'country_region',
        regexp_replace('country_region', "South Korea", "Korea, South"))
    df = df.withColumn('country_region',
                       regexp_replace('country_region', "Taiwan\*", "Taiwan*"))
    df = df.withColumn('combined_key',
                       regexp_replace('combined_key', "Taiwan\*", "Taiwan*"))
    #Replacing , or ,, in combined_key column
    df = df.withColumn('combined_key', regexp_replace('combined_key', "^,+",
                                                      ""))
    #Replacing strings like Hardin,Ohio,US to Hardin, Ohio, US
    df = df.withColumn('combined_key',
                       regexp_replace('combined_key', ",(?=\S)", ", "))
    #Replace strings that start with a whitespace
    df = df.withColumn('combined_key',
                       regexp_replace('combined_key', "^\s*", ""))
    # Replace Diamond Princess in combined_key
    df = df.withColumn(
        'combined_key',
        regexp_replace('combined_key', "Diamond Princess, Cruise Ship*",
                       "Diamond Princess"))

    df = df.withColumn('combined_key',
                       regexp_replace('combined_key', " County,", ","))
    df = df.withColumn(
        'combined_key',
        regexp_replace('combined_key', "^unassigned", "Unassigned"))
    df = df.withColumn('combined_key',
                       regexp_replace('combined_key', "^Unknown, ", ""))
    df = df.withColumn(
        'combined_key',
        regexp_replace('combined_key', "Doña Ana, New Mexico, US",
                       "Dona Ana, New Mexico, US"))
    df = df.withColumn(
        'combined_key',
        when(df['combined_key'].like('District of Columbia%'),
             'District of Columbia, US').otherwise(df['combined_key']))
    df = df.withColumn(
        'country_region',
        when(
            df['country_region'].like("%Diamond Princess%")
            | df['province_state'].like("Diamond%"),
            'Diamond Princess').otherwise(df['country_region']))
    df = df.withColumn(
        'province_state',
        when(
            df['country_region'].like("%Diamond Princess%")
            | df['province_state'].like("Diamond%"),
            None).otherwise(df['province_state']))
    df = df.withColumn(
        'country_region',
        when(df['country_region'] == 'Hong Kong',
             'China').otherwise(df['country_region']))
    df = df.withColumn(
        'country_region',
        when(df['country_region'] == 'Macau',
             'China').otherwise(df['country_region']))
    df = df.withColumn(
        'province_state',
        when(df['country_region'] == df['province_state'],
             None).otherwise(df['province_state']))

    print("COUNTING DF")
    print(df.count())
    return df
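The helper has_column used above is not shown; a common pattern, offered here only as an assumed stand-in, checks whether the column resolves on the DataFrame:

from pyspark.sql.utils import AnalysisException

def has_column(df, col_name):
    # Assumed helper, not the original implementation:
    # True if 'col_name' resolves on df, False otherwise.
    try:
        df[col_name]
        return True
    except AnalysisException:
        return False

A simpler alternative is checking col_name in df.columns, which does an exact, case-sensitive match.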