Ejemplo n.º 1
0
def get_dataset(dataset_type,
                data,
                schemas=None,
                autoinspect_func=autoinspect.columns_exist,
                caching=False):
    """Build a test dataset of the requested backend from in-memory data.

    Args:
        dataset_type: ``'PandasDataset'``, ``'SqlAlchemyDataset'`` or
            ``'SparkDFDataset'``.
        data: Column-oriented data. For the Pandas and SqlAlchemy backends it
            is handed to ``pd.DataFrame`` (so a DataFrame or a dictionary
            that can be instantiated as one). For Spark it must be a mapping
            of column name to a list of values.
        schemas: Optional mapping of backend key (``"pandas"``, ``"sqlite"``,
            ``"postgresql"``, ``"spark"``) to a ``{column: type name}``
            mapping used to coerce column types.
        autoinspect_func: Forwarded to the Pandas and SqlAlchemy dataset
            constructors; not used for Spark.
        caching: Forwarded to the dataset constructor.

    Returns:
        A ``PandasDataset``, ``SqlAlchemyDataset`` or ``SparkDFDataset``.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported names.
    """
    if dataset_type == 'PandasDataset':
        df = pd.DataFrame(data)
        if schemas and "pandas" in schemas:
            pandas_schema = {
                key: np.dtype(value)
                for (key, value) in schemas["pandas"].items()
            }
            df = df.astype(pandas_schema)
        return PandasDataset(df,
                             autoinspect_func=autoinspect_func,
                             caching=caching)

    elif dataset_type == 'SqlAlchemyDataset':
        # Try to use a local postgres instance (e.g. on Travis); this will
        # allow more testing than sqlite.
        try:
            engine = create_engine('postgresql://*****:*****@localhost/test_ci')
            conn = engine.connect()
        except (SQLAlchemyError, ImportError):
            # ImportError covers a missing PostgreSQL driver: create_engine
            # imports the DBAPI module, so without it the fallback would
            # otherwise never be reached.
            warnings.warn("Falling back to sqlite database.")
            engine = create_engine('sqlite://')
            conn = engine.connect()

        # Add the data to the database as a new table
        df = pd.DataFrame(data)

        # Pick the schema matching the dialect we actually connected to.
        # The two dialects differ only in their type map and in the name
        # their schemas use for date/time columns.
        schema = {}
        sql_dtypes = {}
        datetime_type_name = None
        if schemas and "sqlite" in schemas and isinstance(
                engine.dialect, sqlitetypes.dialect):
            schema = schemas["sqlite"]
            sql_dtypes = {
                col: SQLITE_TYPES[dtype]
                for (col, dtype) in schema.items()
            }
            datetime_type_name = "datetime"
        elif schemas and "postgresql" in schemas and isinstance(
                engine.dialect, postgresqltypes.dialect):
            schema = schemas["postgresql"]
            sql_dtypes = {
                col: POSTGRESQL_TYPES[dtype]
                for (col, dtype) in schema.items()
            }
            datetime_type_name = "timestamp"

        # Coerce the pandas columns so they match the declared SQL types.
        for col, type_name in schema.items():
            if type_name == "int":
                df[col] = pd.to_numeric(df[col], downcast='signed')
            elif type_name == "float":
                df[col] = pd.to_numeric(df[col], downcast='float')
            elif type_name == datetime_type_name:
                df[col] = pd.to_datetime(df[col])

        # Random suffix so repeated calls write to distinct tables.
        tablename = "test_data_" + ''.join([
            random.choice(string.ascii_letters + string.digits)
            for _ in range(8)
        ])
        df.to_sql(name=tablename, con=conn, index=False, dtype=sql_dtypes)

        # Build a SqlAlchemyDataset using that database
        return SqlAlchemyDataset(tablename,
                                 engine=conn,
                                 autoinspect_func=autoinspect_func,
                                 caching=caching)

    elif dataset_type == 'SparkDFDataset':
        spark = SparkSession.builder.getOrCreate()
        # Transpose the column-oriented mapping into a list of row tuples.
        data_reshaped = list(zip(*[v for _, v in data.items()]))
        if schemas and 'spark' in schemas:
            schema = schemas['spark']
            # sometimes first method causes Spark to throw a TypeError
            try:
                spark_schema = sparktypes.StructType([
                    sparktypes.StructField(column,
                                           SPARK_TYPES[schema[column]]())
                    for column in schema
                ])
                spark_df = spark.createDataFrame(data_reshaped, spark_schema)
            except TypeError:
                # Fallback: load every column as a string, then cast each
                # column to its declared type.
                string_schema = sparktypes.StructType([
                    sparktypes.StructField(column, sparktypes.StringType())
                    for column in schema
                ])
                spark_df = spark.createDataFrame(data_reshaped, string_schema)
                for c in spark_df.columns:
                    spark_df = spark_df.withColumn(
                        c, spark_df[c].cast(SPARK_TYPES[schema[c]]()))
        elif not data_reshaped:
            # if we have an empty dataset and no schema, need to assign an arbitrary type
            columns = list(data.keys())
            spark_schema = sparktypes.StructType([
                sparktypes.StructField(column, sparktypes.StringType())
                for column in columns
            ])
            spark_df = spark.createDataFrame(data_reshaped, spark_schema)
        else:
            # if no schema provided, uses Spark's schema inference
            columns = list(data.keys())
            spark_df = spark.createDataFrame(data_reshaped, columns)
        return SparkDFDataset(spark_df, caching=caching)

    else:
        raise ValueError("Unknown dataset_type " + str(dataset_type))
Ejemplo n.º 2
0
def _sql_dataset_from_frame(df, conn, schema, type_map, profiler, caching):
    """Write ``df`` to a randomly named table on ``conn`` and wrap it.

    ``schema`` maps column name to a SQL type name; each name is looked up in
    ``type_map`` to build the ``dtype`` argument of ``DataFrame.to_sql`` and
    the matching pandas columns are coerced first. An empty ``schema`` skips
    coercion (``type_map`` is then never consulted).
    """
    sql_dtypes = {col: type_map[dtype] for (col, dtype) in schema.items()}
    for col, type_ in schema.items():
        if type_ in ["INTEGER", "SMALLINT", "BIGINT"]:
            df[col] = pd.to_numeric(df[col], downcast='signed')
        elif type_ in ["FLOAT", "DOUBLE", "DOUBLE_PRECISION"]:
            df[col] = pd.to_numeric(df[col])
        elif type_ in ["DATETIME", "TIMESTAMP"]:
            df[col] = pd.to_datetime(df[col])

    # Random suffix so repeated calls write to distinct tables.
    tablename = "test_data_" + ''.join([
        random.choice(string.ascii_letters + string.digits)
        for _ in range(8)
    ])
    df.to_sql(name=tablename, con=conn, index=False, dtype=sql_dtypes)

    # Build a SqlAlchemyDataset using that database
    return SqlAlchemyDataset(tablename,
                             engine=conn,
                             profiler=profiler,
                             caching=caching)


def get_dataset(dataset_type,
                data,
                schemas=None,
                profiler=ColumnsExistProfiler,
                caching=True):
    """Utility to create datasets for json-formatted tests.

    Args:
        dataset_type: ``'PandasDataset'``, ``'sqlite'``, ``'postgresql'``,
            ``'mysql'`` or ``'SparkDFDataset'``.
        data: Column-oriented data; passed to ``pd.DataFrame`` and, for Spark,
            also read as a mapping of column name to list of values.
        schemas: Optional mapping of backend key (``"pandas"``, ``"sqlite"``,
            ``"postgresql"``, ``"mysql"``, ``"spark"``) to a
            ``{column: type name}`` mapping used to coerce column types.
        profiler: Forwarded to the dataset constructor.
        caching: Forwarded to the dataset constructor.

    Returns:
        A ``PandasDataset``, ``SqlAlchemyDataset`` or ``SparkDFDataset``.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported names.
    """
    df = pd.DataFrame(data)
    if dataset_type == 'PandasDataset':
        if schemas and "pandas" in schemas:
            schema = schemas["pandas"]
            pandas_schema = {}
            for (key, value) in schema.items():
                # Note, these are just names used in our internal schemas to build datasets *for internal tests*
                # Further, some changes in pandas internal about how datetimes are created means to support pandas
                # pre- 0.25, we need to explicitly specify when we want timezone.

                # We will use timestamp for timezone-aware (UTC only) dates in our tests
                if value.lower() in ["timestamp", "datetime64[ns, tz]"]:
                    df[key] = pd.to_datetime(df[key], utc=True)
                    continue
                elif value.lower() in [
                        "datetime", "datetime64", "datetime64[ns]"
                ]:
                    df[key] = pd.to_datetime(df[key])
                    continue
                try:
                    type_ = np.dtype(value)
                except TypeError:
                    type_ = getattr(pd.core.dtypes.dtypes, value)
                    # If this raises AttributeError it's okay: it means someone built a bad test
                pandas_schema[key] = type_
            df = df.astype(pandas_schema)
        return PandasDataset(df, profiler=profiler, caching=caching)

    elif dataset_type == "sqlite":
        from sqlalchemy import create_engine
        engine = create_engine('sqlite://')
        conn = engine.connect()

        schema, type_map = {}, None
        if schemas and "sqlite" in schemas and isinstance(
                engine.dialect, sqlitetypes.dialect):
            schema, type_map = schemas["sqlite"], SQLITE_TYPES
        return _sql_dataset_from_frame(df, conn, schema, type_map, profiler,
                                       caching)

    elif dataset_type == 'postgresql':
        from sqlalchemy import create_engine
        engine = create_engine('postgresql://postgres@localhost/test_ci')
        conn = engine.connect()

        schema, type_map = {}, None
        if schemas and "postgresql" in schemas and isinstance(
                engine.dialect, postgresqltypes.dialect):
            schema, type_map = schemas["postgresql"], POSTGRESQL_TYPES
        return _sql_dataset_from_frame(df, conn, schema, type_map, profiler,
                                       caching)

    elif dataset_type == 'mysql':
        # The function-scope imports in the other branches make
        # ``create_engine`` a local name for the whole function, so it must be
        # imported here too; otherwise this branch raises UnboundLocalError.
        from sqlalchemy import create_engine
        engine = create_engine('mysql://root@localhost/test_ci')
        conn = engine.connect()

        schema, type_map = {}, None
        if schemas and "mysql" in schemas and isinstance(
                engine.dialect, mysqltypes.dialect):
            schema, type_map = schemas["mysql"], MYSQL_TYPES
        return _sql_dataset_from_frame(df, conn, schema, type_map, profiler,
                                       caching)

    elif dataset_type == 'SparkDFDataset':
        from pyspark.sql import SparkSession
        import pyspark.sql.types as sparktypes

        SPARK_TYPES = {
            "StringType": sparktypes.StringType,
            "IntegerType": sparktypes.IntegerType,
            "LongType": sparktypes.LongType,
            "DateType": sparktypes.DateType,
            "TimestampType": sparktypes.TimestampType,
            "FloatType": sparktypes.FloatType,
            "DoubleType": sparktypes.DoubleType,
            "BooleanType": sparktypes.BooleanType,
            "DataType": sparktypes.DataType,
            "NullType": sparktypes.NullType
        }

        spark = SparkSession.builder.getOrCreate()
        # We need to allow null values in some column types that do not support them natively, so we skip
        # use of df in this case.
        data_reshaped = list(
            zip(*[v for _, v in data.items()]))  # create a list of rows
        if schemas and 'spark' in schemas:
            schema = schemas['spark']
            # sometimes first method causes Spark to throw a TypeError
            try:
                spark_schema = sparktypes.StructType([
                    sparktypes.StructField(column,
                                           SPARK_TYPES[schema[column]](), True)
                    for column in schema
                ])
                # We create these every time, which is painful for testing
                # However nuance around null treatment as well as the desire
                # for real datetime support in tests makes this necessary
                data = copy.deepcopy(data)
                for col in schema:
                    type_ = schema[col]
                    if type_ in ["IntegerType", "LongType"]:
                        convert = int
                    elif type_ in ["FloatType", "DoubleType"]:
                        convert = float
                    elif type_ in ["DateType", "TimestampType"]:
                        convert = parse
                    else:
                        continue
                    # None is left untouched: it cannot be converted, but it
                    # can be valid in Spark (as Null).
                    data[col] = [
                        val if val is None else convert(val)
                        for val in data[col]
                    ]
                # Do this again, now that we have done type conversion using the provided schema
                data_reshaped = list(
                    zip(*[v
                          for _, v in data.items()]))  # create a list of rows
                spark_df = spark.createDataFrame(data_reshaped,
                                                 schema=spark_schema)
            except TypeError:
                # Fallback: load every column as a string, then cast each
                # column to its declared type.
                string_schema = sparktypes.StructType([
                    sparktypes.StructField(column, sparktypes.StringType())
                    for column in schema
                ])
                spark_df = spark.createDataFrame(data_reshaped, string_schema)
                for c in spark_df.columns:
                    spark_df = spark_df.withColumn(
                        c, spark_df[c].cast(SPARK_TYPES[schema[c]]()))
        elif len(data_reshaped) == 0:
            # if we have an empty dataset and no schema, need to assign an arbitrary type
            columns = list(data.keys())
            spark_schema = sparktypes.StructType([
                sparktypes.StructField(column, sparktypes.StringType())
                for column in columns
            ])
            spark_df = spark.createDataFrame(data_reshaped, spark_schema)
        else:
            # if no schema provided, uses Spark's schema inference
            columns = list(data.keys())
            spark_df = spark.createDataFrame(data_reshaped, columns)
        return SparkDFDataset(spark_df, profiler=profiler, caching=caching)

    else:
        raise ValueError("Unknown dataset_type " + str(dataset_type))
Ejemplo n.º 3
0
def _load_sql_test_dataset(
    df,
    conn,
    dialect_name,
    schema,
    type_map,
    table_name,
    profiler,
    caching,
    float_type_names=("FLOAT", "DOUBLE", "DOUBLE_PRECISION"),
):
    """Coerce ``df`` per ``schema``, write it to ``conn`` and wrap the table.

    ``schema`` maps column name to a SQL type name; each name is looked up in
    ``type_map`` to build the ``dtype`` argument of ``DataFrame.to_sql``. An
    empty ``schema`` skips coercion (``type_map`` is then never consulted).
    ``dialect_name`` selects the dialect-specific infinity values and
    ``float_type_names`` lists the type names treated as floating point.
    When ``table_name`` is None a random name is generated; an existing table
    of the same name is replaced.
    """
    sql_dtypes = {col: type_map[dtype] for (col, dtype) in schema.items()}
    for col, type_ in schema.items():
        if type_ in ["INTEGER", "SMALLINT", "BIGINT"]:
            df[col] = pd.to_numeric(df[col], downcast="signed")
        elif type_ in float_type_names:
            df[col] = pd.to_numeric(df[col])
            # Swap the API-side infinity markers for the values this dialect
            # uses. Note the replacement runs over the whole frame, not just
            # the current column.
            min_value_dbms = get_sql_dialect_floating_point_infinity_value(
                schema=dialect_name, negative=True)
            max_value_dbms = get_sql_dialect_floating_point_infinity_value(
                schema=dialect_name, negative=False)
            for api_schema_type in ["api_np", "api_cast"]:
                min_value_api = get_sql_dialect_floating_point_infinity_value(
                    schema=api_schema_type, negative=True)
                max_value_api = get_sql_dialect_floating_point_infinity_value(
                    schema=api_schema_type, negative=False)
                df.replace(
                    to_replace=[min_value_api, max_value_api],
                    value=[min_value_dbms, max_value_dbms],
                    inplace=True,
                )
        elif type_ in ["DATETIME", "TIMESTAMP"]:
            df[col] = pd.to_datetime(df[col])

    if table_name is None:
        table_name = "test_data_" + "".join([
            random.choice(string.ascii_letters + string.digits)
            for _ in range(8)
        ])
    df.to_sql(
        name=table_name,
        con=conn,
        index=False,
        dtype=sql_dtypes,
        if_exists="replace",
    )

    # Build a SqlAlchemyDataset using that database
    return SqlAlchemyDataset(table_name,
                             engine=conn,
                             profiler=profiler,
                             caching=caching)


def get_dataset(
    dataset_type,
    data,
    schemas=None,
    profiler=ColumnsExistProfiler,
    caching=True,
    table_name=None,
    sqlite_db_path=None,
):
    """Utility to create datasets for json-formatted tests.

    Args:
        dataset_type: ``"PandasDataset"``, ``"sqlite"``, ``"postgresql"``,
            ``"mysql"``, ``"mssql"`` or ``"SparkDFDataset"``.
        data: Column-oriented data; passed to ``pd.DataFrame`` and, for Spark,
            also read as a mapping of column name to list of values.
        schemas: Optional mapping of backend key (``"pandas"``, ``"sqlite"``,
            ``"postgresql"``, ``"mysql"``, ``"mssql"``, ``"spark"``) to a
            ``{column: type name}`` mapping used to coerce column types.
        profiler: Forwarded to the dataset constructor.
        caching: Forwarded to the dataset constructor.
        table_name: Table to write for SQL backends; a random name is
            generated when None. An existing table of that name is replaced.
        sqlite_db_path: If given, the sqlite engine is created from
            ``sqlite:////<sqlite_db_path>`` instead of ``sqlite://``.

    Returns:
        A ``PandasDataset``, ``SqlAlchemyDataset`` or ``SparkDFDataset``, or
        None for a SQL backend when ``create_engine`` is falsy.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported names.
    """
    df = pd.DataFrame(data)
    if dataset_type == "PandasDataset":
        if schemas and "pandas" in schemas:
            schema = schemas["pandas"]
            pandas_schema = {}
            for (key, value) in schema.items():
                # Note, these are just names used in our internal schemas to build datasets *for internal tests*
                # Further, some changes in pandas internal about how datetimes are created means to support pandas
                # pre- 0.25, we need to explicitly specify when we want timezone.

                # We will use timestamp for timezone-aware (UTC only) dates in our tests
                if value.lower() in ["timestamp", "datetime64[ns, tz]"]:
                    df[key] = pd.to_datetime(df[key], utc=True)
                    continue
                elif value.lower() in [
                        "datetime", "datetime64", "datetime64[ns]"
                ]:
                    df[key] = pd.to_datetime(df[key])
                    continue
                try:
                    type_ = np.dtype(value)
                except TypeError:
                    type_ = getattr(pd.core.dtypes.dtypes, value)
                    # If this raises AttributeError it's okay: it means someone built a bad test
                pandas_schema[key] = type_
            df = df.astype(pandas_schema)
        return PandasDataset(df, profiler=profiler, caching=caching)

    elif dataset_type == "sqlite":
        if not create_engine:
            return None

        if sqlite_db_path is not None:
            engine = create_engine(f"sqlite:////{sqlite_db_path}")
        else:
            engine = create_engine("sqlite://")
        conn = engine.connect()

        schema, type_map = {}, None
        if (schemas and "sqlite" in schemas
                and isinstance(engine.dialect, sqlitetypes.dialect)):
            schema, type_map = schemas["sqlite"], SQLITE_TYPES
        return _load_sql_test_dataset(df, conn, dataset_type, schema,
                                      type_map, table_name, profiler, caching)

    elif dataset_type == "postgresql":
        if not create_engine:
            return None

        engine = create_engine("postgresql://postgres@localhost/test_ci")
        conn = engine.connect()

        schema, type_map = {}, None
        if (schemas and "postgresql" in schemas
                and isinstance(engine.dialect, postgresqltypes.dialect)):
            schema, type_map = schemas["postgresql"], POSTGRESQL_TYPES
        return _load_sql_test_dataset(df, conn, dataset_type, schema,
                                      type_map, table_name, profiler, caching)

    elif dataset_type == "mysql":
        if not create_engine:
            return None

        engine = create_engine("mysql+pymysql://root@localhost/test_ci")
        conn = engine.connect()

        schema, type_map = {}, None
        if (schemas and "mysql" in schemas
                and isinstance(engine.dialect, mysqltypes.dialect)):
            schema, type_map = schemas["mysql"], MYSQL_TYPES
        return _load_sql_test_dataset(df, conn, dataset_type, schema,
                                      type_map, table_name, profiler, caching)

    elif dataset_type == "mssql":
        if not create_engine:
            return None

        engine = create_engine(
            "mssql+pyodbc://sa:ReallyStrongPwd1234%^&*@localhost:1433/test_ci?driver=ODBC Driver 17 for SQL Server&charset=utf8&autocommit=true",
            # echo=True,
        )

        # If "autocommit" is not desired to be on by default, then use the following pattern when explicit "autocommit"
        # is desired (e.g., for temporary tables, "autocommit" is off by default, so the override option may be useful).
        # engine.execute(sa.text(sql_query_string).execution_options(autocommit=True))

        conn = engine.connect()

        schema, type_map = {}, None
        if (schemas and dataset_type in schemas
                and isinstance(engine.dialect, mssqltypes.dialect)):
            schema, type_map = schemas[dataset_type], MSSQL_TYPES
        # Only "FLOAT" is treated as floating point for this dialect.
        return _load_sql_test_dataset(df, conn, dataset_type, schema,
                                      type_map, table_name, profiler, caching,
                                      float_type_names=("FLOAT",))

    elif dataset_type == "SparkDFDataset":
        from pyspark.sql import SparkSession
        import pyspark.sql.types as sparktypes

        SPARK_TYPES = {
            "StringType": sparktypes.StringType,
            "IntegerType": sparktypes.IntegerType,
            "LongType": sparktypes.LongType,
            "DateType": sparktypes.DateType,
            "TimestampType": sparktypes.TimestampType,
            "FloatType": sparktypes.FloatType,
            "DoubleType": sparktypes.DoubleType,
            "BooleanType": sparktypes.BooleanType,
            "DataType": sparktypes.DataType,
            "NullType": sparktypes.NullType,
        }

        spark = SparkSession.builder.getOrCreate()
        # We need to allow null values in some column types that do not support them natively, so we skip
        # use of df in this case.
        data_reshaped = list(
            zip(*[v for _, v in data.items()]))  # create a list of rows
        if schemas and "spark" in schemas:
            schema = schemas["spark"]
            # sometimes first method causes Spark to throw a TypeError
            try:
                spark_schema = sparktypes.StructType([
                    sparktypes.StructField(column,
                                           SPARK_TYPES[schema[column]](), True)
                    for column in schema
                ])
                # We create these every time, which is painful for testing
                # However nuance around null treatment as well as the desire
                # for real datetime support in tests makes this necessary
                data = copy.deepcopy(data)
                for col in schema:
                    type_ = schema[col]
                    if type_ in ["IntegerType", "LongType"]:
                        convert = int
                    elif type_ in ["FloatType", "DoubleType"]:
                        convert = float
                    elif type_ in ["DateType", "TimestampType"]:
                        convert = parse
                    else:
                        continue
                    # None is left untouched: it cannot be converted, but it
                    # can be valid in Spark (as Null).
                    data[col] = [
                        val if val is None else convert(val)
                        for val in data[col]
                    ]
                # Do this again, now that we have done type conversion using the provided schema
                data_reshaped = list(
                    zip(*[v
                          for _, v in data.items()]))  # create a list of rows
                spark_df = spark.createDataFrame(data_reshaped,
                                                 schema=spark_schema)
            except TypeError:
                # Fallback: load every column as a string, then cast each
                # column to its declared type.
                string_schema = sparktypes.StructType([
                    sparktypes.StructField(column, sparktypes.StringType())
                    for column in schema
                ])
                spark_df = spark.createDataFrame(data_reshaped, string_schema)
                for c in spark_df.columns:
                    spark_df = spark_df.withColumn(
                        c, spark_df[c].cast(SPARK_TYPES[schema[c]]()))
        elif len(data_reshaped) == 0:
            # if we have an empty dataset and no schema, need to assign an arbitrary type
            columns = list(data.keys())
            spark_schema = sparktypes.StructType([
                sparktypes.StructField(column, sparktypes.StringType())
                for column in columns
            ])
            spark_df = spark.createDataFrame(data_reshaped, spark_schema)
        else:
            # if no schema provided, uses Spark's schema inference
            columns = list(data.keys())
            spark_df = spark.createDataFrame(data_reshaped, columns)
        return SparkDFDataset(spark_df, profiler=profiler, caching=caching)

    else:
        raise ValueError("Unknown dataset_type " + str(dataset_type))