# Example #1
# 0
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for UploadedFileManager"""

import unittest

from streamlit.stats import CacheStat
from streamlit.uploaded_file_manager import UploadedFileManager
from streamlit.uploaded_file_manager import UploadedFileRec

# Two fixture records that share the same id (0) but differ in name and
# payload bytes.
FILE_1, FILE_2 = (
    UploadedFileRec(id=0, name="file1", type="type", data=b"file1"),
    UploadedFileRec(id=0, name="file2", type="type", data=b"file222"),
)


class UploadedFileManagerTest(unittest.TestCase):
    def setUp(self):
        self.mgr = UploadedFileManager()
        self.filemgr_events = []
        self.mgr.on_files_updated.connect(self._on_files_updated)

    def _on_files_updated(self, file_list, **kwargs):
        self.filemgr_events.append(file_list)

    def test_added_file_id(self):
        """An added file should have a unique ID."""
        f1 = self.mgr.add_file("session", "widget", FILE_1)
    def post(self, **kwargs):
        args: Dict[str, List[bytes]] = {}
        files: Dict[str, List[Any]] = {}

        tornado.httputil.parse_body_arguments(
            content_type=self.request.headers["Content-Type"],
            body=self.request.body,
            arguments=args,
            files=files,
        )

        try:
            session_id = self._require_arg(args, "sessionId")
            widget_id = self._require_arg(args, "widgetId")
            if not self._validate_request(session_id):
                raise Exception("Session '%s' invalid" % session_id)

        except Exception as e:
            self.send_error(400, reason=str(e))
            return

        LOGGER.debug(
            f"{len(files)} file(s) received for session {session_id} widget {widget_id}"
        )

        # Create an UploadedFile object for each file.
        uploaded_files: List[UploadedFileRec] = []
        for id, flist in files.items():
            for file in flist:
                uploaded_files.append(
                    UploadedFileRec(
                        id=id,
                        name=file["filename"],
                        type=file["content_type"],
                        data=file["body"],
                    ))

        if len(uploaded_files) == 0:
            self.send_error(400, reason="Expected at least 1 file, but got 0")
            return
        replace = self.get_argument("replace", "false")

        try:
            total_files = int(self._require_arg(args, "totalFiles"))
        except Exception as e:
            total_files = 1

        self._file_mgr.update_file_count(
            session_id=session_id,
            widget_id=widget_id,
            file_count=total_files,
        )

        update_files = (self._file_mgr.replace_files
                        if replace == "true" else self._file_mgr.add_files)
        update_files(
            session_id=session_id,
            widget_id=widget_id,
            files=uploaded_files,
        )

        LOGGER.debug(
            f"{len(files)} file(s) uploaded for session {session_id} widget {widget_id}. replace {replace}"
        )

        self.set_status(200)
# Example #3
# 0
class HashTest(unittest.TestCase):
    """Tests for ``get_hash`` over builtin, stdlib and third-party objects,
    and for SQLAlchemy engine hashing through ``hash_engine``."""

    def test_string(self):
        self.assertEqual(get_hash("hello"), get_hash("hello"))
        self.assertNotEqual(get_hash("hello"), get_hash("hellö"))

    def test_int(self):
        self.assertEqual(get_hash(145757624235), get_hash(145757624235))
        self.assertNotEqual(get_hash(10), get_hash(11))
        self.assertNotEqual(get_hash(-1), get_hash(1))
        self.assertNotEqual(get_hash(2**7), get_hash(2**7 - 1))
        self.assertNotEqual(get_hash(2**7), get_hash(2**7 + 1))

    def test_mocks_do_not_result_in_infinite_recursion(self):
        try:
            get_hash(Mock())
            get_hash(MagicMock())
        except InternalHashError:
            self.fail("get_hash raised InternalHashError")

    def test_list(self):
        self.assertEqual(get_hash([1, 2]), get_hash([1, 2]))
        self.assertNotEqual(get_hash([1, 2]), get_hash([2, 2]))
        self.assertNotEqual(get_hash([1]), get_hash(1))

        # test that we can hash self-referencing lists
        a = [1, 2, 3]
        a.append(a)
        b = [1, 2, 3]
        b.append(b)
        self.assertEqual(get_hash(a), get_hash(b))

    def test_recursive_hash_func(self):
        def hash_int(x):
            return x

        @st.cache(hash_funcs={int: hash_int})
        def foo(x):
            return x

        self.assertEqual(foo(1), foo(1))
        # Note: We're able to break the recursive cycle caused by the identity
        # hash func but it causes all cycles to hash to the same thing.
        # https://github.com/streamlit/streamlit/issues/1659
        # self.assertNotEqual(foo(2), foo(1))

    def test_tuple(self):
        self.assertEqual(get_hash((1, 2)), get_hash((1, 2)))
        self.assertNotEqual(get_hash((1, 2)), get_hash((2, 2)))
        self.assertNotEqual(get_hash((1, )), get_hash(1))
        self.assertNotEqual(get_hash((1, )), get_hash([1]))

    def test_mappingproxy(self):
        a = types.MappingProxyType({"a": 1})
        b = types.MappingProxyType({"a": 1})
        c = types.MappingProxyType({"c": 1})

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_dict_items(self):
        a = types.MappingProxyType({"a": 1}).items()
        b = types.MappingProxyType({"a": 1}).items()
        c = types.MappingProxyType({"c": 1}).items()

        assert is_type(a, "builtins.dict_items")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_getset_descriptor(self):
        class A:
            x = 1

        class B:
            x = 1

        a = A.__dict__["__dict__"]
        b = B.__dict__["__dict__"]
        assert is_type(a, "builtins.getset_descriptor")

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_dict(self):
        dict_gen = {1: (x for x in range(1))}

        self.assertEqual(get_hash({1: 1}), get_hash({1: 1}))
        self.assertNotEqual(get_hash({1: 1}), get_hash({1: 2}))
        self.assertNotEqual(get_hash({1: 1}), get_hash([(1, 1)]))

        with self.assertRaises(UnhashableTypeError):
            get_hash(dict_gen)
        get_hash(dict_gen, hash_funcs={types.GeneratorType: id})

    def test_self_reference_dict(self):
        d1 = {"cat": "hat"}
        d2 = {"things": [1, 2]}

        self.assertEqual(get_hash(d1), get_hash(d1))
        self.assertNotEqual(get_hash(d1), get_hash(d2))

        # test that a dict nested inside another dict hashes differently
        # from the outer dict
        d2 = {"book": d1}
        self.assertNotEqual(get_hash(d2), get_hash(d1))

    def test_reduce_(self):
        class A(object):
            def __init__(self):
                self.x = [1, 2, 3]

        class B(object):
            def __init__(self):
                self.x = [1, 2, 3]

        class C(object):
            def __init__(self):
                self.x = (x for x in range(1))

        self.assertEqual(get_hash(A()), get_hash(A()))
        self.assertNotEqual(get_hash(A()), get_hash(B()))
        self.assertNotEqual(get_hash(A()), get_hash(A().__reduce__()))

        with self.assertRaises(UnhashableTypeError):
            get_hash(C())
        get_hash(C(), hash_funcs={types.GeneratorType: id})

    def test_generator(self):
        with self.assertRaises(UnhashableTypeError):
            get_hash((x for x in range(1)))

    def test_hashing_broken_code(self):
        # Deliberately import the *module*: it has no `strptime` attribute,
        # and the test expects get_hash to report that as a UserHashError.
        import datetime

        def a():
            return datetime.strptime("%H")

        def b():
            x = datetime.strptime("%H")
            ""
            ""
            return x

        data = [
            (a, '```\nreturn datetime.strptime("%H")\n```'),
            (b, '```\nx = datetime.strptime("%H")\n""\n""\n```'),
        ]
        exc_msg = "module 'datetime' has no attribute 'strptime'"

        for func, code_msg in data:
            with self.assertRaises(UserHashError) as ctx:
                get_hash(func)

            exc = str(ctx.exception)
            self.assertIn(exc_msg, exc)
            self.assertRegex(exc, r"a bug in `.+` near line `\d+`")
            self.assertIn(code_msg, exc)

    def test_hash_funcs_acceptable_keys(self):
        class C(object):
            def __init__(self):
                self.x = (x for x in range(1))

        with self.assertRaises(UnhashableTypeError):
            get_hash(C())

        self.assertEqual(
            get_hash(C(), hash_funcs={types.GeneratorType: id}),
            get_hash(C(), hash_funcs={"builtins.generator": id}),
        )

    def test_hash_funcs_error(self):
        with self.assertRaises(UserHashError):
            get_hash(1, hash_funcs={int: lambda x: "a" + x})

    def test_internal_hashing_error(self):
        def side_effect(i):
            if i == 123456789:
                return "a" + 1
            return i.to_bytes((i.bit_length() + 8) // 8, "little", signed=True)

        with self.assertRaises(InternalHashError):
            with patch("streamlit.hashing._int_to_bytes",
                       side_effect=side_effect):
                get_hash(123456789)

    def test_float(self):
        self.assertEqual(get_hash(0.1), get_hash(0.1))
        self.assertNotEqual(get_hash(23.5234), get_hash(23.5235))

    def test_bool(self):
        self.assertEqual(get_hash(True), get_hash(True))
        self.assertNotEqual(get_hash(True), get_hash(False))

    def test_none(self):
        self.assertEqual(get_hash(None), get_hash(None))
        self.assertNotEqual(get_hash(None), get_hash(False))

    def test_builtins(self):
        self.assertEqual(get_hash(abs), get_hash(abs))
        self.assertNotEqual(get_hash(abs), get_hash(type))

    def test_regex(self):
        p2 = re.compile(".*")
        p1 = re.compile(".*")
        p3 = re.compile(".*", re.I)
        self.assertEqual(get_hash(p1), get_hash(p2))
        self.assertNotEqual(get_hash(p1), get_hash(p3))

    def test_pandas_dataframe(self):
        df1 = pd.DataFrame({"foo": [12]})
        df2 = pd.DataFrame({"foo": [42]})
        df3 = pd.DataFrame({"foo": [12]})

        self.assertEqual(get_hash(df1), get_hash(df3))
        self.assertNotEqual(get_hash(df1), get_hash(df2))

        df4 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)),
                           columns=list("ABCD"))
        df5 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)),
                           columns=list("ABCD"))

        self.assertEqual(get_hash(df4), get_hash(df5))

    def test_pandas_series(self):
        series1 = pd.Series([1, 2])
        series2 = pd.Series([1, 3])
        series3 = pd.Series([1, 2])

        self.assertEqual(get_hash(series1), get_hash(series3))
        self.assertNotEqual(get_hash(series1), get_hash(series2))

        series4 = pd.Series(range(_PANDAS_ROWS_LARGE))
        series5 = pd.Series(range(_PANDAS_ROWS_LARGE))

        self.assertEqual(get_hash(series4), get_hash(series5))

    def test_numpy(self):
        np1 = np.zeros(10)
        np2 = np.zeros(11)
        np3 = np.zeros(10)

        self.assertEqual(get_hash(np1), get_hash(np3))
        self.assertNotEqual(get_hash(np1), get_hash(np2))

        np4 = np.zeros(_NP_SIZE_LARGE)
        np5 = np.zeros(_NP_SIZE_LARGE)

        self.assertEqual(get_hash(np4), get_hash(np5))

    @parameterized.expand([
        (BytesIO, b"123", b"456", b"123"),
        (StringIO, "123", "456", "123"),
        (
            UploadedFile,
            UploadedFileRec("id", "name", "type", b"123"),
            UploadedFileRec("id", "name", "type", b"456"),
            UploadedFileRec("id", "name", "type", b"123"),
        ),
    ])
    def test_io(self, io_type, io_data1, io_data2, io_data3):
        io1 = io_type(io_data1)
        io2 = io_type(io_data2)
        io3 = io_type(io_data3)

        self.assertEqual(get_hash(io1), get_hash(io3))
        self.assertNotEqual(get_hash(io1), get_hash(io2))

        # Changing the stream position should change the hash
        io1.seek(1)
        io3.seek(0)
        self.assertNotEqual(get_hash(io1), get_hash(io3))

    def test_partial(self):
        p1 = functools.partial(int, base=2)
        p2 = functools.partial(int, base=3)
        p3 = functools.partial(int, base=2)

        self.assertEqual(get_hash(p1), get_hash(p3))
        self.assertNotEqual(get_hash(p1), get_hash(p2))

    def test_lambdas(self):
        # self.assertEqual(get_hash(lambda x: x.lower()), get_hash(lambda x: x.lower()))
        self.assertNotEqual(get_hash(lambda x: x.lower()),
                            get_hash(lambda x: x.upper()))

    def test_files(self):
        temp1 = tempfile.NamedTemporaryFile()
        self.addCleanup(temp1.close)
        temp2 = tempfile.NamedTemporaryFile()
        self.addCleanup(temp2.close)

        with open(__file__, "r") as f:
            with open(__file__, "r") as g:
                self.assertEqual(get_hash(f), get_hash(g))

            self.assertNotEqual(get_hash(f), get_hash(temp1))

        self.assertEqual(get_hash(temp1), get_hash(temp1))
        self.assertNotEqual(get_hash(temp1), get_hash(temp2))

    def test_file_position(self):
        with open(__file__, "r") as f:
            h1 = get_hash(f)
            self.assertEqual(h1, get_hash(f))
            f.readline()
            self.assertNotEqual(h1, get_hash(f))
            f.seek(0)
            self.assertEqual(h1, get_hash(f))

    def test_keras_model(self):
        a = keras.applications.vgg16.VGG16(include_top=False, weights=None)
        b = keras.applications.vgg16.VGG16(include_top=False, weights=None)

        # This test still passes if we remove the default hash func for Keras
        # models. Ideally we'd seed the weights before creating the models
        # but not clear how to do so.
        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_tf_keras_model(self):
        a = tf.keras.applications.vgg16.VGG16(include_top=False, weights=None)
        b = tf.keras.applications.vgg16.VGG16(include_top=False, weights=None)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_tf_saved_model(self):
        tempdir = tempfile.TemporaryDirectory()
        # Remove the saved model directory once the test finishes.
        self.addCleanup(tempdir.cleanup)

        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(512, activation="relu", input_shape=(784, )),
        ])
        model.save(tempdir.name)

        a = tf.saved_model.load(tempdir.name)
        b = tf.saved_model.load(tempdir.name)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_pytorch_model(self):
        a = torchvision.models.resnet.resnet18()
        b = torchvision.models.resnet.resnet18()

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_socket(self):
        a = socket.socket()
        self.addCleanup(a.close)
        b = socket.socket()
        self.addCleanup(b.close)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_magic_mock(self):
        """Test that MagicMocks never hash to the same thing."""
        # (This also tests that MagicMock can hash at all, without blowing the
        # stack due to an infinite recursion.)
        self.assertNotEqual(get_hash(MagicMock()), get_hash(MagicMock()))

    def test_tensorflow_session(self):
        tf_config = tf.compat.v1.ConfigProto()
        tf_session = tf.compat.v1.Session(config=tf_config)
        self.addCleanup(tf_session.close)
        self.assertEqual(get_hash(tf_session), get_hash(tf_session))

        tf_session2 = tf.compat.v1.Session(config=tf_config)
        self.addCleanup(tf_session2.close)
        self.assertNotEqual(get_hash(tf_session), get_hash(tf_session2))

    def test_torch_c_tensorbase(self):
        a = torch.ones([1, 1]).__reduce__()[1][2]
        b = torch.ones([1, 1], requires_grad=True).__reduce__()[1][2]
        c = torch.ones([1, 2]).__reduce__()[1][2]

        assert is_type(a, "torch._C._TensorBase")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

        b.mean().backward()
        # Calling backward on a tensorbase doesn't seem to affect the gradient
        self.assertEqual(get_hash(a), get_hash(b))

    def test_torch_tensor(self):
        a = torch.ones([1, 1])
        b = torch.ones([1, 1], requires_grad=True)
        c = torch.ones([1, 2])

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

        b.mean().backward()

        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_non_hashable(self):
        """Test user provided hash functions."""

        g = (x for x in range(1))

        # Unhashable object raises an error
        with self.assertRaises(UnhashableTypeError):
            get_hash(g)

        id_hash_func = {types.GeneratorType: id}

        self.assertEqual(
            get_hash(g, hash_funcs=id_hash_func),
            get_hash(g, hash_funcs=id_hash_func),
        )

        unique_hash_func = {types.GeneratorType: lambda x: time.time()}

        self.assertNotEqual(
            get_hash(g, hash_funcs=unique_hash_func),
            get_hash(g, hash_funcs=unique_hash_func),
        )

    def test_override_streamlit_hash_func(self):
        """Test that a user provided hash function has priority over a streamlit one."""

        hash_funcs = {int: lambda x: "hello"}
        self.assertNotEqual(get_hash(1), get_hash(1, hash_funcs=hash_funcs))

    def _build_cffi(self, name):
        ffibuilder = cffi.FFI()
        ffibuilder.set_source(
            "cffi_bin._%s" % name,
            r"""
                static int %s(int x)
                {
                    return x + "A";
                }
            """ % name,
        )

        ffibuilder.cdef("int %s(int);" % name)
        ffibuilder.compile(verbose=True)

    def test_compiled_ffi(self):
        self._build_cffi("foo")
        self._build_cffi("bar")
        from cffi_bin._foo import ffi as foo
        from cffi_bin._bar import ffi as bar

        # Note: We've verified that all properties on CompiledFFI objects
        # are global, except have not verified `error` either way.
        self.assertIn(get_fqn_type(foo), _FFI_TYPE_NAMES)
        self.assertEqual(get_hash(foo), get_hash(bar))

    def test_sqlite_sqlalchemy_engine(self):
        """Separate tests for sqlite since it uses a file based
        and in memory database and has no auth
        """

        mem = "sqlite://"
        foo = "sqlite:///foo.db"

        self.assertEqual(hash_engine(mem), hash_engine(mem))
        self.assertEqual(hash_engine(foo), hash_engine(foo))
        self.assertNotEqual(hash_engine(foo), hash_engine("sqlite:///bar.db"))
        self.assertNotEqual(hash_engine(foo), hash_engine(mem))

        # Need to use absolute paths otherwise one path resolves
        # relatively and the other absolute
        self.assertEqual(
            hash_engine("sqlite:////foo.db", connect_args={"uri": True}),
            hash_engine("sqlite:////foo.db?uri=true"),
        )

        self.assertNotEqual(
            hash_engine(foo, connect_args={"uri": True}),
            hash_engine(foo, connect_args={"uri": False}),
        )

        self.assertNotEqual(
            hash_engine(foo, creator=lambda: False),
            hash_engine(foo, creator=lambda: True),
        )

    def test_mssql_sqlalchemy_engine(self):
        """Specialized tests for mssql since it uses a different way of
        passing connection arguments to the engine
        """

        url = "mssql:///?odbc_connect"
        # Same credentials as UID/PWD in `params_foo`, so the two must hash
        # equally.
        auth_url = "mssql://foo:pass@localhost/db"

        params_foo = urllib.parse.quote_plus(
            "Server=localhost;Database=db;UID=foo;PWD=pass")
        params_bar = urllib.parse.quote_plus(
            "Server=localhost;Database=db;UID=bar;PWD=pass")
        params_foo_caps = urllib.parse.quote_plus(
            "SERVER=localhost;Database=db;UID=foo;PWD=pass")
        params_foo_order = urllib.parse.quote_plus(
            "Database=db;Server=localhost;UID=foo;PWD=pass")

        self.assertEqual(
            hash_engine(auth_url),
            hash_engine("%s=%s" % (url, params_foo)),
        )
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_bar)),
        )

        # Note: False negative because the ordering of the keys affects
        # the hash
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_foo_order)),
        )

        # Note: False negative because the keys are case insensitive
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_foo_caps)),
        )

        # Note: False negative because `connect_args` doesn't affect the
        # connection string
        self.assertNotEqual(
            hash_engine(url, connect_args={"user": "foo"}),
            hash_engine(url, connect_args={"user": "bar"}),
        )

    @parameterized.expand([
        ("postgresql", "password"),
        ("mysql", "passwd"),
        ("oracle", "password"),
        ("mssql", "password"),
    ])
    def test_sqlalchemy_engine(self, dialect, password_key):
        def connect():
            pass

        url = "%s://localhost/db" % dialect
        auth_url = "%s://user:pass@localhost/db" % dialect

        self.assertEqual(hash_engine(url), hash_engine(url))
        self.assertEqual(
            hash_engine(auth_url, creator=connect),
            hash_engine(auth_url, creator=connect),
        )

        # Note: Hashing an engine with a creator can only be equal to the hash of another
        # engine with a creator, even if the underlying connection arguments are the same
        self.assertNotEqual(hash_engine(url),
                            hash_engine(url, creator=connect))

        self.assertNotEqual(hash_engine(url), hash_engine(auth_url))
        self.assertNotEqual(hash_engine(url, encoding="utf-8"),
                            hash_engine(url, encoding="ascii"))
        self.assertNotEqual(hash_engine(url, creator=connect),
                            hash_engine(url, creator=lambda: True))

        # mssql doesn't use `connect_args`
        if dialect != "mssql":
            # Same user/password as embedded in `auth_url`.
            self.assertEqual(
                hash_engine(auth_url),
                hash_engine(url,
                            connect_args={
                                "user": "user",
                                password_key: "pass"
                            }),
            )

            self.assertNotEqual(
                hash_engine(url, connect_args={"user": "foo"}),
                hash_engine(url, connect_args={"user": "bar"}),
            )
# Example #4
# 0
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for UploadedFileManager"""

import unittest

from streamlit.uploaded_file_manager import UploadedFileManager
from streamlit.uploaded_file_manager import UploadedFileRec

# Two fixture records with distinct ids, names and payload bytes.
file1, file2 = (
    UploadedFileRec(id="id1", name="file1", type="type", data=b"file1"),
    UploadedFileRec(id="id2", name="file2", type="type", data=b"file2"),
)


class UploadedFileManagerTest(unittest.TestCase):
    def setUp(self):
        """Build a fresh manager and start recording its update events."""
        self.filemgr_events = []
        self.mgr = UploadedFileManager()
        self.mgr.on_files_updated.connect(self._on_files_updated)

    def _on_files_updated(self, file_list, **kwargs):
        self.filemgr_events.append(file_list)

    def test_add_file(self):
        self.assertIsNone(self.mgr.get_files("non-report", "non-widget"))