예제 #1
0
 def test_custom_img(self):
     """Interpret an image with a user-supplied interpretation callable.

     The callable turns its input into a nested list, so the result is
     compared against the RGB-decoded test image converted the same way.
     """
     brightest = lambda img: img.max()
     to_nested_list = lambda img: img.tolist()
     encoded = gradio.test_data.BASE64_IMAGE
     iface = Interface(
         brightest, "image", "label", interpretation=to_nested_list
     )
     actual = iface.interpret([encoded])[0]
     rgb_image = decode_base64_to_image(encoded).convert('RGB')
     expected = np.asarray(rgb_image).tolist()
     self.assertEqual(actual, expected)
예제 #2
0
    def test_state(self):
        """The "state" component accumulates history across requests.

        The same predict request is posted twice with one session hash; the
        second response is expected to contain the concatenated text.
        """
        def predict(input, history):
            if history is None:
                history = ""
            history += input
            return history, history

        request_body = {
            "data": ["test", None],
            "fn_index": 0,
            "session_hash": "_",
        }
        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            for expected_text in ("test", "testtest"):
                response = client.post("/api/predict/", json=request_body)
                output = dict(response.json())
                self.assertEqual(output["data"], [expected_text, None])
        finally:
            # Shut the launched interface down even when an assertion fails.
            io.close()
예제 #3
0
 def test_default_image(self):
     """Default interpretation gives the top-left of the image a positive score."""
     brightest = lambda img: img.max()
     iface = Interface(brightest, "image", "number", interpretation="default")
     pixels = np.zeros((100, 100))
     pixels[0, 0] = 1  # the only non-zero pixel sits in the top-left corner
     encoded = encode_array_to_base64(pixels)
     scores = iface.interpret([encoded])[0]
     # Checks to see if the top-left has >0 score.
     top_left = scores[0][0]
     self.assertGreater(top_left, 0)
예제 #4
0
 def test_show_error(self):
     """A failing prediction with show_error=True yields HTTP 500 and an "error" key."""
     io = Interface(lambda x: 1 / x, "number", "number")
     app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
     try:
         client = TestClient(app)
         # Sending 0 makes the prediction function divide by zero.
         response = client.post("/api/predict/", json={"data": [0]})
         self.assertEqual(response.status_code, 500)
         self.assertIn("error", response.json())
     finally:
         # Close the interface even if one of the assertions above fails.
         io.close()
예제 #5
0
 def test_shapley_text(self):
     """Shapley interpretation of a text input scores the first word above zero."""
     longest_word = lambda text: max([len(word) for word in text.split(" ")])
     iface = Interface(longest_word, "textbox", "label", interpretation="shapley")
     first_input = iface.interpret(["quickest brown fox"])[0]
     first_entry = first_input["interpretation"][0]
     # Checks to see if the first word has >0 score.
     self.assertGreater(first_entry[1], 0)
예제 #6
0
 def test_custom_text(self):
     """A custom interpretation that scores every character 1 is returned as-is."""
     longest_word = lambda text: max([len(word) for word in text.split(" ")])
     score_each_char = lambda text: [(char, 1) for char in text]
     iface = Interface(
         longest_word, "textbox", "label", interpretation=score_each_char
     )
     first_input = iface.interpret(["quickest brown fox"])[0]
     first_entry = first_input["interpretation"][0]
     # Checks to see if the first letter has score of 1.
     self.assertEqual(first_entry[1], 1)
예제 #7
0
 def test_default_text(self):
     """Default text interpretation: first word scores >0, last word scores 0."""
     longest_word = lambda text: max([len(word) for word in text.split(" ")])
     iface = Interface(
         longest_word, "textbox", "label", interpretation="default"
     )
     per_word = iface.interpret(["quickest brown fox"])[0][0]
     # Checks to see if the first word has >0 score.
     self.assertGreater(per_word[0][1], 0)
     # Checks to see if the last word has 0 score.
     self.assertEqual(per_word[-1][1], 0)
예제 #8
0
 def test_process_example(self):
     """process_example runs the interface on its first stored example."""
     io = Interface(lambda x: "Hello " + x,
                    "text",
                    "text",
                    examples=[["World"]])
     prediction = process_examples.process_example(io, 0)
     # assertEqual replaces the deprecated assertEquals alias.
     self.assertEqual(prediction[0], "Hello World")
예제 #9
0
class TestRoutes(unittest.TestCase):
    """HTTP route checks against a launched text-echo Interface."""

    def setUp(self) -> None:
        """Launch an echo Interface and attach a TestClient to its app."""
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get('/')
        self.assertEqual(response.status_code, 200)

    def test_get_api_route(self):
        response = self.client.get('/api/')
        self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible
        response = self.client.get(r'/static/..%2findex.html')
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r'/static/..%2f..%2fapi_docs.html')
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get('/config/')
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post('/api/predict/', json={"data": ["test"]})
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["test"])
        # assertIn reports the actual keys on failure, unlike assertTrue(x in y).
        self.assertIn("durations", output)
        self.assertIn("avg_durations", output)

    # NOTE(review): the two queue tests below are disabled and share one name;
    # if they are re-enabled the second needs a distinct name, otherwise it
    # silently replaces the first.
    # def test_queue_push_route(self):
    #     networking.queue.push = mock.MagicMock(return_value=(None, None))
    #     response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
    #     self.assertEqual(response.status_code, 200)

    # def test_queue_push_route(self):
    #     networking.queue.get_status = mock.MagicMock(return_value=(None, None))
    #     response = self.client.post('/api/queue/status/', json={"hash": "test"})
    #     self.assertEqual(response.status_code, 200)

    def tearDown(self) -> None:
        """Close the launched Interface and call reset_all()."""
        self.io.close()
        reset_all()
예제 #10
0
 def test_default_image(self):
     """self.default_method gives the top-left of the image a positive score."""
     brightest = lambda img: img.max()
     iface = Interface(brightest, "image", "label")
     pixels = np.zeros((100, 100))
     pixels[0, 0] = 1  # the only non-zero pixel sits in the top-left corner
     scores = self.default_method(iface, [pixels])[0]
     # Checks to see if the top-left has >0 score.
     top_left = scores[0][0]
     self.assertGreater(top_left, 0)
 def test_get_classification_value(self):
     """Differing plain labels give 1; identical labels give 0."""
     iface = Interface(lambda text: text, ["textbox"], ["label"])
     diff = gradio.interpretation.get_regression_or_classification_value(
         iface, ["cat"], ["test"])
     # assertEqual replaces the deprecated assertEquals alias.
     self.assertEqual(diff, 1)
     diff = gradio.interpretation.get_regression_or_classification_value(
         iface, ["test"], ["test"])
     self.assertEqual(diff, 0)
 def test_quantify_difference_with_label(self):
     """quantify_difference_in_label on numeric-string labels: "3" vs "10" -> -7, "0" vs "100" -> -100."""
     iface = Interface(lambda text: len(text), ["textbox"], ["label"])
     diff = gradio.interpretation.quantify_difference_in_label(
         iface, ["3"], ["10"])
     # assertEqual replaces the deprecated assertEquals alias.
     self.assertEqual(diff, -7)
     diff = gradio.interpretation.quantify_difference_in_label(
         iface, ["0"], ["100"])
     self.assertEqual(diff, -100)
 def test_quantify_difference_with_textbox(self):
     """For textbox outputs, identical text gives 0 and different text gives 1."""
     iface = Interface(lambda text: text, ["textbox"], ["textbox"])
     diff = gradio.interpretation.quantify_difference_in_label(
         iface, ["test"], ["test"])
     # assertEqual replaces the deprecated assertEquals alias.
     self.assertEqual(diff, 0)
     diff = gradio.interpretation.quantify_difference_in_label(
         iface, ["test"], ["test_diff"])
     self.assertEqual(diff, 1)
예제 #14
0
def process_example(interface: Interface,
                    example_id: int) -> Tuple[List[Any], List[float]]:
    """Run the interface on one of its stored examples.

    Parameters:
    interface: the Interface whose `examples` and `input_components` are used
    example_id: index into `interface.examples`
    Returns: the result of `interface.process` for the preprocessed example
    """
    chosen_example = interface.examples[example_id]
    preprocessed = []
    # Each value is preprocessed by the input component at the same position.
    for position, value in enumerate(chosen_example):
        component = interface.input_components[position]
        preprocessed.append(component.preprocess_example(value))
    return interface.process(preprocessed)
 def test_get_regression_value(self):
     """Confidence-dict outputs: a NaN confidence gives 0; 0.9/0.1 vs 0.1/0.6 gives about 0.1."""
     iface = Interface(lambda text: text, ["textbox"], ["label"])
     output_1 = {"cat": 0.9, "dog": 0.1}
     output_2 = {"cat": float("nan"), "dog": 0.4}
     output_3 = {"cat": 0.1, "dog": 0.6}
     diff = gradio.interpretation.get_regression_or_classification_value(
         iface, [output_1], [output_2])
     # assertEqual / assertAlmostEqual replace the deprecated *Equals aliases.
     self.assertEqual(diff, 0)
     diff = gradio.interpretation.get_regression_or_classification_value(
         iface, [output_1], [output_3])
     self.assertAlmostEqual(diff, 0.1)
예제 #16
0
class TestAuthenticatedRoutes(unittest.TestCase):
    """Login route checks for an Interface launched with username/password auth."""

    def setUp(self) -> None:
        """Launch an echo Interface protected by the ("test", "correct_password") account."""
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(auth=("test", "correct_password"),
                                        prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_post_login(self):
        """Valid credentials return 302; a wrong password returns 400."""
        # The credentials must match the account passed to launch() in setUp;
        # two identical requests cannot yield both 302 and 400.
        response = self.client.post("/login",
                                    data=dict(username="test",
                                              password="correct_password"))
        self.assertEqual(response.status_code, 302)
        response = self.client.post("/login",
                                    data=dict(username="test",
                                              password="incorrect_password"))
        self.assertEqual(response.status_code, 400)

    def tearDown(self) -> None:
        """Close the launched Interface and call reset_all()."""
        self.io.close()
        reset_all()
 def test_quantify_difference_with_confidences(self):
     """quantify_difference_in_label on confidence dicts: expected about 0.3 and about 0.8."""
     iface = Interface(lambda text: len(text), ["textbox"], ["label"])
     output_1 = {"cat": 0.9, "dog": 0.1}
     output_2 = {"cat": 0.6, "dog": 0.4}
     output_3 = {"cat": 0.1, "dog": 0.6}
     diff = gradio.interpretation.quantify_difference_in_label(
         iface, [output_1], [output_2])
     # assertAlmostEqual replaces the deprecated assertAlmostEquals alias.
     self.assertAlmostEqual(diff, 0.3)
     diff = gradio.interpretation.quantify_difference_in_label(
         iface, [output_1], [output_3])
     self.assertAlmostEqual(diff, 0.8)
예제 #18
0
 def test_create_tunnel(self):
     """create_tunnel connects the SSH client once and starts one thread.

     Thread.start and SSHClient.connect are replaced only for the duration
     of the create_tunnel call, so the mocks cannot leak into other tests.
     """
     response = requests.get(networking.GRADIO_API_SERVER)
     payload = response.json()[0]
     io = Interface(lambda x: x, "text", "text")
     _, path_to_local_server, _ = io.launch(prevent_thread_lock=True,
                                            share=False)
     try:
         _, localhost, port = path_to_local_server.split(":")
         start_mock = mock.MagicMock(return_value=None)
         connect_mock = mock.MagicMock(return_value=None)
         with mock.patch.object(threading.Thread, "start", start_mock), \
                 mock.patch.object(paramiko.SSHClient, "connect",
                                   connect_mock):
             tunneling.create_tunnel(payload, localhost, port)
         start_mock.assert_called_once()
         connect_mock.assert_called_once()
     finally:
         # Close the interface even if the tunnel call or an assertion fails.
         io.close()
예제 #19
0
 def test_flagging_analytics(self):
     """Posting to /api/flag/ calls the flagging callback once and aiohttp's post.

     aiohttp.ClientSession.post is patched only inside this test and is
     restored afterwards, instead of being overwritten for the whole process.
     """
     callback = flagging.CSVLogger()
     callback.flag = mock.MagicMock()
     post_mock = mock.MagicMock()
     post_mock.__aenter__ = None
     post_mock.__aexit__ = None
     with mock.patch.object(aiohttp.ClientSession, "post", post_mock):
         io = Interface(lambda x: x,
                        "text",
                        "text",
                        analytics_enabled=True,
                        flagging_callback=callback)
         app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
         try:
             client = TestClient(app)
             response = client.post(
                 '/api/flag/',
                 json={"data": {
                     "input_data": ["test"],
                     "output_data": ["test"]
                 }})
             post_mock.assert_called()
             callback.flag.assert_called_once()
             self.assertEqual(response.status_code, 200)
         finally:
             # Close the interface even when an assertion fails.
             io.close()
예제 #20
0
 def test_caching(self):
     """Cached example predictions can be loaded back by example index."""
     io = Interface(
         lambda x: "Hello " + x,
         "text",
         "text",
         examples=[["World"], ["Dunya"], ["Monde"]],
     )
     io.launch(prevent_thread_lock=True)
     try:
         process_examples.cache_interface_examples(io)
         prediction = process_examples.load_from_cache(io, 1)
     finally:
         # Close the interface even if caching or loading raises.
         io.close()
     # assertEqual replaces the deprecated assertEquals alias.
     self.assertEqual(prediction[0], "Hello Dunya")
예제 #21
0
def get_output_instance(iface: Interface):
    """Resolve an output specification into an OutputComponent instance.

    NOTE(review): the parameter is annotated as `Interface`, but the body
    only accepts a str, a dict or an OutputComponent.

    Parameters:
    iface: a shortcut string, a dict with a `name` key (the lower-cased
        component class name) whose other keys are constructor arguments,
        or an existing OutputComponent instance.
    Returns: the OutputComponent instance.
    Raises: ValueError if no direct OutputComponent subclass matches the dict
        `name`, or if `iface` is of an unsupported type.
    """
    if isinstance(iface, str):
        shortcut = OutputComponent.get_all_shortcut_implementations()[iface]
        return shortcut[0](**shortcut[1])
    # a dict with `name` as the output component type and other keys as parameters
    elif isinstance(iface, dict):
        # Work on a copy so the caller's dict keeps its `name` key.
        params = dict(iface)
        name = params.pop('name')
        for component in OutputComponent.__subclasses__():
            if component.__name__.lower() == name:
                break
        else:
            raise ValueError("No such OutputComponent: {}".format(name))
        return component(**params)
    elif isinstance(iface, OutputComponent):
        return iface
    else:
        # Trailing space added so the message no longer reads "or`OutputComponent`".
        raise ValueError("Output interface must be of type `str` or `dict` or "
                         "`OutputComponent` but is {}".format(iface))
예제 #22
0
    def test_start_server(self):
        """start_server serves over http on the requested port."""
        io = Interface(lambda x: x, "number", "number")
        io.favicon_path = None
        io.config = io.get_config_file()
        io.show_error = True
        io.flagging_callback.setup(gr.Number(), io.flagging_dir)
        io.auth = None

        port = networking.get_first_available_port(
            networking.INITIAL_PORT_VALUE,
            networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
        )
        _, local_path, _, server = networking.start_server(io,
                                                           server_port=port)
        try:
            url = urllib.parse.urlparse(local_path)
            # assertEqual replaces the deprecated assertEquals alias.
            self.assertEqual(url.scheme, "http")
            self.assertEqual(url.port, port)
        finally:
            # Stop the server even when an assertion fails.
            server.close()
예제 #23
0
 def test_interpretation(self):
     """Posting to /api/interpret/ returns 200 and calls aiohttp's post.

     aiohttp.ClientSession.post is patched only inside this test and is
     restored afterwards, instead of being overwritten for the whole process.
     """
     io = Interface(lambda x: len(x),
                    "text",
                    "label",
                    interpretation="default",
                    analytics_enabled=True)
     app, _, _ = io.launch(prevent_thread_lock=True)
     post_mock = mock.MagicMock()
     post_mock.__aenter__ = None
     post_mock.__aexit__ = None
     # Patch only after launch so that assert_called() reflects the request below.
     with mock.patch.object(aiohttp.ClientSession, "post", post_mock):
         try:
             client = TestClient(app)
             io.interpret = mock.MagicMock(return_value=(None, None))
             response = client.post('/api/interpret/',
                                    json={"data": ["test test"]})
             post_mock.assert_called()
             self.assertEqual(response.status_code, 200)
         finally:
             # Close the interface even when an assertion fails.
             io.close()
예제 #24
0
 def setUp(self) -> None:
     """Launch a text-doubling Interface and attach a TestClient to its app."""
     interface = Interface(lambda x: x + x, "text", "text")
     self.io = interface
     launched_app, _, _ = interface.launch(prevent_thread_lock=True)
     self.app = launched_app
     self.client = TestClient(launched_app)
예제 #25
0
def start_server(
    interface: Interface,
    server_name: Optional[str] = None,
    server_port: Optional[int] = None,
) -> Tuple[int, str, fastapi.FastAPI, "Server"]:
    """Launches a local server running the provided Interface
    Parameters:
    interface: The interface object to run on the server
    server_name: to make app accessible on local network, set this to "0.0.0.0". Defaults to LOCALHOST_NAME when not given.
    server_port: will start gradio app on this port (if available). When None, the first available port starting at INITIAL_PORT_VALUE is used.
    Returns:
    (port, path_to_local_server, app, server): the port in use, the local URL, the configured app and the Server running in a thread.
    Raises:
    OSError: if `server_port` is given but cannot be bound.
    ValueError: if queueing is enabled together with authentication or encryption.
    """
    server_name = server_name or LOCALHOST_NAME
    # if port is not specified, search for first available port
    if server_port is None:
        port = get_first_available_port(INITIAL_PORT_VALUE,
                                        INITIAL_PORT_VALUE + TRY_NUM_PORTS)
    else:
        # Probe the requested port; the context manager closes the socket
        # even when bind() fails, so no file descriptor is leaked.
        # NOTE(review): the probe binds to LOCALHOST_NAME, not server_name —
        # confirm this is intended.
        try:
            with socket.socket() as probe:
                probe.bind((LOCALHOST_NAME, server_port))
        except OSError:
            raise OSError(
                "Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all()."
                .format(server_port))
        port = server_port

    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
    path_to_local_server = "http://{}:{}/".format(url_host_name, port)
    # Authentication is taken from the interface: a callable is stored as-is,
    # anything else is treated as an iterable of (username, password) pairs.
    auth = interface.auth
    if auth is not None:
        if not callable(auth):
            app.auth = {account[0]: account[1] for account in auth}
        else:
            app.auth = auth
    else:
        app.auth = None
    app.interface = interface
    app.cwd = os.getcwd()
    app.favicon_path = interface.favicon_path
    app.tokens = {}

    if app.interface.enable_queue:
        if auth is not None or app.interface.encrypt:
            raise ValueError(
                "Cannot queue with encryption or authentication enabled.")
        queueing.init()
        app.queue_thread = threading.Thread(target=queue_thread,
                                            args=(path_to_local_server, ))
        app.queue_thread.start()
    if interface.save_to is not None:  # Used for selenium tests
        interface.save_to["port"] = port

    config = uvicorn.Config(app=app,
                            port=port,
                            host=server_name,
                            log_level="warning")
    server = Server(config=config)
    server.run_in_thread()
    return port, path_to_local_server, app, server
예제 #26
0
    raise DeprecationWarning(
        "This function is deprecated. To create stateful demos, pass 'state'"
        "as both an input and output component. Please see the getting started"
        "guide for more information.")


def set_state(*args):
    """Deprecated: always raises DeprecationWarning.

    Parameters:
    *args: accepted and ignored.
    Raises:
    DeprecationWarning: on every call, pointing to the 'state' component.
    """
    # Each fragment ends with a space so the concatenated message reads
    # correctly (previously "'state'as both" and "getting startedguide").
    raise DeprecationWarning(
        "This function is deprecated. To create stateful demos, pass 'state' "
        "as both an input and output component. Please see the getting started "
        "guide for more information.")

    
if __name__ == '__main__': # Run directly for debugging: python app.py
    from gradio import Interface

    # Attach a minimal greeting demo to the module-level app.
    app.interface = Interface(
        lambda x: "Hello, " + x,
        "text",
        "text",
        analytics_enabled=False,
    )
    _debug_iface = app.interface
    _debug_iface.config = _debug_iface.get_config_file()
    _debug_iface.show_error = True
    _debug_iface.flagging_callback.setup(_debug_iface.flagging_dir)
    app.favicon_path = None
    app.tokens = {}

    # Hard-coded switch: set to False to run the debug server without login.
    auth = True
    if not auth:
        app.auth = None
    else:
        _debug_iface.auth = ("a", "b")
        app.auth = {"a": "b"}
        _debug_iface.auth_message = None

    uvicorn.run(app)
예제 #27
0
 def test_quantify_difference_with_number(self):
     """For number outputs, quantify_difference_in_label of 4 vs 6 is -2."""
     iface = Interface(lambda text: text, ["textbox"], ["number"])
     diff = gradio.interpretation.quantify_difference_in_label(iface, [4], [6])
     # assertEqual replaces the deprecated assertEquals alias.
     self.assertEqual(diff, -2)
예제 #28
0
 def setUp(self) -> None:
     """Launch an echo Interface with ("test", "correct_password") auth and a TestClient."""
     interface = Interface(lambda x: x, "text", "text")
     self.io = interface
     credentials = ("test", "correct_password")
     launched_app, _, _ = interface.launch(auth=credentials,
                                           prevent_thread_lock=True)
     self.app = launched_app
     self.client = TestClient(launched_app)
예제 #29
0
class TestRoutes(unittest.TestCase):
    """HTTP route checks against a launched text-doubling Interface."""

    def setUp(self) -> None:
        """Launch a text-doubling Interface and attach a TestClient to its app."""
        self.io = Interface(lambda x: x + x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get("/")
        self.assertEqual(response.status_code, 200)

    # NOTE(review): disabled test kept for reference.
    # def test_get_api_route(self):
    #     response = self.client.get("/api/")
    #     self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible
        response = self.client.get(r"/static/..%2findex.html")
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get("/config/")
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post("/api/predict/",
                                    json={
                                        "data": ["test"],
                                        "fn_index": 0
                                    })
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest"])

    def test_state(self):
        """The "state" component accumulates history across requests.

        The same predict request is posted twice with one session hash; the
        second response is expected to contain the concatenated text.
        """
        def predict(input, history):
            if history is None:
                history = ""
            history += input
            return history, history

        request_body = {
            "data": ["test", None],
            "fn_index": 0,
            "session_hash": "_",
        }
        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            for expected_text in ("test", "testtest"):
                response = client.post("/api/predict/", json=request_body)
                output = dict(response.json())
                self.assertEqual(output["data"], [expected_text, None])
        finally:
            # This interface is separate from self.io, so tearDown would not
            # close it; close it here even when an assertion fails.
            io.close()

    def test_queue_push_route(self):
        # Patch only for the duration of the request so queueing.push is
        # restored for other tests.
        push_mock = mock.MagicMock(return_value=(None, None))
        with mock.patch.object(queueing, "push", push_mock):
            response = self.client.post("/api/queue/push/",
                                        json={
                                            "data": "test",
                                            "action": "test"
                                        })
        self.assertEqual(response.status_code, 200)

    def test_queue_status_route(self):
        # Patch only for the duration of the request so queueing.get_status
        # is restored for other tests.
        status_mock = mock.MagicMock(return_value=(None, None))
        with mock.patch.object(queueing, "get_status", status_mock):
            response = self.client.post("/api/queue/status/",
                                        json={"hash": "test"})
        self.assertEqual(response.status_code, 200)

    def tearDown(self) -> None:
        """Close the launched Interface and call reset_all()."""
        self.io.close()
        reset_all()