# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for UploadedFileManager""" import unittest from streamlit.stats import CacheStat from streamlit.uploaded_file_manager import UploadedFileManager from streamlit.uploaded_file_manager import UploadedFileRec FILE_1 = UploadedFileRec(id=0, name="file1", type="type", data=b"file1") FILE_2 = UploadedFileRec(id=0, name="file2", type="type", data=b"file222") class UploadedFileManagerTest(unittest.TestCase): def setUp(self): self.mgr = UploadedFileManager() self.filemgr_events = [] self.mgr.on_files_updated.connect(self._on_files_updated) def _on_files_updated(self, file_list, **kwargs): self.filemgr_events.append(file_list) def test_added_file_id(self): """An added file should have a unique ID.""" f1 = self.mgr.add_file("session", "widget", FILE_1)
def post(self, **kwargs):
    """Handle a multipart upload request for one widget of one session.

    Parses the request body into form arguments and files.

    Responds 400 (and returns) in either of these cases:
    - ``sessionId``/``widgetId`` cannot be read, or the session fails
      ``_validate_request``;
    - no file was attached.

    Otherwise it does the following, then responds 200:
    - records the expected file count (``totalFiles``, default 1);
    - adds the files to ``self._file_mgr``, or replaces the widget's files
      when the ``replace`` argument is ``"true"``.
    """
    args: Dict[str, List[bytes]] = {}
    files: Dict[str, List[Any]] = {}

    tornado.httputil.parse_body_arguments(
        content_type=self.request.headers["Content-Type"],
        body=self.request.body,
        arguments=args,
        files=files,
    )

    try:
        session_id = self._require_arg(args, "sessionId")
        widget_id = self._require_arg(args, "widgetId")
        if not self._validate_request(session_id):
            raise Exception("Session '%s' invalid" % session_id)
    except Exception as e:
        # Any failure above (missing argument, invalid session, or an error
        # raised by the helpers) is reported to the client as a 400 whose
        # reason is the exception text.
        self.send_error(400, reason=str(e))
        return

    # `files` maps each multipart field name to a *list* of file entries.
    # Flatten it into one UploadedFileRec per file, using the field name
    # as the record id (`file_id` avoids shadowing the `id` builtin).
    uploaded_files: List[UploadedFileRec] = [
        UploadedFileRec(
            id=file_id,
            name=file_info["filename"],
            type=file_info["content_type"],
            data=file_info["body"],
        )
        for file_id, file_list in files.items()
        for file_info in file_list
    ]

    # Count actual files rather than multipart fields: one field may
    # carry several files, so len(files) could under-report.
    LOGGER.debug(
        f"{len(uploaded_files)} file(s) received for session {session_id} widget {widget_id}"
    )

    if not uploaded_files:
        self.send_error(400, reason="Expected at least 1 file, but got 0")
        return

    replace = self.get_argument("replace", "false")

    try:
        total_files = int(self._require_arg(args, "totalFiles"))
    except Exception:
        # `totalFiles` is optional: if it is missing or not an integer,
        # fall back to an expected count of 1.
        total_files = 1

    self._file_mgr.update_file_count(
        session_id=session_id,
        widget_id=widget_id,
        file_count=total_files,
    )

    # "true" replaces the widget's existing files; anything else appends.
    update_files = (
        self._file_mgr.replace_files
        if replace == "true"
        else self._file_mgr.add_files
    )
    update_files(
        session_id=session_id,
        widget_id=widget_id,
        files=uploaded_files,
    )

    LOGGER.debug(
        f"{len(uploaded_files)} file(s) uploaded for session {session_id} widget {widget_id}. replace {replace}"
    )

    self.set_status(200)
class HashTest(unittest.TestCase):
    """Tests for ``get_hash`` / ``hash_engine``.

    Checks that equal inputs hash equally and distinct inputs hash
    distinctly across builtin, stdlib and third-party types.
    """

    def test_string(self):
        self.assertEqual(get_hash("hello"), get_hash("hello"))
        self.assertNotEqual(get_hash("hello"), get_hash("hellö"))

    def test_int(self):
        self.assertEqual(get_hash(145757624235), get_hash(145757624235))
        self.assertNotEqual(get_hash(10), get_hash(11))
        self.assertNotEqual(get_hash(-1), get_hash(1))
        self.assertNotEqual(get_hash(2**7), get_hash(2**7 - 1))
        self.assertNotEqual(get_hash(2**7), get_hash(2**7 + 1))

    def test_mocks_do_not_result_in_infinite_recursion(self):
        try:
            get_hash(Mock())
            get_hash(MagicMock())
        except InternalHashError:
            self.fail("get_hash raised InternalHashError")

    def test_list(self):
        self.assertEqual(get_hash([1, 2]), get_hash([1, 2]))
        self.assertNotEqual(get_hash([1, 2]), get_hash([2, 2]))
        self.assertNotEqual(get_hash([1]), get_hash(1))

        # test that we can hash self-referencing lists
        a = [1, 2, 3]
        a.append(a)
        b = [1, 2, 3]
        b.append(b)
        self.assertEqual(get_hash(a), get_hash(b))

    def test_recursive_hash_func(self):
        def hash_int(x):
            return x

        @st.cache(hash_funcs={int: hash_int})
        def foo(x):
            return x

        self.assertEqual(foo(1), foo(1))
        # Note: We're able to break the recursive cycle caused by the identity
        # hash func but it causes all cycles to hash to the same thing.
        # https://github.com/streamlit/streamlit/issues/1659
        # self.assertNotEqual(foo(2), foo(1))

    def test_tuple(self):
        self.assertEqual(get_hash((1, 2)), get_hash((1, 2)))
        self.assertNotEqual(get_hash((1, 2)), get_hash((2, 2)))
        self.assertNotEqual(get_hash((1,)), get_hash(1))
        self.assertNotEqual(get_hash((1,)), get_hash([1]))

    def test_mappingproxy(self):
        a = types.MappingProxyType({"a": 1})
        b = types.MappingProxyType({"a": 1})
        c = types.MappingProxyType({"c": 1})

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_dict_items(self):
        a = types.MappingProxyType({"a": 1}).items()
        b = types.MappingProxyType({"a": 1}).items()
        c = types.MappingProxyType({"c": 1}).items()

        assert is_type(a, "builtins.dict_items")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

    def test_getset_descriptor(self):
        class A:
            x = 1

        class B:
            x = 1

        a = A.__dict__["__dict__"]
        b = B.__dict__["__dict__"]
        assert is_type(a, "builtins.getset_descriptor")

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_dict(self):
        dict_gen = {1: (x for x in range(1))}

        self.assertEqual(get_hash({1: 1}), get_hash({1: 1}))
        self.assertNotEqual(get_hash({1: 1}), get_hash({1: 2}))
        self.assertNotEqual(get_hash({1: 1}), get_hash([(1, 1)]))

        with self.assertRaises(UnhashableTypeError):
            get_hash(dict_gen)

        get_hash(dict_gen, hash_funcs={types.GeneratorType: id})

    def test_self_reference_dict(self):
        d1 = {"cat": "hat"}
        d2 = {"things": [1, 2]}

        self.assertEqual(get_hash(d1), get_hash(d1))
        self.assertNotEqual(get_hash(d1), get_hash(d2))

        # A dict that nests another dict must not hash like the inner dict.
        d2 = {"book": d1}
        self.assertNotEqual(get_hash(d2), get_hash(d1))

    def test_reduce_(self):
        class A(object):
            def __init__(self):
                self.x = [1, 2, 3]

        class B(object):
            def __init__(self):
                self.x = [1, 2, 3]

        class C(object):
            def __init__(self):
                self.x = (x for x in range(1))

        self.assertEqual(get_hash(A()), get_hash(A()))
        self.assertNotEqual(get_hash(A()), get_hash(B()))
        self.assertNotEqual(get_hash(A()), get_hash(A().__reduce__()))

        with self.assertRaises(UnhashableTypeError):
            get_hash(C())

        get_hash(C(), hash_funcs={types.GeneratorType: id})

    def test_generator(self):
        with self.assertRaises(UnhashableTypeError):
            get_hash((x for x in range(1)))

    def test_hashing_broken_code(self):
        import datetime

        # Both functions deliberately call a name the `datetime` module
        # does not have, so hashing them should surface a UserHashError.
        def a():
            return datetime.strptime("%H")

        def b():
            x = datetime.strptime("%H")
            ""
            ""
            return x

        data = [
            (a, '```\nreturn datetime.strptime("%H")\n```'),
            (b, '```\nx = datetime.strptime("%H")\n""\n""\n```'),
        ]

        for func, code_msg in data:
            exc_msg = "module 'datetime' has no attribute 'strptime'"

            with self.assertRaises(UserHashError) as ctx:
                get_hash(func)

            exc = str(ctx.exception)
            self.assertIn(exc_msg, exc)
            self.assertRegex(exc, r"a bug in `.+` near line `\d+`")
            self.assertIn(code_msg, exc)

    def test_hash_funcs_acceptable_keys(self):
        class C(object):
            def __init__(self):
                self.x = (x for x in range(1))

        with self.assertRaises(UnhashableTypeError):
            get_hash(C())

        # hash_funcs may be keyed by the type object or by its FQN string.
        self.assertEqual(
            get_hash(C(), hash_funcs={types.GeneratorType: id}),
            get_hash(C(), hash_funcs={"builtins.generator": id}),
        )

    def test_hash_funcs_error(self):
        with self.assertRaises(UserHashError):
            get_hash(1, hash_funcs={int: lambda x: "a" + x})

    def test_internal_hashing_error(self):
        def side_effect(i):
            if i == 123456789:
                return "a" + 1
            return i.to_bytes((i.bit_length() + 8) // 8, "little", signed=True)

        with self.assertRaises(InternalHashError):
            with patch("streamlit.hashing._int_to_bytes", side_effect=side_effect):
                get_hash(123456789)

    def test_float(self):
        self.assertEqual(get_hash(0.1), get_hash(0.1))
        self.assertNotEqual(get_hash(23.5234), get_hash(23.5235))

    def test_bool(self):
        self.assertEqual(get_hash(True), get_hash(True))
        self.assertNotEqual(get_hash(True), get_hash(False))

    def test_none(self):
        self.assertEqual(get_hash(None), get_hash(None))
        self.assertNotEqual(get_hash(None), get_hash(False))

    def test_builtins(self):
        self.assertEqual(get_hash(abs), get_hash(abs))
        self.assertNotEqual(get_hash(abs), get_hash(type))

    def test_regex(self):
        p2 = re.compile(".*")
        p1 = re.compile(".*")
        p3 = re.compile(".*", re.I)

        self.assertEqual(get_hash(p1), get_hash(p2))
        self.assertNotEqual(get_hash(p1), get_hash(p3))

    def test_pandas_dataframe(self):
        df1 = pd.DataFrame({"foo": [12]})
        df2 = pd.DataFrame({"foo": [42]})
        df3 = pd.DataFrame({"foo": [12]})

        self.assertEqual(get_hash(df1), get_hash(df3))
        self.assertNotEqual(get_hash(df1), get_hash(df2))

        df4 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))
        df5 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))

        self.assertEqual(get_hash(df4), get_hash(df5))

    def test_pandas_series(self):
        series1 = pd.Series([1, 2])
        series2 = pd.Series([1, 3])
        series3 = pd.Series([1, 2])

        self.assertEqual(get_hash(series1), get_hash(series3))
        self.assertNotEqual(get_hash(series1), get_hash(series2))

        series4 = pd.Series(range(_PANDAS_ROWS_LARGE))
        series5 = pd.Series(range(_PANDAS_ROWS_LARGE))

        self.assertEqual(get_hash(series4), get_hash(series5))

    def test_numpy(self):
        np1 = np.zeros(10)
        np2 = np.zeros(11)
        np3 = np.zeros(10)

        self.assertEqual(get_hash(np1), get_hash(np3))
        self.assertNotEqual(get_hash(np1), get_hash(np2))

        np4 = np.zeros(_NP_SIZE_LARGE)
        np5 = np.zeros(_NP_SIZE_LARGE)

        self.assertEqual(get_hash(np4), get_hash(np5))

    @parameterized.expand(
        [
            (BytesIO, b"123", b"456", b"123"),
            (StringIO, "123", "456", "123"),
            (
                UploadedFile,
                UploadedFileRec("id", "name", "type", b"123"),
                UploadedFileRec("id", "name", "type", b"456"),
                UploadedFileRec("id", "name", "type", b"123"),
            ),
        ]
    )
    def test_io(self, io_type, io_data1, io_data2, io_data3):
        io1 = io_type(io_data1)
        io2 = io_type(io_data2)
        io3 = io_type(io_data3)

        self.assertEqual(get_hash(io1), get_hash(io3))
        self.assertNotEqual(get_hash(io1), get_hash(io2))

        # Changing the stream position should change the hash
        io1.seek(1)
        io3.seek(0)
        self.assertNotEqual(get_hash(io1), get_hash(io3))

    def test_partial(self):
        p1 = functools.partial(int, base=2)
        p2 = functools.partial(int, base=3)
        p3 = functools.partial(int, base=2)

        self.assertEqual(get_hash(p1), get_hash(p3))
        self.assertNotEqual(get_hash(p1), get_hash(p2))

    def test_lambdas(self):
        # NOTE(review): an equality check for two textually identical
        # lambdas was left disabled here; confirm whether get_hash is
        # expected to treat them as equal.
        # self.assertEqual(get_hash(lambda x: x.lower()), get_hash(lambda x: x.lower()))
        self.assertNotEqual(
            get_hash(lambda x: x.lower()), get_hash(lambda x: x.upper())
        )

    def test_files(self):
        temp1 = tempfile.NamedTemporaryFile()
        self.addCleanup(temp1.close)
        temp2 = tempfile.NamedTemporaryFile()
        self.addCleanup(temp2.close)

        with open(__file__, "r") as f:
            with open(__file__, "r") as g:
                self.assertEqual(get_hash(f), get_hash(g))

            self.assertNotEqual(get_hash(f), get_hash(temp1))

        self.assertEqual(get_hash(temp1), get_hash(temp1))
        self.assertNotEqual(get_hash(temp1), get_hash(temp2))

    def test_file_position(self):
        with open(__file__, "r") as f:
            h1 = get_hash(f)
            self.assertEqual(h1, get_hash(f))
            f.readline()
            self.assertNotEqual(h1, get_hash(f))
            f.seek(0)
            self.assertEqual(h1, get_hash(f))

    def test_keras_model(self):
        a = keras.applications.vgg16.VGG16(include_top=False, weights=None)
        b = keras.applications.vgg16.VGG16(include_top=False, weights=None)

        # This test still passes if we remove the default hash func for Keras
        # models. Ideally we'd seed the weights before creating the models
        # but not clear how to do so.
        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_tf_keras_model(self):
        a = tf.keras.applications.vgg16.VGG16(include_top=False, weights=None)
        b = tf.keras.applications.vgg16.VGG16(include_top=False, weights=None)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_tf_saved_model(self):
        tempdir = tempfile.TemporaryDirectory()
        # Remove the saved model from disk once the test finishes.
        self.addCleanup(tempdir.cleanup)

        model = tf.keras.models.Sequential(
            [
                tf.keras.layers.Dense(512, activation="relu", input_shape=(784,)),
            ]
        )
        model.save(tempdir.name)

        a = tf.saved_model.load(tempdir.name)
        b = tf.saved_model.load(tempdir.name)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_pytorch_model(self):
        a = torchvision.models.resnet.resnet18()
        b = torchvision.models.resnet.resnet18()

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_socket(self):
        a = socket.socket()
        self.addCleanup(a.close)
        b = socket.socket()
        self.addCleanup(b.close)

        self.assertEqual(get_hash(a), get_hash(a))
        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_magic_mock(self):
        """Test that MagicMocks never hash to the same thing."""
        # (This also tests that MagicMock can hash at all, without blowing the
        # stack due to an infinite recursion.)
        self.assertNotEqual(get_hash(MagicMock()), get_hash(MagicMock()))

    def test_tensorflow_session(self):
        tf_config = tf.compat.v1.ConfigProto()
        tf_session = tf.compat.v1.Session(config=tf_config)
        self.addCleanup(tf_session.close)
        self.assertEqual(get_hash(tf_session), get_hash(tf_session))

        tf_session2 = tf.compat.v1.Session(config=tf_config)
        self.addCleanup(tf_session2.close)
        self.assertNotEqual(get_hash(tf_session), get_hash(tf_session2))

    def test_torch_c_tensorbase(self):
        a = torch.ones([1, 1]).__reduce__()[1][2]
        b = torch.ones([1, 1], requires_grad=True).__reduce__()[1][2]
        c = torch.ones([1, 2]).__reduce__()[1][2]

        assert is_type(a, "torch._C._TensorBase")
        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

        b.mean().backward()
        # Calling backward on a tensorbase doesn't seem to affect the gradient
        self.assertEqual(get_hash(a), get_hash(b))

    def test_torch_tensor(self):
        a = torch.ones([1, 1])
        b = torch.ones([1, 1], requires_grad=True)
        c = torch.ones([1, 2])

        self.assertEqual(get_hash(a), get_hash(b))
        self.assertNotEqual(get_hash(a), get_hash(c))

        b.mean().backward()

        self.assertNotEqual(get_hash(a), get_hash(b))

    def test_non_hashable(self):
        """Test user provided hash functions."""

        g = (x for x in range(1))

        # Unhashable object raises an error
        with self.assertRaises(UnhashableTypeError):
            get_hash(g)

        id_hash_func = {types.GeneratorType: id}

        self.assertEqual(
            get_hash(g, hash_funcs=id_hash_func),
            get_hash(g, hash_funcs=id_hash_func),
        )

        unique_hash_func = {types.GeneratorType: lambda x: time.time()}

        self.assertNotEqual(
            get_hash(g, hash_funcs=unique_hash_func),
            get_hash(g, hash_funcs=unique_hash_func),
        )

    def test_override_streamlit_hash_func(self):
        """Test that a user provided hash function has priority over a streamlit one."""

        hash_funcs = {int: lambda x: "hello"}
        self.assertNotEqual(get_hash(1), get_hash(1, hash_funcs=hash_funcs))

    def _build_cffi(self, name):
        """Compile a tiny cffi extension module ``cffi_bin._<name>``."""
        ffibuilder = cffi.FFI()
        ffibuilder.set_source(
            "cffi_bin._%s" % name,
            r""" static int %s(int x) { return x + "A"; } """ % name,
        )

        ffibuilder.cdef("int %s(int);" % name)
        ffibuilder.compile(verbose=True)

    def test_compiled_ffi(self):
        self._build_cffi("foo")
        self._build_cffi("bar")
        from cffi_bin._foo import ffi as foo
        from cffi_bin._bar import ffi as bar

        # Note: We've verified that all properties on CompiledFFI objects
        # are global, except have not verified `error` either way.
        self.assertIn(get_fqn_type(foo), _FFI_TYPE_NAMES)
        self.assertEqual(get_hash(foo), get_hash(bar))

    def test_sqlite_sqlalchemy_engine(self):
        """Separate tests for sqlite since it uses a file based
        and in memory database and has no auth
        """

        mem = "sqlite://"
        foo = "sqlite:///foo.db"

        self.assertEqual(hash_engine(mem), hash_engine(mem))
        self.assertEqual(hash_engine(foo), hash_engine(foo))
        self.assertNotEqual(hash_engine(foo), hash_engine("sqlite:///bar.db"))
        self.assertNotEqual(hash_engine(foo), hash_engine(mem))

        # Need to use absolute paths otherwise one path resolves
        # relatively and the other absolute
        self.assertEqual(
            hash_engine("sqlite:////foo.db", connect_args={"uri": True}),
            hash_engine("sqlite:////foo.db?uri=true"),
        )

        self.assertNotEqual(
            hash_engine(foo, connect_args={"uri": True}),
            hash_engine(foo, connect_args={"uri": False}),
        )

        self.assertNotEqual(
            hash_engine(foo, creator=lambda: False),
            hash_engine(foo, creator=lambda: True),
        )

    def test_mssql_sqlalchemy_engine(self):
        """Specialized tests for mssql since it uses a different way of
        passing connection arguments to the engine
        """

        url = "mssql:///?odbc_connect"
        # Credentials must match UID/PWD in `params_foo` for the first
        # assertEqual below to be meaningful.
        auth_url = "mssql://foo:pass@localhost/db"

        params_foo = urllib.parse.quote_plus(
            "Server=localhost;Database=db;UID=foo;PWD=pass"
        )
        params_bar = urllib.parse.quote_plus(
            "Server=localhost;Database=db;UID=bar;PWD=pass"
        )
        params_foo_caps = urllib.parse.quote_plus(
            "SERVER=localhost;Database=db;UID=foo;PWD=pass"
        )
        params_foo_order = urllib.parse.quote_plus(
            "Database=db;Server=localhost;UID=foo;PWD=pass"
        )

        self.assertEqual(
            hash_engine(auth_url),
            hash_engine("%s=%s" % (url, params_foo)),
        )
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_bar)),
        )

        # Note: False negative because the ordering of the keys affects
        # the hash
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_foo_order)),
        )

        # Note: False negative because the keys are case insensitive
        self.assertNotEqual(
            hash_engine("%s=%s" % (url, params_foo)),
            hash_engine("%s=%s" % (url, params_foo_caps)),
        )

        # Note: False negative because `connect_args` doesn't affect the
        # connection string
        # The two users must differ: identical calls would always hash
        # equally and make this assertion fail.
        self.assertNotEqual(
            hash_engine(url, connect_args={"user": "foo"}),
            hash_engine(url, connect_args={"user": "bar"}),
        )

    @parameterized.expand(
        [
            ("postgresql", "password"),
            ("mysql", "passwd"),
            ("oracle", "password"),
            ("mssql", "password"),
        ]
    )
    def test_sqlalchemy_engine(self, dialect, password_key):
        def connect():
            pass

        url = "%s://localhost/db" % dialect
        auth_url = "%s://user:pass@localhost/db" % dialect

        self.assertEqual(hash_engine(url), hash_engine(url))
        self.assertEqual(
            hash_engine(auth_url, creator=connect),
            hash_engine(auth_url, creator=connect),
        )

        # Note: Hashing an engine with a creator can only be equal to the hash of another
        # engine with a creator, even if the underlying connection arguments are the same
        self.assertNotEqual(hash_engine(url), hash_engine(url, creator=connect))

        self.assertNotEqual(hash_engine(url), hash_engine(auth_url))
        self.assertNotEqual(
            hash_engine(url, encoding="utf-8"), hash_engine(url, encoding="ascii")
        )
        self.assertNotEqual(
            hash_engine(url, creator=connect), hash_engine(url, creator=lambda: True)
        )

        # mssql doesn't use `connect_args`
        if dialect != "mssql":
            # connect_args must carry the same credentials as `auth_url`
            # ("user"/"pass") for the hashes to be equal.
            self.assertEqual(
                hash_engine(auth_url),
                hash_engine(url, connect_args={"user": "user", password_key: "pass"}),
            )

            # Distinct users must produce distinct hashes.
            self.assertNotEqual(
                hash_engine(url, connect_args={"user": "foo"}),
                hash_engine(url, connect_args={"user": "bar"}),
            )
# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for UploadedFileManager""" import unittest from streamlit.uploaded_file_manager import UploadedFileManager from streamlit.uploaded_file_manager import UploadedFileRec file1 = UploadedFileRec(id="id1", name="file1", type="type", data=b"file1") file2 = UploadedFileRec(id="id2", name="file2", type="type", data=b"file2") class UploadedFileManagerTest(unittest.TestCase): def setUp(self): self.mgr = UploadedFileManager() self.filemgr_events = [] self.mgr.on_files_updated.connect(self._on_files_updated) def _on_files_updated(self, file_list, **kwargs): self.filemgr_events.append(file_list) def test_add_file(self): self.assertIsNone(self.mgr.get_files("non-report", "non-widget"))