Example #1
0
def test_formdata():
    """Exercise the mapping-style API of ``FormData``.

    The form is built from a list of pairs in which the key ``"a"`` appears
    twice and the key ``"b"`` holds an ``UploadFile``.
    """
    payload = io.BytesIO(b"data")
    upload_file = UploadFile(filename="file", file=payload)
    pairs = [("a", "123"), ("a", "456"), ("b", upload_file)]
    form_data = FormData(pairs)

    # Membership checks, including a differently-cased and an absent key.
    assert "a" in form_data
    assert "A" not in form_data
    assert "c" not in form_data

    # Single-value lookups are expected to yield the last value given for "a".
    assert form_data["a"] == "456"
    assert form_data.get("a") == "456"
    assert form_data.get("nope", default=None) is None
    assert form_data.getlist("a") == ["123", "456"]

    # Dict-like views are expected to list each key once.
    assert list(form_data.keys()) == ["a", "b"]
    assert list(form_data.values()) == ["456", upload_file]
    assert list(form_data.items()) == [("a", "456"), ("b", upload_file)]
    assert len(form_data) == 2
    assert list(form_data) == ["a", "b"]
    assert dict(form_data) == {"a": "456", "b": upload_file}

    # The repr is expected to show every pair, duplicates included.
    expected_repr = (
        "FormData([('a', '123'), ('a', '456'), ('b', "
        + repr(upload_file)
        + ")])"
    )
    assert repr(form_data) == expected_repr

    # Equality holds between FormData instances but not against a plain dict.
    assert FormData(form_data) == form_data
    mapping = {"a": "123", "b": "789"}
    assert FormData(mapping) == FormData([("a", "123"), ("b", "789")])
    assert FormData(mapping) != mapping
Example #2
0
    def validate_file(cls, v):
        """Turn a sequence of attachments into a list of ``UploadFile`` objects.

        Each element of ``v`` may be either a filesystem path given as ``str``
        or an existing ``UploadFile``. Paths are opened in binary mode and
        wrapped in a new ``UploadFile``; existing ``UploadFile`` objects are
        passed through unchanged.

        Raises:
            WrongFile: if a path is not a readable regular file, is rejected
                by ``validate_path``, or if an element is neither a ``str``
                nor an ``UploadFile``.
        """
        guesser = MimeTypes()
        attachments = []

        for item in v:
            if isinstance(item, str):
                usable = (os.path.isfile(item)
                          and os.access(item, os.R_OK)
                          and validate_path(item))
                if not usable:
                    raise WrongFile(
                        "incorrect file path for attachment or not readable")

                # guess_type returns a (type, encoding) pair; only the type
                # is forwarded as the content type.
                guessed = guesser.guess_type(item)
                with open(item, mode="rb") as handle:
                    attachments.append(
                        UploadFile(handle.name,
                                   handle.read(),
                                   content_type=guessed[0]))
                continue

            if isinstance(item, UploadFile):
                attachments.append(item)
                continue

            raise WrongFile(
                "attachments field type incorrect, must be UploadFile or path"
            )

        return attachments
Example #3
0
async def test_upload_file_file_input():
    """Check an ``UploadFile`` built around a caller-supplied stream."""
    backing = io.BytesIO(b"data")
    upload = UploadFile(filename="file", file=backing)

    # The first read is expected to return the stream's initial content.
    first_read = await upload.read()
    assert first_read == b"data"

    # After writing, a read without seeking is expected to return nothing.
    await upload.write(b" and more data!")
    second_read = await upload.read()
    assert second_read == b""

    # Rewinding exposes the original bytes followed by the written ones.
    await upload.seek(0)
    full_content = await upload.read()
    assert full_content == b"data and more data!"
Example #4
0
async def test_resolver_args_filter_handles_uploaded_files_from_asgi(mocker):
    """Check how ``filter_resolver_args`` represents an uploaded file.

    An ``UploadFile`` holding ``file_size`` zero bytes is passed through an
    ``OpenTracingExtension`` whose ``arg_filter`` returns its arguments
    untouched; the filtered value is expected to be a descriptive string
    rather than the file object.
    """

    def passthrough(args, _):
        return args

    file_size = 1024 * 1024
    tracing = OpenTracingExtension(arg_filter=passthrough)

    upload = UploadFile(filename="test", content_type="text/plain")
    await upload.write(b"\0" * file_size)

    resolver_kwargs = {"0": upload}
    resolve_info = mocker.Mock()
    filtered = tracing.filter_resolver_args(resolver_kwargs, resolve_info)

    expected = (
        "<class 'starlette.datastructures.UploadFile'>"
        f"(mime_type={upload.content_type}, size={file_size}, filename={upload.filename})"
    )
    assert expected == filtered["0"]
Example #5
0
    async def parse(self) -> FormData:
        """Parse the multipart request body into a ``FormData``.

        The multipart boundary and an optional charset are taken from the
        request's ``Content-Type`` header. Each chunk of ``self.stream`` is
        fed to a ``multipart.MultipartParser``; after every chunk the queued
        entries of ``self.messages`` (presumably filled by the ``on_*``
        callbacks registered below) are drained and replayed through a small
        state machine that assembles one form item per part.

        Returns:
            A ``FormData`` built from ``(field name, value)`` pairs in the
            order the parts ended. The value is the decoded text for plain
            fields, or an ``UploadFile`` (rewound to position 0) for parts
            whose Content-Disposition carries a ``filename`` option.

        Raises:
            KeyError: expected when the ``boundary`` parameter is missing
                from the Content-Type header, or when a part's
                Content-Disposition has no ``name`` option, since both are
                looked up by subscript.
        """
        # Only the header parameters are needed here; the media type itself
        # is not used.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        # The default above is already text; a charset taken from the header
        # is bytes and must be decoded before use as a codec name.
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        boundary = params[b"boundary"]

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)

        # Per-header accumulators: a header name or value may arrive split
        # across several messages.
        header_field = b""
        header_value = b""
        # Per-part state, reset on PART_BEGIN / HEADERS_FINISHED.
        content_disposition = None
        content_type = b""
        field_name = ""
        data = b""
        file: typing.Optional[UploadFile] = None

        items: typing.List[typing.Tuple[str, typing.Union[str,
                                                          UploadFile]]] = []
        item_headers: typing.List[typing.Tuple[bytes, bytes]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            parser.write(chunk)
            # Take a snapshot and empty the queue so the next chunk starts
            # from a clean list.
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == MultiPartMessage.PART_BEGIN:
                    content_disposition = None
                    content_type = b""
                    data = b""
                    item_headers = []
                elif message_type == MultiPartMessage.HEADER_FIELD:
                    header_field += message_bytes
                elif message_type == MultiPartMessage.HEADER_VALUE:
                    header_value += message_bytes
                elif message_type == MultiPartMessage.HEADER_END:
                    # Header names are compared and stored lower-cased.
                    field = header_field.lower()
                    if field == b"content-disposition":
                        content_disposition = header_value
                    elif field == b"content-type":
                        content_type = header_value
                    item_headers.append((field, header_value))
                    header_field = b""
                    header_value = b""
                elif message_type == MultiPartMessage.HEADERS_FINISHED:
                    # Only the options of Content-Disposition matter; the
                    # disposition type itself is not inspected.
                    _, options = parse_options_header(content_disposition)
                    field_name = _user_safe_decode(options[b"name"], charset)
                    if b"filename" in options:
                        filename = _user_safe_decode(options[b"filename"],
                                                     charset)
                        file = UploadFile(
                            filename=filename,
                            content_type=content_type.decode("latin-1"),
                            headers=Headers(raw=item_headers),
                        )
                    else:
                        # Plain field: its body is buffered in ``data``.
                        file = None
                elif message_type == MultiPartMessage.PART_DATA:
                    if file is None:
                        data += message_bytes
                    else:
                        await file.write(message_bytes)
                elif message_type == MultiPartMessage.PART_END:
                    if file is None:
                        items.append(
                            (field_name, _user_safe_decode(data, charset)))
                    else:
                        # Rewind so consumers read the upload from the start.
                        await file.seek(0)
                        items.append((field_name, file))

        parser.finalize()
        return FormData(items)
    async def parse(self) -> FormData:
        """Consume the multipart request body and return it as ``FormData``.

        The boundary is read from the request's ``Content-Type`` header and
        handed to a ``multipart.MultipartParser``. Every chunk of
        ``self.stream`` is written to the parser, after which the entries
        queued in ``self.messages`` (presumably by the ``on_*`` callbacks
        registered below) are drained and folded into form items.

        Returns:
            A ``FormData`` of ``(field name, value)`` pairs in the order the
            parts ended; the value is latin-1 decoded text for plain fields
            and an ``UploadFile`` rewound to position 0 for parts that carry
            a ``filename`` option.
        """
        # Only the header parameters are of interest; the boundary lives
        # among them.
        _, header_params = parse_options_header(self.headers["Content-Type"])
        parser = multipart.MultipartParser(
            header_params.get(b"boundary"),
            {
                "on_part_begin": self.on_part_begin,
                "on_part_data": self.on_part_data,
                "on_part_end": self.on_part_end,
                "on_header_field": self.on_header_field,
                "on_header_value": self.on_header_value,
                "on_header_end": self.on_header_end,
                "on_headers_finished": self.on_headers_finished,
                "on_end": self.on_end,
            },
        )

        # A header name or value may arrive in several pieces, so both are
        # accumulated until HEADER_END.
        name_buf = b""
        value_buf = b""
        part_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
        part_name = ""
        part_body = b""
        upload = None  # type: typing.Optional[UploadFile]
        parsed = (
            []
        )  # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]

        async for chunk in self.stream:
            parser.write(chunk)
            # Snapshot the queue, then empty it ready for the next chunk.
            pending = list(self.messages)
            self.messages.clear()

            for kind, payload in pending:
                if kind == MultiPartMessage.PART_BEGIN:
                    part_headers = []
                    part_body = b""
                    continue

                if kind == MultiPartMessage.HEADER_FIELD:
                    name_buf += payload
                    continue

                if kind == MultiPartMessage.HEADER_VALUE:
                    value_buf += payload
                    continue

                if kind == MultiPartMessage.HEADER_END:
                    # Header names are stored lower-cased.
                    part_headers.append((name_buf.lower(), value_buf))
                    name_buf = b""
                    value_buf = b""
                    continue

                if kind == MultiPartMessage.HEADERS_FINISHED:
                    headers = Headers(raw=part_headers)
                    content_disposition = headers.get("Content-Disposition")
                    content_type = headers.get("Content-Type", "")
                    _, options = parse_options_header(content_disposition)
                    part_name = options[b"name"].decode("latin-1")
                    # A "filename" option marks the part as a file upload;
                    # otherwise its body is buffered as a plain field.
                    upload = (
                        UploadFile(
                            filename=options[b"filename"].decode("latin-1"),
                            content_type=content_type,
                        )
                        if b"filename" in options
                        else None
                    )
                    continue

                if kind == MultiPartMessage.PART_DATA:
                    if upload is not None:
                        await upload.write(payload)
                    else:
                        part_body += payload
                    continue

                if kind == MultiPartMessage.PART_END:
                    if upload is not None:
                        # Rewind so consumers read the upload from the start.
                        await upload.seek(0)
                        parsed.append((part_name, upload))
                    else:
                        parsed.append((part_name, part_body.decode("latin-1")))
                    continue

                # Anything else (e.g. MultiPartMessage.END) needs no handling.

        parser.finalize()
        return FormData(parsed)