def start(self, parent_thread: Thread, port: int) -> None:
    """
    Launch the visualization server in a separate process.

    A helper thread is also started; it blocks until `parent_thread`
    terminates and then calls `self.stop_server()`.

    :param parent_thread: the thread that called start; the server is
        stopped once this thread has finished.
    :param port: first port to try; the port actually used is picked by
        `find_free_port` from the 100 ports starting at this one.
    """

    def _stop_when_parent_exits():
        # Block until the caller's thread is done, then shut the server down.
        parent_thread.join()
        self.stop_server()

    # Look for a usable port in [port, port + 100).
    candidate_ports = range(port, port + 100)
    self.port = find_free_port(candidate_ports)

    server_args = (self.game_state, self.port, self.child_conn)
    self._process = Process(target=self.start_server, name='flask',
                            args=server_args)
    self._process.start()

    Thread(target=_stop_when_parent_exits,
           name='waiting_on_parent_exit').start()

    print("Viewer started at http://localhost:{}.".format(self.port))

    if not self.open_automatically:
        return

    check_modules(["webbrowser"], missing_modules)
    with SupressStdStreams():
        webbrowser.open("http://localhost:{}/".format(self.port))
def get_webdriver(path=None):
    """
    Get a headless Selenium webdriver for the first supported driver
    executable found on the system.

    Driver executables are looked up with `which`, in alphabetical order:
    'chromedriver', 'chromium-driver', then 'geckodriver'.

    :param path: path to browser binary.
    :return: driver
    :raises WebdriverNotFoundError: if none of the driver executables is found.
    :raises ConnectionResetError: if Chrome cannot be reached after all retries.
    """
    check_modules(["selenium", "webdriver"], missing_modules)

    def chrome_driver(path=None):
        # Build a headless Chrome/Chromium driver, retrying on protocol errors.
        import urllib3
        from selenium.webdriver.chrome.options import Options

        options = Options()
        for argument in ('headless',
                         'ignore-certificate-errors',
                         "test-type",
                         "no-sandbox",
                         "disable-gpu",
                         "allow-insecure-localhost",
                         "allow-running-insecure-content"):
            options.add_argument(argument)

        if path is not None:
            options.binary_location = path

        SELENIUM_RETRIES = 10
        SELENIUM_DELAY = 3  # seconds
        for _ in range(SELENIUM_RETRIES):
            try:
                # `options=` replaces the deprecated `chrome_options=` keyword
                # and matches how the Firefox driver is created below.
                return webdriver.Chrome(options=options)
            except urllib3.exceptions.ProtocolError:
                # https://github.com/SeleniumHQ/selenium/issues/5296
                time.sleep(SELENIUM_DELAY)

        # The retry count has to be interpolated explicitly: the message is a
        # plain string, so a bare "{SELENIUM_RETRIES}" would be shown verbatim.
        raise ConnectionResetError(
            'Cannot connect to Chrome, giving up after {} attempts.'.format(
                SELENIUM_RETRIES)
        )

    def firefox_driver(path=None):
        # Build a headless Firefox driver.
        from selenium.webdriver.firefox.options import Options

        options = Options()
        options.add_argument('headless')
        if path is not None:
            # Set the binary on the options object (same approach as Chrome)
            # rather than through the legacy `firefox_binary=` keyword.
            options.binary_location = path

        return webdriver.Firefox(options=options)

    driver_mapping = {
        'geckodriver': firefox_driver,
        'chromedriver': chrome_driver,
        'chromium-driver': chrome_driver
    }

    for driver_name in sorted(driver_mapping):
        if which(driver_name) is not None:
            return driver_mapping[driver_name](path)

    raise WebdriverNotFoundError(
        "Chrome/Chromium/FireFox Webdriver not found.")
def concat_images(*images):
    """
    Paste the given images side by side, left to right, into a new RGB image.

    The result is as wide as the sum of the input widths and as tall as the
    tallest input; every input is pasted with its top edge at y = 0.

    :param images: image objects exposing `size` as (width, height).
    :return: the combined Image object.
    """
    check_modules(["pillow"], missing_modules)

    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new('RGB', (sum(widths), max(heights)))

    left = 0
    for img, width in zip(images, widths):
        canvas.paste(img, (left, 0))
        left += width

    return canvas
def visualize(world: Union[Game, State, GameState, World],
              interactive: bool = False):
    """
    Show the current state of the world.

    The state is rendered to a temporary HTML file (under a "textworld"
    folder in the system temp directory), of which screenshots of the
    "inventory" and "world" elements are taken and concatenated.

    :param world: Object representing a game state to be visualized.
    :param interactive: Whether or not to visualize the state in the browser.
    :return: Image object of the visualization.
    :raises ValueError: if `world` is not a Game, GameState, World or State.
    """
    # Local import: only needed to wrap the descriptor returned by mkstemp.
    import os

    check_modules(["webbrowser"], missing_modules)

    if isinstance(world, Game):
        game = world
        state = load_state(game.world, game.infos)
        state["objective"] = game.objective
    elif isinstance(world, GameState):
        state = load_state_from_game_state(game_state=world)
    elif isinstance(world, World):
        state = load_state(world)
    elif isinstance(world, State):
        state = load_state(World.from_facts(world.facts))
    else:
        raise ValueError("Don't know how to visualize: {!r}".format(world))

    state["command"] = ""
    state["history"] = ""
    html = get_html_template(game_state=json.dumps(state))

    tmpdir = maybe_mkdir(pjoin(tempfile.gettempdir(), "textworld"))
    fd, filename = tempfile.mkstemp(suffix=".html", dir=tmpdir, text=True)
    url = 'file://' + filename

    # mkstemp hands back an already-open OS-level descriptor that the caller
    # must close. Writing through it (instead of re-opening the file by name)
    # guarantees it is released when the `with` block exits.
    with os.fdopen(fd, 'w') as f:
        f.write(html)

    img_graph = take_screenshot(url, id="world")
    img_inventory = take_screenshot(url, id="inventory")
    image = concat_images(
        img_inventory, img_graph,
    )

    if interactive:
        try:
            webbrowser.open(url)
        except Exception:
            # Opening the browser is best effort: the image is still returned
            # if it fails. Unlike a `return` inside `finally`, this does not
            # swallow KeyboardInterrupt / SystemExit.
            pass

    return image
def get_html_template(game_state=None):
    """
    Compile the 'slideshow.handlebars' template found in WEB_SERVER_RESOURCES.

    :param game_state: value exposed to the template as 'game_state'.
        When None, the template is not rendered.
    :return: the compiled template when `game_state` is None; otherwise the
        result of calling the template with 'game_state' and 'template_path'
        (set to WEB_SERVER_RESOURCES).
    """
    check_modules(["pybars"], missing_modules)

    # Read and compile the template source.
    compiler = pybars.Compiler()
    template_file = pjoin(WEB_SERVER_RESOURCES, 'slideshow.handlebars')
    with open(template_file, 'r') as f:
        source = f.read()

    template = compiler.compile(source)

    if game_state is not None:
        context = {
            'game_state': game_state,
            'template_path': WEB_SERVER_RESOURCES,
        }
        return template(context)

    return template
def __init__(self, game_state: dict, port: int):
    """
    Set up the Flask application used for the visualization.

    Note: routes are registered through `app.add_url_rule` (rather than
    decorators) so that bound methods of `self` can serve as view functions.

    :param game_state: game state returned from load_state_from_game_state
    :param port: port to run visualization on
    """
    check_modules(["gevent", "flask"], missing_modules)
    super().__init__()

    # Disable the 'werkzeug' logger.
    logging.getLogger('werkzeug').disabled = True

    self.port = port
    self.results = Queue()
    self.subscribers = []
    self.game_state = game_state

    static_dir = pjoin(WEB_SERVER_RESOURCES, 'static')
    self.app = Flask(__name__, static_folder=static_dir)
    self.app.add_url_rule('/', 'index', self.index)
    self.app.add_url_rule('/subscribe', 'subscribe', self.subscribe)

    # With no game state, get_html_template returns the compiled template.
    self.slideshow_template = get_html_template()
def take_screenshot(url: str, id: str = 'world'):
    """
    Takes a screenshot of DOM element given its id.

    The whole page is captured with a webdriver obtained from
    `get_webdriver`, then cropped to the element's bounding box.

    :param url: URL of webpage to open headlessly.
    :param id: ID of DOM element. (The name shadows the builtin but is kept
        because callers pass it by keyword.)
    :return: Image object.
    """
    check_modules(["pillow"], missing_modules)

    driver = get_webdriver()
    try:
        driver.get(url)
        # Generic lookup with the "id" strategy; equivalent to the deprecated
        # `find_element_by_id` helper.
        element = driver.find_element("id", id)
        location = element.location
        size = element.size
        png = driver.get_screenshot_as_png()
    finally:
        # Always end the browser session, even if the page or the element
        # could not be loaded. `quit()` shuts down the whole driver, whereas
        # `close()` only closes the current window.
        driver.quit()

    image = Image.open(io.BytesIO(png))

    left = location['x']
    top = location['y']
    right = left + size['width']
    bottom = top + size['height']
    return image.crop((left, top, right, bottom))
def show_graph(facts: Iterable[Proposition],
               title: str = "Knowledge Graph",
               renderer: Optional[str] = None,
               save: Optional[str] = None) -> "plotly.graph_objs._figure.Figure":
    r""" Visualizes the graph made from a collection of facts.

    Arguments:
        facts: Collection of facts representing a state of a game.
        title: Title for the figure
        renderer: Which Plotly's renderer to use (e.g., 'browser').
        save: If provided, path where to save a PNG version of the graph.

    Returns:
        The Plotly's figure representing the graph.

    Example:

    >>> import textworld
    >>> options = textworld.GameOptions()
    >>> options.seeds = 1234
    >>> game_file, game = textworld.make(options)
    >>> import gym
    >>> import textworld.gym
    >>> from textworld import EnvInfos
    >>> request_infos = EnvInfos(facts=True)
    >>> env_id = textworld.gym.register_game(game_file, request_infos)
    >>> env = gym.make(env_id)
    >>> _, infos = env.reset()
    >>> textworld.render.show_graph(infos["facts"])

    """
    check_modules(["matplotlib", "plotly"], missing_modules)

    # Smallest |angle| (in degrees) used for the arrow standoff, so that
    # horizontal edges get a finite value instead of log(90 / 0).
    min_standoff_angle = 1.0
    # Distance between an edge's midpoint and its label, along the normal.
    label_offset = 5

    def _get_angle(p0, p1):
        # Rotation (in degrees) applied to the label of the edge p0 -> p1.
        # The slope is scaled by 16/9 (presumably to account for the 16:9
        # aspect ratio used when the figure is exported).
        dx = p1[0] - p0[0]
        dy = p1[1] - p0[1]
        if dx == 0:
            # Vertical edge: the slope is undefined, the label is vertical.
            # Coincident endpoints have no direction at all.
            return 90.0 if dy != 0 else 0.0

        return -np.rad2deg(np.arctan(dy / dx / (16 / 9)))

    def _calc_arrow_standoff(angle, label):
        # Gap between the arrow head and the target node: the flatter the
        # edge and the longer the widest word of the label, the larger it is.
        steepness = max(abs(angle), min_standoff_angle)
        widest_word = max(map(len, label.split()), default=0)
        return 5 + np.log(90 / steepness) * widest_word

    G = build_graph_from_facts(facts)
    pos = nx.drawing.nx_pydot.pydot_layout(G, prog="fdp")

    edge_traces = []
    annotations = []
    for src, dst, edge_data in G.edges(data=True):
        p0, p1 = pos[src], pos[dst]
        x0, y0 = p0
        x1, y1 = p1
        mid = ((x0 + x1) / 2., (y0 + y1) / 2.)

        edge_traces.append(
            go.Scatter(
                x=[x0, mid[0], x1, None],
                y=[y0, mid[1], y1, None],
                mode='lines',
                line=dict(width=0.5, color='#888', shape='spline', smoothing=1),
                hoverinfo='none'
            )
        )

        # Unit vector orthogonal to the edge, used to shift the label away
        # from the line. A zero-length edge has no normal: keep the label
        # on the midpoint rather than producing NaN coordinates.
        rvec = (x0 - x1, y0 - y1)  # Vector from dest -> src.
        length = np.sqrt(rvec[0] ** 2 + rvec[1] ** 2)
        if length > 0:
            orthogonal = (rvec[1] / length, -rvec[0] / length)
        else:
            orthogonal = (0.0, 0.0)

        label_pos = (mid[0] + label_offset * orthogonal[0],
                     mid[1] + label_offset * orthogonal[1])

        angle = _get_angle(p0, p1)

        # Relation arrow, drawn from the middle of the edge to its target.
        annotations.append(
            go.layout.Annotation(
                x=x1, y=y1,
                ax=mid[0], ay=mid[1],
                axref="x", ayref="y",
                showarrow=True,
                arrowhead=2,
                arrowsize=3,
                arrowwidth=0.5,
                arrowcolor="#888",
                standoff=_calc_arrow_standoff(angle, G.nodes[dst]['label']),
            )
        )
        # Relation name, written along the edge.
        annotations.append(
            go.layout.Annotation(
                x=label_pos[0], y=label_pos[1],
                showarrow=False,
                text="<i>{}</i>".format(edge_data['type']),
                textangle=angle,
                font=dict(
                    family="sans serif",
                    size=12,
                    color="blue"
                ),
            )
        )

    node_x = []
    node_y = []
    node_labels = []
    for node, node_data in G.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # One word per line, in bold.
        node_labels.append("<b>{}</b>".format(node_data['label'].replace(" ", "<br>")))

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='text',
        text=node_labels,
        textfont=dict(
            family="sans serif",
            size=12,
            color="black"
        ),
        hoverinfo='none',
        marker=dict(
            showscale=True,
            color=[],
            size=10,
            line_width=2
        )
    )

    fig = go.Figure(
        data=[*edge_traces, node_trace],
        layout=go.Layout(
            # `title.font` is the supported spelling of the deprecated
            # `titlefont` layout attribute.
            title=dict(text=title, font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
        )
    )

    # Add relation names and relation arrows.
    fig.update_layout(annotations=annotations)

    if renderer:
        fig.show(renderer=renderer)

    if save:
        fig.write_image(save, width=1920, height=1080, scale=4)

    return fig