Example #1
def _start_watch_thread(self, origin: GrpcServerRepositoryLocationOrigin) -> None:
    location_name = origin.location_name
    check.invariant(location_name not in self._watch_thread_shutdown_events)
    client = origin.create_client()
    shutdown_event, watch_thread = create_grpc_watch_thread(
        location_name,
        client,
        on_updated=lambda location_name, new_server_id: self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_UPDATED,
                location_name=location_name,
                message="Server has been updated.",
                server_id=new_server_id,
            )
        ),
        on_error=lambda location_name: self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_ERROR,
                location_name=location_name,
                message="Unable to reconnect to server. You can reload the server once it is "
                "reachable again",
            )
        ),
    )
    self._watch_thread_shutdown_events[location_name] = shutdown_event
    self._watch_threads[location_name] = watch_thread
    watch_thread.start()
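A minimal counterpart sketch (not part of the original example; the method name _stop_watch_thread and its use here are assumptions) showing how the shutdown event and thread stored above might be torn down:

def _stop_watch_thread(self, location_name: str) -> None:
    # Hypothetical counterpart: signal the gRPC watch loop created by
    # _start_watch_thread to exit and wait for its thread to finish.
    shutdown_event = self._watch_thread_shutdown_events.pop(location_name)
    watch_thread = self._watch_threads.pop(location_name)
    shutdown_event.set()
    watch_thread.join()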
Example #2
def _location_origin_from_grpc_server_config(
        grpc_server_config: Dict,
        yaml_path: str) -> GrpcServerRepositoryLocationOrigin:
    check.dict_param(grpc_server_config, "grpc_server_config")
    check.str_param(yaml_path, "yaml_path")

    port, socket, host, location_name, use_ssl = (
        grpc_server_config.get("port"),
        grpc_server_config.get("socket"),
        grpc_server_config.get("host"),
        grpc_server_config.get("location_name"),
        grpc_server_config.get("ssl"),
    )

    check.invariant((socket or port) and not (socket and port),
                    "must supply either a socket or a port")

    if not host:
        host = "localhost"

    return GrpcServerRepositoryLocationOrigin(
        port=port,
        socket=socket,
        host=host,
        location_name=location_name,
        use_ssl=use_ssl,
    )
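A minimal usage sketch (the config values and yaml path are assumptions, not taken from the original example) of how this helper might be called with a dict parsed from a grpc_server workspace entry:

# Hypothetical invocation; only one of "port"/"socket" may be supplied.
origin = _location_origin_from_grpc_server_config(
    {"host": "localhost", "port": 4266, "location_name": "my_grpc_server"},
    yaml_path="workspace.yaml",
)
assert origin.host == "localhost" and origin.port == 4266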
Example #3
def test_sensor_timeout():
    port = find_free_port()
    python_file = file_relative_path(__file__, "grpc_repo.py")

    subprocess_args = [
        "dagster",
        "api",
        "grpc",
        "--port",
        str(port),
        "--python-file",
        python_file,
    ]

    process = subprocess.Popen(
        subprocess_args,
        stdout=subprocess.PIPE,
    )

    try:
        wait_for_grpc_server(
            process, DagsterGrpcClient(port=port, host="localhost"), subprocess_args
        )
        client = DagsterGrpcClient(port=port)

        with instance_for_test() as instance:
            repo_origin = ExternalRepositoryOrigin(
                repository_location_origin=GrpcServerRepositoryLocationOrigin(
                    port=port, host="localhost"
                ),
                repository_name="bar_repo",
            )
            with pytest.raises(DagsterUserCodeUnreachableError) as exc_info:
                client.external_sensor_execution(
                    sensor_execution_args=SensorExecutionArgs(
                        repository_origin=repo_origin,
                        instance_ref=instance.get_ref(),
                        sensor_name="slow_sensor",
                        last_completion_time=None,
                        last_run_key=None,
                        cursor=None,
                    ),
                    timeout=2,
                )

            assert "Deadline Exceeded" in str(exc_info.getrepr())

            # Call succeeds without the timeout
            client.external_sensor_execution(
                sensor_execution_args=SensorExecutionArgs(
                    repository_origin=repo_origin,
                    instance_ref=instance.get_ref(),
                    sensor_name="slow_sensor",
                    last_completion_time=None,
                    last_run_key=None,
                    cursor=None,
                ),
            )
    finally:
        process.terminate()
Example #4
def create_origins(self):
    return [
        GrpcServerRepositoryLocationOrigin(
            port=self.port,
            socket=self.socket,
            host=self.host,
            location_name=self.location_name,
        )
    ]
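For comparison, a hypothetical sketch (all values assumed) of constructing the same kind of origin directly for a server reachable over TCP rather than a Unix socket:

# Hypothetical values; a server handle would normally supply these fields itself.
origin = GrpcServerRepositoryLocationOrigin(
    host="localhost",
    port=4266,
    socket=None,
    location_name="my_grpc_server",
)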
Example #5
def test_sensor_timeout():
    port = find_free_port()
    python_file = file_relative_path(__file__, "grpc_repo.py")

    ipc_output_file = _get_ipc_output_file()
    process = subprocess.Popen(
        [
            "dagster",
            "api",
            "grpc",
            "--port",
            str(port),
            "--python-file",
            python_file,
            "--ipc-output-file",
            ipc_output_file,
        ],
        stdout=subprocess.PIPE,
    )

    try:
        wait_for_grpc_server(process, ipc_output_file)
        client = DagsterGrpcClient(port=port)

        with instance_for_test() as instance:
            repo_origin = ExternalRepositoryOrigin(
                repository_location_origin=GrpcServerRepositoryLocationOrigin(
                    port=port, host="localhost"
                ),
                repository_name="bar_repo",
            )
            with pytest.raises(Exception, match="Deadline Exceeded"):
                client.external_sensor_execution(
                    sensor_execution_args=SensorExecutionArgs(
                        repository_origin=repo_origin,
                        instance_ref=instance.get_ref(),
                        sensor_name="slow_sensor",
                        last_completion_time=None,
                        last_run_key=None,
                    ),
                    timeout=2,
                )

            # Call succeeds without the timeout
            client.external_sensor_execution(
                sensor_execution_args=SensorExecutionArgs(
                    repository_origin=repo_origin,
                    instance_ref=instance.get_ref(),
                    sensor_name="slow_sensor",
                    last_completion_time=None,
                    last_run_key=None,
                ),
            )
    finally:
        process.terminate()
Example #6
File: handle.py  Project: M-EZZ/dagster
    def create_from_repository_origin(repository_origin, instance):
        check.inst_param(repository_origin, "repository_origin", RepositoryOrigin)
        check.inst_param(instance, "instance", DagsterInstance)

        if isinstance(repository_origin, RepositoryGrpcServerOrigin):
            return RepositoryLocationHandle.create_from_repository_location_origin(
                GrpcServerRepositoryLocationOrigin(
                    port=repository_origin.port,
                    socket=repository_origin.socket,
                    host=repository_origin.host,
                )
            )
        elif isinstance(repository_origin, RepositoryPythonOrigin):
            loadable_target_origin = repository_origin.loadable_target_origin

            repo_location_origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
                loadable_target_origin
            )

            return RepositoryLocationHandle.create_from_repository_location_origin(
                repo_location_origin
            )
        else:
            raise DagsterInvariantViolationError("Unexpected repository origin type")
Example #7
def test_external_job_origin_instigator_origin():
    def build_legacy_whitelist_map():
        legacy_env = WhitelistMap.create()

        @_whitelist_for_serdes(legacy_env)
        class ExternalJobOrigin(
                namedtuple("_ExternalJobOrigin",
                           "external_repository_origin job_name")):
            def get_id(self):
                return create_snapshot_id(self)

        @_whitelist_for_serdes(legacy_env)
        class ExternalRepositoryOrigin(
                namedtuple("_ExternalRepositoryOrigin",
                           "repository_location_origin repository_name")):
            def get_id(self):
                return create_snapshot_id(self)

        class GrpcServerOriginSerializer(DefaultNamedTupleSerializer):
            @classmethod
            def skip_when_empty(cls):
                return {"use_ssl"}

        @_whitelist_for_serdes(whitelist_map=legacy_env, serializer=GrpcServerOriginSerializer)
        class GrpcServerRepositoryLocationOrigin(
            namedtuple(
                "_GrpcServerRepositoryLocationOrigin",
                "host port socket location_name use_ssl",
            )
        ):
            def __new__(cls, host, port=None, socket=None, location_name=None, use_ssl=None):
                return super(GrpcServerRepositoryLocationOrigin, cls).__new__(
                    cls, host, port, socket, location_name, use_ssl
                )

        return (
            legacy_env,
            ExternalJobOrigin,
            ExternalRepositoryOrigin,
            GrpcServerRepositoryLocationOrigin,
        )

    legacy_env, klass, repo_klass, location_klass = build_legacy_whitelist_map()

    from dagster.core.host_representation.origin import (
        ExternalInstigatorOrigin,
        ExternalRepositoryOrigin,
        GrpcServerRepositoryLocationOrigin,
    )

    # serialize from current code, compare against old code
    instigator_origin = ExternalInstigatorOrigin(
        external_repository_origin=ExternalRepositoryOrigin(
            repository_location_origin=GrpcServerRepositoryLocationOrigin(
                host="localhost", port=1234, location_name="test_location"),
            repository_name="the_repo",
        ),
        instigator_name="simple_schedule",
    )
    instigator_origin_str = serialize_dagster_namedtuple(instigator_origin)
    instigator_to_job = _deserialize_json(instigator_origin_str, legacy_env)
    assert isinstance(instigator_to_job, klass)
    # ensure that the origin id is stable
    assert instigator_to_job.get_id() == instigator_origin.get_id()

    # serialize from old code, compare against current code
    job_origin = klass(
        external_repository_origin=repo_klass(
            repository_location_origin=location_klass(
                host="localhost", port=1234, location_name="test_location"),
            repository_name="the_repo",
        ),
        job_name="simple_schedule",
    )
    job_origin_str = serialize_value(job_origin, legacy_env)
    from dagster.serdes.serdes import _WHITELIST_MAP

    job_to_instigator = deserialize_json_to_dagster_namedtuple(job_origin_str)
    assert isinstance(job_to_instigator, ExternalInstigatorOrigin)
    # ensure that the origin id is stable
    assert job_to_instigator.get_id() == job_origin.get_id()