Beispiel #1
0
 def save_tags(self, n_clicks, tags):
     """Persist edited tags for the single selected experiment.

     Args:
         n_clicks: Click count of the save button; falsy means not clicked.
         tags: Selected tag values. ``None`` (e.g. a cleared dropdown) is
             treated as "no tags".

     Returns:
         ``(True, None)`` for the callback's two outputs (declared elsewhere).

     Raises:
         PreventUpdate: when not clicked, or unless exactly one experiment
             is selected.
     """
     if not n_clicks or not self.s.selected or len(self.s.selected) != 1:
         raise PreventUpdate
     # Normalize None up front: the join and the membership check below both
     # need an iterable, and set_tags should never receive None.
     tags = tags or []
     es = ExperimentService(Path(self.s.selected[0]["path"]))
     es.set_tags(tags)
     # Mirror the new tags into the selected row, space-joined like the
     # database "tags" column.
     self.s.selected[0]["tags"] = " ".join(tags)
     all_tags = get_all_tags()
     if any(tag not in all_tags for tag in tags):
         # A tag unknown to the memoized `get_all_tags` was added, so its
         # cached result is stale and must be cleared.
         cache.delete_memoized(get_all_tags)
     return True, None
Beispiel #2
0
 def re_build_database(n_clicks):
     """Re-index experiments in the database and clear the tag cache.

     Args:
         n_clicks: Click count of the rebuild button; falsy means not clicked.

     Returns:
         ``(False, True, None)`` for the callback's three outputs.

     Raises:
         PreventUpdate: if the button has not been clicked.
     """
     if not n_clicks:
         raise PreventUpdate
     min_duration = 1.0
     # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
     start_time = time.monotonic()
     # Remove any deleted experiments, track new experiments added.
     ExperimentService.init_db_table()
     # Tags may have changed, so clear the memoized cache for `get_all_tags`.
     cache.delete_memoized(get_all_tags)
     # Ensure this takes at least `min_duration` seconds so that the spinner
     # notification has time to display. Sleep only for the remainder so the
     # total is ~1 s rather than the work time plus a full extra second.
     remaining = min_duration - (time.monotonic() - start_time)
     if remaining > 0:
         time.sleep(remaining)
     return False, True, None
Beispiel #3
0
def serve(config, launch, skip_db_update):
    """
    Serve the dashboard locally.
    """
    # Patch the stdlib first (`monkey` is presumably gevent.monkey, matching
    # the "gevent" worker class below); the app imports are deferred until
    # after patching.
    monkey.patch_all()

    from mallennlp.app import create_app
    from mallennlp.services.db import get_db_from_cli, init_tables, Tables
    from mallennlp.services.experiment import ExperimentService

    if not skip_db_update:
        db = get_db_from_cli(config)
        try:
            click.echo("Updating experiments index...")
            init_tables(db, (Tables.EXPERIMENTS.value, ))
            entries = [
                s.get_db_fields() for s in ExperimentService.find_experiments()
            ]
            if entries:
                click.echo(
                    f"Found {click.style(str(len(entries)), fg='green')} experiments"
                )
                ExperimentService.init_db_table(db=db, entries=entries)
            else:
                click.secho("No experiments found", fg="yellow")
        finally:
            # Release the connection even if re-indexing fails part-way.
            db.close()

    click.secho(
        f"Serving AllenNLP manager for {click.style(config.project.name, bold=True)}",
        fg="green",
    )

    application = create_app(config)

    if launch:
        url = f"http://localhost:{config.server.port}"
        click.launch(url)

    # Options handed to StandaloneApplication together with the app.
    options = {
        "timeout": 300,
        "workers": config.server.workers,
        "worker_class": "gevent",
        "worker_connections": config.server.worker_connections,
        "bind": f":{config.server.port}",
        "loglevel": config.project.loglevel.lower(),
        "preload_app": True,
    }

    StandaloneApplication(application, options).run()
Beispiel #4
0
 def check_path_valid(self, attribute, value):
     """Ensure ``value`` names an experiment directory inside the project.

     Presumably used as an attrs validator (``self``/``attribute`` unused).

     Args:
         value: Candidate experiment path.

     Returns:
         The canonical path as a string.

     Raises:
         InvalidPageParametersError: if the path is outside the project,
             does not exist, is a file, or is a directory that is not an
             experiment.
     """
     # Keep the try narrow: the ValueError is expected from canonicalization
     # (path outside the project). A ValueError from the checks below, or an
     # InvalidPageParametersError if it subclasses ValueError, must not be
     # rewritten into the "outside of project" message.
     try:
         path = ExperimentService.get_canonical_path(Path(value))
     except ValueError as err:
         raise InvalidPageParametersError(
             "Directory outside of project.") from err
     if not path.exists():
         raise InvalidPageParametersError(
             "Directory does not exist.")
     if path.is_dir() and not ExperimentService.is_experiment(path):
         raise InvalidPageParametersError(
             "Directory is not an experiment.")
     if path.is_file():
         raise InvalidPageParametersError(
             "Path is a file, not an experiment directory.")
     return str(path)
def test_set_tags(experiment_service, db):
    """Tags set on a service are returned, written to disk, and indexed."""
    expected_tags = ["copynet", "seq2seq"]
    # Hand over a copy so the assertions cannot pass merely through aliasing.
    experiment_service.set_tags(list(expected_tags))
    assert experiment_service.get_tags() == expected_tags

    # The meta file must exist, and a fresh service over the same directory
    # must read the tags back from it.
    assert experiment_service.e.meta.path.exists()
    reloaded = ExperimentService(experiment_service.get_path())
    assert reloaded.get_tags() == expected_tags

    # The experiment's database row stores the tags space-joined.
    query = f"SELECT * FROM {Tables.EXPERIMENTS.value} WHERE path = ?"
    rows = list(db.execute(query, (str(experiment_service.get_path()), )))
    assert len(rows) == 1
    assert rows[0]["tags"] == "copynet seq2seq"
Beispiel #6
0
def reset(db, table: str):
    # Re-initialize `table` in `db` via init_tables. For the experiments
    # table, repopulate it from the experiments found in the project.
    # `db` is closed before returning, including on error.
    click.echo(f"Resetting db table {click.style(table, fg='green')}...")

    from mallennlp.services.db import init_tables, Tables
    from mallennlp.services.experiment import ExperimentService

    try:
        init_tables(db, (table, ))
        if table == Tables.EXPERIMENTS.value:
            # Recursively search project for experiments and populate database.
            entries = [
                s.get_db_fields() for s in ExperimentService.find_experiments()
            ]
            if entries:
                click.echo(
                    f"Found {click.style(str(len(entries)), fg='green')} experiments"
                )
                ExperimentService.init_db_table(db=db, entries=entries)
    finally:
        # Release the connection even if re-initialization fails.
        db.close()
    click.secho("Success!", fg="green")
Beispiel #7
0
def display_tags(es: ExperimentService):
    """Render an experiment's tags as removable badges with tooltips.

    The first ``MAX_TAG_BADGES`` tags each get a visible badge with id
    ``experiment-tag-{i}``. The remaining ids up to ``MAX_TAG_BADGES`` are
    filled with hidden placeholder badges, so every ``experiment-tag-{i}``
    id always exists in the layout. A final badge opens the edit-tags modal:
    an ellipsis icon when tags were truncated, otherwise an edit icon.

    Returns:
        ``[html.Span(badges), html.Div(children=tooltips)]``.
    """
    all_tags = es.get_tags()
    visible = all_tags[:MAX_TAG_BADGES]
    badges = []
    hints = []
    for idx, tag in enumerate(visible):
        badge_id = f"experiment-tag-{idx}"
        badges.append(
            dbc.Badge(
                [tag, html.I(className="fas fa-times-circle", style={"margin-left": "5px"})],
                id=badge_id,
                key=tag,
                color="info",
                className="mr-1",
                pill=True,
                href="#",
            )
        )
        hints.append(dbc.Tooltip(f'Remove "{tag}" tag', target=badge_id))

    # Hidden placeholders for unused badge slots.
    for idx in range(len(visible), MAX_TAG_BADGES):
        badges.append(
            dbc.Badge("", id=f"experiment-tag-{idx}", href="#", style={"display": "none"})
        )

    if len(all_tags) > MAX_TAG_BADGES:
        edit_icon = html.I(className="fas fa-ellipsis-h")
        edit_hint = "See all tags or edit"
    else:
        edit_icon = html.I(className="fas fa-edit", style={"margin-left": "3px"})
        edit_hint = "Edit tags"
    badges.append(
        dbc.Badge(
            edit_icon,
            id="experiment-edit-tags-modal-open",
            color="info",
            className="mr-1",
            pill=True,
            href="#",
        )
    )
    hints.append(dbc.Tooltip(edit_hint, target="experiment-edit-tags-modal-open"))
    return [html.Span(badges), html.Div(children=hints)]
Beispiel #8
0
def display_metrics(es: ExperimentService, epoch: Optional[int] = None):
    """Render an experiment's metrics as a Markdown component.

    Fields listed in ``METRIC_EPOCH_DISPLAY_FIELDS`` come first, in that
    mapping's order, each %-formatted with its formatter; missing or ``None``
    values are skipped. Any other ``best_validation_*`` metric is appended
    as "**Best epoch validation <name>:** `<value>`".

    Args:
        es: Experiment to read metrics from.
        epoch: Passed through to ``es.get_metrics``.

    Returns:
        ``dcc.Markdown``; a "no metrics" message when ``es.get_metrics``
        returns ``None``.
    """
    metrics = es.get_metrics(epoch)
    if metrics is None:
        return dcc.Markdown("**No metrics to display**")

    lines: List[str] = []
    for name, fmt in METRIC_EPOCH_DISPLAY_FIELDS.items():
        value = metrics.get(name)
        if value is not None:
            lines.append(fmt % value)

    prefix = "best_validation_"
    for name in metrics:
        if name in METRIC_EPOCH_DISPLAY_FIELDS or not name.startswith(prefix):
            continue
        label = name[len(prefix):].replace("_", " ")
        lines.append(f"**Best epoch validation {label}:** `{metrics[name]}`")
    return dcc.Markdown("\n\n".join(lines))
Beispiel #9
0
class ExperimentPage(Page):
    """Dashboard page for a single experiment directory.

    Renders a header (path breadcrumbs, status badge, tag badges, tag-edit
    modal, refresh interval) above a sidebar layout with "overview" and
    "settings" entries. It also registers the Dash callbacks that refresh the
    metric plot, status badge, metrics summary and tags, and edit the tags.
    """

    @serde
    class Params:
        # Experiment directory; validated by `check_path_valid` below.
        path: str = attr.ib()
        # Key of the active sidebar entry ("overview" or "settings").
        active: str = "overview"

        @path.validator
        def check_path_valid(self, attribute, value):
            """Ensure ``value`` names an experiment directory in the project.

            Raises:
                InvalidPageParametersError: if the path is outside the
                    project, does not exist, is a file, or is a directory
                    that is not an experiment.
            """
            # Keep the try narrow: the ValueError is expected from
            # canonicalization (path outside the project). A ValueError from
            # the checks below, or an InvalidPageParametersError if it
            # subclasses ValueError, must not become that message.
            try:
                path = ExperimentService.get_canonical_path(Path(value))
            except ValueError as err:
                raise InvalidPageParametersError(
                    "Directory outside of project.") from err
            if not path.exists():
                raise InvalidPageParametersError(
                    "Directory does not exist.")
            if path.is_dir() and not ExperimentService.is_experiment(path):
                raise InvalidPageParametersError(
                    "Directory is not an experiment.")
            if path.is_file():
                raise InvalidPageParametersError(
                    "Path is a file, not an experiment directory.")
            return str(path)

    def __init__(self, state, params):
        super().__init__(state, params)
        # Resolve the requested path once; elements and callbacks reuse it.
        self.path = ExperimentService.get_canonical_path(Path(self.p.path))
        self.es = ExperimentService(self.path)

    def get_experiment_header_elements(self):
        """Build the header: breadcrumbs and status badge, tags, the
        edit-tags modal, and the interval that drives periodic callbacks."""
        status = ec.get_status(self.es)
        elements = [
            html.H5([
                *ec.get_path_breadcrumbs(self.path),
                ec.get_status_badge(status)
            ]),
            html.Div(id="experiment-tags", children=ec.display_tags(self.es)),
            ec.edit_tags_modal("experiment"),
        ]
        # Poll every 30 s unless the experiment has finished or failed, in
        # which case poll only every 5 min.
        if status not in (Status.FINISHED, Status.FAILED):
            interval = 1000 * 30
        else:
            interval = 1000 * 300
        elements.append(
            dcc.Interval(id="experiment-update-interval", interval=interval))
        return elements

    def get_overview_elements(self):
        """Build the overview: metrics summary, STDOUT/STDERR log links and,
        when the experiment has epochs, a metric dropdown plus plot
        (initially showing "loss")."""
        out = [
            html.Div(id="experiment-overview",
                     children=ec.display_metrics(self.es)),
            html.Strong("Logs: "),
            dcc.Link(
                "STDOUT",
                href="/log-stream?" +
                urlparse.urlencode({"path": self.path / self.es.STDOUT_FNAME}),
            ),
            ", ",
            dcc.Link(
                "STDERR",
                href="/log-stream?" +
                urlparse.urlencode({"path": self.path / self.es.STDERR_FNAME}),
            ),
        ]
        epochs = self.es.get_epochs()
        if epochs:
            out.extend([
                html.Br(),
                html.Br(),
                dcc.Dropdown(
                    id="experiment-metric-plot-dropdown",
                    options=[{
                        "label": m,
                        "value": m
                    } for m in self.es.get_metric_names()],
                    value="loss",
                ),
                dcc.Graph(
                    id="experiment-metric-plot",
                    config={"displayModeBar": False},
                    figure=ec.get_metric_plot_figure(self.es, "loss"),
                ),
            ])
        return out

    def get_settings_elements(self):
        """Placeholder content for the settings section."""
        return ["Coming soon"]

    def wrap_elements(self, elements):
        """Stack the experiment header row above a row holding ``elements``."""
        return [
            dbc.Row(
                dbc.Col(
                    self.get_experiment_header_elements(),
                    className="dash-padded-element experiment-header-element",
                )),
            dbc.Row(
                dbc.Col(
                    elements,
                    className=
                    "dash-padded-element experiment-main-content-element",
                )),
        ]

    def get_elements(self):
        """Refresh the DB entry, then build the sidebar layout.

        Each entry's content is passed as a callable that wraps the
        section's elements with the header.
        """
        # Update database entry.
        self.es.update_db_entry()
        # Create sidebar entries.
        entries = OrderedDict([
            (
                "overview",
                SidebarEntry(
                    "Overview",
                    lambda: self.wrap_elements(self.get_overview_elements()),
                    className="",
                ),
            ),
            (
                "settings",
                SidebarEntry(
                    "Settings",
                    lambda: self.wrap_elements(self.get_settings_elements()),
                    className="",
                ),
            ),
        ])
        return SidebarLayout("Experiment", entries, self.p.active, self.p)

    def get_notifications(self):
        """Return the (initially closed) "tags updated" toast."""
        return [
            dbc.Toast(
                "Tags successfully updated",
                id="experiment-edit-tags-noti",
                header="Success",
                dismissable=True,
                duration=4000,
                is_open=False,
                icon="success",
            )
        ]

    @Page.callback(
        [Output("experiment-metric-plot", "figure")],
        [
            Input("experiment-metric-plot-dropdown", "value"),
            Input("experiment-update-interval", "n_intervals"),
        ],
        mutating=False,
    )
    def update_metric_plot_figure(self, metric_name, n_intervals):
        """Redraw the metric plot on a dropdown change or interval tick."""
        # Without a selected metric (e.g. a cleared dropdown) there is
        # nothing to plot, even on an interval tick; plotting anyway would
        # look up "training_None".
        if not metric_name:
            raise PreventUpdate
        return ec.get_metric_plot_figure(self.es, metric_name)

    @Page.callback(
        [Output("experiment-status-badge", "children")],
        [Input("experiment-update-interval", "n_intervals")],
        mutating=False,
    )
    def update_status_badge(self, _):
        """Refresh the DB entry and the status badge on each interval tick."""
        # Update database entry.
        self.es.update_db_entry()
        status = ec.get_status(self.es)
        return ec.get_status_badge(status)

    @Page.callback(
        [Output("experiment-overview", "children")],
        [Input("experiment-update-interval", "n_intervals")],
        mutating=False,
    )
    def update_epoch_metrics(self, _):
        """Refresh the metrics summary on each interval tick."""
        # The summary is rendered into the div with id "experiment-overview"
        # (see get_overview_elements), so that is the output to update.
        return ec.display_metrics(self.es)

    @Page.callback(
        [Output("experiment-tags", "children")],
        [
            Input(f"experiment-tag-{i}", "n_clicks")
            for i in range(ec.MAX_TAG_BADGES)
        ] + [Input("experiment-edit-tags-noti", "is_open")],
        [
            State(f"experiment-tag-{i}", "key")
            for i in range(ec.MAX_TAG_BADGES)
        ],
        mutating=False,
        permissions=Permissions.READ_WRITE,
    )
    def update_tags(self, *args):
        """Re-render the tag badges, deleting the tag whose badge was clicked.

        Also fires when the "tags updated" toast opens or closes, so the
        badges pick up tags saved through the modal.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            raise PreventUpdate
        trigger = ctx.triggered[0]
        button_id = trigger["prop_id"].split(".")[0]
        # Require an actual click (truthy n_clicks): a badge firing with an
        # empty click count must never delete its tag.
        if button_id.startswith("experiment-tag-") and trigger.get("value"):
            # Try deleting tag.
            current_tags = self.es.get_tags()
            tag = ctx.states[f"{button_id}.key"]
            if tag in current_tags:
                self.es.set_tags([t for t in current_tags if t != tag])
                # Need to clear the memoized cache for `get_all_tags` now.
                cache.delete_memoized(ec.get_all_tags)
        return ec.display_tags(self.es)

    @Page.callback(
        [
            Output("experiment-edit-tags-modal", "is_open"),
            Output("experiment-edit-tags-dropdown", "value"),
        ],
        [
            Input("experiment-edit-tags-modal-open", "n_clicks"),
            Input("experiment-edit-tags-modal-close", "n_clicks"),
        ],
        [State("experiment-edit-tags-modal", "is_open")],
        mutating=False,
        permissions=Permissions.READ_WRITE,
    )
    def toggle_modal(self, n1, n2, is_open):
        """Toggle the edit-tags modal.

        When opening, preload the dropdown with the current tags. Otherwise
        the dropdown value is set to ``None``. With no clicks on either
        button, the modal state is left as is.
        """
        tags: Optional[List[str]] = None
        if n1 or n2:
            will_open = not is_open
            if will_open and self.es:
                tags = self.es.get_tags()
            return will_open, tags
        return is_open, tags

    @staticmethod
    @Page.callback(
        [Output("experiment-edit-tags-dropdown", "options")],
        [Input("experiment-edit-tags-dropdown", "search_value")],
        [State("experiment-edit-tags-dropdown", "value")],
    )
    def update_tag_options(search, value):
        """Offer existing tags starting with ``search``, plus ``search``
        itself and every currently selected value."""
        if not search:
            raise PreventUpdate
        options = {t for t in ec.get_all_tags() if t.startswith(search)}
        options.add(search)
        for v in value or []:
            options.add(v)
        return [{"label": t, "value": t} for t in options]

    @Page.callback(
        [Output("experiment-edit-tags-noti", "is_open")],
        [Input("experiment-edit-tags-save", "n_clicks")],
        [State("experiment-edit-tags-dropdown", "value")],
        mutating=False,
    )
    def save_tags(self, n_clicks, tags):
        """Save the dropdown's tags and open the success toast."""
        if not n_clicks:
            raise PreventUpdate
        # A cleared dropdown reports None; treat it as "no tags" so the
        # membership check below does not fail.
        tags = tags or []
        self.es.set_tags(tags)
        all_tags = ec.get_all_tags()
        if any(tag not in all_tags for tag in tags):
            # Need to clear the memoized cache for `get_all_tags` now.
            cache.delete_memoized(ec.get_all_tags)
        return True
def experiment_service(project, db):
    """Return a service for the ``test_experiment`` directory of ``project``."""
    experiment_dir = project / "test_experiment"
    return ExperimentService(experiment_dir, db)
def test_add_experiments(db, entries):
    """Populating a freshly initialized experiments table raises no error."""
    experiments_table = Tables.EXPERIMENTS.value
    init_tables(db, (experiments_table, ))
    ExperimentService.init_db_table(db=db, entries=entries)
def test_get_canonical_path(path, root, result):
    """``get_canonical_path(path, root)`` stringifies to ``result``."""
    canonical = ExperimentService.get_canonical_path(Path(path), Path(root))
    assert str(canonical) == result
def test_remove_db_entry(experiment_service, db):
    """Removing the fixture experiment's entry empties the experiments table."""
    ExperimentService.remove_db_entry(experiment_service.get_path(), db=db)
    remaining = list(db.execute(f"SELECT * FROM {Tables.EXPERIMENTS.value}"))
    assert len(remaining) == 0
def test_find_experiments(project):
    """The fixture project holds exactly one experiment, ``test_experiment``."""
    found = list(ExperimentService.find_experiments(project))
    assert len(found) == 1
    [experiment] = found
    assert experiment.e.path == Path("test_experiment")
Beispiel #15
0
 def __init__(self, state, params):
     """Initialize the page and bind it to the requested experiment.

     ``self.p`` is presumably the parsed params set by the base class
     (TODO confirm); its ``path`` is canonicalized once and reused.
     """
     super().__init__(state, params)
     canonical_path = ExperimentService.get_canonical_path(Path(self.p.path))
     self.path = canonical_path
     self.es = ExperimentService(canonical_path)
Beispiel #16
0
def init_project(verb: str, name: str, path: Path, username: str,
                 password: str, **kwargs):
    """Create a project: config, instance directory, database and admin user.

    Args:
        verb: Word describing the action, printed before "project named ...".
        name: Project name.
        path: Project directory.
        username: Name of the admin user to create.
        password: Password of the admin user.
        **kwargs: Config options. Keys prefixed ``server_`` go to
            ``ServerConfig`` with the prefix stripped; the rest go to
            ``ProjectConfig``. ``None`` values are dropped.

    Raises:
        click.ClickException: if ``path / Config.CONFIG_PATH`` already exists.
    """
    from mallennlp.domain.config import ProjectConfig, ServerConfig
    from mallennlp.domain.user import Permissions
    from mallennlp.services.db import init_db, get_db_from_cli
    from mallennlp.services.experiment import ExperimentService
    from mallennlp.services.config import Config
    from mallennlp.services.user import UserService

    if (path / Config.CONFIG_PATH).exists():
        raise click.ClickException(
            click.style("project already exists", fg="red"))

    # Initialize config.
    server_prefix = "server_"
    project_options = {
        k: v
        for k, v in kwargs.items()
        if not k.startswith(server_prefix) and v is not None
    }
    server_options = {
        k[len(server_prefix):]: v
        for k, v in kwargs.items()
        if k.startswith(server_prefix) and v is not None
    }
    config = Config(
        ProjectConfig(path, name=name, **project_options),
        ServerConfig(path, **server_options),
    )

    # Ensure the server's instance path exists. exist_ok: the directory may
    # be left over from an earlier attempt that failed before the config
    # file was saved (the guard above only checks that file, and it is
    # written last), which would otherwise make a retry impossible.
    os.makedirs(config.server.instance_path, exist_ok=True)

    # Initialize the database file.
    init_db(config)

    db = get_db_from_cli(config)
    try:
        # Add the user to the database.
        user_service = UserService(db=db)
        user_service.create(username, password, permissions=Permissions.ADMIN)

        # Find existing experiments and add to database (does nothing if new project).
        experiment_entries = [
            s.get_db_fields() for s in ExperimentService.find_experiments()
        ]
        if experiment_entries:
            click.echo(
                f"Found {click.style(str(len(experiment_entries)), fg='green')} existing experiments"
            )
            ExperimentService.init_db_table(db=db, entries=experiment_entries)
    finally:
        # Release the connection even if user creation or indexing fails.
        db.close()

    # Save the config to the 'Project.toml' file in the project directory.
    config.to_toml(path)

    click.echo(
        f"{verb} project named {click.style(name, fg='green', bold=True)}")
    click.echo(
        f" --> To edit the project's config, run {click.style('mallennlp edit', fg='yellow')} "
        f"from within the project "
        f"directory or edit the {click.style('Project.toml', fg='green')} file directly."
    )
    click.echo(
        f" --> To serve the dashboard, run {click.style('mallennlp serve', fg='yellow')} "
        f"from within the project directory.")
    click.echo(
        f" --> You can log in to the dashboard with the username {click.style(username, bold=True, fg='green')}"
    )
Beispiel #17
0
def get_metric_plot_figure(es: ExperimentService, metric_name: str):
    """Build a Plotly figure dict of per-epoch training/validation values.

    A series is plotted only when every epoch supplied a value for it, so
    its points line up with x positions ``0 .. len(epochs) - 1``.

    Args:
        es: Experiment whose epochs are plotted.
        metric_name: Metric name without the ``training_``/``validation_``
            prefix, e.g. ``"loss"``.

    Returns:
        A dict with ``"data"`` (zero to two line+marker traces, training
        first) and ``"layout"``.
    """
    epochs = es.get_epochs()

    training_metric_name = f"training_{metric_name}"
    validation_metric_name = f"validation_{metric_name}"
    x_vals = list(range(len(epochs)))

    training_metrics: List[float] = []
    validation_metrics: List[float] = []
    for epoch in epochs:
        metrics = epoch.metrics.data
        if metrics is None:
            continue
        # Compare against None: a legitimate value of 0 / 0.0 is falsy and
        # would otherwise be dropped, hiding the entire series below.
        t = metrics.get(training_metric_name)
        if t is not None:
            training_metrics.append(t)
        v = metrics.get(validation_metric_name)
        if v is not None:
            validation_metrics.append(v)

    data: List[Dict[str, Any]] = []
    for series_name, values in (
        (training_metric_name, training_metrics),
        (validation_metric_name, validation_metrics),
    ):
        # Skip a series with gaps: its points would not match x_vals.
        if len(values) != len(epochs):
            continue
        data.append(
            {
                "name": series_name,
                "x": x_vals,
                "y": values,
                "mode": "lines+markers",
                "hoverinfo": "text",
                "text": [
                    f"epoch {i} {series_name}: {m:.4f}"
                    for i, m in enumerate(values)
                ],
            }
        )
    return {
        "data": data,
        "layout": {
            "clickmode": "event+select",
            "xaxis": {
                "range": [0, len(epochs) - 0.5],
                "tickvals": list(range(len(epochs))),
                "ticktext": [str(i) for i in range(0, len(epochs))],
            },
            "margin": {"l": 40, "b": 30, "t": 20, "pad": 2},
            "uirevision": True,
        },
    }
Beispiel #18
0
def get_status(es: ExperimentService) -> Status:
    """Return the status reported by ``es.get_status()``."""
    current = es.get_status()
    return current