def test_custom_img(self):
    """A custom image interpretation function's result is surfaced by interpret().

    The custom function returns the decoded pixels as nested lists, so the
    interpretation must equal the RGB array of the test image.
    """
    encoded = gradio.test_data.BASE64_IMAGE
    img_interface = Interface(
        lambda img: img.max(),
        "image",
        "label",
        interpretation=lambda img: img.tolist(),
    )
    result = img_interface.interpret([encoded])[0]
    rgb_image = decode_base64_to_image(encoded).convert('RGB')
    self.assertEqual(result, np.asarray(rgb_image).tolist())
def test_state(self):
    """Two predict calls sharing a session_hash accumulate server-held state.

    Both requests send ``None`` for the state slot. The second response still
    reflects the first call ("testtest"), and the assertions expect the
    returned state slot to be ``None``.
    """
    def predict(input, history):
        if history is None:
            history = ""
        history += input
        return history, history

    io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
    app, _, _ = io.launch(prevent_thread_lock=True)
    try:
        client = TestClient(app)
        payload = {"data": ["test", None], "fn_index": 0, "session_hash": "_"}

        output = dict(client.post("/api/predict/", json=payload).json())
        self.assertEqual(output["data"], ["test", None])

        # Same session_hash: history from the previous call is reused.
        output = dict(client.post("/api/predict/", json=payload).json())
        self.assertEqual(output["data"], ["testtest", None])
    finally:
        # Always shut the launched server down, even if an assertion fails.
        io.close()
def test_default_image(self):
    """Default interpretation gives the single bright top-left pixel a positive score."""
    img_interface = Interface(
        lambda img: img.max(), "image", "number", interpretation="default"
    )
    pixels = np.zeros((100, 100))
    pixels[0, 0] = 1
    encoded = encode_array_to_base64(pixels)
    scores = img_interface.interpret([encoded])[0]
    top_left = scores[0][0]
    self.assertGreater(top_left, 0)
def test_show_error(self):
    """With show_error=True, a failing prediction returns HTTP 500 and an "error" key."""
    io = Interface(lambda x: 1 / x, "number", "number")
    app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
    try:
        client = TestClient(app)
        # Input 0 makes the function divide by zero.
        response = client.post("/api/predict/", json={"data": [0]})
        self.assertEqual(response.status_code, 500)
        self.assertIn("error", response.json())
    finally:
        # Close even when an assertion fails so the server does not leak.
        io.close()
def test_shapley_text(self):
    """Shapley interpretation gives the longest (first) word a positive score."""
    text_interface = Interface(
        lambda text: max([len(word) for word in text.split(" ")]),
        "textbox",
        "label",
        interpretation="shapley",
    )
    output = text_interface.interpret(["quickest brown fox"])[0]
    first_token = output["interpretation"][0]
    # Element 1 of each entry is the score.
    self.assertGreater(first_token[1], 0)
def test_custom_text(self):
    """A custom text interpretation's (token, score) pairs surface under "interpretation"."""
    text_interface = Interface(
        lambda text: max([len(word) for word in text.split(" ")]),
        "textbox",
        "label",
        interpretation=lambda text: [(char, 1) for char in text],
    )
    output = text_interface.interpret(["quickest brown fox"])[0]
    first_pair = output["interpretation"][0]
    self.assertEqual(first_pair[1], 1)  # the first character's score is 1
def test_default_text(self):
    """Default interpretation scores the longest word positively and the last word ("fox") as zero."""
    text_interface = Interface(
        lambda text: max([len(word) for word in text.split(" ")]),
        "textbox",
        "label",
        interpretation="default",
    )
    scores = text_interface.interpret(["quickest brown fox"])[0][0]
    first_score = scores[0][1]
    last_score = scores[-1][1]
    self.assertGreater(first_score, 0)
    self.assertEqual(last_score, 0)
def test_process_example(self):
    """process_example runs the interface on the stored example at the given index."""
    io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
    prediction = process_examples.process_example(io, 0)
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(prediction[0], "Hello World")
class TestRoutes(unittest.TestCase):
    """HTTP routes of an echo Interface (text in, same text out)."""

    def setUp(self) -> None:
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get('/')
        self.assertEqual(response.status_code, 200)

    def test_get_api_route(self):
        response = self.client.get('/api/')
        self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible
        response = self.client.get(r'/static/..%2findex.html')
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r'/static/..%2f..%2fapi_docs.html')
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get('/config/')
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post('/api/predict/', json={"data": ["test"]})
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["test"])
        self.assertIn("durations", output)
        self.assertIn("avg_durations", output)

    # TODO: tests for /api/queue/push/ and /api/queue/status/ were disabled
    # here; re-add them using mock.patch.object so the patches are restored.

    def tearDown(self) -> None:
        self.io.close()
        reset_all()
def test_default_image(self):
    """self.default_method scores the lone bright top-left pixel positively.

    NOTE(review): default_method is presumably set on the enclosing test class.
    """
    img_interface = Interface(lambda img: img.max(), "image", "label")
    pixels = np.zeros((100, 100))
    pixels[0, 0] = 1
    scores = self.default_method(img_interface, [pixels])[0]
    top_left = scores[0][0]
    self.assertGreater(top_left, 0)
def test_get_classification_value(self):
    """Differing string labels give 1, identical labels give 0."""
    iface = Interface(lambda text: text, ["textbox"], ["label"])
    diff = gradio.interpretation.get_regression_or_classification_value(
        iface, ["cat"], ["test"])
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(diff, 1)
    diff = gradio.interpretation.get_regression_or_classification_value(
        iface, ["test"], ["test"])
    self.assertEqual(diff, 0)
def test_quantify_difference_with_label(self):
    """Numeric-string labels are compared by signed difference (first minus second)."""
    iface = Interface(lambda text: len(text), ["textbox"], ["label"])
    diff = gradio.interpretation.quantify_difference_in_label(
        iface, ["3"], ["10"])
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(diff, -7)
    diff = gradio.interpretation.quantify_difference_in_label(
        iface, ["0"], ["100"])
    self.assertEqual(diff, -100)
def test_quantify_difference_with_textbox(self):
    """Textbox outputs: identical strings give 0, different strings give 1."""
    iface = Interface(lambda text: text, ["textbox"], ["textbox"])
    diff = gradio.interpretation.quantify_difference_in_label(
        iface, ["test"], ["test"])
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(diff, 0)
    diff = gradio.interpretation.quantify_difference_in_label(
        iface, ["test"], ["test_diff"])
    self.assertEqual(diff, 1)
def process_example(interface: Interface, example_id: int) -> Tuple[List[Any], List[float]]:
    """Loads an example from the interface and returns its prediction.

    Each value of ``interface.examples[example_id]`` is passed through the
    input component at the same position (``preprocess_example``). The
    resulting list is handed to ``interface.process``, and whatever that
    returns is returned unchanged.

    NOTE(review): the annotation promises a (predictions, durations) pair;
    confirm this against what Interface.process actually returns.
    """
    example_set = interface.examples[example_id]
    raw_input = []
    for position, value in enumerate(example_set):
        component = interface.input_components[position]
        raw_input.append(component.preprocess_example(value))
    return interface.process(raw_input)
def test_get_regression_value(self):
    """Confidence-dict outputs: NaN-containing comparison yields 0, otherwise ~0.1."""
    iface = Interface(lambda text: text, ["textbox"], ["label"])
    output_1 = {"cat": 0.9, "dog": 0.1}
    output_2 = {"cat": float("nan"), "dog": 0.4}
    output_3 = {"cat": 0.1, "dog": 0.6}
    diff = gradio.interpretation.get_regression_or_classification_value(
        iface, [output_1], [output_2])
    # assertEqual / assertAlmostEqual: the *Equals aliases were removed in Python 3.12.
    self.assertEqual(diff, 0)
    diff = gradio.interpretation.get_regression_or_classification_value(
        iface, [output_1], [output_3])
    self.assertAlmostEqual(diff, 0.1)
class TestAuthenticatedRoutes(unittest.TestCase):
    """Login route of an Interface launched with one username/password pair."""

    def setUp(self) -> None:
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_post_login(self):
        # The credentials configured in setUp are accepted...
        response = self.client.post("/login", data=dict(username="test", password="correct_password"))
        self.assertEqual(response.status_code, 302)
        # ...and a wrong password is rejected. (Previously both requests sent
        # the same redacted credentials, so they could not differ in status.)
        response = self.client.post("/login", data=dict(username="test", password="incorrect_password"))
        self.assertEqual(response.status_code, 400)

    def tearDown(self) -> None:
        self.io.close()
        reset_all()
def test_quantify_difference_with_confidences(self):
    """Confidence-dict outputs: the difference tracks the drop in the top label's confidence."""
    iface = Interface(lambda text: len(text), ["textbox"], ["label"])
    output_1 = {"cat": 0.9, "dog": 0.1}
    output_2 = {"cat": 0.6, "dog": 0.4}
    output_3 = {"cat": 0.1, "dog": 0.6}
    diff = gradio.interpretation.quantify_difference_in_label(
        iface, [output_1], [output_2])
    # assertAlmostEqual, not the assertAlmostEquals alias (removed in Python 3.12).
    self.assertAlmostEqual(diff, 0.3)
    diff = gradio.interpretation.quantify_difference_in_label(
        iface, [output_1], [output_3])
    self.assertAlmostEqual(diff, 0.8)
def test_create_tunnel(self):
    """create_tunnel starts one thread and makes one SSH connect call.

    NOTE(review): this performs a real HTTP GET against
    networking.GRADIO_API_SERVER to obtain the tunnel payload.
    """
    response = requests.get(networking.GRADIO_API_SERVER)
    payload = response.json()[0]
    io = Interface(lambda x: x, "text", "text")
    _, path_to_local_server, _ = io.launch(prevent_thread_lock=True, share=False)
    try:
        _, localhost, port = path_to_local_server.split(":")
        # patch.object restores the class attributes on exit; plain assignment
        # would leave Thread.start / SSHClient.connect mocked for later tests.
        with mock.patch.object(threading.Thread, "start", return_value=None) as mock_start:
            with mock.patch.object(paramiko.SSHClient, "connect", return_value=None) as mock_connect:
                tunneling.create_tunnel(payload, localhost, port)
                mock_start.assert_called_once()
                mock_connect.assert_called_once()
    finally:
        io.close()
def test_flagging_analytics(self):
    """Flagging calls the flagging callback once and posts analytics via aiohttp."""
    callback = flagging.CSVLogger()
    callback.flag = mock.MagicMock()
    # Scoped patch: the class attribute is restored when the block exits.
    with mock.patch.object(aiohttp.ClientSession, "post") as mock_post:
        # NOTE(review): these only affect mock_post itself, not the object
        # returned by calling it; kept for parity with the original setup.
        mock_post.__aenter__ = None
        mock_post.__aexit__ = None
        io = Interface(lambda x: x, "text", "text", analytics_enabled=True, flagging_callback=callback)
        app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
        try:
            client = TestClient(app)
            response = client.post(
                '/api/flag/',
                json={"data": {
                    "input_data": ["test"],
                    "output_data": ["test"]
                }})
            mock_post.assert_called()
            callback.flag.assert_called_once()
            self.assertEqual(response.status_code, 200)
        finally:
            io.close()
def test_caching(self):
    """Cached example predictions can be loaded back by example index."""
    io = Interface(
        lambda x: "Hello " + x,
        "text",
        "text",
        examples=[["World"], ["Dunya"], ["Monde"]],
    )
    io.launch(prevent_thread_lock=True)
    try:
        process_examples.cache_interface_examples(io)
        prediction = process_examples.load_from_cache(io, 1)
    finally:
        # Close even if caching/loading raises.
        io.close()
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(prediction[0], "Hello Dunya")
def get_output_instance(iface: Interface):
    """Resolve an output-component specification into an OutputComponent.

    Parameters:
    iface: despite the annotation, one of
        - a str shortcut name, looked up in
          ``OutputComponent.get_all_shortcut_implementations()``;
        - a dict whose ``'name'`` key equals the lower-cased class name of a
          direct subclass of OutputComponent; all other keys are passed to
          that class's constructor as keyword arguments (the dict itself is
          not modified);
        - an OutputComponent instance, returned unchanged.
    Returns:
    An OutputComponent instance.
    Raises:
    KeyError: for an unknown str shortcut, or a dict without a ``'name'`` key.
    ValueError: for an unknown component name, or an unsupported ``iface`` type.
    """
    if isinstance(iface, str):
        shortcut = OutputComponent.get_all_shortcut_implementations()[iface]
        # shortcut[0] is called with shortcut[1] as keyword arguments.
        return shortcut[0](**shortcut[1])
    # a dict with `name` as the output component type and other keys as parameters
    elif isinstance(iface, dict):
        # Work on a copy so the caller's spec dict keeps its 'name' key and can
        # be reused (popping from iface directly made a second call fail).
        params = dict(iface)
        name = params.pop('name')
        for component in OutputComponent.__subclasses__():
            if component.__name__.lower() == name:
                break
        else:
            raise ValueError("No such OutputComponent: {}".format(name))
        return component(**params)
    elif isinstance(iface, OutputComponent):
        return iface
    else:
        raise ValueError("Output interface must be of type `str` or `dict` or "
                         "`OutputComponent` but is {}".format(iface))
def test_start_server(self):
    """start_server serves over http on the explicitly requested port."""
    io = Interface(lambda x: x, "number", "number")
    io.favicon_path = None
    io.config = io.get_config_file()
    io.show_error = True
    io.flagging_callback.setup(gr.Number(), io.flagging_dir)
    io.auth = None
    port = networking.get_first_available_port(
        networking.INITIAL_PORT_VALUE,
        networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
    )
    _, local_path, _, server = networking.start_server(io, server_port=port)
    try:
        url = urllib.parse.urlparse(local_path)
        # assertEqual, not the assertEquals alias (removed in Python 3.12).
        self.assertEqual(url.scheme, "http")
        self.assertEqual(url.port, port)
    finally:
        # Stop the server even when an assertion fails.
        server.close()
def test_interpretation(self):
    """The /api/interpret/ route responds 200 and posts analytics via aiohttp."""
    io = Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
    app, _, _ = io.launch(prevent_thread_lock=True)
    try:
        client = TestClient(app)
        # Scoped patch: the class attribute is restored when the block exits.
        with mock.patch.object(aiohttp.ClientSession, "post") as mock_post:
            # NOTE(review): these only affect mock_post itself, not the object
            # returned by calling it; kept for parity with the original setup.
            mock_post.__aenter__ = None
            mock_post.__aexit__ = None
            # Instance attribute only; no restore needed.
            io.interpret = mock.MagicMock(return_value=(None, None))
            response = client.post('/api/interpret/', json={"data": ["test test"]})
            mock_post.assert_called()
        self.assertEqual(response.status_code, 200)
    finally:
        io.close()
def setUp(self) -> None:
    """Launch a text-doubling Interface and wrap its app in a TestClient."""
    self.io = Interface(lambda x: x + x, "text", "text")
    app, _local_url, _ = self.io.launch(prevent_thread_lock=True)
    self.app = app
    self.client = TestClient(app)
def start_server(
    interface: Interface,
    server_name: Optional[str] = None,
    server_port: Optional[int] = None,
) -> Tuple[int, str, fastapi.FastAPI, "Server"]:
    """Launches a local server running the provided Interface.

    Configures the module-level ``app`` from ``interface`` and starts a
    uvicorn ``Server`` in a background thread.

    Parameters:
    interface: The interface object to run on the server. Its ``auth``,
        ``favicon_path``, ``enable_queue``, ``encrypt`` and ``save_to``
        attributes are read here. ``auth`` may be None, a callable, a list of
        (username, password) pairs, or a single (username, password) tuple.
    server_name: host to bind; defaults to LOCALHOST_NAME. Set to "0.0.0.0"
        to make the app accessible on the local network.
    server_port: port to bind. If None, get_first_available_port searches
        starting at INITIAL_PORT_VALUE.
    Returns:
    (port, path_to_local_server, app, server): the port in use, the local URL
    ("http://<host>:<port>/"), the FastAPI app, and the running Server.
    Raises:
    OSError: if an explicit server_port cannot be bound on LOCALHOST_NAME.
    ValueError: if queueing is enabled together with auth or encryption.
    """
    server_name = server_name or LOCALHOST_NAME
    # if port is not specified, search for first available port
    if server_port is None:
        port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
    else:
        # Probe the requested port; the with-block closes the probe socket
        # whether or not bind succeeds.
        try:
            with socket.socket() as s:
                s.bind((LOCALHOST_NAME, server_port))
        except OSError as err:
            raise OSError(
                "Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all()."
                .format(server_port)) from err
        port = server_port

    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
    path_to_local_server = "http://{}:{}/".format(url_host_name, port)

    auth = interface.auth
    if auth is not None and not callable(auth):
        # Accept a single (username, password) pair as well as a list of pairs;
        # iterating a bare pair would otherwise index into the strings.
        if isinstance(auth, tuple) and len(auth) == 2 and all(isinstance(item, str) for item in auth):
            auth = [auth]
        app.auth = {account[0]: account[1] for account in auth}
    else:
        # None (no auth) or a callable validator is stored as-is.
        app.auth = auth

    app.interface = interface
    app.cwd = os.getcwd()
    app.favicon_path = interface.favicon_path
    app.tokens = {}

    if app.interface.enable_queue:
        if auth is not None or app.interface.encrypt:
            raise ValueError(
                "Cannot queue with encryption or authentication enabled.")
        queueing.init()
        app.queue_thread = threading.Thread(target=queue_thread,
                                            args=(path_to_local_server, ))
        app.queue_thread.start()
    if interface.save_to is not None:  # Used for selenium tests
        interface.save_to["port"] = port

    config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning")
    server = Server(config=config)
    server.run_in_thread()
    return port, path_to_local_server, app, server
raise DeprecationWarning( "This function is deprecated. To create stateful demos, pass 'state'" "as both an input and output component. Please see the getting started" "guide for more information.") def set_state(*args): raise DeprecationWarning( "This function is deprecated. To create stateful demos, pass 'state'" "as both an input and output component. Please see the getting started" "guide for more information.") if __name__ == '__main__': # Run directly for debugging: python app.py from gradio import Interface app.interface = Interface(lambda x: "Hello, " + x, "text", "text", analytics_enabled=False) app.interface.config = app.interface.get_config_file() app.interface.show_error = True app.interface.flagging_callback.setup(app.interface.flagging_dir) app.favicon_path = None app.tokens = {} auth = True if auth: app.interface.auth = ("a", "b") app.auth = {"a": "b"} app.interface.auth_message = None else: app.auth = None uvicorn.run(app)
def test_quantify_difference_with_number(self):
    """Number outputs are compared by signed difference (first minus second)."""
    iface = Interface(lambda text: text, ["textbox"], ["number"])
    diff = gradio.interpretation.quantify_difference_in_label(iface, [4], [6])
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(diff, -2)
def setUp(self) -> None:
    """Launch an echo Interface behind a single username/password and wrap it in a TestClient."""
    self.io = Interface(lambda x: x, "text", "text")
    credentials = ("test", "correct_password")
    app, _local_url, _ = self.io.launch(auth=credentials, prevent_thread_lock=True)
    self.app = app
    self.client = TestClient(app)
class TestRoutes(unittest.TestCase):
    """HTTP routes of a text-doubling Interface (x -> x + x)."""

    def setUp(self) -> None:
        self.io = Interface(lambda x: x + x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get("/")
        self.assertEqual(response.status_code, 200)

    # TODO: a GET /api/ route test was disabled here; re-enable it once the
    # route is available in this version.

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible
        response = self.client.get(r"/static/..%2findex.html")
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get("/config/")
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post("/api/predict/", json={
            "data": ["test"],
            "fn_index": 0
        })
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest"])

    def test_state(self):
        """Two calls sharing a session_hash accumulate server-held state."""
        def predict(input, history):
            if history is None:
                history = ""
            history += input
            return history, history

        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            payload = {"data": ["test", None], "fn_index": 0, "session_hash": "_"}
            output = dict(client.post("/api/predict/", json=payload).json())
            self.assertEqual(output["data"], ["test", None])
            # Same session_hash: history from the previous call is reused.
            output = dict(client.post("/api/predict/", json=payload).json())
            self.assertEqual(output["data"], ["testtest", None])
        finally:
            # This test launches its own Interface; tearDown only closes self.io.
            io.close()

    def test_queue_push_route(self):
        # patch.object restores queueing.push afterwards instead of leaking
        # the mock into later tests.
        with mock.patch.object(queueing, "push", return_value=(None, None)):
            response = self.client.post("/api/queue/push/", json={
                "data": "test",
                "action": "test"
            })
        self.assertEqual(response.status_code, 200)

    def test_queue_push_route_2(self):
        """Exercises the /api/queue/status/ route."""
        with mock.patch.object(queueing, "get_status", return_value=(None, None)):
            response = self.client.post("/api/queue/status/", json={"hash": "test"})
        self.assertEqual(response.status_code, 200)

    def tearDown(self) -> None:
        self.io.close()
        reset_all()