Example #1
    def set_object(self, key, obj, serialization_strategy=None):
        check.str_param(key, 'key')

        logging.info('Writing GCS object at: ' + self.uri_for_key(key))

        # cannot check obj since could be arbitrary Python object
        check.inst_param(serialization_strategy, 'serialization_strategy',
                         SerializationStrategy)  # cannot be none here

        if self.has_object(key):
            logging.warning('Removing existing GCS key: {key}'.format(key=key))
            backoff(self.rm_object, args=[key], retry_on=(TooManyRequests,))

        # binary write modes (and all of Python 2) buffer through BytesIO;
        # text modes on Python 3 buffer through StringIO
        with (BytesIO() if serialization_strategy.write_mode == 'wb'
              or sys.version_info < (3, 0) else StringIO()) as file_like:
            serialization_strategy.serialize(obj, file_like)
            file_like.seek(0)
            backoff(
                self.bucket_obj.blob(key).upload_from_file,
                args=[file_like],
                retry_on=(TooManyRequests,),
            )

        return ObjectStoreOperation(
            op=ObjectStoreOperationType.SET_OBJECT,
            key=self.uri_for_key(key),
            dest_key=None,
            obj=obj,
            serialization_strategy_name=serialization_strategy.name,
            object_store_name=self.name,
        )
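Note that set_object above assumes the strategy object exposes a name, a write_mode, and a serialize(obj, file_like) method. As a rough illustration only (dagster ships its own SerializationStrategy implementations, whose interface may differ in detail), a pickle-based strategy could look like this:

import pickle

class IllustrativePickleStrategy:
    # hypothetical sketch, not dagster's actual pickle strategy
    name = "pickle"
    write_mode = "wb"  # binary, so set_object buffers through BytesIO
    read_mode = "rb"

    def serialize(self, obj, write_file_obj):
        # write the pickled bytes into the file-like buffer set_object provides
        pickle.dump(obj, write_file_obj)

    def deserialize(self, read_file_obj):
        return pickle.load(read_file_obj)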
Example #2
def test_backoff():
    # a fn that always fails exhausts the default retry budget:
    # one initial call plus four retries, five attempts in total
    fn = Failer(fails=100)
    with pytest.raises(RetryableException):
        backoff(fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={'foo': 'bar'})

    assert fn.call_count == 5
    assert all(args == (3, 2, 1) for args in fn.args)
    assert all(kwargs == {'foo': 'bar'} for kwargs in fn.kwargs)

    # a fn that succeeds immediately is called exactly once
    fn = Failer()
    assert backoff(fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={'foo': 'bar'})
    assert fn.call_count == 1

    # one failure followed by a success: two calls in total
    fn = Failer(fails=1)
    assert backoff(fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={'foo': 'bar'})
    assert fn.call_count == 2

    # max_retries=0 disables retries, so a single failure propagates
    fn = Failer(fails=1)
    with pytest.raises(RetryableException):
        backoff(
            fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={'foo': 'bar'}, max_retries=0
        )
    assert fn.call_count == 1

    # two failures exceed max_retries=1: the exception propagates after two calls
    fn = Failer(fails=2)
    with pytest.raises(RetryableException):
        backoff(
            fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={'foo': 'bar'}, max_retries=1
        )
    assert fn.call_count == 2
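Taken together, these assertions pin down the contract of backoff: args and kwargs are forwarded on every attempt, the default budget is five attempts (one initial call plus four retries), max_retries counts retries after the first call, and the exception is re-raised once the budget is spent. A minimal sketch satisfying that contract might look like the following; dagster's real implementation may differ, for example in how it spaces its sleep intervals.

import time

class RetryableException(Exception):
    """Stand-in for the exception type the test retries on."""

class Failer:
    """Test helper reconstructed from its usage above: raises
    RetryableException for the first `fails` calls, then returns True,
    recording every call's args and kwargs."""

    def __init__(self, fails=0):
        self.fails = fails
        self.call_count = 0
        self.args = []
        self.kwargs = []

    def __call__(self, *args, **kwargs):
        self.call_count += 1
        self.args.append(args)
        self.kwargs.append(kwargs)
        if self.call_count <= self.fails:
            raise RetryableException()
        return True

def backoff(fn, retry_on, args=None, kwargs=None, max_retries=4, initial_delay=0.1):
    """Call fn(*args, **kwargs), retrying on the exception types in
    retry_on. Makes at most max_retries + 1 attempts, doubling the sleep
    between them, and re-raises once the budget is exhausted."""
    args = args or []
    kwargs = kwargs or {}
    delay = initial_delay
    for attempt in range(max_retries + 1):
        try:
            return fn(*args, **kwargs)
        except retry_on:
            if attempt == max_retries:
                raise
            time.sleep(delay)
            delay *= 2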
Example #3
    def handle_output(self, context, obj):
        key = self._get_path(context)
        context.log.debug(f"Writing GCS object at: {self._uri_for_key(key)}")

        if self._has_object(key):
            context.log.warning(f"Removing existing GCS key: {key}")
            self._rm_object(key)

        pickled_obj = pickle.dumps(obj, PICKLE_PROTOCOL)

        backoff(
            self.bucket_obj.blob(key).upload_from_string,
            args=[pickled_obj],
            retry_on=(TooManyRequests, Forbidden),
        )
Example #4
    def set_asset(self, context, obj):
        key = self._get_path(context)
        logging.info("Writing GCS object at: " + self._uri_for_key(key))

        if self._has_object(key):
            logging.warning("Removing existing GCS key: {key}".format(key=key))
            self._rm_object(key)

        pickled_obj = pickle.dumps(obj, PICKLE_PROTOCOL)

        backoff(
            self.bucket_obj.blob(key).upload_from_string,
            args=[pickled_obj],
            retry_on=(TooManyRequests,),
        )
Example #5
    def get_step_events(self, run_id: str, step_key: str):
        path = self._dbfs_path(run_id, step_key, PICKLED_EVENTS_FILE_NAME)

        def _get_step_records():
            serialized_records = self.databricks_runner.client.read_file(path)
            if not serialized_records:
                return []
            return deserialize_value(pickle.loads(serialized_records))

        try:
            # reading from dbfs while it writes can be flaky
            # allow for retry if we get malformed data
            return backoff(
                fn=_get_step_records,
                retry_on=(pickle.UnpicklingError,),
                max_retries=2,
            )
        # if you poll before the Databricks process has had a chance to create the file,
        # we expect to get this error
        except HTTPError as e:
            if e.response.json().get("error_code") == "RESOURCE_DOES_NOT_EXIST":
                return []

        return []
Example #6
    def set_object(self, key, obj, serialization_strategy=None):
        check.str_param(key, "key")

        logging.info("Writing GCS object at: " + self.uri_for_key(key))

        # cannot check obj since could be arbitrary Python object
        check.inst_param(serialization_strategy, "serialization_strategy",
                         SerializationStrategy)  # cannot be none here

        if self.has_object(key):
            logging.warning("Removing existing GCS key: {key}".format(key=key))
            backoff(self.rm_object, args=[key], retry_on=(TooManyRequests,))

        # binary write modes (and all of Python 2) buffer through BytesIO;
        # text modes on Python 3 buffer through StringIO
        with (BytesIO() if serialization_strategy.write_mode == "wb"
              or sys.version_info < (3, 0) else StringIO()) as file_like:
            serialization_strategy.serialize(obj, file_like)
            file_like.seek(0)
            backoff(
                self.bucket_obj.blob(key).upload_from_file,
                args=[file_like],
                retry_on=(TooManyRequests,),
            )

        return self.uri_for_key(key)
Example #7
def default_ecs_task_metadata(ec2, ecs):
    """
    ECS injects an environment variable into each Fargate task. The value
    of this environment variable is a url that can be queried to introspect
    information about the current processes's running task:

    https://docs.aws.amazon.com/AmazonECS/latest/userguide/task-metadata-endpoint-v4-fargate.html
    """
    container_metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")
    name = requests.get(container_metadata_uri).json()["Name"]

    task_metadata_uri = container_metadata_uri + "/task"
    response = requests.get(task_metadata_uri).json()
    cluster = response.get("Cluster")
    task_arn = response.get("TaskARN")

    def describe_task_or_raise(task_arn, cluster):
        try:
            return ecs.describe_tasks(tasks=[task_arn], cluster=cluster)["tasks"][0]
        except IndexError:
            raise EcsNoTasksFound

    try:
        task = backoff(
            describe_task_or_raise,
            retry_on=(EcsNoTasksFound,),
            kwargs={"task_arn": task_arn, "cluster": cluster},
            max_retries=BACKOFF_RETRIES,
        )
    except EcsNoTasksFound:
        raise EcsEventualConsistencyTimeout

    enis = []
    subnets = []
    for attachment in task["attachments"]:
        if attachment["type"] == "ElasticNetworkInterface":
            for detail in attachment["details"]:
                if detail["name"] == "subnetId":
                    subnets.append(detail["value"])
                if detail["name"] == "networkInterfaceId":
                    enis.append(ec2.NetworkInterface(detail["value"]))

    public_ip = False
    security_groups = []
    for eni in enis:
        if (eni.association_attribute or {}).get("PublicIp"):
            public_ip = True
        for group in eni.groups:
            security_groups.append(group["GroupId"])

    task_definition_arn = task["taskDefinitionArn"]
    task_definition = ecs.describe_task_definition(taskDefinition=task_definition_arn)[
        "taskDefinition"
    ]

    container_definition = next(
        iter(
            [
                container
                for container in task_definition["containerDefinitions"]
                if container["name"] == name
            ]
        )
    )

    return TaskMetadata(
        cluster=cluster,
        subnets=subnets,
        security_groups=security_groups,
        task_definition=task_definition,
        container_definition=container_definition,
        assign_public_ip="ENABLED" if public_ip else "DISABLED",
    )
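A hypothetical invocation from inside a Fargate task, assuming boto3 handles and that TaskMetadata exposes its fields as attributes (the variable names and printed fields are illustrative, not part of the excerpt above):

import boto3

# a resource for EC2, since the function constructs ec2.NetworkInterface(...),
# and a plain client for ECS, since it calls ecs.describe_tasks(...)
ec2 = boto3.resource("ec2")
ecs = boto3.client("ecs")

metadata = default_ecs_task_metadata(ec2, ecs)
print(metadata.subnets, metadata.security_groups, metadata.assign_public_ip)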