def test_resource_invocation_with_config():
    """Directly invoke a resource whose config schema has a required field.

    Covers the failure modes (no context, context without config, bad
    ``configured`` payload) and the two successful invocation paths.
    """

    @resource(config_schema={"foo": str})
    def resource_reqs_config(context):
        assert context.resource_config["foo"] == "bar"
        return 5

    # A None context is rejected because the schema has a required field.
    with pytest.raises(
        DagsterInvalidInvocationError,
        match="Resource has required config schema, but no context was provided.",
    ):
        resource_reqs_config(None)

    # A context that carries no config does not satisfy the schema either.
    empty_context = build_init_resource_context()
    with pytest.raises(DagsterInvalidConfigError, match="Error in config for resource"):
        resource_reqs_config(empty_context)

    # Configuring with a key the schema does not declare fails on invocation.
    with pytest.raises(
        DagsterInvalidConfigError,
        match="Error when applying config mapping for resource",
    ):
        resource_reqs_config.configured({"foobar": "bar"})(None)

    # Once the resource is configured, a None context is acceptable.
    preconfigured = resource_reqs_config.configured({"foo": "bar"})
    assert preconfigured(None) == 5

    # Supplying the config through the context works as well.
    populated_context = build_init_resource_context(config={"foo": "bar"})
    assert resource_reqs_config(populated_context) == 5
def test_build_with_cm_resource():
    """Exercise build_init_resource_context with a generator-style resource.

    The ``foo`` resource yields its value and appends to ``entered`` in its
    ``finally`` clause, so the list records how many times cleanup has run.
    """
    entered = []

    @resource
    def foo(_):
        try:
            yield "foo"
        finally:
            # Marks that this generator's cleanup step executed.
            entered.append("true")

    @resource(required_resource_keys={"foo"})
    def reqs_cm_resource(context):
        return context.resources.foo + "bar"

    context = build_init_resource_context(resources={"foo": foo})
    # Reading `resources` on a context that was not entered via `with` is
    # expected to raise.
    with pytest.raises(DagsterInvariantViolationError):
        context.resources  # pylint: disable=pointless-statement

    # Dropping the only reference is expected to trigger the resource's
    # cleanup; the assertion below depends on that happening immediately.
    del context
    assert entered == ["true"]

    # Used as a context manager, the resource value is available inside the
    # block and to a resource that requires it.
    with build_init_resource_context(resources={"foo": foo}) as context:
        assert context.resources.foo == "foo"
        assert reqs_cm_resource(context) == "foobar"

    # Leaving the `with` block runs the cleanup a second time.
    assert entered == ["true", "true"]
def test_sync_and_poll_invalid(data, match):
    """Expect ``sync_and_poll`` to raise ``Failure`` matching ``match``.

    ``data`` is forwarded to ``get_sample_connector_response`` to shape the
    mocked connector-details payload.
    """
    resource_config = {"api_key": "some_key", "api_secret": "some_secret"}
    ft_resource = fivetran_resource(build_init_resource_context(config=resource_config))
    connector_url = f"{ft_resource.api_base_url}{DEFAULT_CONNECTOR_ID}"

    with pytest.raises(Failure, match=match), responses.RequestsMock() as rsps:
        # Registration order is kept: schema config, connector details,
        # update, then the forced sync.
        rsps.add(
            rsps.GET,
            f"{connector_url}/schemas",
            json=get_complex_sample_connector_schema_config(),
        )
        rsps.add(rsps.GET, connector_url, json=get_sample_connector_response(data=data))
        rsps.add(rsps.PATCH, connector_url, json=get_sample_update_response())
        rsps.add(rsps.POST, f"{connector_url}/force", json=get_sample_sync_response())

        ft_resource.sync_and_poll(DEFAULT_CONNECTOR_ID, poll_interval=0.1)
def test_resource_invocation_default_config():
    """Check that optional config fields take their defaults on direct invocation."""

    # Dict schema whose single field is optional: a None context suffices.
    @resource(config_schema={"foo": Field(str, is_required=False, default_value="bar")})
    def resource_requires_config(context):
        foo_value = context.resource_config["foo"]
        assert foo_value == "bar"
        return foo_value

    assert resource_requires_config(None) == "bar"

    # Scalar schema carrying a default value.
    @resource(config_schema=Field(str, is_required=False, default_value="bar"))
    def resource_requires_config_val(context):
        assert context.resource_config == "bar"
        return context.resource_config

    assert resource_requires_config_val(None) == "bar"

    # Mixed schema: "foo" is defaulted while "baz" must be supplied.
    partial_schema = {
        "foo": Field(str, is_required=False, default_value="bar"),
        "baz": str,
    }

    @resource(config_schema=partial_schema)
    def resource_requires_config_partial(context):
        assert context.resource_config["foo"] == "bar"
        assert context.resource_config["baz"] == "bar"
        return context.resource_config["foo"] + context.resource_config["baz"]

    partial_context = build_init_resource_context(config={"baz": "bar"})
    assert resource_requires_config_partial(partial_context) == "barbar"
def test_handle_output_spark_then_load_input_pandas():
    """Store a Spark DataFrame via the Snowflake IO manager and load it back.

    The loaded value is compared, by string representation, with the pandas
    conversion of the Spark frame that was written.
    """
    init_context = build_init_resource_context(config={"database": "TESTDB"})
    snowflake_manager = snowflake_io_manager(init_context)

    jar_packages = (
        "net.snowflake:snowflake-jdbc:3.8.0,net.snowflake:spark-snowflake_2.12:2.8.2-spark_3.0"
    )
    spark = SparkSession.builder.config("spark.jars.packages", jar_packages).getOrCreate()

    schema = StructType(
        [
            StructField("col1", StringType()),
            StructField("col2", IntegerType()),
        ]
    )
    contents = spark.createDataFrame([Row(col1="Thom", col2=51)], schema)

    seed_frame = PandasDataFrame([{"col1": "a", "col2": 1}])
    with temporary_snowflake_table(seed_frame) as temp_table_name:
        output_context = mock_output_context(temp_table_name)
        # handle_output is consumed fully so that every step of it runs.
        list(snowflake_manager.handle_output(output_context, contents))

        input_context = mock_input_context(output_context)
        input_value = snowflake_manager.load_input(input_context)
        contents_pandas = contents.toPandas()
        assert str(input_value) == str(contents_pandas), f"{input_value}\n\n{contents_pandas}"
def test_get_connector_details_flake(max_retries, n_flakes):
    """Check retry handling when the first ``n_flakes`` requests return HTTP 500.

    With more flakes than ``max_retries`` a ``Failure`` is expected; otherwise
    the call must return the sample connector data.
    """
    resource_config = {
        "api_key": "some_key",
        "api_secret": "some_secret",
        "request_max_retries": max_retries,
        "request_retry_delay": 0,
    }
    ft_resource = fivetran_resource(build_init_resource_context(config=resource_config))

    def _mock_interaction():
        connector_url = f"{ft_resource.api_base_url}{DEFAULT_CONNECTOR_ID}"
        with responses.RequestsMock() as rsps:
            # The leading n_flakes responses are server errors...
            for _ in range(n_flakes):
                rsps.add(rsps.GET, connector_url, status=500)
            # ...followed by one successful response.
            rsps.add(rsps.GET, connector_url, json=get_sample_connector_response())
            return ft_resource.get_connector_details(DEFAULT_CONNECTOR_ID)

    if n_flakes <= max_retries:
        assert _mock_interaction() == get_sample_connector_response()["data"]
        return

    with pytest.raises(Failure, match="Exceeded max number of retries."):
        _mock_interaction()
def test_resource_invocation_kitchen_sink_config():
    """Round-trip a config mixing scalars, lists, dicts, a selector and noneables."""
    list_of_int_dicts = [{"an_int": int}]
    kitchen_sink_schema = {
        "str_field": str,
        "int_field": int,
        "list_int": [int],
        "list_list_int": [[int]],
        "dict_field": {"a_string": str},
        "list_dict_field": list_of_int_dicts,
        "selector_of_things": Selector(
            {"select_list_dict_field": [{"an_int": int}], "select_int": int}
        ),
        "optional_list_of_optional_string": Noneable([Noneable(str)]),
    }

    @resource(config_schema=kitchen_sink_schema)
    def kitchen_sink(context):
        return context.resource_config

    resource_config = {
        "str_field": "kjf",
        "int_field": 2,
        "list_int": [3],
        "list_list_int": [[1], [2, 3]],
        "dict_field": {"a_string": "kdjfkd"},
        "list_dict_field": [{"an_int": 2}, {"an_int": 4}],
        "selector_of_things": {"select_int": 3},
        "optional_list_of_optional_string": ["foo", None],
    }

    # The resource returns its config, which must equal what was passed in.
    init_context = build_init_resource_context(config=resource_config)
    assert kitchen_sink(init_context) == resource_config
def test_resource_invocation_dict_config():
    """Invoke resources whose config schema is ``dict`` or ``Noneable(dict)``."""
    expected_config = {"foo": "bar"}

    @resource(config_schema=dict)
    def resource_requires_dict(context):
        assert context.resource_config == {"foo": "bar"}
        return context.resource_config

    dict_context = build_init_resource_context(config={"foo": "bar"})
    assert resource_requires_dict(dict_context) == expected_config

    @resource(config_schema=Noneable(dict))
    def resource_noneable_dict(context):
        return context.resource_config

    # Both a config-less context and a None context yield a None config.
    for init_context in (build_init_resource_context(), None):
        assert resource_noneable_dict(init_context) is None
def test_versioned_filesystem_io_manager_default_base_dir():
    """Without explicit config, base_dir is ``versioned_outputs`` under instance storage."""
    with TemporaryDirectory() as temp_dir, instance_for_test(temp_dir=temp_dir) as instance:
        init_context = build_init_resource_context(instance=instance)
        my_io_manager = versioned_filesystem_io_manager(init_context)

        expected_base_dir = os.path.join(instance.storage_directory(), "versioned_outputs")
        assert my_io_manager.base_dir == expected_base_dir
def test_trigger_connection_fail():
    """Expect ``sync_and_poll`` to raise ``Failure`` once its retries are used up."""
    airbyte_config = {"host": "some_host", "port": "8000"}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))

    with pytest.raises(Failure, match="Exceeded max number of retries"):
        ab_resource.sync_and_poll("some_connection")
def get_dbt_cloud_resource(**kwargs):
    """Build a dbt Cloud resource from sample credentials.

    Args:
        **kwargs: Extra config entries; they are applied after the sample
            values and therefore override them on key collision.

    Returns:
        Whatever ``dbt_cloud_resource`` returns for the assembled context.
    """
    resource_config = {
        "auth_token": "some_auth_token",
        "account_id": SAMPLE_ACCOUNT_ID,
    }
    resource_config.update(kwargs)
    init_context = build_init_resource_context(config=resource_config)
    return dbt_cloud_resource(init_context)
def test_build_no_args():
    """A context built with no arguments is usable for a resource that ignores it."""
    default_context = build_init_resource_context()
    assert isinstance(default_context, InitResourceContext)

    @resource
    def basic(_):
        return "foo"

    produced = basic(default_context)
    assert produced == "foo"
def common_bucket_s3_pickle_io_manager(init_context):
    """
    Variant of s3_pickle_io_manager whose bucket comes from a resource.

    The ``s3_bucket`` resource on ``init_context`` supplies the ``s3_bucket``
    config value, and the ``s3`` resource is forwarded unchanged.
    """
    available = init_context.resources
    inner_context = build_init_resource_context(
        config={"s3_bucket": available.s3_bucket},
        resources={"s3": available.s3},
    )
    return s3_pickle_io_manager(inner_context)
def test_resync_and_poll(n_polls, succeed_at_end):
    """Drive ``resync_and_poll`` through ``n_polls`` unchanged polls and a final state.

    When ``succeed_at_end`` is true the final connector payload carries
    ``succeeded_at`` and a ``FivetranOutput`` is expected; otherwise it carries
    ``failed_at`` and a ``Failure`` is expected.
    """
    resource_config = {"api_key": "some_key", "api_secret": "some_secret"}
    ft_resource = fivetran_resource(build_init_resource_context(config=resource_config))
    api_prefix = f"{ft_resource.api_base_url}{DEFAULT_CONNECTOR_ID}"

    if succeed_at_end:
        final_data = {"succeeded_at": "2021-01-01T02:00:00.0Z"}
    else:
        final_data = {"failed_at": "2021-01-01T02:00:00.0Z"}

    def _mock_interaction():
        with responses.RequestsMock() as rsps:
            rsps.add(
                rsps.GET,
                f"{api_prefix}/schemas",
                json=get_complex_sample_connector_schema_config(),
            )
            rsps.add(rsps.PATCH, api_prefix, json=get_sample_update_response())
            rsps.add(
                rsps.POST,
                f"{api_prefix}/schemas/tables/resync",
                json=get_sample_resync_response(),
            )
            # Connector state before anything changes.
            rsps.add(rsps.GET, api_prefix, json=get_sample_connector_response())
            # Polls that still report the unchanged state.
            for _ in range(n_polls):
                rsps.add(rsps.GET, api_prefix, json=get_sample_connector_response())
            # Last poll reports the terminal state.
            rsps.add(
                rsps.GET,
                api_prefix,
                json=get_sample_connector_response(data=final_data),
            )
            return ft_resource.resync_and_poll(
                DEFAULT_CONNECTOR_ID,
                resync_parameters={"xyz1": ["abc1", "abc2"]},
                poll_interval=0.1,
            )

    if succeed_at_end:
        expected_output = FivetranOutput(
            connector_details=get_sample_connector_response(data=final_data)["data"],
            schema_config=get_complex_sample_connector_schema_config()["data"],
        )
        assert _mock_interaction() == expected_output
    else:
        with pytest.raises(Failure, match="failed!"):
            _mock_interaction()
def test_trigger_connection_fail():
    """Expect ``Failure`` when the sync endpoint answers HTTP 500 with one retry allowed."""
    airbyte_config = {"host": "some_host", "port": "8000", "request_max_retries": 1}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))

    sync_url = ab_resource.api_base_url + "/connections/sync"
    responses.add(method=responses.POST, url=sync_url, status=500)

    with pytest.raises(Failure, match="Exceeded max number of retries"):
        ab_resource.sync_and_poll("some_connection")
def test_handle_output_then_load_input_pandas():
    """Write a pandas frame through the Snowflake IO manager and read it back."""
    init_context = build_init_resource_context(config={"database": "TESTDB"})
    snowflake_manager = snowflake_io_manager(init_context)

    # The first frame only seeds the temporary table so its column types are
    # right; the second frame is the data actually written.
    contents1 = PandasDataFrame([{"col1": "a", "col2": 1}])
    contents2 = PandasDataFrame([{"col1": "b", "col2": 2}])

    with temporary_snowflake_table(contents1) as temp_table_name:
        output_context = mock_output_context(temp_table_name)
        # Consume handle_output fully so that every step of it runs.
        list(snowflake_manager.handle_output(output_context, contents2))

        input_context = mock_input_context(output_context)
        input_value = snowflake_manager.load_input(input_context)
        assert input_value.equals(contents2), f"{input_value}\n\n{contents2}"
def test_resource_invocation_with_resources():
    """Directly invoke a resource that declares a required resource key."""

    @resource(required_resource_keys={"foo"})
    def resource_reqs_resources(init_context):
        return init_context.resources.foo

    # No context at all: the required resource cannot be provided.
    with pytest.raises(
        DagsterInvalidInvocationError,
        match="Resource has required resources, but no context was provided.",
    ):
        resource_reqs_resources(None)

    # A context exists, but it lacks the "foo" resource.
    bare_context = build_init_resource_context()
    with pytest.raises(
        DagsterInvalidInvocationError,
        match='Resource requires resource "foo", but no resource with that key was found on the context.',
    ):
        resource_reqs_resources(bare_context)

    # With "foo" supplied, the resource returns its value.
    satisfied_context = build_init_resource_context(resources={"foo": "bar"})
    assert resource_reqs_resources(satisfied_context) == "bar"
def test_sync_and_poll_timeout():
    """Expect a timeout ``Failure`` when the mocked job never leaves pending/running."""
    airbyte_config = {"host": "some_host", "port": "8000"}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))
    base_url = ab_resource.api_base_url

    responses.add(
        method=responses.POST,
        url=base_url + "/connections/get",
        json={},
        status=200,
    )
    responses.add(
        method=responses.POST,
        url=base_url + "/connections/sync",
        json={"job": {"id": 1}},
        status=200,
    )
    # Job status responses, registered in order: pending, then running twice.
    for job_status in ("pending", "running", "running"):
        responses.add(
            method=responses.POST,
            url=base_url + "/jobs/get",
            json={"job": {"id": 1, "status": job_status}},
            status=200,
        )

    # The poll interval is longer than the timeout.
    poll_wait_second = 2
    timeout = 1
    with pytest.raises(Failure, match="Timeout: Airbyte job"):
        ab_resource.sync_and_poll("some_connection", poll_wait_second, timeout)
def test_sync_and_poll(state):
    """Check ``sync_and_poll`` for the job status given by ``state``.

    Error, cancelled and unrecognized states must raise ``Failure`` with a
    state-specific message; any other state must yield an ``AirbyteOutput``.
    """
    airbyte_config = {"host": "some_host", "port": "8000"}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))
    base_url = ab_resource.api_base_url

    responses.add(
        method=responses.POST,
        url=base_url + "/connections/get",
        json=get_sample_connection_json(),
        status=200,
    )
    responses.add(
        method=responses.POST,
        url=base_url + "/connections/sync",
        json={"job": {"id": 1}},
        status=200,
    )
    responses.add(
        method=responses.POST,
        url=base_url + "/jobs/get",
        json={"job": {"id": 1, "status": state}},
        status=200,
    )

    # Pick the failure message expected for this state, if any.
    if state == AirbyteState.ERROR:
        failure_pattern = "Job failed"
    elif state == AirbyteState.CANCELLED:
        failure_pattern = "Job was cancelled"
    elif state == "unrecognized":
        failure_pattern = "unexpected state"
    else:
        failure_pattern = None

    if failure_pattern is not None:
        with pytest.raises(Failure, match=failure_pattern):
            ab_resource.sync_and_poll("some_connection", 0)
        return

    result = ab_resource.sync_and_poll("some_connection", 0)
    expected_output = AirbyteOutput(
        job_details={"job": {"id": 1, "status": state}},
        connection_details=get_sample_connection_json(),
    )
    assert result == expected_output
def test_get_job_status_bad_out_fail():
    """Expect ``check.CheckError`` when the job-status endpoint returns 204 with no body."""
    airbyte_config = {"host": "some_host", "port": "8000"}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))

    jobs_url = ab_resource.api_base_url + "/jobs/get"
    responses.add(method=responses.POST, url=jobs_url, json=None, status=204)

    with pytest.raises(check.CheckError):
        ab_resource.get_job_status("some_connection")
def test_build_with_resources():
    """Build a context from a resource definition plus a plain value and use both."""

    @resource
    def foo(_):
        return "foo"

    # "foo" is given as a resource definition, "bar" as a raw value.
    provided = {"foo": foo, "bar": "bar"}
    context = build_init_resource_context(resources=provided)
    assert context.resources.foo == "foo"
    assert context.resources.bar == "bar"

    @resource(required_resource_keys={"foo", "bar"})
    def reqs_resources(context):
        available = context.resources
        return available.foo + available.bar

    assert reqs_resources(context) == "foobar"
def test_df_to_csv_io_manager():
    """Round-trip a DataFrame through df_to_csv_io_manager in a temporary directory."""
    with tempfile.TemporaryDirectory() as temp_dir:
        init_context = build_init_resource_context(config={"base_dir": temp_dir})
        my_io_manager = df_to_csv_io_manager(init_context)
        test_df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})

        # Writing: the output path must exist and hold the same data as CSV.
        output_context = build_output_context(name="abc", step_key="123")
        my_io_manager.handle_output(output_context, test_df)
        output_path = my_io_manager._get_path(output_context)  # pylint:disable=protected-access
        assert os.path.exists(output_path)
        written_df = pd.read_csv(output_path)
        assert test_df.equals(written_df)

        # Reading: loading from the upstream output returns an equal frame.
        input_context = build_input_context(upstream_output=output_context)
        loaded_df = my_io_manager.load_input(input_context)
        assert test_df.equals(loaded_df)
def test_trigger_connection():
    """``start_sync`` returns the JSON body of the mocked sync endpoint."""
    airbyte_config = {"host": "some_host", "port": "8000"}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))

    responses.add(
        method=responses.POST,
        url=ab_resource.api_base_url + "/connections/sync",
        json={"job": {"id": 1}},
        status=200,
    )

    sync_response = ab_resource.start_sync("some_connection")
    assert sync_response == {"job": {"id": 1}}
def test_get_connector_sync_status(data, expected):
    """``get_connector_sync_status`` returns ``expected`` for a connector shaped by ``data``."""
    resource_config = {"api_key": "some_key", "api_secret": "some_secret"}
    ft_resource = fivetran_resource(build_init_resource_context(config=resource_config))
    connector_url = f"{ft_resource.api_base_url}{DEFAULT_CONNECTOR_ID}"

    with responses.RequestsMock() as rsps:
        rsps.add(rsps.GET, connector_url, json=get_sample_connector_response(data=data))
        sync_status = ft_resource.get_connector_sync_status(DEFAULT_CONNECTOR_ID)
        assert sync_status == expected
def test_get_connector_details():
    """``get_connector_details`` returns the ``data`` entry of the mocked response."""
    resource_config = {"api_key": "some_key", "api_secret": "some_secret"}
    ft_resource = fivetran_resource(build_init_resource_context(config=resource_config))
    connector_url = f"{ft_resource.api_base_url}{DEFAULT_CONNECTOR_ID}"

    with responses.RequestsMock() as rsps:
        rsps.add(rsps.GET, connector_url, json=get_sample_connector_response())
        details = ft_resource.get_connector_details(DEFAULT_CONNECTOR_ID)
        assert details == get_sample_connector_response()["data"]
def test_handle_output_then_load_input_pandas():
    """Write a pandas frame to a temporary Snowflake table and load it back.

    The target table is passed to the IO manager through output metadata.
    """
    init_context = build_init_resource_context(
        config={"database": "TESTDB"},
        resources={"partition_bounds": None},
    )
    snowflake_manager = snowflake_io_manager(init_context)

    # The first frame only seeds the temporary table so its column types are
    # right; the second frame is the data actually written.
    contents1 = PandasDataFrame([{"col1": "a", "col2": 1}])
    contents2 = PandasDataFrame([{"col1": "b", "col2": 2}])

    with temporary_snowflake_table(contents1) as temp_table_name:
        output_context = build_output_context(metadata={"table": f"public.{temp_table_name}"})
        # Consume handle_output fully so that every step of it runs.
        list(snowflake_manager.handle_output(output_context, contents2))

        input_context = build_input_context(upstream_output=output_context)
        input_value = snowflake_manager.load_input(input_context)
        assert input_value.equals(contents2), f"{input_value}\n\n{contents2}"
def test_handle_output_then_load_input():
    """Write a frame to a temporary Snowflake table and load it back.

    The generated Snowflake config is passed to the init context and to both
    the output and input contexts.
    """
    snowflake_config = generate_snowflake_config()
    init_context = build_init_resource_context(config=snowflake_config)
    snowflake_manager = snowflake_io_manager(init_context)

    # The first frame only seeds the temporary table so its column types are
    # right; the second frame is the data actually written.
    contents1 = DataFrame([{"col1": "a", "col2": 1}])
    contents2 = DataFrame([{"col1": "b", "col2": 2}])

    with temporary_snowflake_table(contents1) as temp_table_name:
        metadata = {"table": f"public.{temp_table_name}"}
        output_context = build_output_context(
            metadata=metadata,
            resource_config=snowflake_config,
        )
        # Consume handle_output fully so that every step of it runs.
        list(snowflake_manager.handle_output(output_context, contents2))

        input_context = build_input_context(
            upstream_output=output_context,
            resource_config=snowflake_config,
        )
        input_value = snowflake_manager.load_input(input_context)
        assert input_value.equals(contents2), f"{input_value}\n\n{contents2}"
def test_assets():
    """Generate materializations from a mocked Airbyte sync and check their metadata."""
    airbyte_config = {"host": "some_host", "port": "8000"}
    ab_resource = airbyte_resource(build_init_resource_context(config=airbyte_config))
    base_url = ab_resource.api_base_url

    responses.add(
        method=responses.POST,
        url=base_url + "/connections/get",
        json=get_sample_connection_json(),
        status=200,
    )
    responses.add(
        method=responses.POST,
        url=base_url + "/connections/sync",
        json={"job": {"id": 1}},
        status=200,
    )
    responses.add(
        method=responses.POST,
        url=base_url + "/jobs/get",
        json=get_sample_job_json(),
        status=200,
    )

    airbyte_output = ab_resource.sync_and_poll("some_connection", 0, None)
    materializations = list(generate_materializations(airbyte_output, []))
    assert len(materializations) == 3

    # Both metadata entries are expected on the first materialization.
    first_entries = materializations[0].metadata_entries
    assert MetadataEntry("bytesEmitted", value=1234) in first_entries
    assert MetadataEntry("recordsCommitted", value=4321) in first_entries
def test_resource_config_example():
    """The ``connection`` config value is exposed on the object ``db_resource`` returns."""
    init_context = build_init_resource_context(config={"connection": "foo"})
    dbconn = db_resource(init_context)
    assert dbconn.connection == "foo"
def test_my_resource_with_context():
    """Invoke ``my_resource_requires_context`` with both a resource and config supplied."""
    init_context = build_init_resource_context(
        resources={"foo": "foo_str"},
        config={"bar": "bar_str"},
    )
    outcome = my_resource_requires_context(init_context)
    assert outcome == ("foo_str", "bar_str")