예제 #1
0
 def get_grid_metrics(self):
     """Describe the column layout of the tracking table.

     Returns:
         A list of three right-aligned ``tt.group`` objects, in the order
         "train", "valid", "test". The "test" group holds one ".2f" leaf
         per name in ``self.test_metrics``.
     """
     train_leaves = [tt.leaf("epoch"), tt.leaf("reco", ".3f")]

     # Validation columns: penalty/ms with one decimal, reco/breco as
     # percentages, and the b_nsdr columns with two decimals.
     valid_leaves = [tt.leaf(key, ".1f") for key in ("penalty", "ms")]
     valid_leaves += [tt.leaf(key, ".2%") for key in ("reco", "breco")]
     valid_leaves += [
         tt.leaf(key, ".2f")
         for key in ("b_nsdr", "b_nsdr_drums", "b_nsdr_bass",
                     "b_nsdr_other", "b_nsdr_vocals")
     ]

     test_leaves = [tt.leaf(metric, ".2f") for metric in self.test_metrics]

     sections = (("train", train_leaves),
                 ("valid", valid_leaves),
                 ("test", test_leaves))
     return [tt.group(title, leaves, align=">") for title, leaves in sections]
    def check(self, trim=None, reset=False):
        """Print a table with the status and latest metrics of every job.

        Args:
            trim: optional index into ``self.jobs``. When given, every job's
                metric history is cut to as many entries as that reference
                job has recorded, so all rows are compared at the same point.
            reset: if true, jobs whose status is "failed" or "completing" are
                passed to ``reset_job`` and reported with status "reset".
        """
        jobs = self.jobs
        # Status is resolved per job rather than through a dict keyed by sid:
        # several jobs may have ``sid is None``, and a shared key let a single
        # finished sid-less job's "done" leak onto every other sid-less job
        # (which then also escaped ``reset``).
        done = [get_done(job.name) for job in jobs]
        to_check = [job.sid for job, is_done in zip(jobs, done)
                    if not is_done and job.sid is not None]
        # ``_check`` yields status per sid; ``dict`` accepts either a mapping
        # or key/value pairs, exactly like the ``dict.update`` used before.
        checked = dict(_check(to_check))

        # ``trim`` is a job index; convert it to that job's entry count.
        trim_len = None
        if trim is not None:
            trim_len = len(get_metrics(jobs[trim].name))

        lines = []
        for index, (job, is_done) in enumerate(zip(jobs, done)):
            if is_done:
                status = "done"
            else:
                # No sid, or a sid ``_check`` did not report on -> "failed".
                status = checked.get(job.sid, "failed")
            if reset and status in ("failed", "completing"):
                reset_job(job.name)
                status = "reset"

            metrics = get_metrics(job.name)
            if trim_len is not None:
                metrics = metrics[:trim_len]
            meta = {
                'name': job.name,
                'sid': job.sid,
                'status': status[:2],  # only the first two characters
                "index": index,
                "epoch": len(metrics),  # entries recorded (after trimming)
            }
            # Display the most recent entry, or an empty row if none exist.
            latest = metrics[-1] if metrics else {}
            lines.append({'meta': meta, 'metrics': latest})

        table = tt.table(shorten=True,
                         groups=[
                             tt.group("meta", [
                                 tt.leaf("index", align=">"),
                                 tt.leaf("name"),
                                 tt.leaf("sid", align=">"),
                                 tt.leaf("status"),
                                 tt.leaf("epoch", align=">")
                             ]),
                             tt.group("metrics", [
                                 tt.leaf("train", ".2%"),
                                 tt.leaf("valid", ".2%"),
                                 tt.leaf("best", ".2%"),
                                 tt.leaf("true_model_size", ".2f"),
                                 tt.leaf("compressed_model_size", ".2f"),
                             ])
                         ])
        print(tt.treetable(lines, table, colors=["0", "38;5;245"]))
예제 #3
0
            parts = []
        else:
            parts = [
                p.split("=") for p in name.split(" ") if "tasnet" not in p
            ]
        if not args.individual:
            parts = [(k, v) for k, v in parts if k != STD_KEY]
        name = model + " " + " ".join(f"{k}={v}" for k, v in parts)
        all_stats[name].append(metric)

# Column layout: a "name" column followed by a "valid" group holding the
# mean score, its standard error and how many values were aggregated.
metrics = [
    tt.leaf(key, fmt)
    for key, fmt in (("score", ".4f"), ("std", ".3f"), ("count", ".2f"))
]

mytable = tt.table([tt.leaf("name"), tt.group("valid", metrics)])

lines = []
for name, stats in all_stats.items():
    stats = np.array(stats)
    n_values = stats.shape[0]
    line = {
        "name": name,
        "valid": {
            "score": stats.mean(),
            # standard error of the mean: std / sqrt(n)
            "std": stats.std() / n_values**0.5,
            "count": n_values,
        },
    }
    lines.append(line)
# Ascending by mean score.
lines.sort(key=lambda row: row["valid"]["score"])
print(tt.treetable(lines, mytable, colors=['33', '0']))
예제 #4
0
        parts = [p.split("=") for p in name.split(" ") if "tasnet" not in p]
    if not args.individual:
        parts = [(k, v) for k, v in parts if k != STD_KEY]
    name = model + " " + " ".join(f"{k}={v}" for k, v in parts)
    stats = read(args.metric, results)
    if (not stats or len(stats["drums"]) != 50):
        print(f"Missing stats for {results}", file=sys.stderr)
    else:
        all_stats[name].append(stats)

# Table layout: a "name" column, an "all" group (score/std plus a "count"
# column), then one score/std group per source. The same leaf objects in
# ``metrics`` are reused by every group.
metrics = [tt.leaf("score", ".2f"), tt.leaf("std", ".2f")]
sources = ["drums", "bass", "other", "vocals"]

mytable = tt.table(
    [tt.leaf("name"),
     tt.group("all", metrics + [tt.leaf("count")])] +
    [tt.group(source, metrics) for source in sources])
# Build one table row per configuration name in ``all_stats``.
lines = []
for name, stats in all_stats.items():
    line = {"name": name}
    # NOTE(review): ``stats`` is the list built via ``all_stats[name].append``,
    # so this membership test looks for the string among the list elements,
    # and ``del stats['accompaniment']`` would raise TypeError on a list.
    # Presumably the intent was to drop that key from each entry -- confirm.
    if 'accompaniment' in stats:
        del stats['accompaniment']
    alls = []
    for source in sources:
        # Median of each entry's values for this source (NaNs ignored), then
        # mean and standard error (std / sqrt(n)) across entries.
        stat = [np.nanmedian(s[source]) for s in stats]
        alls.append(stat)
        line[source] = {
            "score": np.mean(stat),
            "std": np.std(stat) / len(stat)**0.5
        }
        # NOTE(review): from here to the end of the outer loop the code looks
        # spliced in from another snippet (the step that parses run names and
        # fills ``all_stats``). It only parses because the ``else:`` below
        # binds to ``for source`` as a for-else; since the loop never breaks,
        # the ``else`` always runs and overwrites ``parts = []``. ``model``,
        # ``results``, ``args``, ``STD_KEY`` and ``read`` are not defined in
        # this view, ``line`` is never appended to ``lines``, and
        # ``all_stats[name].append`` may add a key to ``all_stats`` while it is
        # being iterated. Verify against the upstream script before use.
        parts = []
    else:
        parts = [p.split("=") for p in name.split(" ") if p != '--tasnet']
    if not args.individual:
        parts = [(k, v) for k, v in parts if k != STD_KEY]
    name = model + " " + " ".join(f"{k}={v}" for k, v in parts)
    stats = read(args.metric, results)
    # NOTE(review): 50 presumably is the expected number of evaluated items
    # per source -- confirm.
    if (not stats or len(stats["drums"]) != 50):
        print(f"Missing stats for {results}", file=sys.stderr)
    else:
        all_stats[name].append(stats)

# Table layout: a "name" column, an "all" group (score/std plus a "count"
# column), then one score/std group per source. The same leaf objects in
# ``metrics`` are reused by every group.
metrics = [tt.leaf("score", ".2f"), tt.leaf("std", ".2f")]
sources = ["drums", "bass", "other", "vocals"]

mytable = tt.table([tt.leaf("name"), tt.group("all", metrics + [tt.leaf("count")])] +
                   [tt.group(source, metrics) for source in sources])

lines = []
for name, stats in all_stats.items():
    line = {"name": name}
    if 'accompaniment' in stats:
        del stats['accompaniment']
    alls = []
    for source in sources:
        stat = [np.nanmedian(s[source]) for s in stats]
        alls.append(stat)
        line[source] = {"score": np.mean(stat), "std": np.std(stat) / len(stat)**0.5}
    alls = np.array(alls)
    line["all"] = {
        "score": alls.mean(),