Esempio n. 1
0
    def test_unique_uploaded_file_instance(self, get_file_recs_patch):
        """We should get a unique UploadedFile instance each time we access
        the file_uploader widget."""

        # Patch UploadFileManager to return two files
        file_recs = [
            UploadedFileRec(1, "file1", "type", b"123"),
            UploadedFileRec(2, "file2", "type", b"456"),
        ]

        get_file_recs_patch.return_value = file_recs

        # These file_uploaders have different labels so that we don't cause
        # a DuplicateKey error - but because we're patching the get_files
        # function, both file_uploaders will refer to the same files.
        file1: UploadedFile = st.file_uploader("a",
                                               accept_multiple_files=False)
        file2: UploadedFile = st.file_uploader("b",
                                               accept_multiple_files=False)

        # Distinct wrapper instances, even though both wrap the same record.
        self.assertNotEqual(file1, file2)

        # Seeking in one instance should not impact the position in the other.
        file1.seek(2)
        self.assertEqual(b"3", file1.read())
        self.assertEqual(b"123", file2.read())
Esempio n. 2
0
    def test_remove_orphaned_files(self, get_file_recs_patch,
                                   remove_orphaned_files_patch):
        """When file_uploader is accessed, it should call
        UploadedFileManager.remove_orphaned_files.
        """
        ctx = get_report_ctx()
        # Pin the manager's id counter so the expected newest_file_id below
        # is deterministic (next id would be 101, so newest existing is 100).
        ctx.uploaded_file_mgr._file_id_counter = 101

        file_recs = [
            UploadedFileRec(1, "file1", "type", b"123"),
            UploadedFileRec(2, "file2", "type", b"456"),
        ]
        get_file_recs_patch.return_value = file_recs

        st.file_uploader("foo", accept_multiple_files=True)

        # remove_orphaned_files must be invoked with keyword args only.
        args, kwargs = remove_orphaned_files_patch.call_args
        self.assertEqual(len(args), 0)
        self.assertEqual(kwargs["session_id"], "test session id")
        self.assertEqual(kwargs["newest_file_id"], 100)
        self.assertEqual(kwargs["active_file_ids"], [1, 2])

        # Patch _get_file_recs to return [] instead. remove_orphaned_files
        # should not be called when file_uploader is accessed.
        get_file_recs_patch.return_value = []
        remove_orphaned_files_patch.reset_mock()

        st.file_uploader("foo")
        remove_orphaned_files_patch.assert_not_called()
    def test_delete_file(self):
        """A DELETE for an existing (session, widget, file) should succeed.

        The deleted file must be removed from the manager, while files
        belonging to other sessions/widgets are untouched.
        """
        file1 = UploadedFileRec("1234", "name", "type", b"1234")
        file2 = UploadedFileRec("4567", "name", "type", b"1234")

        self.file_mgr.add_files("session1", "widget1", [file1])
        self.file_mgr.add_files("session2", "widget2", [file2])

        # Plain string literal: the original f-string had no placeholders.
        response = self.fetch("/upload_file/session1/widget1/1234", method="DELETE")
        # A successful delete returns 200. (The original assertion of 404
        # contradicted the docstring: 404 is the mismatch/failure case,
        # which test_delete_file_across_sessions covers.)
        self.assertEqual(200, response.code)
        # file1 is gone; file2 (different session/widget) survives.
        self.assertFalse(len(self.file_mgr.get_files("session1", "widget1")))
        self.assertTrue(len(self.file_mgr.get_files("session2", "widget2")))
    def test_delete_file_across_sessions(self):
        """Deleting file param mismatch should fail with 404 status."""
        file1 = UploadedFileRec("1234", "name", "type", b"1234")
        file2 = UploadedFileRec("4567", "name", "type", b"1234")

        self.file_mgr.add_files("session1", "widget1", [file1])
        self.file_mgr.add_files("session2", "widget2", [file2])

        # session2 does not own widget1's file "1234", so the delete fails.
        # (Plain string literal: the original f-string had no placeholders.)
        response = self.fetch("/upload_file/session2/widget1/1234", method="DELETE")
        self.assertEqual(404, response.code)
        # Neither session's files were removed.
        self.assertTrue(len(self.file_mgr.get_files("session1", "widget1")))
        self.assertTrue(len(self.file_mgr.get_files("session2", "widget2")))
Esempio n. 5
0
    def post(self, **kwargs):
        """Receive 1 or more uploaded files and add them to our
        UploadedFileManager.

        Responds 400 on a missing/invalid sessionId or widgetId or an empty
        upload, 200 on success. With the query param `replace=true`, the
        widget's existing files are replaced instead of appended to.
        """
        args: Dict[str, List[bytes]] = {}
        files: Dict[str, List[Any]] = {}

        # NOTE(review): a request without a Content-Type header raises
        # KeyError here and surfaces as a 500 rather than a 400 — confirm
        # whether that is acceptable for this endpoint.
        tornado.httputil.parse_body_arguments(
            content_type=self.request.headers["Content-Type"],
            body=self.request.body,
            arguments=args,
            files=files,
        )

        try:
            session_id = self._require_arg(args, "sessionId")
            widget_id = self._require_arg(args, "widgetId")
            if not self._is_valid_session_id(session_id):
                raise Exception(f"Invalid session_id: '{session_id}'")

        except Exception as e:
            self.send_error(400, reason=str(e))
            return

        LOGGER.debug(
            f"{len(files)} file(s) received for session {session_id} widget {widget_id}"
        )

        # Create an UploadedFileRec for each file. The multipart field name
        # becomes the record's id. ("file_id" avoids shadowing the builtin
        # `id`, which the original loop variable did.)
        uploaded_files: List[UploadedFileRec] = []
        for file_id, flist in files.items():
            for file in flist:
                uploaded_files.append(
                    UploadedFileRec(
                        id=file_id,
                        name=file["filename"],
                        type=file["content_type"],
                        data=file["body"],
                    )
                )

        if len(uploaded_files) == 0:
            self.send_error(400, reason="Expected at least 1 file, but got 0")
            return

        replace = self.get_argument("replace", "false") == "true"
        if replace:
            self._file_mgr.replace_files(
                session_id=session_id, widget_id=widget_id, files=uploaded_files
            )
        else:
            self._file_mgr.add_files(
                session_id=session_id, widget_id=widget_id, files=uploaded_files
            )

        LOGGER.debug(
            f"{len(files)} file(s) uploaded for session {session_id} widget {widget_id}. replace {replace}"
        )

        self.set_status(200)
Esempio n. 6
0
    def test_multiple_uploaded_file_triggers_one_rerun(self):
        """Uploading a file should trigger a re-run in the associated
        ReportSession."""
        # (Generator-style Tornado test: the yields are driven by a
        # gen_test-like decorator outside this view — TODO confirm.)
        with self._patch_report_session():
            yield self.start_server_loop()

            # Connect twice and get associated ReportSessions
            yield self.ws_connect()
            yield self.ws_connect()
            session_info1 = list(self.server._session_info_by_id.values())[0]
            session_info2 = list(self.server._session_info_by_id.values())[1]

            file = UploadedFileRec("id", "file.txt", "type", b"123")

            # "Upload 2 files" for Session1
            self.server._uploaded_file_mgr.update_file_count(
                session_id=session_info1.session.id,
                widget_id="widget_id",
                file_count=2,
            )
            self.server._uploaded_file_mgr.add_files(
                session_id=session_info1.session.id,
                widget_id="widget_id",
                files=[file, file],
            )

            # Both records are now associated with session1's widget.
            self.assertEqual(
                self.server._uploaded_file_mgr.get_files(
                    session_info1.session.id, "widget_id"),
                [file, file],
            )

            # Session1 should have a rerun request; Session2 should not
            session_info1.session.request_rerun.assert_called_once()
            session_info2.session.request_rerun.assert_not_called()
Esempio n. 7
0
    def test_file_uploader_serde(self, get_file_recs_patch):
        """The file_uploader widget's value should survive a serde round-trip."""
        # Patch the file manager to report a single uploaded file.
        rec = UploadedFileRec(1, "file1", "type", b"123")
        get_file_recs_patch.return_value = [rec]

        widget_value = st.file_uploader("file_uploader", key="file_uploader")
        check_roundtrip("file_uploader", widget_value)
Esempio n. 8
0
    def post(self, **kwargs):
        """Receive an uploaded file and add it to our UploadedFileManager.
        Return the file's ID, so that the client can refer to it.

        Responds 400 on a missing/invalid sessionId or widgetId, or when the
        body does not contain exactly one file; 200 with the assigned file id
        in the response body on success.
        """
        args: Dict[str, List[bytes]] = {}
        files: Dict[str, List[Any]] = {}

        tornado.httputil.parse_body_arguments(
            content_type=self.request.headers["Content-Type"],
            body=self.request.body,
            arguments=args,
            files=files,
        )

        try:
            session_id = self._require_arg(args, "sessionId")
            widget_id = self._require_arg(args, "widgetId")
            if not self._is_valid_session_id(session_id):
                raise Exception(f"Invalid session_id: '{session_id}'")

        except Exception as e:
            self.send_error(400, reason=str(e))
            return

        LOGGER.debug(
            f"{len(files)} file(s) received for session {session_id} widget {widget_id}"
        )

        # Create an UploadedFile object for each file.
        # We assign an initial, invalid file_id to each file in this loop.
        # The file_mgr will assign unique file IDs and return in `add_file`,
        # below. (The field names are unused, so iterate values() directly
        # rather than items() — ruff PERF102.)
        uploaded_files: List[UploadedFileRec] = []
        for flist in files.values():
            for file in flist:
                uploaded_files.append(
                    UploadedFileRec(
                        id=0,
                        name=file["filename"],
                        type=file["content_type"],
                        data=file["body"],
                    )
                )

        if len(uploaded_files) != 1:
            self.send_error(
                400, reason=f"Expected 1 file, but got {len(uploaded_files)}"
            )
            return

        added_file = self._file_mgr.add_file(
            session_id=session_id, widget_id=widget_id, file=uploaded_files[0]
        )

        # Return the file_id to the client. (The client will parse
        # the string back to an int.)
        self.write(str(added_file.id))
        self.set_status(200)
Esempio n. 9
0
    def test_multiple_files(self, register_widget_patch, get_files_patch):
        """Test the accept_multiple_files flag"""
        # Patch UploadFileManager to return two files
        file_recs = [
            UploadedFileRec(1, "file1", "type", b"123"),
            UploadedFileRec(2, "file2", "type", b"456"),
        ]

        get_files_patch.return_value = file_recs

        # Patch register_widget to return the IDs of our two files
        file_ids = SInt64Array()
        file_ids.data[:] = [rec.id for rec in file_recs]
        register_widget_patch.return_value = file_ids

        for accept_multiple in [True, False]:
            return_val = st.file_uploader(
                "label", type="png", accept_multiple_files=accept_multiple)
            # The emitted element proto must mirror the flag we passed in.
            c = self.get_delta_from_queue().new_element.file_uploader
            self.assertEqual(accept_multiple, c.multiple_files)

            # If "accept_multiple_files" is True, then we should get a list of
            # values back. Otherwise, we should just get a single value.

            # Because file_uploader returns unique UploadedFile instances
            # each time it's called, we convert the return value back
            # from UploadedFile -> UploadedFileRec (which implements
            # equals()) to test equality.

            if accept_multiple:
                results = [
                    UploadedFileRec(file.id, file.name, file.type,
                                    file.getvalue()) for file in return_val
                ]
                self.assertEqual(file_recs, results)
            else:
                results = UploadedFileRec(
                    return_val.id,
                    return_val.name,
                    return_val.type,
                    return_val.getvalue(),
                )
                self.assertEqual(file_recs[0], results)
Esempio n. 10
0
    def test_replace_uploaded_file_triggers_one_rerun(self):
        """Uploading a file should trigger a re-run in the associated
        ReportSession."""
        with self._patch_report_session():
            yield self.start_server_loop()

            # Connect twice and get associated ReportSessions
            yield self.ws_connect()
            yield self.ws_connect()
            session_info = list(self.server._session_info_by_id.values())[0]

            file1 = UploadedFileRec("id1", "file1.txt", "type", b"123")
            file2 = UploadedFileRec("id2", "file2.txt", "type", b"456")

            # update_file_count presumably announces how many files the
            # widget expects — TODO confirm. Adding file1 triggers rerun #1.
            self.server._uploaded_file_mgr.update_file_count(
                session_id=session_info.session.id,
                widget_id="widget_id",
                file_count=1,
            )
            self.server._uploaded_file_mgr.add_files(
                session_id=session_info.session.id,
                widget_id="widget_id",
                files=[file1],
            )

            session_info.session.request_rerun.assert_called_once()

            # Replacing the widget's files should trigger rerun #2.
            self.server._uploaded_file_mgr.replace_files(
                session_id=session_info.session.id,
                widget_id="widget_id",
                files=[file2],
            )

            # Only the replacement file remains.
            self.assertEqual(
                self.server._uploaded_file_mgr.get_files(
                    session_info.session.id, "widget_id"),
                [file2],
            )
            self.assertEqual(session_info.session.request_rerun.call_count, 2)
Esempio n. 11
0
    def test_orphaned_upload_file_deletion(self):
        """An uploaded file with no associated ReportSession should be
        deleted."""
        with self._patch_report_session():
            yield self.start_server_loop()
            yield self.ws_connect()

            # "Upload a file" for a session that doesn't exist
            self.server._uploaded_file_mgr.add_files(
                session_id="no_such_session",
                widget_id="widget_id",
                files=[UploadedFileRec("id", "file.txt", "type", b"123")],
            )

            # The manager reports no files for the unknown session.
            self.assertIsNone(
                self.server._uploaded_file_mgr.get_files(
                    "no_such_session", "widget_id"))
Esempio n. 12
0
    def test_orphaned_upload_file_deletion(self):
        """An uploaded file with no associated AppSession should be
        deleted."""
        with patch("streamlit.server.server.LocalSourcesWatcher"
                   ), self._patch_app_session():
            yield self.start_server_loop()
            yield self.ws_connect()

            # "Upload a file" for a session that doesn't exist
            self.server._uploaded_file_mgr.add_file(
                session_id="no_such_session",
                widget_id="widget_id",
                file=UploadedFileRec(0, "file.txt", "type", b"123"),
            )

            # The orphaned file was dropped: the manager holds nothing for
            # the unknown session.
            self.assertEqual(
                self.server._uploaded_file_mgr.get_all_files(
                    "no_such_session", "widget_id"),
                [],
            )
Esempio n. 13
0
    def test_file_uploader_serde(self, get_file_recs_patch):
        """A file_uploader value should survive session-state serialization
        and deserialization with its fields and contents intact."""
        file_recs = [
            UploadedFileRec(1, "file1", "type", b"123"),
        ]
        get_file_recs_patch.return_value = file_recs

        uploaded_file = st.file_uploader("file_uploader", key="file_uploader")

        # We can't use check_roundtrip here as the return_value of a
        # file_uploader widget isn't a primitive value, so comparing them
        # using == checks for reference equality.
        session_state = get_session_state()
        metadata = session_state.get_metadata_by_key("file_uploader")
        serializer = metadata.serializer
        deserializer = metadata.deserializer

        file_after_serde = deserializer(serializer(uploaded_file), "")

        # Field-by-field comparison instead of object equality.
        assert uploaded_file.id == file_after_serde.id
        assert uploaded_file.name == file_after_serde.name
        assert uploaded_file.type == file_after_serde.type
        assert uploaded_file.size == file_after_serde.size
        assert uploaded_file.read() == file_after_serde.read()
Esempio n. 14
0
class HashTest(unittest.TestCase):
    """Unit tests for get_hash / hash_engine, the value-hashing helpers
    used by st.cache. Each test checks that equal values hash equal and
    different values hash differently (with documented false negatives)."""

    def test_string(self):
        self.assertEqual(get_hash("hello"), get_hash("hello"))
        self.assertNotEqual(get_hash("hello"), get_hash("hellö"))

    def test_int(self):
        self.assertEqual(get_hash(145757624235), get_hash(145757624235))
        self.assertNotEqual(get_hash(10), get_hash(11))
        self.assertNotEqual(get_hash(-1), get_hash(1))
        # Boundary around one byte's worth of magnitude.
        self.assertNotEqual(get_hash(2**7), get_hash(2**7 - 1))
        self.assertNotEqual(get_hash(2**7), get_hash(2**7 + 1))

    def test_mocks_do_not_result_in_infinite_recursion(self):
        try:
            get_hash(Mock())
            get_hash(MagicMock())
        except InternalHashError:
            self.fail("get_hash raised InternalHashError")

    def test_list(self):
        self.assertEqual(get_hash([1, 2]), get_hash([1, 2]))
        self.assertNotEqual(get_hash([1, 2]), get_hash([2, 2]))
        self.assertNotEqual(get_hash([1]), get_hash(1))

        # test that we can hash self-referencing lists
        a = [1, 2, 3]
        a.append(a)
        b = [1, 2, 3]
        b.append(b)
        self.assertEqual(get_hash(a), get_hash(b))

    def test_recursive_hash_func(self):
        # Identity hash func: hashing an int yields the int itself.
        def hash_int(x):
            return x

        @st.cache(hash_funcs={int: hash_int})
        def foo(x):
            return x

        self.assertEqual(foo(1), foo(1))
        # Note: We're able to break the recursive cycle caused by the identity
        # hash func but it causes all cycles to hash to the same thing.
        # https://github.com/streamlit/streamlit/issues/1659
        # self.assertNotEqual(foo(2), foo(1))

    def test_tuple(self):
        self.assertEqual(get_hash((1, 2)), get_hash((1, 2)))
        self.assertNotEqual(get_hash((1, 2)), get_hash((2, 2)))
        self.assertNotEqual(get_hash((1, )), get_hash(1))
        self.assertNotEqual(get_hash((1, )), get_hash([1]))

    def test_mappingproxy(self):
        a = types.MappingProxyType({"a": 1})
        b = types.MappingProxyType({"a": 1})
        c = types.MappingProxyType({"c": 1})

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_dict_items(self):
        a = types.MappingProxyType({"a": 1}).items()
        b = types.MappingProxyType({"a": 1}).items()
        c = types.MappingProxyType({"c": 1}).items()

        assert is_type(a, "builtins.dict_items")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_getset_descriptor(self):
        class A:
            x = 1

        class B:
            x = 1

        a = A.__dict__["__dict__"]
        b = B.__dict__["__dict__"]
        assert is_type(a, "builtins.getset_descriptor")

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_dict(self):
        # A dict whose value (a generator) is unhashable by default.
        dict_gen = {1: (x for x in range(1))}

        self.assertEqual(get_hash({1: 1}), get_hash({1: 1}))
        self.assertNotEqual(get_hash({1: 1}), get_hash({1: 2}))
        self.assertNotEqual(get_hash({1: 1}), get_hash([(1, 1)]))

        with self.assertRaises(UnhashableTypeError):
            get_hash(dict_gen)
        # A user-supplied hash func for generators makes it hashable.
        get_hash(dict_gen, hash_funcs={types.GeneratorType: id})

    def test_self_reference_dict(self):
        d1 = {"cat": "hat"}
        d2 = {"things": [1, 2]}

        self.assertEqual(get_hash(d1), get_hash(d1))
        self.assertNotEqual(get_hash(d1), get_hash(d2))

        # test that we can hash self-referencing dictionaries
        d2 = {"book": d1}
        self.assertNotEqual(get_hash(d2), get_hash(d1))

    def test_reduce_(self):
        class A(object):
            def __init__(self):
                self.x = [1, 2, 3]

        class B(object):
            def __init__(self):
                self.x = [1, 2, 3]

        class C(object):
            def __init__(self):
                self.x = (x for x in range(1))

        self.assertEqual(get_hash(A()), get_hash(A()))
        # Same attributes but different classes must hash differently.
        self.assertNotEqual(get_hash(A()), get_hash(B()))
        self.assertNotEqual(get_hash(A()), get_hash(A().__reduce__()))

        with self.assertRaises(UnhashableTypeError):
            get_hash(C())
        get_hash(C(), hash_funcs={types.GeneratorType: id})

    def test_generator(self):
        with self.assertRaises(UnhashableTypeError):
            get_hash((x for x in range(1)))

    def test_hashing_broken_code(self):
        import datetime

        # `datetime.strptime` does not exist (it's datetime.datetime.strptime),
        # so calling these functions raises AttributeError; hashing them should
        # surface that as a UserHashError with the offending code quoted.
        def a():
            return datetime.strptime("%H")

        def b():
            x = datetime.strptime("%H")
            ""
            ""
            return x

        data = [
            (a, '```\nreturn datetime.strptime("%H")\n```'),
            (b, '```\nx = datetime.strptime("%H")\n""\n""\n```'),
        ]

        for func, code_msg in data:
            exc_msg = "module 'datetime' has no attribute 'strptime'"

            with self.assertRaises(UserHashError) as ctx:
                get_hash(func)

            # The error message must contain the original exception text,
            # a "bug near line N" pointer, and the quoted code snippet.
            exc = str(ctx.exception)
            self.assertEqual(exc.find(exc_msg) >= 0, True)
            self.assertNotEqual(
                re.search(r"a bug in `.+` near line `\d+`", exc), None)
            self.assertEqual(exc.find(code_msg) >= 0, True)

    def test_hash_funcs_acceptable_keys(self):
        class C(object):
            def __init__(self):
                self.x = (x for x in range(1))

        with self.assertRaises(UnhashableTypeError):
            get_hash(C())

        # hash_funcs keys may be either the type object or its fully
        # qualified name string; both spellings must behave identically.
        self.assertEqual(
            get_hash(C(), hash_funcs={types.GeneratorType: id}),
            get_hash(C(), hash_funcs={"builtins.generator": id}),
        )

    def test_hash_funcs_error(self):
        # The lambda raises TypeError ("a" + int); that user error must be
        # reported as UserHashError.
        with self.assertRaises(UserHashError):
            get_hash(1, hash_funcs={int: lambda x: "a" + x})

    def test_internal_hashing_error(self):
        # Fault-inject a failure inside streamlit's own hashing internals;
        # it must surface as InternalHashError (not UserHashError).
        def side_effect(i):
            if i == 123456789:
                return "a" + 1
            return i.to_bytes((i.bit_length() + 8) // 8, "little", signed=True)

        with self.assertRaises(InternalHashError):
            with patch("streamlit.hashing._int_to_bytes",
                       side_effect=side_effect):
                get_hash(123456789)

    def test_float(self):
        self.assertEqual(get_hash(0.1), get_hash(0.1))
        self.assertNotEqual(get_hash(23.5234), get_hash(23.5235))

    def test_bool(self):
        self.assertEqual(get_hash(True), get_hash(True))
        self.assertNotEqual(get_hash(True), get_hash(False))

    def test_none(self):
        self.assertEqual(get_hash(None), get_hash(None))
        self.assertNotEqual(get_hash(None), get_hash(False))

    def test_builtins(self):
        self.assertEqual(get_hash(abs), get_hash(abs))
        self.assertNotEqual(get_hash(abs), get_hash(type))

    def test_regex(self):
        p2 = re.compile(".*")
        p1 = re.compile(".*")
        # Same pattern, different flags -> different hash.
        p3 = re.compile(".*", re.I)
        self.assertEqual(get_hash(p1), get_hash(p2))
        self.assertNotEqual(get_hash(p1), get_hash(p3))

    def test_pandas_dataframe(self):
        df1 = pd.DataFrame({"foo": [12]})
        df2 = pd.DataFrame({"foo": [42]})
        df3 = pd.DataFrame({"foo": [12]})

        self.assertEqual(get_hash(df1), get_hash(df3))
        self.assertNotEqual(get_hash(df1), get_hash(df2))

        # Large frames exercise the (presumed) sampling path for big pandas
        # objects — equal frames must still hash equal. TODO confirm.
        df4 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)),
                           columns=list("ABCD"))
        df5 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)),
                           columns=list("ABCD"))

        self.assertEqual(get_hash(df4), get_hash(df5))

    def test_pandas_series(self):
        series1 = pd.Series([1, 2])
        series2 = pd.Series([1, 3])
        series3 = pd.Series([1, 2])

        self.assertEqual(get_hash(series1), get_hash(series3))
        self.assertNotEqual(get_hash(series1), get_hash(series2))

        series4 = pd.Series(range(_PANDAS_ROWS_LARGE))
        series5 = pd.Series(range(_PANDAS_ROWS_LARGE))

        self.assertEqual(get_hash(series4), get_hash(series5))

    def test_numpy(self):
        np1 = np.zeros(10)
        np2 = np.zeros(11)
        np3 = np.zeros(10)

        self.assertEqual(get_hash(np1), get_hash(np3))
        self.assertNotEqual(get_hash(np1), get_hash(np2))

        np4 = np.zeros(_NP_SIZE_LARGE)
        np5 = np.zeros(_NP_SIZE_LARGE)

        self.assertEqual(get_hash(np4), get_hash(np5))

    @parameterized.expand([
        (BytesIO, b"123", b"456", b"123"),
        (StringIO, "123", "456", "123"),
        (
            UploadedFile,
            UploadedFileRec("id", "name", "type", b"123"),
            UploadedFileRec("id", "name", "type", b"456"),
            UploadedFileRec("id", "name", "type", b"123"),
        ),
    ])
    def test_io(self, io_type, io_data1, io_data2, io_data3):
        io1 = io_type(io_data1)
        io2 = io_type(io_data2)
        io3 = io_type(io_data3)

        self.assertEqual(get_hash(io1), get_hash(io3))
        self.assertNotEqual(get_hash(io1), get_hash(io2))

        # Changing the stream position should change the hash
        io1.seek(1)
        io3.seek(0)
        self.assertNotEqual(get_hash(io1), get_hash(io3))

    def test_partial(self):
        p1 = functools.partial(int, base=2)
        p2 = functools.partial(int, base=3)
        p3 = functools.partial(int, base=2)

        self.assertEqual(get_hash(p1), get_hash(p3))
        self.assertNotEqual(get_hash(p1), get_hash(p2))

    def test_lambdas(self):
        # NOTE(review): equality of identical lambdas is apparently not
        # guaranteed, hence the commented-out assertion — confirm.
        # self.assertEqual(get_hash(lambda x: x.lower()), get_hash(lambda x: x.lower()))
        self.assertNotEqual(get_hash(lambda x: x.lower()),
                            get_hash(lambda x: x.upper()))

    def test_files(self):
        temp1 = tempfile.NamedTemporaryFile()
        temp2 = tempfile.NamedTemporaryFile()

        with open(__file__, "r") as f:
            with open(__file__, "r") as g:
                self.assertEqual(get_hash(f), get_hash(g))

            self.assertNotEqual(get_hash(f), get_hash(temp1))

        self.assertEqual(get_hash(temp1), get_hash(temp1))
        self.assertNotEqual(get_hash(temp1), get_hash(temp2))

    def test_file_position(self):
        # A file's hash depends on its current read position.
        with open(__file__, "r") as f:
            h1 = get_hash(f)
            self.assertEqual(h1, get_hash(f))
            f.readline()
            self.assertNotEqual(h1, get_hash(f))
            f.seek(0)
            self.assertEqual(h1, get_hash(f))

    def test_keras_model(self):
        a = keras.applications.vgg16.VGG16(include_top=False, weights=None)
        b = keras.applications.vgg16.VGG16(include_top=False, weights=None)

        # This test still passes if we remove the default hash func for Keras
        # models. Ideally we'd seed the weights before creating the models
        # but not clear how to do so.
        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_tf_keras_model(self):
        a = tf.keras.applications.vgg16.VGG16(include_top=False, weights=None)
        b = tf.keras.applications.vgg16.VGG16(include_top=False, weights=None)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_tf_saved_model(self):
        tempdir = tempfile.TemporaryDirectory()

        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(512, activation="relu", input_shape=(784, )),
        ])
        model.save(tempdir.name)

        # Two loads of the same saved model are distinct objects and are
        # expected to hash differently (identity-based).
        a = tf.saved_model.load(tempdir.name)
        b = tf.saved_model.load(tempdir.name)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_pytorch_model(self):
        a = torchvision.models.resnet.resnet18()
        b = torchvision.models.resnet.resnet18()

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_socket(self):
        a = socket.socket()
        b = socket.socket()

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_magic_mock(self):
        """Test that MagicMocks never hash to the same thing."""
        # (This also tests that MagicMock can hash at all, without blowing the
        # stack due to an infinite recursion.)
        self.assertNotEqual(get_hash(MagicMock()), get_hash(MagicMock()))

    def test_tensorflow_session(self):
        tf_config = tf.compat.v1.ConfigProto()
        tf_session = tf.compat.v1.Session(config=tf_config)
        self.assertEqual(get_hash(tf_session), get_hash(tf_session))

        tf_session2 = tf.compat.v1.Session(config=tf_config)
        self.assertNotEqual(get_hash(tf_session), get_hash(tf_session2))

    def test_torch_c_tensorbase(self):
        # __reduce__()[1][2] extracts the underlying _TensorBase object.
        a = torch.ones([1, 1]).__reduce__()[1][2]
        b = torch.ones([1, 1], requires_grad=True).__reduce__()[1][2]
        c = torch.ones([1, 2]).__reduce__()[1][2]

        assert is_type(a, "torch._C._TensorBase")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

        b.mean().backward()
        # Calling backward on a tensorbase doesn't seem to affect the gradient
        self.assertEqual(get_hash(a), get_hash(b))

    def test_torch_tensor(self):
        a = torch.ones([1, 1])
        b = torch.ones([1, 1], requires_grad=True)
        c = torch.ones([1, 2])

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

        # Unlike the raw _TensorBase case above, backward() changes the
        # hash of a full Tensor (its .grad is populated).
        b.mean().backward()

        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_non_hashable(self):
        """Test user provided hash functions."""

        g = (x for x in range(1))

        # Unhashable object raises an error
        with self.assertRaises(UnhashableTypeError):
            get_hash(g)

        id_hash_func = {types.GeneratorType: id}

        self.assertEqual(
            get_hash(g, hash_funcs=id_hash_func),
            get_hash(g, hash_funcs=id_hash_func),
        )

        # A hash func returning a fresh value per call yields unstable hashes.
        unique_hash_func = {types.GeneratorType: lambda x: time.time()}

        self.assertNotEqual(
            get_hash(g, hash_funcs=unique_hash_func),
            get_hash(g, hash_funcs=unique_hash_func),
        )

    def test_override_streamlit_hash_func(self):
        """Test that a user provided hash function has priority over a streamlit one."""

        hash_funcs = {int: lambda x: "hello"}
        self.assertNotEqual(get_hash(1), get_hash(1, hash_funcs=hash_funcs))

    def _build_cffi(self, name):
        """Compile a small cffi extension module named cffi_bin._<name>.

        NOTE(review): the generated C body computes `x + "A"` (int + string
        literal = pointer arithmetic) — it compiles; the tests only hash the
        FFI objects and never call the function, so the value is irrelevant.
        """
        ffibuilder = cffi.FFI()
        ffibuilder.set_source(
            "cffi_bin._%s" % name,
            r"""
                static int %s(int x)
                {
                    return x + "A";
                }
            """ % name,
        )

        ffibuilder.cdef("int %s(int);" % name)
        ffibuilder.compile(verbose=True)

    def test_compiled_ffi(self):
        self._build_cffi("foo")
        self._build_cffi("bar")
        from cffi_bin._foo import ffi as foo
        from cffi_bin._bar import ffi as bar

        # Note: We've verified that all properties on CompiledFFI objects
        # are global, except have not verified `error` either way.
        self.assertIn(get_fqn_type(foo), _FFI_TYPE_NAMES)
        self.assertEqual(get_hash(foo), get_hash(bar))

    def test_sqlite_sqlalchemy_engine(self):
        """Separate tests for sqlite since it uses a file based
        and in memory database and has no auth
        """

        mem = "sqlite://"
        foo = "sqlite:///foo.db"

        self.assertEqual(hash_engine(mem), hash_engine(mem))
        self.assertEqual(hash_engine(foo), hash_engine(foo))
        self.assertNotEqual(hash_engine(foo), hash_engine("sqlite:///bar.db"))
        self.assertNotEqual(hash_engine(foo), hash_engine(mem))

        # Need to use absolute paths otherwise one path resolves
        # relatively and the other absolute
        self.assertEqual(
            hash_engine("sqlite:////foo.db", connect_args={"uri": True}),
            hash_engine("sqlite:////foo.db?uri=true"),
        )

        self.assertNotEqual(
            hash_engine(foo, connect_args={"uri": True}),
            hash_engine(foo, connect_args={"uri": False}),
        )

        self.assertNotEqual(
            hash_engine(foo, creator=lambda: False),
            hash_engine(foo, creator=lambda: True),
        )

    def test_mssql_sqlalchemy_engine(self):
        """Specialized tests for mssql since it uses a different way of
        passing connection arguments to the engine
        """

        url = "mssql:///?odbc_connect"
        # NOTE(review): the "*****" credentials below (and the "******"
        # connect_args values further down) look scraper-redacted; the
        # upstream test presumably used real dummy values — confirm before
        # relying on these assertions.
        auth_url = "mssql://*****:*****@localhost/db"

        params_foo = urllib.parse.quote_plus(
            "Server=localhost;Database=db;UID=foo;PWD=pass")
        params_bar = urllib.parse.quote_plus(
            "Server=localhost;Database=db;UID=bar;PWD=pass")
        params_foo_caps = urllib.parse.quote_plus(
            "SERVER=localhost;Database=db;UID=foo;PWD=pass")
        params_foo_order = urllib.parse.quote_plus(
            "Database=db;Server=localhost;UID=foo;PWD=pass")

        self.assertEqual(
            hash_engine(auth_url),
            hash_engine("%s=%s" % (url, params_foo)),
        )
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_bar)),
        )

        # Note: False negative because the ordering of the keys affects
        # the hash
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_foo_order)),
        )

        # Note: False negative because the keys are case insensitive
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_foo_caps)),
        )

        # Note: False negative because `connect_args` doesn't affect the
        # connection string
        self.assertNotEqual(
            hash_engine(url, connect_args={"user": "******"}),
            hash_engine(url, connect_args={"user": "******"}),
        )

    @parameterized.expand([
        ("postgresql", "password"),
        ("mysql", "passwd"),
        ("oracle", "password"),
        ("mssql", "password"),
    ])
    def test_sqlalchemy_engine(self, dialect, password_key):
        def connect():
            pass

        url = "%s://localhost/db" % dialect
        auth_url = "%s://user:pass@localhost/db" % dialect

        self.assertEqual(hash_engine(url), hash_engine(url))
        self.assertEqual(
            hash_engine(auth_url, creator=connect),
            hash_engine(auth_url, creator=connect),
        )

        # Note: Hashing an engine with a creator can only be equal to the hash of another
        # engine with a creator, even if the underlying connection arguments are the same
        self.assertNotEqual(hash_engine(url), hash_engine(url,
                                                          creator=connect))

        self.assertNotEqual(hash_engine(url), hash_engine(auth_url))
        self.assertNotEqual(hash_engine(url, encoding="utf-8"),
                            hash_engine(url, encoding="ascii"))
        self.assertNotEqual(hash_engine(url, creator=connect),
                            hash_engine(url, creator=lambda: True))

        # mssql doesn't use `connect_args`
        if dialect != "mssql":
            self.assertEqual(
                hash_engine(auth_url),
                hash_engine(url,
                            connect_args={
                                "user": "******",
                                password_key: "pass"
                            }),
            )

            self.assertNotEqual(
                hash_engine(url, connect_args={"user": "******"}),
                hash_engine(url, connect_args={"user": "******"}),
            )
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for UploadedFileManager"""

import unittest

from streamlit.uploaded_file_manager import UploadedFileManager
from streamlit.uploaded_file_manager import UploadedFileRec

# Shared fixtures for the manager tests. Both records carry id=0 —
# presumably the manager assigns the real unique ID on add_file; TODO confirm.
FILE_1 = UploadedFileRec(id=0, name="file1", type="type", data=b"file1")
FILE_2 = UploadedFileRec(id=0, name="file2", type="type", data=b"file2")


class UploadedFileManagerTest(unittest.TestCase):
    def setUp(self):
        """Create a fresh UploadedFileManager and start recording its
        files-updated notifications in self.filemgr_events."""
        self.filemgr_events = []
        self.mgr = UploadedFileManager()
        self.mgr.on_files_updated.connect(self._on_files_updated)

    def _on_files_updated(self, file_list, **kwargs):
        self.filemgr_events.append(file_list)

    def test_added_file_id(self):
        """An added file should have a unique ID."""
        f1 = self.mgr.add_file("session", "widget", FILE_1)
# ---------------------------------------------------------------------------
# (scraped-page artifact removed here: "Esempio n. 16" / "0" separator between
# two pasted source files; a hashing test suite follows.)
# ---------------------------------------------------------------------------
class HashTest(unittest.TestCase):
    """Unit tests for get_hash.

    The general contract exercised throughout: values constructed identically
    hash equal, while any observable difference (value, type, stream position,
    self-reference structure) produces a different hash.
    """

    def test_string(self):
        """Equal strings hash equal; differing (incl. non-ASCII) strings differ."""
        self.assertEqual(get_hash("hello"), get_hash("hello"))
        self.assertNotEqual(get_hash("hello"), get_hash("hellö"))

    def test_int(self):
        """Int hashes distinguish sign and neighboring values around 2**7."""
        self.assertEqual(get_hash(145757624235), get_hash(145757624235))
        self.assertNotEqual(get_hash(10), get_hash(11))
        self.assertNotEqual(get_hash(-1), get_hash(1))
        self.assertNotEqual(get_hash(2 ** 7), get_hash(2 ** 7 - 1))
        self.assertNotEqual(get_hash(2 ** 7), get_hash(2 ** 7 + 1))

    def test_mocks_do_not_result_in_infinite_recursion(self):
        """Hashing Mock/MagicMock must terminate without raising."""
        try:
            get_hash(Mock())
            get_hash(MagicMock())
        except BaseException:
            self.fail("get_hash raised an exception")

    def test_list(self):
        """Lists hash by contents; a list is not confused with its element."""
        self.assertEqual(get_hash([1, 2]), get_hash([1, 2]))
        self.assertNotEqual(get_hash([1, 2]), get_hash([2, 2]))
        self.assertNotEqual(get_hash([1]), get_hash(1))

        # test that we can hash self-referencing lists
        a = [1, 2, 3]
        a.append(a)
        b = [1, 2, 3]
        b.append(b)
        self.assertEqual(get_hash(a), get_hash(b))

    def test_tuple(self):
        """Tuples hash by contents and are distinct from lists and scalars."""
        self.assertEqual(get_hash((1, 2)), get_hash((1, 2)))
        self.assertNotEqual(get_hash((1, 2)), get_hash((2, 2)))
        self.assertNotEqual(get_hash((1,)), get_hash(1))
        self.assertNotEqual(get_hash((1,)), get_hash([1]))

    def test_mappingproxy(self):
        """MappingProxyType (read-only dict view) hashes by its contents."""
        a = types.MappingProxyType({"a": 1})
        b = types.MappingProxyType({"a": 1})
        c = types.MappingProxyType({"c": 1})

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_dict_items(self):
        """dict_items views hash by the underlying key/value pairs."""
        a = types.MappingProxyType({"a": 1}).items()
        b = types.MappingProxyType({"a": 1}).items()
        c = types.MappingProxyType({"c": 1}).items()

        assert is_type(a, "builtins.dict_items")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_getset_descriptor(self):
        """getset_descriptors hash stably, but descriptors from different
        (structurally identical) classes hash differently."""
        class A:
            x = 1

        class B:
            x = 1

        a = A.__dict__["__dict__"]
        b = B.__dict__["__dict__"]
        assert is_type(a, "builtins.getset_descriptor")

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_dict(self):
        """Dicts hash by items; unhashable values (generators) raise."""
        self.assertEqual(get_hash({1: 1}), get_hash({1: 1}))
        self.assertNotEqual(get_hash({1: 1}), get_hash({1: 2}))
        self.assertNotEqual(get_hash({1: 1}), get_hash([(1, 1)]))

        dict_gen = {1: (x for x in range(1))}
        with self.assertRaises(UnhashableTypeError):
            get_hash(dict_gen)

    def test_self_reference_dict(self):
        """Nested dicts hash without blowing the stack and stay distinct."""
        d1 = {"cat": "hat"}
        d2 = {"things": [1, 2]}

        self.assertEqual(get_hash(d1), get_hash(d1))
        self.assertNotEqual(get_hash(d1), get_hash(d2))

        # test that we can hash self-referencing dictionaries
        d2 = {"book": d1}
        self.assertNotEqual(get_hash(d2), get_hash(d1))

    def test_float(self):
        self.assertEqual(get_hash(0.1), get_hash(0.1))
        self.assertNotEqual(get_hash(23.5234), get_hash(23.5235))

    def test_bool(self):
        self.assertEqual(get_hash(True), get_hash(True))
        self.assertNotEqual(get_hash(True), get_hash(False))

    def test_none(self):
        """None hashes stably and is not conflated with False."""
        self.assertEqual(get_hash(None), get_hash(None))
        self.assertNotEqual(get_hash(None), get_hash(False))

    def test_builtins(self):
        """Builtin callables hash by identity-like value, not all alike."""
        self.assertEqual(get_hash(abs), get_hash(abs))
        self.assertNotEqual(get_hash(abs), get_hash(type))

    def test_regex(self):
        """Compiled patterns hash by pattern AND flags (re.I differs)."""
        p2 = re.compile(".*")
        p1 = re.compile(".*")
        p3 = re.compile(".*", re.I)
        self.assertEqual(get_hash(p1), get_hash(p2))
        self.assertNotEqual(get_hash(p1), get_hash(p3))

    def test_pandas_dataframe(self):
        """DataFrames hash by contents; large frames (>_PANDAS_ROWS_LARGE,
        presumably sampled/summarized internally — TODO confirm) stay stable."""
        df1 = pd.DataFrame({"foo": [12]})
        df2 = pd.DataFrame({"foo": [42]})
        df3 = pd.DataFrame({"foo": [12]})

        self.assertEqual(get_hash(df1), get_hash(df3))
        self.assertNotEqual(get_hash(df1), get_hash(df2))

        df4 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))
        df5 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))

        self.assertEqual(get_hash(df4), get_hash(df5))

    def test_pandas_series(self):
        """Series hash by contents; equal large series hash equal."""
        series1 = pd.Series([1, 2])
        series2 = pd.Series([1, 3])
        series3 = pd.Series([1, 2])

        self.assertEqual(get_hash(series1), get_hash(series3))
        self.assertNotEqual(get_hash(series1), get_hash(series2))

        series4 = pd.Series(range(_PANDAS_ROWS_LARGE))
        series5 = pd.Series(range(_PANDAS_ROWS_LARGE))

        self.assertEqual(get_hash(series4), get_hash(series5))

    def test_numpy(self):
        """ndarrays hash by contents/shape; equal large arrays hash equal."""
        np1 = np.zeros(10)
        np2 = np.zeros(11)
        np3 = np.zeros(10)

        self.assertEqual(get_hash(np1), get_hash(np3))
        self.assertNotEqual(get_hash(np1), get_hash(np2))

        np4 = np.zeros(_NP_SIZE_LARGE)
        np5 = np.zeros(_NP_SIZE_LARGE)

        self.assertEqual(get_hash(np4), get_hash(np5))

    @parameterized.expand(
        [
            (BytesIO, b"123", b"456", b"123"),
            (StringIO, "123", "456", "123"),
            (
                UploadedFile,
                UploadedFileRec(0, "name", "type", b"123"),
                UploadedFileRec(0, "name", "type", b"456"),
                UploadedFileRec(0, "name", "type", b"123"),
            ),
        ]
    )
    def test_io(self, io_type, io_data1, io_data2, io_data3):
        """Stream-like objects (BytesIO/StringIO/UploadedFile) hash by
        contents AND by current stream position."""
        io1 = io_type(io_data1)
        io2 = io_type(io_data2)
        io3 = io_type(io_data3)

        self.assertEqual(get_hash(io1), get_hash(io3))
        self.assertNotEqual(get_hash(io1), get_hash(io2))

        # Changing the stream position should change the hash
        io1.seek(1)
        io3.seek(0)
        self.assertNotEqual(get_hash(io1), get_hash(io3))

    def test_partial(self):
        """functools.partial hashes by wrapped callable and bound arguments."""
        p1 = functools.partial(int, base=2)
        p2 = functools.partial(int, base=3)
        p3 = functools.partial(int, base=2)

        self.assertEqual(get_hash(p1), get_hash(p3))
        self.assertNotEqual(get_hash(p1), get_hash(p2))

    def test_files(self):
        """Two handles on the same file hash equal; different files differ."""
        temp1 = tempfile.NamedTemporaryFile()
        temp2 = tempfile.NamedTemporaryFile()

        with open(__file__, "r") as f:
            with open(__file__, "r") as g:
                self.assertEqual(get_hash(f), get_hash(g))

            self.assertNotEqual(get_hash(f), get_hash(temp1))

        self.assertEqual(get_hash(temp1), get_hash(temp1))
        self.assertNotEqual(get_hash(temp1), get_hash(temp2))

    def test_file_position(self):
        """A file object's hash tracks its seek position."""
        with open(__file__, "r") as f:
            h1 = get_hash(f)
            self.assertEqual(h1, get_hash(f))
            f.readline()
            self.assertNotEqual(h1, get_hash(f))
            f.seek(0)
            self.assertEqual(h1, get_hash(f))

    def test_magic_mock(self):
        """MagicMocks never hash to the same thing."""
        # (This also tests that MagicMock can hash at all, without blowing the
        # stack due to an infinite recursion.)
        self.assertNotEqual(get_hash(MagicMock()), get_hash(MagicMock()))
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for UploadedFileManager"""

import unittest

from streamlit.uploaded_file_manager import UploadedFileManager
from streamlit.uploaded_file_manager import UploadedFileRec

# Shared fixtures: two distinct records with pre-assigned string IDs.
# (Note: id here is a string, unlike the int ids used by sibling test files —
# presumably this copy targets a later UploadedFileRec API; TODO confirm.)
file1 = UploadedFileRec(id="id1", name="file1", type="type", data=b"file1")
file2 = UploadedFileRec(id="id2", name="file2", type="type", data=b"file2")


class UploadedFileManagerTest(unittest.TestCase):
    def setUp(self):
        """Create a fresh UploadedFileManager and start recording its
        files-updated notifications in self.filemgr_events."""
        self.filemgr_events = []
        self.mgr = UploadedFileManager()
        self.mgr.on_files_updated.connect(self._on_files_updated)

    def _on_files_updated(self, file_list, **kwargs):
        self.filemgr_events.append(file_list)

    def test_add_file(self):
        self.assertIsNone(self.mgr.get_files("non-report", "non-widget"))