Example #1
def test_start_schedule_manual_add_debug(
        restore_cron_tab,
        snapshot  # pylint:disable=unused-argument,redefined-outer-name
):
    with TemporaryDirectory() as tempdir:
        instance = define_scheduler_instance(tempdir)

        # Initialize scheduler
        instance.reconcile_scheduler_state(
            python_path="fake path",
            repository_path=file_relative_path(__file__,
                                               '.../repository.yaml'),
            repository=test_repository,
        )

        # Manually add the schedule to the crontab
        instance.scheduler._start_cron_job(  # pylint: disable=protected-access
            instance,
            test_repository.name,
            instance.get_schedule_by_name(
                test_repository.name, "no_config_pipeline_every_min_schedule"),
        )

        # Check debug command
        debug_info = instance.scheduler_debug_info()
        assert len(debug_info.errors) == 1

        # Reconcile should fix error
        instance.reconcile_scheduler_state(
            python_path="fake path",
            repository_path=file_relative_path(__file__,
                                               '.../repository.yaml'),
            repository=test_repository,
        )
        debug_info = instance.scheduler_debug_info()
        assert len(debug_info.errors) == 0
Example #2
def test_event_log_asset_key_migration():
    src_dir = file_relative_path(
        __file__, "snapshot_0_7_8_pre_asset_key_migration/sqlite")
    with copy_directory(src_dir) as test_dir:
        db_path = os.path.join(test_dir, "history", "runs",
                               "722183e4-119f-4a00-853f-e1257be82ddb.db")
        assert get_current_alembic_version(db_path) == "3b1e175a2be3"
        assert "asset_key" not in set(
            get_sqlite3_columns(db_path, "event_logs"))

        # Make sure the schema is migrated
        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))
        instance.upgrade()

        assert "asset_key" in set(get_sqlite3_columns(db_path, "event_logs"))
Example #3
def test_dagster_yaml():
    dagster_yaml_folder = file_relative_path(
        __file__, "../../docs_snippets/deploying/docker/")

    res, custom_instance_class = dagster_instance_config(
        dagster_yaml_folder, "dagster.yaml")
    assert set(res.keys()) == {
        "run_storage",
        "event_log_storage",
        "schedule_storage",
        "compute_logs",
        "local_artifact_storage",
    }

    assert custom_instance_class is None
Example #4
def test_event_log_asset_partition_migration():
    src_dir = file_relative_path(__file__,
                                 "snapshot_0_9_22_pre_asset_partition/sqlite")
    with copy_directory(src_dir) as test_dir:
        db_path = os.path.join(test_dir, "history", "runs",
                               "1a1d3c4b-1284-4c74-830c-c8988bd4d779.db")
        assert get_current_alembic_version(db_path) == "c34498c29964"
        assert "partition" not in set(
            get_sqlite3_columns(db_path, "event_logs"))

        # Make sure the schema is migrated
        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))
        instance.upgrade()

        assert "partition" in set(get_sqlite3_columns(db_path, "event_logs"))
Example #5
def test_0_6_4():
    test_dir = file_relative_path(__file__, 'snapshot_0_6_4')
    with restore_directory(test_dir):
        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))

        runs = instance.get_runs()
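        # Reading event logs from this 0.6.4 snapshot should fail until `dagster instance migrate` is run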
        with pytest.raises(
                DagsterInstanceMigrationRequired,
                match=re.escape(
                    'Instance is out of date and must be migrated (SqliteEventLogStorage for run '
                    'c7a6c4d7-6c88-46d0-8baa-d4937c3cefe5). Database is at revision None, head is '
                    '567bc23fd1ac. Please run `dagster instance migrate`.'),
        ):
            for run in runs:
                instance.all_logs(run.run_id)
Example #6
def test_snapshot_command_pipeline():
    runner = CliRunner()
    result = runner.invoke(
        pipeline_snapshot_command,
        ['-y', file_relative_path(__file__, 'repository_file.yaml'), 'foo'],
    )
    assert result.exit_code == 0
    # Now that we have the snapshot make sure that it can be properly deserialized
    external_pipeline_data = deserialize_json_to_dagster_namedtuple(result.output)
    assert isinstance(external_pipeline_data, ExternalPipelineData)
    assert external_pipeline_data.name == 'foo'
    assert (
        len(external_pipeline_data.pipeline_snapshot.solid_definitions_snapshot.solid_def_snaps)
        == 2
    )
Example #7
def test_pipelines_success(file_path, run_config_path):

    with pushd(
            file_relative_path(__file__,
                               '../../../docs_snippets/legacy/data_science/')):
        instance = DagsterInstance.local_temp()
        run_config = load_yaml_from_path(
            run_config_path) if run_config_path else None
        recon_pipeline = ReconstructablePipeline.for_file(
            file_path, 'iris_pipeline')

        pipeline_result = execute_pipeline(recon_pipeline,
                                           run_config=run_config,
                                           instance=instance)
        assert pipeline_result.success
Example #8
File: pipelines.py Project: cs947/dagster
def define_airline_demo_ingest_pipeline():
    solids = [process_on_time_data, sfo_weather_data, s3_to_dw_table]
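    # The s3_to_dw_table solid is aliased three times, once per quarterly data set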
    dependencies = {
        SolidInstance('s3_to_dw_table', alias='process_q2_coupon_data'): {},
        SolidInstance('s3_to_dw_table', alias='process_q2_market_data'): {},
        SolidInstance('s3_to_dw_table', alias='process_q2_ticket_data'): {},
    }

    return PipelineDefinition(
        name="airline_demo_ingest_pipeline",
        solids=solids,
        dependencies=dependencies,
        mode_definitions=[test_mode, local_mode, prod_mode],
        preset_definitions=[
            PresetDefinition(
                name='local_fast',
                mode='local',
                environment_files=[
                    file_relative_path(__file__,
                                       'environments/local_base.yaml'),
                    file_relative_path(__file__,
                                       'environments/local_fast_ingest.yaml'),
                ],
            ),
            PresetDefinition(
                name='local_full',
                mode='local',
                environment_files=[
                    file_relative_path(__file__,
                                       'environments/local_base.yaml'),
                    file_relative_path(__file__,
                                       'environments/local_full_ingest.yaml'),
                ],
            ),
        ],
    )
Example #9
def test_run_always_finishes():  # pylint: disable=redefined-outer-name
    with instance_for_test() as instance:
        pipeline_run = instance.create_run_for_pipeline(
            pipeline_def=slow_pipeline, run_config=None)
        run_id = pipeline_run.run_id

        loadable_target_origin = LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="nope",
            python_file=file_relative_path(__file__,
                                           "test_default_run_launcher.py"),
        )
        server_process = GrpcServerProcess(
            loadable_target_origin=loadable_target_origin, max_workers=4)
        with server_process.create_ephemeral_client() as api_client:
            with GrpcServerRepositoryLocationOrigin(
                    location_name="test",
                    port=api_client.port,
                    socket=api_client.socket,
                    host=api_client.host,
            ).create_handle() as handle:
                repository_location = GrpcServerRepositoryLocation(handle)

                external_pipeline = repository_location.get_repository(
                    "nope").get_full_external_pipeline("slow_pipeline")

                assert instance.get_run_by_id(
                    run_id).status == PipelineRunStatus.NOT_STARTED

                instance.launch_run(run_id=pipeline_run.run_id,
                                    external_pipeline=external_pipeline)

        # Server process now receives shutdown event, run has not finished yet
        pipeline_run = instance.get_run_by_id(run_id)
        assert not pipeline_run.is_finished
        assert server_process.server_process.poll() is None

        # Server should wait until run finishes, then shutdown
        pipeline_run = poll_for_finished_run(instance, run_id)
        assert pipeline_run.status == PipelineRunStatus.SUCCESS

        start_time = time.time()
        while server_process.server_process.poll() is None:
            time.sleep(0.05)
            # Verify server process cleans up eventually
            assert time.time() - start_time < 5

        server_process.wait()
Example #10
def test_run_partition_data_migration():
    src_dir = file_relative_path(
        __file__, "snapshot_0_9_22_post_schema_pre_data_partition/sqlite")
    with copy_directory(src_dir) as test_dir:
        from dagster.core.storage.runs.sql_run_storage import SqlRunStorage
        from dagster.core.storage.runs.migration import RUN_PARTITIONS

        # load db that has migrated schema, but not populated data for run partitions
        db_path = os.path.join(test_dir, "history", "runs.db")
        assert get_current_alembic_version(db_path) == "375e95bad550"

        # Make sure the schema is migrated
        assert "partition" in set(get_sqlite3_columns(db_path, "runs"))
        assert "partition_set" in set(get_sqlite3_columns(db_path, "runs"))

        with DagsterInstance.from_ref(
                InstanceRef.from_dir(test_dir)) as instance:
            instance._run_storage.upgrade()

        run_storage = instance._run_storage
        assert isinstance(run_storage, SqlRunStorage)

        partition_set_name = "ingest_and_train"
        partition_name = "2020-01-02"

        # ensure old tag-based reads are working
        assert not run_storage.has_built_index(RUN_PARTITIONS)
        assert len(
            run_storage._get_partition_runs(partition_set_name,
                                            partition_name)) == 2

        # turn on reads for the partition column, without migrating the data
        run_storage.mark_index_built(RUN_PARTITIONS)

        # ensure that no runs are returned because the data has not been migrated
        assert run_storage.has_built_index(RUN_PARTITIONS)
        assert len(
            run_storage._get_partition_runs(partition_set_name,
                                            partition_name)) == 0

        # actually migrate the data
        run_storage.build_missing_indexes(force_rebuild_all=True)

        # ensure that we get the same partitioned runs returned
        assert run_storage.has_built_index(RUN_PARTITIONS)
        assert len(
            run_storage._get_partition_runs(partition_set_name,
                                            partition_name)) == 2
Example #11
def test_dask_terminate():
    run_config = {
        "solids": {
            "sleepy_dask_solid": {
                "inputs": {
                    "df": {
                        "csv": {
                            "path": file_relative_path(__file__, "ex*.csv")
                        }
                    }
                }
            }
        }
    }

    interrupt_thread = None
    result_types = []
    received_interrupt = False

    with instance_for_test() as instance:
        try:
            for result in execute_pipeline_iterator(
                    pipeline=ReconstructablePipeline.for_file(
                        __file__, sleepy_dask_pipeline.name),
                    run_config=run_config,
                    instance=instance,
            ):
                # Interrupt once the first step starts
                if result.event_type == DagsterEventType.STEP_START and not interrupt_thread:
                    interrupt_thread = Thread(target=send_interrupt, args=())
                    interrupt_thread.start()

                if result.event_type == DagsterEventType.STEP_FAILURE:
                    assert ("DagsterExecutionInterruptedError"
                            in result.event_specific_data.error.message)

                result_types.append(result.event_type)

            assert False
        except DagsterExecutionInterruptedError:
            received_interrupt = True

        assert received_interrupt

        interrupt_thread.join()

        assert DagsterEventType.STEP_FAILURE in result_types
        assert DagsterEventType.PIPELINE_FAILURE in result_types
Example #12
def test_pandas_dask():
    environment_dict = {
        'solids': {
            'pandas_solid': {
                'inputs': {'df': {'csv': {'path': file_relative_path(__file__, 'ex.csv')}}}
            }
        }
    }

    result = execute_on_dask(
        ExecutionTargetHandle.for_pipeline_python_file(__file__, pandas_pipeline.name),
        env_config={'storage': {'filesystem': {}}, **environment_dict},
        dask_config=DaskConfig(timeout=30),
    )

    assert result.success
Example #13
        def _mgr_fn(recon_repo, instance, read_only):
            """Goes out of process but same process as host process"""
            check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)

            with WorkspaceProcessContext(
                instance,
                PythonFileTarget(
                    python_file=file_relative_path(__file__, "setup.py"),
                    attribute="test_dict_repo",
                    working_directory=None,
                    location_name="test",
                ),
                version="",
                read_only=read_only,
            ) as workspace:
                yield workspace
Example #14
def test_snapshot_command_handle_repository():
    runner = CliRunner()
    with safe_tempfile_path() as fp:
        result = runner.invoke(
            snapshot_command,
            [fp, '-y',
             file_relative_path(__file__, 'repository_file.yaml')],
        )
        assert result.exit_code == 0
        # Now that we have the snapshot make sure that it can be properly deserialized
        with open(fp) as buffer:
            active_repository_data = deserialize_json_to_dagster_namedtuple(
                buffer.read())
        assert isinstance(active_repository_data, ActiveRepositoryData)
        assert active_repository_data.name == 'bar'
        assert len(active_repository_data.active_pipeline_datas) == 2
Example #15
File: many_events.py Project: keyz/dagster
def many_table_materializations(_context):
    with open(file_relative_path(__file__, MARKDOWN_EXAMPLE), "r") as f:
        md_str = f.read()
        for table in raw_tables:
            yield AssetMaterialization(
                asset_key="table_info",
                metadata={
                    "table_name": table,
                    "table_path": EventMetadata.path(f"/path/to/{table}"),
                    "table_data": {"name": table},
                    "table_name_big": EventMetadata.url(f"https://bigty.pe/{table}"),
                    "table_blurb": EventMetadata.md(md_str),
                    "big_int": 29119888133298982934829348,
                    "float_nan": float("nan"),
                },
            )
Example #16
def get_external_pipeline_from_python_location(pipeline_name):
    repository_location_handle = RepositoryLocationHandle.create_python_env_location(
        loadable_target_origin=LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="nope",
            python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
        ),
        location_name="nope",
        user_process_api=UserProcessApi.CLI,
    )

    yield (
        RepositoryLocation.from_handle(repository_location_handle)
        .get_repository("nope")
        .get_full_external_pipeline(pipeline_name)
    )
Example #17
def get_external_pipeline_from_managed_grpc_python_env_repository(
        pipeline_name):
    with RepositoryLocationHandle.create_from_repository_location_origin(
            ManagedGrpcPythonEnvRepositoryLocationOrigin(
                loadable_target_origin=LoadableTargetOrigin(
                    executable_path=sys.executable,
                    attribute="nope",
                    python_file=file_relative_path(
                        __file__, "test_default_run_launcher.py"),
                ),
                location_name="nope",
            )) as repository_location_handle:
        repository_location = GrpcServerRepositoryLocation(
            repository_location_handle)
        yield repository_location.get_repository(
            "nope").get_full_external_pipeline(pipeline_name)
Example #18
def many_table_materializations(_context):
    with open(file_relative_path(__file__, MARKDOWN_EXAMPLE), 'r') as f:
        md_str = f.read()
        for table in raw_tables:
            yield Materialization(
                label='table_info',
                metadata_entries=[
                    EventMetadataEntry.text(text=table, label='table_name'),
                    EventMetadataEntry.fspath(path='/path/to/{}'.format(table), label='table_path'),
                    EventMetadataEntry.json(data={'name': table}, label='table_data'),
                    EventMetadataEntry.url(
                        url='https://bigty.pe/{}'.format(table), label='table_name_big'
                    ),
                    EventMetadataEntry.md(md_str=md_str, label='table_blurb'),
                ],
            )
Example #19
def test_snapshot_command_error_on_pipeline_definition():
    runner = CliRunner()
    with pytest.raises(ParameterCheckError):
        with safe_tempfile_path() as fp:
            result = runner.invoke(
                snapshot_command,
                [
                    fp,
                    '-f',
                    file_relative_path(__file__, 'test_cli_commands.py'),
                    '-n',
                    'baz_pipeline',
                ],
            )
            assert result.exit_code == 1
            raise result.exception
Example #20
def test_pipelines_success(file_path, run_config_path):

    with pushd(file_relative_path(__file__, "../../../docs_snippets/legacy/data_science/")):
        with instance_for_test() as instance:
            run_config = load_yaml_from_path(run_config_path) if run_config_path else {}
            recon_pipeline = ReconstructablePipeline.for_file(file_path, "iris_pipeline")

            with tempfile.TemporaryDirectory() as temp_dir:
                run_config["resources"] = {"io_manager": {"config": {"base_dir": temp_dir}}}
                pipeline_result = execute_pipeline(
                    recon_pipeline,
                    run_config=run_config,
                    instance=instance,
                    solid_selection=["k_means_iris"],  # skip download_file in tests
                )
                assert pipeline_result.success
Example #21
def test_event_log_asset_key_migration():
    test_dir = file_relative_path(
        __file__, 'snapshot_0_7_8_pre_asset_key_migration/sqlite')
    with restore_directory(test_dir):
        db_path = os.path.join(test_dir, 'history', 'runs',
                               '722183e4-119f-4a00-853f-e1257be82ddb.db')
        assert get_current_alembic_version(db_path) == '3b1e175a2be3'
        assert 'asset_key' not in set(get_sqlite3_columns(
            db_path, 'event_logs'))

        # Make sure the schema is migrated
        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))
        instance.upgrade()

        assert get_current_alembic_version(db_path) == 'c39c047fa021'
        assert 'asset_key' in set(get_sqlite3_columns(db_path, 'event_logs'))
Example #22
def test_schedule_namedtuple_job_instigator_backcompat():
    src_dir = file_relative_path(__file__, "snapshot_0_13_19_instigator_named_tuples/sqlite")
    with copy_directory(src_dir) as test_dir:
        with DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) as instance:
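            # State and ticks stored in the 0.13.19 snapshot should still deserialize
            # into the current InstigatorState / InstigatorTick types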
            states = instance.all_instigator_state()
            assert len(states) == 2
            check.is_list(states, of_type=InstigatorState)
            for state in states:
                assert state.instigator_type
                assert state.instigator_data
                ticks = instance.get_ticks(state.instigator_origin_id, state.selector_id)
                check.is_list(ticks, of_type=InstigatorTick)
                for tick in ticks:
                    assert tick.tick_data
                    assert tick.instigator_type
                    assert tick.instigator_name
Example #23
def test_downgrade_and_upgrade():
    test_dir = file_relative_path(__file__, 'snapshot_0_7_6_pre_add_pipeline_snapshot/sqlite')
    with restore_directory(test_dir):
        # invariant check to make sure migration has not been run yet

        db_path = os.path.join(test_dir, 'history', 'runs.db')

        assert get_current_alembic_version(db_path) == '9fe9e746268c'

        assert 'snapshots' not in get_sqlite3_tables(db_path)

        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))

        assert len(instance.get_runs()) == 1

        # Make sure the schema is migrated
        instance.upgrade()

        assert get_current_alembic_version(db_path) == 'c63a27054f08'

        assert 'snapshots' in get_sqlite3_tables(db_path)
        assert {'id', 'snapshot_id', 'snapshot_body', 'snapshot_type'} == set(
            get_sqlite3_columns(db_path, 'snapshots')
        )

        assert len(instance.get_runs()) == 1

        instance._run_storage._alembic_downgrade(rev='9fe9e746268c')

        assert get_current_alembic_version(db_path) == '9fe9e746268c'

        assert 'snapshots' not in get_sqlite3_tables(db_path)

        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))

        assert len(instance.get_runs()) == 1

        instance.upgrade()

        assert get_current_alembic_version(db_path) == 'c63a27054f08'

        assert 'snapshots' in get_sqlite3_tables(db_path)
        assert {'id', 'snapshot_id', 'snapshot_body', 'snapshot_type'} == set(
            get_sqlite3_columns(db_path, 'snapshots')
        )

        assert len(instance.get_runs()) == 1
Example #24
def test_terminate_after_shutdown():
    with instance_for_test() as instance:
        with RepositoryLocationHandle.create_from_repository_location_origin(
            ManagedGrpcPythonEnvRepositoryLocationOrigin(
                loadable_target_origin=LoadableTargetOrigin(
                    executable_path=sys.executable,
                    attribute="nope",
                    python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
                ),
                location_name="nope",
            )
        ) as repository_location_handle:
            repository_location = GrpcServerRepositoryLocation(repository_location_handle)

            external_pipeline = repository_location.get_repository(
                "nope"
            ).get_full_external_pipeline("sleepy_pipeline")

            pipeline_run = instance.create_run_for_pipeline(
                pipeline_def=sleepy_pipeline, run_config=None
            )

            instance.launch_run(pipeline_run.run_id, external_pipeline)

            poll_for_step_start(instance, pipeline_run.run_id)

            # Tell the server to shut down once executions finish
            repository_location_handle.client.cleanup_server()

            # Trying to start another run fails
            doomed_to_fail_external_pipeline = repository_location.get_repository(
                "nope"
            ).get_full_external_pipeline("math_diamond")
            doomed_to_fail_pipeline_run = instance.create_run_for_pipeline(
                pipeline_def=math_diamond, run_config=None
            )

            with pytest.raises(DagsterLaunchFailedError):
                instance.launch_run(
                    doomed_to_fail_pipeline_run.run_id, doomed_to_fail_external_pipeline
                )

            launcher = instance.run_launcher

            # Can terminate the run even after the shutdown event has been received
            assert launcher.can_terminate(pipeline_run.run_id)
            assert launcher.terminate(pipeline_run.run_id)
Example #25
def test_normalize_column_names():
    path = file_relative_path(__file__, "canada.csv")

    input_df = dd.read_csv(path)
    assert all(col in input_df.columns for col in ("ID", "provinceOrTerritory", "country"))

    # Set normalize_column_names=False to not modify the column names
    run_config = generate_config(path, normalize_column_names=False)
    result = execute_solid(passthrough, run_config=run_config)
    output_df = result.output_value()
    assert all(col in output_df.columns for col in ("ID", "provinceOrTerritory", "country"))

    # Set normalize_column_names=True to modify the column names
    run_config = generate_config(path, normalize_column_names=True)
    result = execute_solid(passthrough, run_config=run_config)
    output_df = result.output_value()
    assert all(col in output_df.columns for col in ("id", "province_or_territory", "country"))
Example #26
File: step_four.py Project: yuhan/dagster
def get_in_repo_preset_definition():
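    # Preset points the cereals input at data/cereal.csv and enables multiprocess
    # execution with filesystem storage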
    return PresetDefinition(
        "in_repo",
        run_config={
            "solids": {
                "add_sugar_per_cup": {
                    "inputs": {
                        "cereals": {
                            "csv": {"path": file_relative_path(__file__, "data/cereal.csv")}
                        }
                    }
                }
            },
            "execution": {"multiprocess": {}},
            "storage": {"filesystem": {}},
        },
    )
Example #27
def get_external_pipeline_from_managed_grpc_python_env_repository(
        pipeline_name):
    repository_location_handle = RepositoryLocationHandle.create_process_bound_grpc_server_location(
        loadable_target_origin=LoadableTargetOrigin(
            attribute="nope",
            python_file=file_relative_path(__file__,
                                           "test_cli_api_run_launcher.py"),
        ),
        location_name="nope",
    )
    repository_location = GrpcServerRepositoryLocation(
        repository_location_handle)
    try:
        yield repository_location.get_repository(
            "nope").get_full_external_pipeline(pipeline_name)
    finally:
        repository_location_handle.cleanup()
Example #28
def test_spark_dataframe_output_csv():
    spark = SparkSession.builder.getOrCreate()
    num_df = (spark.read.format('csv').options(
        header='true',
        inferSchema='true').load(file_relative_path(__file__, 'num.csv')))

    assert num_df.collect() == [Row(num1=1, num2=2)]

    @solid
    def emit(_):
        return num_df

    @solid(input_defs=[InputDefinition('df', DataFrame)],
           output_defs=[OutputDefinition(DataFrame)])
    def passthrough_df(_context, df):
        return df

    @pipeline
    def passthrough():
        passthrough_df(emit())

    with seven.TemporaryDirectory() as tempdir:
        file_name = os.path.join(tempdir, 'output.csv')
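        # Write the passthrough output to this CSV, then read it back below to verify the round trip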
        result = execute_pipeline(
            passthrough,
            run_config={
                'solids': {
                    'passthrough_df': {
                        'outputs': [{
                            'result': {
                                'csv': {
                                    'path': file_name,
                                    'header': True
                                }
                            }
                        }]
                    }
                },
            },
        )

        from_file_df = (spark.read.format('csv').options(
            header='true', inferSchema='true').load(file_name))

        assert (result.result_for_solid('passthrough_df').output_value().
                collect() == from_file_df.collect())
Example #29
File: test_types.py Project: yuhan/dagster
def test_spark_dataframe_output_csv():
    spark = SparkSession.builder.getOrCreate()
    num_df = (spark.read.format("csv").options(
        header="true",
        inferSchema="true").load(file_relative_path(__file__, "num.csv")))

    assert num_df.collect() == [Row(num1=1, num2=2)]

    @solid
    def emit(_):
        return num_df

    @solid(input_defs=[InputDefinition("df", DataFrame)],
           output_defs=[OutputDefinition(DataFrame)])
    def passthrough_df(_context, df):
        return df

    @pipeline
    def passthrough():
        passthrough_df(emit())

    with seven.TemporaryDirectory() as tempdir:
        file_name = os.path.join(tempdir, "output.csv")
        result = execute_pipeline(
            passthrough,
            run_config={
                "solids": {
                    "passthrough_df": {
                        "outputs": [{
                            "result": {
                                "csv": {
                                    "path": file_name,
                                    "header": True
                                }
                            }
                        }]
                    }
                },
            },
        )

        from_file_df = (spark.read.format("csv").options(
            header="true", inferSchema="true").load(file_name))

        assert (result.result_for_solid("passthrough_df").output_value().
                collect() == from_file_df.collect())
Example #30
def test_downgrade_and_upgrade():
    test_dir = file_relative_path(
        __file__, "snapshot_0_7_6_pre_add_pipeline_snapshot/sqlite")
    with restore_directory(test_dir):
        # invariant check to make sure migration has not been run yet

        db_path = os.path.join(test_dir, "history", "runs.db")

        assert get_current_alembic_version(db_path) == "9fe9e746268c"

        assert "snapshots" not in get_sqlite3_tables(db_path)

        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))

        assert len(instance.get_runs()) == 1

        # Make sure the schema is migrated
        instance.upgrade()

        assert get_current_alembic_version(db_path) == "c63a27054f08"

        assert "snapshots" in get_sqlite3_tables(db_path)
        assert {"id", "snapshot_id", "snapshot_body", "snapshot_type"
                } == set(get_sqlite3_columns(db_path, "snapshots"))

        assert len(instance.get_runs()) == 1

        instance._run_storage._alembic_downgrade(rev="9fe9e746268c")

        assert get_current_alembic_version(db_path) == "9fe9e746268c"

        assert "snapshots" not in get_sqlite3_tables(db_path)

        instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir))

        assert len(instance.get_runs()) == 1

        instance.upgrade()

        assert get_current_alembic_version(db_path) == "c63a27054f08"

        assert "snapshots" in get_sqlite3_tables(db_path)
        assert {"id", "snapshot_id", "snapshot_body", "snapshot_type"
                } == set(get_sqlite3_columns(db_path, "snapshots"))

        assert len(instance.get_runs()) == 1